package org.seqcode.ml.classification; import java.util.Collections; import java.util.Enumeration; import java.util.Vector; import weka.classifiers.AbstractClassifier; import weka.classifiers.Classifier; import weka.classifiers.meta.Bagging; import weka.classifiers.trees.RandomTree; import weka.core.AdditionalMeasureProducer; import weka.core.Aggregateable; import weka.core.Capabilities; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.PartitionGenerator; import weka.core.Randomizable; import weka.core.RevisionUtils; import weka.core.TechnicalInformation; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.core.TechnicalInformationHandler; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.core.SerializationHelper; /** <!-- globalinfo-start --> * Class for constructing a forest of random trees.<br/> * <br/> * For more information see: <br/> * <br/> * Leo Breiman (2001). Random Forests. Machine Learning. 45(1):5-32. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @article{Breiman2001, * author = {Leo Breiman}, * journal = {Machine Learning}, * number = {1}, * pages = {5-32}, * title = {Random Forests}, * volume = {45}, * year = {2001} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -I <number of trees> * Number of trees to build. * (default 100)</pre> * * <pre> -K <number of features> * Number of features to consider (<1=int(log_2(#predictors)+1)). * (default 0)</pre> * * <pre> -S * Seed for random number generator. * (default 1)</pre> * * <pre> -depth <num> * The maximum depth of the trees, 0 for unlimited. * (default 0)</pre> * * <pre> -O * Don't calculate the out of bag error.</pre> * * <pre> -print * Print the individual trees in the output</pre> * * <pre> -num-slots <num> * Number of execution slots. * (default 1 - i.e. no parallelism)</pre> * * <pre> -output-debug-info * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -do-not-check-capabilities * If set, classifier capabilities are not checked before classifier is built * (use with caution).</pre> * <!-- options-end --> * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 11003 $ */ public class BaggedRandomForest extends AbstractClassifier implements OptionHandler, Randomizable, WeightedInstancesHandler, AdditionalMeasureProducer, TechnicalInformationHandler, PartitionGenerator, Aggregateable<BaggedRandomForest> { public double[] getAttributeWeights() { if (m_bagger == null) throw new Error("Random forest not built yet"); double[] sum = null; for (int i = 0; i < m_bagger.getNumIterations(); i++) { Classifier c = m_bagger.getClassifier(i); AttributeRandomTree t = (AttributeRandomTree) c; // System.out.println(t); int[] numInstances = t.nodeNumInstances(); if (sum == null) sum = new double[numInstances.length]; for (int j = 0; j < numInstances.length; j++) sum[j] += numInstances[j]; } double max = -1; for (int i = 0; i < sum.length; i++) if (sum[i] > max) max = sum[i]; for (int i = 0; i < sum.length; i++) sum[i] /= max; return sum; } public String[] getAttributes() { return ((AttributeRandomTree) m_bagger.getClassifier(0)).getAttributes(); } /** for serialization */ static final long serialVersionUID = 1116839470751428698L; /** Number of trees in forest. */ protected int m_numTrees = 100; /** The size of each bag sample, as a percentage of the training size */ protected int m_BagSizePercent = 100; /** Whether to calculate the out of bag error */ protected boolean m_CalcOutOfBag = false; /** * Number of features to consider in random feature selection. If less than 1 * will use int(log_2(M)+1) ) */ protected int m_numFeatures = 0; /** The random seed. */ protected int m_randomSeed = 1; /** Final number of features that were considered in last build. */ protected int m_KValue = 0; /** The bagger. */ protected AttributeBagging m_bagger = null; /** The maximum depth of the trees (0 = unlimited) */ protected int m_MaxDepth = 0; /** The number of threads to have executing at any one time */ protected int m_numExecutionSlots = 1; /** Print the individual trees in the output */ protected boolean m_printTrees = false; /** Don't calculate the out of bag error */ protected boolean m_dontCalculateOutOfBagError; /** * Returns a string describing classifier * * @return a description suitable for displaying in the explorer/experimenter * gui */ public String globalInfo() { return "Class for constructing a forest of random trees.\n\n" + "For more information see: \n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing detailed * information about the technical background of this class, e.g., paper * reference or book this class is based on. * * @return the technical information about this class */ @Override public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "Leo Breiman"); result.setValue(Field.YEAR, "2001"); result.setValue(Field.TITLE, "Random Forests"); result.setValue(Field.JOURNAL, "Machine Learning"); result.setValue(Field.VOLUME, "45"); result.setValue(Field.NUMBER, "1"); result.setValue(Field.PAGES, "5-32"); return result; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numTreesTipText() { return "The number of trees to be generated."; } /** * Get the value of numTrees. * * @return Value of numTrees. */ public int getNumTrees() { return m_numTrees; } /** * Set the value of numTrees. * * @param newNumTrees Value to assign to numTrees. */ public void setNumTrees(int newNumTrees) { m_numTrees = newNumTrees; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numFeaturesTipText() { return "The number of attributes to be used in random selection (see RandomTree)."; } /** * Get the number of features used in random selection. * * @return Value of numFeatures. */ public int getNumFeatures() { return m_numFeatures; } /** * Set the number of features to use in random selection. * * @param newNumFeatures Value to assign to numFeatures. */ public void setNumFeatures(int newNumFeatures) { m_numFeatures = newNumFeatures; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String seedTipText() { return "The random number seed to be used."; } /** * Set the seed for random number generation. * * @param seed the seed */ @Override public void setSeed(int seed) { m_randomSeed = seed; } /** * Gets the seed for the random number generations * * @return the seed for the random number generation */ @Override public int getSeed() { return m_randomSeed; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String maxDepthTipText() { return "The maximum depth of the trees, 0 for unlimited."; } /** * Get the maximum depth of trh tree, 0 for unlimited. * * @return the maximum depth. */ public int getMaxDepth() { return m_MaxDepth; } /** * Set the maximum depth of the tree, 0 for unlimited. * * @param value the maximum depth. */ public void setMaxDepth(int value) { m_MaxDepth = value; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String printTreesTipText() { return "Print the individual trees in the output"; } /** * Set whether to print the individual ensemble trees in the output * * @param print true if the individual trees are to be printed */ public void setPrintTrees(boolean print) { m_printTrees = print; } /** * Get whether to print the individual ensemble trees in the output * * @return true if the individual trees are to be printed */ public boolean getPrintTrees() { return m_printTrees; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String dontCalculateOutOfBagErrorTipText() { return "If true, then the out of bag error is not computed"; } /** * Set whether to turn off the calculation of out of bag error * * @param b true to turn off the calculation of out of bag error */ public void setDontCalculateOutOfBagError(boolean b) { m_dontCalculateOutOfBagError = b; } /** * Get whether to turn off the calculation of out of bag error * * @return true to turn off the calculation of out of bag error */ public boolean getDontCalculateOutOfBagError() { return m_dontCalculateOutOfBagError; } /** * Gets the out of bag error that was calculated as the classifier was built. * * @return the out of bag error */ public double measureOutOfBagError() { if (m_bagger != null && !m_dontCalculateOutOfBagError) { return m_bagger.measureOutOfBagError(); } else { return Double.NaN; } } /** * Set the number of execution slots (threads) to use for building the members * of the ensemble. * * @param numSlots the number of slots to use. */ public void setNumExecutionSlots(int numSlots) { m_numExecutionSlots = numSlots; } /** * Get the number of execution slots (threads) to use for building the members * of the ensemble. * * @return the number of slots to use */ public int getNumExecutionSlots() { return m_numExecutionSlots; } /** * Returns the tip text for this property * * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String numExecutionSlotsTipText() { return "The number of execution slots (threads) to use for " + "constructing the ensemble."; } /** * Returns an enumeration of the additional measure names. * * @return an enumeration of the measure names */ @Override public Enumeration<String> enumerateMeasures() { Vector<String> newVector = new Vector<String>(1); newVector.addElement("measureOutOfBagError"); return newVector.elements(); } /** * Returns the value of the named measure. * * @param additionalMeasureName the name of the measure to query for its value * @return the value of the named measure * @throws IllegalArgumentException if the named measure is not supported */ @Override public double getMeasure(String additionalMeasureName) { if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) { return measureOutOfBagError(); } else { throw new IllegalArgumentException(additionalMeasureName + " not supported (RandomForest)"); } } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options */ @Override public Enumeration<Option> listOptions() { Vector<Option> newVector = new Vector<Option>(); newVector.addElement(new Option("\tNumber of trees to build.\n\t(default 100)", "I", 1, "-I <number of trees>")); newVector.addElement(new Option( "\tNumber of features to consider (<1=int(log_2(#predictors)+1)).\n\t(default 0)", "K", 1, "-K <number of features>")); newVector.addElement(new Option("\tSeed for random number generator.\n" + "\t(default 1)", "S", 1, "-S")); newVector.addElement(new Option( "\tThe maximum depth of the trees, 0 for unlimited.\n" + "\t(default 0)", "depth", 1, "-depth <num>")); newVector.addElement(new Option("\tDon't calculate the out of bag error.", "O", 0, "-O")); newVector.addElement(new Option( "\tPrint the individual trees in the output", "print", 0, "-print")); newVector.addElement(new Option("\tNumber of execution slots.\n" + "\t(default 1 - i.e. no parallelism)", "num-slots", 1, "-num-slots <num>")); newVector.addAll(Collections.list(super.listOptions())); return newVector.elements(); } /** * Gets the current settings of the forest. * * @return an array of strings suitable for passing to setOptions() */ @Override public String[] getOptions() { Vector<String> result = new Vector<String>(); result.add("-I"); result.add("" + getNumTrees()); result.add("-K"); result.add("" + getNumFeatures()); result.add("-S"); result.add("" + getSeed()); if (getMaxDepth() > 0) { result.add("-depth"); result.add("" + getMaxDepth()); } if (getDontCalculateOutOfBagError()) { result.add("-O"); } if (m_printTrees) { result.add("-print"); } result.add("-num-slots"); result.add("" + getNumExecutionSlots()); Collections.addAll(result, super.getOptions()); return result.toArray(new String[result.size()]); } /** * Parses a given list of options. * <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -I <number of trees> * Number of trees to build. * (default 100)</pre> * * <pre> -K <number of features> * Number of features to consider (<1=int(log_2(#predictors)+1)). * (default 0)</pre> * * <pre> -S * Seed for random number generator. * (default 1)</pre> * * <pre> -depth <num> * The maximum depth of the trees, 0 for unlimited. * (default 0)</pre> * * <pre> -O * Don't calculate the out of bag error.</pre> * * <pre> -print * Print the individual trees in the output</pre> * * <pre> -num-slots <num> * Number of execution slots. * (default 1 - i.e. no parallelism)</pre> * * <pre> -output-debug-info * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -do-not-check-capabilities * If set, classifier capabilities are not checked before classifier is built * (use with caution).</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ @Override public void setOptions(String[] options) throws Exception { String tmpStr; tmpStr = Utils.getOption('I', options); if (tmpStr.length() != 0) { m_numTrees = Integer.parseInt(tmpStr); } else { m_numTrees = 100; } m_CalcOutOfBag = Utils.getFlag('O', options); tmpStr = Utils.getOption('P', options); if (tmpStr.length() != 0) { m_BagSizePercent = Integer.parseInt(tmpStr); } else { m_BagSizePercent = 100; } tmpStr = Utils.getOption('K', options); if (tmpStr.length() != 0) { m_numFeatures = Integer.parseInt(tmpStr); } else { m_numFeatures = 0; } tmpStr = Utils.getOption('S', options); if (tmpStr.length() != 0) { setSeed(Integer.parseInt(tmpStr)); } else { setSeed(1); } tmpStr = Utils.getOption("depth", options); if (tmpStr.length() != 0) { setMaxDepth(Integer.parseInt(tmpStr)); } else { setMaxDepth(0); } setDontCalculateOutOfBagError(Utils.getFlag('O', options)); setPrintTrees(Utils.getFlag("print", options)); tmpStr = Utils.getOption("num-slots", options); if (tmpStr.length() > 0) { setNumExecutionSlots(Integer.parseInt(tmpStr)); } else { setNumExecutionSlots(1); } super.setOptions(options); Utils.checkForRemainingOptions(options); } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ @Override public Capabilities getCapabilities() { return new AttributeRandomTree().getCapabilities(); } /** * Builds a classifier for a set of instances. * * @param data the instances to train the classifier with * @throws Exception if something goes wrong */ @Override public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); m_bagger = new AttributeBagging(); // RandomTree implements WeightedInstancesHandler, so we can // represent copies using weights to achieve speed-up. m_bagger.setRepresentCopiesUsingWeights(true); AttributeRandomTree rTree = new AttributeRandomTree(); // set up the random tree options m_KValue = m_numFeatures; if (m_KValue < 1) { m_KValue = (int) Utils.log2(data.numAttributes() - 1) + 1; } rTree.setKValue(m_KValue); rTree.setMaxDepth(getMaxDepth()); rTree.setDoNotCheckCapabilities(true); // set up the bagger and build the forest m_bagger.setBagSizePercent(m_BagSizePercent); m_bagger.setCalcOutOfBag(m_CalcOutOfBag); m_bagger.setClassifier(rTree); m_bagger.setSeed(m_randomSeed); m_bagger.setNumIterations(m_numTrees); m_bagger.setNumExecutionSlots(m_numExecutionSlots); m_bagger.buildClassifier(data); } /** * Returns the class probability distribution for an instance. * * @param instance the instance to be classified * @return the distribution the forest generates for the instance * @throws Exception if computation fails */ @Override public double[] distributionForInstance(Instance instance) throws Exception { return m_bagger.distributionForInstance(instance); } /** * Outputs a description of this classifier. * * @return a string containing a description of the classifier */ @Override public String toString() { if (m_bagger == null) { return "Random forest not built yet"; } else { StringBuffer temp = new StringBuffer(); temp.append("Random forest of " + m_numTrees + " trees, each constructed while considering " + m_KValue + " random feature" + (m_KValue == 1 ? "" : "s") + ".\n" + (!getDontCalculateOutOfBagError() ? "Out of bag error: " + Utils.doubleToString(m_bagger.measureOutOfBagError(), 4) : "") + "\n" + (getMaxDepth() > 0 ? ("Max. depth of trees: " + getMaxDepth() + "\n") : ("")) + "\n"); if (m_printTrees) { temp.append(m_bagger.toString()); } return temp.toString(); } } /** * Builds the classifier to generate a partition. */ @Override public void generatePartition(Instances data) throws Exception { buildClassifier(data); } /** * Computes an array that indicates leaf membership */ @Override public double[] getMembershipValues(Instance inst) throws Exception { return m_bagger.getMembershipValues(inst); } /** * Returns the number of elements in the partition. */ @Override public int numElements() throws Exception { return m_bagger.numElements(); } /** * Returns the revision string. * * @return the revision */ @Override public String getRevision() { return RevisionUtils.extract("$Revision: 11003 $"); } /** * Main method for this class. * * @param argv the options */ public static void main(String[] argv) { runClassifier(new BaggedRandomForest(), argv); } @Override public BaggedRandomForest aggregate(BaggedRandomForest toAggregate) throws Exception { m_bagger.aggregate(toAggregate.m_bagger); return this; } @Override public void finalizeAggregation() throws Exception { m_bagger.finalizeAggregation(); } }