/*
* RapidMiner
*
* Copyright (C) 2001-2014 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.operator.postprocessing;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.NominalMapping;
import com.rapidminer.operator.AbstractExampleSetProcessing;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ports.metadata.ExampleSetPrecondition;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.container.Tupel;
/**
* This operator will generate predictions of the second, .. n-th most probable
* class from the confidences attributes generated by applying a classification model.
*
* @author Sebastian Land
*/
public class GeneratePredictionRankingOperator extends AbstractExampleSetProcessing {
public static final String PARAMETER_NUMBER_OF_RANKS = "number_of_ranks";
public static final String PARAMETER_REMOVE_OLD_PREDICTIONS = "remove_old_predictions";
public GeneratePredictionRankingOperator(OperatorDescription description) {
super(description);
getExampleSetInputPort().addPrecondition(new ExampleSetPrecondition(getExampleSetInputPort(), Attributes.PREDICTION_NAME, Ontology.NOMINAL));
getExampleSetInputPort().addPrecondition(new ExampleSetPrecondition(getExampleSetInputPort(), Attributes.CONFIDENCE_NAME, Ontology.NUMERICAL));
}
@Override
public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
// searching confidence attributes
Attributes attributes = exampleSet.getAttributes();
Attribute predictedLabel = attributes.getPredictedLabel();
if (predictedLabel == null) {
throw new UserError(this, 107);
}
NominalMapping mapping = predictedLabel.getMapping();
int numberOfLabels = mapping.size();
Attribute[] confidences = new Attribute[numberOfLabels];
String[] labelValue = new String[numberOfLabels];
int i = 0;
for (String value: mapping.getValues()) {
labelValue[i] = value;
confidences[i] = attributes.getConfidence(value);
if (confidences[i] == null) {
throw new UserError(this, 154, value);
}
i++;
}
// generating new prediction attributes
int k = Math.min(numberOfLabels, getParameterAsInt(PARAMETER_NUMBER_OF_RANKS));
Attribute[] kthPredictions = new Attribute[k];
Attribute[] kthConfidences = new Attribute[k];
for (i = 0; i < k; i++) {
kthPredictions[i] = AttributeFactory.createAttribute(predictedLabel.getValueType());
kthPredictions[i].setName(predictedLabel.getName() + "_" + (i + 1));
kthPredictions[i].setMapping((NominalMapping) predictedLabel.getMapping().clone());
kthConfidences[i] = AttributeFactory.createAttribute(Ontology.REAL);
kthConfidences[i].setName(Attributes.CONFIDENCE_NAME + "_" + (i + 1));
attributes.addRegular(kthPredictions[i]);
attributes.addRegular(kthConfidences[i]);
attributes.setSpecialAttribute(kthPredictions[i], Attributes.PREDICTION_NAME + "_" + (i + 1));
attributes.setSpecialAttribute(kthConfidences[i], Attributes.CONFIDENCE_NAME + "_" + (i + 1));
}
exampleSet.getExampleTable().addAttributes(Arrays.asList(kthConfidences));
exampleSet.getExampleTable().addAttributes(Arrays.asList(kthPredictions));
// now setting values
for (Example example: exampleSet) {
ArrayList<Tupel<Double, Integer>> labelConfidences = new ArrayList<Tupel<Double,Integer>>(numberOfLabels);
for (i = 0; i < numberOfLabels; i++) {
labelConfidences.add(new Tupel<Double, Integer>(example.getValue(confidences[i]), i));
}
Collections.sort(labelConfidences);
for (i = 0; i < k; i++) {
Tupel<Double, Integer> tupel = labelConfidences.get(numberOfLabels - i - 1);
example.setValue(kthPredictions[i], tupel.getSecond()); // Can use index since mapping has been cloned from above
example.setValue(kthConfidences[i], tupel.getFirst());
}
}
// deleting old prediction / confidences
attributes.remove(predictedLabel);
if (getParameterAsBoolean(PARAMETER_REMOVE_OLD_PREDICTIONS)) {
for (i = 0; i < confidences.length; i++) {
attributes.remove(confidences[i]);
}
}
return exampleSet;
}
@Override
public boolean writesIntoExistingData() {
return false;
}
@Override
public List<ParameterType> getParameterTypes() {
List<ParameterType> types = super.getParameterTypes();
types.add(new ParameterTypeInt(PARAMETER_NUMBER_OF_RANKS, "This determines how many ranks will be considered. ", 2, Integer.MAX_VALUE, false));
types.add(new ParameterTypeBoolean(PARAMETER_REMOVE_OLD_PREDICTIONS, "This indicates if the old confidence attributes should be removed.", true, false));
return types;
}
}