/*
* RapidMiner
*
* Copyright (C) 2001-2014 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.functions;
import com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimplePredictionModel;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.kernels.DotKernel;
import com.rapidminer.tools.math.kernels.Kernel;
/**
 * This model is a separating hyperplane for two classes. If no class names are given it acts as a
 * plain linear regression model and returns the raw decision value.
 * <p>
 * The decision value of an example is {@code intercept + kernel.calculateDistance(x, w)}, where
 * {@code x} holds the values of the example's attributes (in iteration order) and {@code w} the
 * coefficients. The default kernel is a {@link DotKernel}; presumably this yields the ordinary
 * dot product (TODO confirm against the kernel implementation).
 *
 * @author Sebastian Land
 */
public class HyperplaneModel extends SimplePredictionModel {

	private static final long serialVersionUID = -4990692589416639697L;

	/** Regular attribute names of the training set, used to label the coefficients in output. */
	private String[] coefficientNames;

	/** Hyperplane weights; {@code null} until {@link #init} or {@link #setCoefficients} is called. */
	private double[] coefficients;

	/** Constant offset added to the kernel value. */
	private double intercept;

	/** Class predicted for a negative decision value; {@code null} for regression. */
	private String classNegative;

	/** Class predicted for a non-negative decision value; {@code null} for regression. */
	private String classPositive;

	/** Kernel used to combine example values with the coefficients. */
	private Kernel kernel;

	/**
	 * Creates a hyperplane model without class names (regression) using a {@link DotKernel}.
	 *
	 * @param exampleSet the training example set
	 */
	public HyperplaneModel(ExampleSet exampleSet) {
		this(exampleSet, null, null);
	}

	/**
	 * Creates a hyperplane model for the given two classes using a {@link DotKernel}.
	 *
	 * @param exampleSet the training example set
	 * @param classNegative class returned for negative decision values (may be {@code null})
	 * @param classPositive class returned for non-negative decision values (may be {@code null})
	 */
	public HyperplaneModel(ExampleSet exampleSet, String classNegative, String classPositive) {
		this(exampleSet, classNegative, classPositive, new DotKernel());
	}

	/**
	 * Creates a hyperplane model for the given two classes using the given kernel.
	 *
	 * @param exampleSet the training example set; its regular attribute names label the coefficients
	 * @param classNegative class returned for negative decision values (may be {@code null})
	 * @param classPositive class returned for non-negative decision values (may be {@code null})
	 * @param kernel kernel combining example values and coefficients
	 */
	public HyperplaneModel(ExampleSet exampleSet, String classNegative, String classPositive, Kernel kernel) {
		super(exampleSet);
		this.coefficientNames = com.rapidminer.example.Tools.getRegularAttributeNames(exampleSet);
		this.classNegative = classNegative;
		this.classPositive = classPositive;
		this.kernel = kernel;
	}

	/**
	 * Computes the prediction for a single example.
	 * <p>
	 * For a numerical label the raw decision value is returned. For a nominal label the mapping
	 * index of {@code classNegative} is returned when the decision value is negative, otherwise the
	 * index of {@code classPositive}. If the example already has a predicted label attribute, the
	 * confidences of both classes are set via the logistic function of the decision value.
	 *
	 * @param example the example to classify
	 * @return the decision value (regression) or the mapping index of the predicted class
	 * @throws OperatorException declared by the overridden method
	 */
	@Override
	public double predict(Example example) throws OperatorException {
		// Values are paired position-by-position with the coefficients by the kernel.
		double[] values = new double[example.getAttributes().size()];
		int i = 0;
		for (Attribute currentAttribute : example.getAttributes()) {
			values[i++] = example.getValue(currentAttribute);
		}
		double distance = intercept + kernel.calculateDistance(values, coefficients);

		Attribute label = getLabel();
		if (!label.isNominal()) {
			return distance;
		}

		// Both lookups are kept (and in this order) in case mapString has side effects on the mapping.
		int positiveMapping = label.getMapping().mapString(classPositive);
		int negativeMapping = label.getMapping().mapString(classNegative);

		// Confidences can only be stored when a predicted label attribute exists.
		boolean isApplying = example.getAttributes().getPredictedLabel() != null;
		if (isApplying) {
			example.setConfidence(classPositive, 1.0d / (1.0d + java.lang.Math.exp(-distance)));
			example.setConfidence(classNegative, 1.0d / (1.0d + java.lang.Math.exp(distance)));
		}
		return distance < 0 ? negativeMapping : positiveMapping;
	}

	/**
	 * Sets the hyperplane parameters.
	 *
	 * @param coefficients the weights, one per regular attribute (array is stored, not copied)
	 * @param intercept the constant offset
	 */
	public void init(double[] coefficients, double intercept) {
		this.coefficients = coefficients;
		this.intercept = intercept;
	}

	/** @return the internal coefficient array (not a copy); {@code null} if never set */
	public double[] getCoefficients() {
		return coefficients;
	}

	/** @return the intercept of the hyperplane */
	public double getIntercept() {
		return intercept;
	}

	/** @param coefficients the new weights (array is stored, not copied) */
	public void setCoefficients(double[] coefficients) {
		this.coefficients = coefficients;
	}

	/** @param intercept the new intercept */
	public void setIntercept(double intercept) {
		this.intercept = intercept;
	}

	/**
	 * Renders a human-readable description: header, intercept and one line per coefficient.
	 * If no coefficients have been set yet, the coefficient list is left empty instead of failing.
	 */
	@Override
	public String toString() {
		String lineSeparator = Tools.getLineSeparator();
		StringBuilder buffer = new StringBuilder();
		if (classPositive != null && classNegative != null) {
			buffer.append("Hyperplane separating ").append(classPositive).append(" and ").append(classNegative)
					.append(".").append(lineSeparator);
		} else {
			buffer.append("Hyperplane for linear regression.").append(lineSeparator);
		}
		buffer.append("Intercept: ").append(Double.toString(intercept)).append(lineSeparator);
		buffer.append("Coefficients: ").append(lineSeparator);
		if (coefficients != null) {
			for (int j = 0; j < coefficients.length; j++) {
				buffer.append("w(").append(coefficientNames[j]).append(") = ")
						.append(Tools.formatIntegerIfPossible(coefficients[j], 3)).append(lineSeparator);
			}
		}
		buffer.append(lineSeparator);
		return buffer.toString();
	}

	/**
	 * Builds a two-column table (attribute name, weight) with one row per regular training attribute.
	 * Returns an empty table if no coefficients have been set yet.
	 *
	 * @return the weights table
	 */
	public DataTable createWeightsTable() {
		SimpleDataTable weightTable = new SimpleDataTable("Hyperplane Model Weights", new String[] { "Attribute", "Weight" });
		if (coefficients == null) {
			return weightTable;
		}
		for (int j = 0; j < this.coefficientNames.length; j++) {
			// the nominal name column stores the mapping index of the attribute name
			int nameIndex = weightTable.mapString(0, this.coefficientNames[j]);
			weightTable.add(new SimpleDataTableRow(new double[] { nameIndex, this.coefficients[j] }));
		}
		return weightTable;
	}
}