package org.streaminer.stream.classifier; import org.streaminer.stream.learner.LearnerUtils; import org.streaminer.util.distance.Kernel; import org.streaminer.stream.data.Data; import org.streaminer.stream.model.HyperplaneModel; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Perceptron stream learner * * @author Helge Homburg, Christian Bockermann */ public class Perceptron extends AbstractClassifier<Data, String> { private static final long serialVersionUID = -3263838547557335984L; private static Logger LOG = LoggerFactory.getLogger(Perceptron.class); /* The learning rate gamma */ private Double learnRate; /* The label attribute */ private String labelAttribute; /* The default labels predicted by this model */ private List<String> labels = new ArrayList<String>(); /* The attribute which this learner acts upon */ private List<String> attributes = new ArrayList<String>(); private HyperplaneModel model; public Perceptron() { this(Kernel.INNER_PRODUCT, 0.05); } public Perceptron(int learnRate) { this(Kernel.INNER_PRODUCT, learnRate); } public Perceptron(int kernelType, double learnRate) { this.model = new HyperplaneModel(kernelType); this.model.initModel(new LinkedHashMap<String,Double>(), 0.0d); this.learnRate = learnRate; } /** * @return the labelAttribute */ public String getLabelAttribute() { return labelAttribute; } /** * @param labelAttribute the labelAttribute to set */ public void setLabelAttribute(String labelAttribute) { this.labelAttribute = labelAttribute; } /** * @return the learnRate */ public Double getLearnRate() { return learnRate; } /** * @param learnRate the learnRate to set */ public void setLearnRate(Double learnRate) { this.learnRate = learnRate; } @Override public void learn(Data item) { if (labelAttribute == null) labelAttribute = LearnerUtils.detectLabelAttribute(item); if (labelAttribute == null) { LOG.info("No label defined!"); return; } String label = null; if (item.get(labelAttribute) == null) { LOG.error("No label found for example!"); return; } else { label = item.get(labelAttribute).toString(); } int labelIndex = labels.indexOf(label); if (labelIndex < 0 && labels.size() < 2) { LOG.info("Adding label '{}'", label); labels.add(label); labelIndex = labels.indexOf(label); } if (labelIndex < 0) { LOG.error("My labels are {}, unknown label: {}", labels, label); if (labels.size() == 2) LOG.error("The perceptron algorithm only works for binary classification tasks!"); return; } Map<String,Double> example = LearnerUtils.getNumericVector(item); if (example.isEmpty()) { LOG.info("No numerical attributes found for learning! Ignoring example!"); return; } //---reading label // ---start computation Double prediction = model.predict(item); if (prediction != null && prediction.intValue() != labelIndex) { double direction = (labelIndex == 0) ? -1 : 1; // adjusting bias model.setBias(model.getBias() + (learnRate * direction)); // adjusting models weights Map<String,Double> weights = model.getWeights(); for (String attribute : attributes) { Double attributeValue = example.get(attribute); Double weight = weights.get(attribute); if (weight == null) weight = 0.0d; weight += learnRate * direction * attributeValue; weights.put(attribute, weight); } model.setWeights(weights); } } @Override public String predict(Data item) { if (labels.isEmpty()) { LOG.warn("No labels available, predicting '?'!"); return "?"; } if (labels.size() == 1) { LOG.warn("Only 1 label available, predicting '{}'!", labels.get(0)); return labels.get(0); } Double pred = model.predict(item); if (pred < 0.5) { return this.labels.get(0); } else { return this.labels.get(1); } } }