/*
* RapidMiner
*
* Copyright (C) 2001-2014 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.learner.meta;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.set.SplittedExampleSet;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.operator.performance.EstimatedPerformance;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeList;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.RandomGenerator;
import com.rapidminer.tools.Tools;
import com.rapidminer.tools.math.optimization.ec.es.ESOptimization;
import com.rapidminer.tools.math.optimization.ec.es.Individual;
/**
* <p>This operator uses a set of class weights and also allows a weight for the fact
* that an example is not classified at all (marked as unknown). Based on the
* predictions of the model of the inner learner this operator optimized
* a set of thresholds regarding the defined weights.</p>
*
* <p>
* This operator might be very useful in cases where it is better to not classify
* an example then to classify it in a wrong way. This way, it is often possible
* to get very high accuracies for the remaining examples (which are actually
* classified) for the cost of having some examples which must still be manually
* classified.
* </p>
*
* @author Ingo Mierswa
*/
public class CostBasedThresholdLearner extends AbstractMetaLearner {
/** The parameter name for "The weights for all classes (first column: class names, second column: weight), empty: using 1 for all classes. The costs for not classifying at all are defined with class name '?'." */
public static final String PARAMETER_CLASS_WEIGHTS = "class_weights";
public static final String PARAMETER_ALLOW_UNKOWN_PREDICTIONS = "allow_unkown_predictions";
/** The parameter name for "Use this cost value for predicting an example as unknown (-1: use same costs as for correct class)." */
public static final String PARAMETER_PREDICT_UNKNOWN_COSTS = "predict_unknown_costs";
/** The parameter name for "Use this amount of input data for model learning and the rest for threshold optimization." */
public static final String PARAMETER_TRAINING_RATIO = "training_ratio";
/** The parameter name for "Defines the number of optimization iterations." */
public static final String PARAMETER_NUMBER_OF_ITERATIONS = "number_of_iterations";
public CostBasedThresholdLearner(OperatorDescription description) {
super(description);
}
public Model learn(ExampleSet exampleSet) throws OperatorException {
Attribute label = exampleSet.getAttributes().getLabel();
List<String[]> classWeights = getParameterList(PARAMETER_CLASS_WEIGHTS);
// some checks
if (!exampleSet.getAttributes().getLabel().isNominal()) {
throw new UserError(this, 101, getName(), label.getName());
}
if (classWeights.size() == 0) {
throw new UserError(this, 205, "class_weights");
}
// derive possible class weights
double unknownWeight = getParameterAsBoolean(PARAMETER_ALLOW_UNKOWN_PREDICTIONS) ? getParameterAsDouble(PARAMETER_PREDICT_UNKNOWN_COSTS) : -1d;
double[] weights = new double[label.getMapping().size()];
for (int i = 0; i < weights.length; i++) {
weights[i] = 1.0d;
}
Iterator<String[]> i = classWeights.iterator();
while (i.hasNext()) {
String[] classWeightArray = i.next();
String className = classWeightArray[0];
double classWeight = Double.valueOf(classWeightArray[1]);
int index = label.getMapping().getIndex(className);
if (index == -1) {
throw new UserError(this, 955, className);
} else {
weights[index] = classWeight;
}
}
// logging
List<String> weightList = new LinkedList<String>();
for (double d : weights)
weightList.add(Tools.formatIntegerIfPossible(d));
log("Used class weights --> " + weightList + ", unknown weight: " + Tools.formatIntegerIfPossible(unknownWeight));
return calculateThresholdModel(exampleSet, weights, unknownWeight);
}
private Model calculateThresholdModel(ExampleSet exampleSet, final double[] classWeights, final double unknownWeight) throws OperatorException {
SplittedExampleSet trainingSet = new SplittedExampleSet(exampleSet, getParameterAsDouble(PARAMETER_TRAINING_RATIO), SplittedExampleSet.STRATIFIED_SAMPLING, getParameterAsBoolean(RandomGenerator.PARAMETER_USE_LOCAL_RANDOM_SEED), getParameterAsInt(RandomGenerator.PARAMETER_LOCAL_RANDOM_SEED));
trainingSet.selectSingleSubset(0);
Model innerModel = applyInnerLearner(trainingSet);
trainingSet.selectSingleSubset(1);
final ExampleSet appliedTrainingSet = innerModel.apply(trainingSet);
final Attribute label = appliedTrainingSet.getAttributes().getLabel();
int numberOfGenerations = getParameterAsInt(PARAMETER_NUMBER_OF_ITERATIONS);
ESOptimization optimization = new ESOptimization(0.0d, 1.0d, 5, classWeights.length,
ESOptimization.INIT_TYPE_RANDOM,
numberOfGenerations, Math.max(1, numberOfGenerations / 10),
ESOptimization.TOURNAMENT_SELECTION,
0.4, true, ESOptimization.GAUSSIAN_MUTATION,
0.9, false, false,
RandomGenerator.getRandomGenerator(this),
this) {
@Override
public PerformanceVector evaluateIndividual(Individual individual) throws OperatorException {
double costs = 0.0d;
double[] thresholds = individual.getValues();
for (Example example : appliedTrainingSet) {
int predictionIndex = (int)example.getPredictedLabel();
String className = label.getMapping().mapIndex(predictionIndex);
double confidence = example.getConfidence(className);
// confident...
if (confidence > thresholds[predictionIndex]) {
// wrong -> malus
if (example.getLabel() != example.getPredictedLabel()) {
costs += classWeights[(int)example.getLabel()];
} else {
// correct -> bonus
//costs -= classWeights[(int)example.getLabel()];
}
// not so confident...
} else {
double usedWeight = unknownWeight;
if (unknownWeight < 0.0d) {
usedWeight = classWeights[(int)example.getLabel()];
}
// correct -> malus
if (example.getLabel() == example.getPredictedLabel()) {
costs += usedWeight;
} else {
// wrong -> bonus
//costs -= usedWeight;
}
}
}
PerformanceVector performanceVector = new PerformanceVector();
performanceVector.addCriterion(new EstimatedPerformance("Costs", costs, 1, true));
return performanceVector;
}
};
optimization.optimize();
PredictionModel.removePredictedLabel(appliedTrainingSet);
double[] bestValues = optimization.getBestValuesEver();
return new ThresholdModel(appliedTrainingSet, innerModel, bestValues);
}
@Override
public List<ParameterType> getParameterTypes() {
List<ParameterType> types = super.getParameterTypes();
ParameterType type = new ParameterTypeList(PARAMETER_CLASS_WEIGHTS, "The weights for all classes, empty: using 1 for all classes. The costs for not classifying at all are defined with class name '?'.",
new ParameterTypeString("class_name", "The name of the class."),
new ParameterTypeDouble("weight", "The weight for this class.", 0.0d, Double.POSITIVE_INFINITY, 1.0d));
type.setExpert(false);
types.add(type);
types.add(new ParameterTypeBoolean(PARAMETER_ALLOW_UNKOWN_PREDICTIONS, "This indicates if unkown predictions are allowed. If checked, the costs for unkown predictions must be specified.", false));
type = new ParameterTypeDouble(PARAMETER_PREDICT_UNKNOWN_COSTS, "Use this cost value for predicting an example as unknown.", 0d, Double.POSITIVE_INFINITY, 1d);
type.setExpert(false);
type.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_ALLOW_UNKOWN_PREDICTIONS, true, true));
types.add(type);
types.add(new ParameterTypeDouble(PARAMETER_TRAINING_RATIO, "Use this amount of input data for model learning and the rest for threshold optimization.", 0.0d, 1.0d, 0.7d));
types.add(new ParameterTypeInt(PARAMETER_NUMBER_OF_ITERATIONS, "Defines the number of optimization iterations.", 1, Integer.MAX_VALUE, 200));
types.addAll(RandomGenerator.getRandomGeneratorParameters(this));
return types;
}
@Override
public boolean supportsCapability(OperatorCapability capability) {
switch (capability) {
case NUMERICAL_LABEL:
case NO_LABEL:
case UPDATABLE:
case FORMULA_PROVIDER:
return false;
default:
return true;
}
}
}