/*
* RapidMiner
*
* Copyright (C) 2001-2014 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.features.selection;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Statistics;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
/**
* <p>
* Removes (un-) correlated features due to the selected filter relation. The
* procedure is quadratic in number of attributes. In order to get more stable
* results, the original, random, and reverse order of attributes is available.
* </p>
*
* <p>
* Please note that this operator might fail in some cases when the attributes
* should be filtered out, e.g. it might not be able to remove for example all
* negative correlated features. The reason for this behaviour seems to be that
* for the complete m x m - matrix of correlations (for m attributes) the
* correlations will not be recalculated and hence not checked if one of the
* attributes of the current pair was already marked for removal. That means for
* three attributes a1, a2, and a3 that it might be that a2 was already ruled
* out by the negative correlation with a1 and is now not able to rule out a3
* any longer.
* </p>
*
* <p>
* The used correlation function is the Pearson correlation.
* </p>
*
* @author Daniel Hakenjos, Ingo Mierswa
*/
public class RemoveCorrelatedFeatures extends AbstractFeatureSelection {
/**
* The parameter name for "Use this correlation for the filter
* relation."
*/
public static final String PARAMETER_CORRELATION = "correlation";
/**
* The parameter name for "Removes one of two features if their
* correlation fulfill this relation."
*/
public static final String PARAMETER_FILTER_RELATION = "filter_relation";
/**
* The parameter name for "The algorithm takes this attribute order to
* calculate correlation and filter."
*/
public static final String PARAMETER_ATTRIBUTE_ORDER = "attribute_order";
/**
* The parameter name for "Indicates if the absolute value of the
* correlations should be used for comparison."
*/
public static final String PARAMETER_USE_ABSOLUTE_CORRELATION = "use_absolute_correlation";
public static final OperatorVersion VERSION_DETERMINISTIC_RANDOM_NUMBERS = new OperatorVersion(5, 2, 1);
private static final String[] ORDER = new String[] { "original", "random", "reverse" };
private static final int ORDER_ORIGINAL = 0;
private static final int ORDER_RANDOM = 1;
private static final int ORDER_REVERSE = 2;
private static final String[] FILTER_RELATIONS = new String[] { "greater", "greater equals", "equals", "less equals", "less" };
private static final int GREATER = 0;
private static final int GREATER_EQUALS = 1;
private static final int EQUALS = 2;
private static final int LESS_EQUALS = 3;
private static final int LESS = 4;
/** The number of removed features (for logging as value, see constructor.) */
private int removedFeatures = 0;
public RemoveCorrelatedFeatures(OperatorDescription description) {
super(description);
addValue(new ValueDouble("features_removed", "Number of removed features") {
@Override
public double getDoubleValue() {
return removedFeatures;
}
});
}
@Override
public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
exampleSet.recalculateAllAttributeStatistics();
double[] means = new double[exampleSet.getAttributes().size()];
double[] deviations = new double[exampleSet.getAttributes().size()];
boolean[] removeFeature = new boolean[exampleSet.getAttributes().size()];
int[] attributeIndex = new int[exampleSet.getAttributes().size()];
int index = 0;
for (Attribute attribute : exampleSet.getAttributes()) {
means[index] = exampleSet.getStatistics(attribute, Statistics.AVERAGE);
deviations[index] = Math.sqrt(exampleSet.getStatistics(attribute, Statistics.VARIANCE));
removeFeature[index] = false;
attributeIndex[index] = index;
index++;
}
double[][] samples = new double[exampleSet.size()][exampleSet.getAttributes().size()];
int counter = 0;
for (Example example : exampleSet) {
int d = 0;
for (Attribute attribute : example.getAttributes()) {
samples[counter][d++] = example.getValue(attribute);
}
counter++;
}
// attribute order
int order = getParameterAsInt(PARAMETER_ATTRIBUTE_ORDER);
if (order == ORDER_ORIGINAL) {
for (int i = 0; i < exampleSet.getAttributes().size(); i++)
attributeIndex[i] = i;
} else if (order == ORDER_RANDOM) {
// random attributes
Vector<Integer> vector = new Vector<Integer>();
for (int i = 0; i < exampleSet.getAttributes().size(); i++) {
vector.add(i);
}
int vindex;
for (int i = 0; i < exampleSet.getAttributes().size(); i++) {
RandomGenerator randomGenerator = RandomGenerator.getRandomGenerator(this);
if (getCompatibilityLevel().isAtMost(VERSION_DETERMINISTIC_RANDOM_NUMBERS)) {
vindex = (int) Math.floor(Math.random() * vector.size());
attributeIndex[i] = vector.remove(vindex).intValue();
} else {
vindex = randomGenerator.nextInt(vector.size());
attributeIndex[i] = vector.remove(vindex);
}
}
} else if (order == ORDER_REVERSE) {
for (int i = 0; i < exampleSet.getAttributes().size(); i++) {
attributeIndex[i] = exampleSet.getAttributes().size() - 1 - i;
}
}
// absolute value
boolean absolute = getParameterAsBoolean(PARAMETER_USE_ABSOLUTE_CORRELATION);
// filter relation
int relation = getParameterAsInt(PARAMETER_FILTER_RELATION);
// filtering
double threshold = getParameterAsDouble(PARAMETER_CORRELATION);
if (absolute && threshold < 0.0d) {
threshold = Math.abs(threshold);
logWarning("Correlation value is lower zero. Setting to absolute: " + threshold);
}
Attribute[] allAttributes = exampleSet.getAttributes().createRegularAttributeArray();
for (int i = 0; i < exampleSet.getAttributes().size() - 1; i++) {
if (removeFeature[attributeIndex[i]] == true) {
continue;
}
for (int j = i + 1; j < exampleSet.getAttributes().size(); j++) {
if (removeFeature[attributeIndex[j]] == true) {
continue;
}
double correlation = getCorrelation(samples, means, deviations, attributeIndex[i], attributeIndex[j]);
if (absolute) {
correlation = Math.abs(correlation);
}
if (fulfillRelation(correlation, threshold, relation)) {
removeFeature[attributeIndex[j]] = true;
String first = allAttributes[attributeIndex[i]].getName();
String second = allAttributes[attributeIndex[j]].getName();
log("Removed Attribute : " + second + " --> correlation(" + first + "," + second + ")=" + correlation);
}
}
}
// actual removal (and counter)
this.removedFeatures = 0;
index = 0;
Iterator<Attribute> iterator = exampleSet.getAttributes().iterator();
while (iterator.hasNext()) {
iterator.next();
if (removeFeature[index]) {
iterator.remove();
this.removedFeatures++;
}
index++;
}
log("Removed " + this.removedFeatures + "features." + Tools.getLineSeparator() + "ExampleSet has now " + exampleSet.getAttributes().size() + " features.");
return exampleSet;
}
private boolean fulfillRelation(double correlation, double threshold, int relation) {
switch (relation) {
case GREATER:
return (correlation > threshold);
case GREATER_EQUALS:
return (correlation >= threshold);
case EQUALS:
return (correlation == threshold);
case LESS_EQUALS:
return (correlation <= threshold);
case LESS:
return (correlation < threshold);
}
return false;
}
/**
* Calculates the correlation between the two features
*
* @param att1
* index of feature 1
* @param att2
* index of feature 2
* @return the correlation in (-1.0,1.0)
*/
private double getCorrelation(double[][] samples, double[] means, double[] deviations, int att1, int att2) {
// calculate covariance
double covariance = 0.0d;
for (int j = 0; j < samples.length; j++) {
covariance += (samples[j][att1] - means[att1]) * (samples[j][att2] - means[att2]);
}
covariance = covariance / (samples.length - 1);
// calculate correlation
double correlation = 0.0d;
correlation = (deviations[att1] * deviations[att2]);
if (correlation == 0.0d) {
correlation = covariance;
} else {
correlation = covariance / correlation;
}
return correlation;
}
@Override
public List<ParameterType> getParameterTypes() {
List<ParameterType> types = super.getParameterTypes();
ParameterType type = new ParameterTypeDouble(PARAMETER_CORRELATION, "Use this correlation for the filter relation.", -1.0d, 1.0d, 0.95d);
type.setExpert(false);
types.add(type);
type = new ParameterTypeCategory(PARAMETER_FILTER_RELATION, "Removes one of two features if their correlation fulfill this relation.", FILTER_RELATIONS, 0);
types.add(type);
type = new ParameterTypeCategory(PARAMETER_ATTRIBUTE_ORDER, "The algorithm takes this attribute order to calculate correlation and filter.", ORDER, 0);
types.add(type);
type = new ParameterTypeBoolean(PARAMETER_USE_ABSOLUTE_CORRELATION, "Indicates if the absolute value of the correlations should be used for comparison.", true);
types.add(type);
types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
return types;
}
@Override
public OperatorVersion[] getIncompatibleVersionChanges() {
OperatorVersion[] incompatibleVersionChanges = super.getIncompatibleVersionChanges();
OperatorVersion[] newIncompatibleVersionChanges = new OperatorVersion[incompatibleVersionChanges.length + 1];
for (int i = 0; i < incompatibleVersionChanges.length; ++i) {
newIncompatibleVersionChanges[i] = incompatibleVersionChanges[i];
}
newIncompatibleVersionChanges[newIncompatibleVersionChanges.length-1] = VERSION_DETERMINISTIC_RANDOM_NUMBERS;
return newIncompatibleVersionChanges;
}
}