package org.streaminer.stream.classifier.tree;

import org.streaminer.stream.data.Data;
import org.streaminer.stream.learner.AbstractRegressor;
import org.streaminer.stream.learner.LearnerUtils;
import org.streaminer.stream.learner.Regressor;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * <p>
 * This class implements a regression tree. A regression tree is a model tree, i.e.
 * an induced decision tree with prediction models at its leaves. In case of the regression
 * tree these prediction models are simply regression models.
 * </p>
 * <p>
 * Each call to {@link #learn(Data)} routes the item to a leaf, updates that leaf's
 * regression model and per-feature BTrees, recomputes {@link #epsilon} and then checks
 * whether one candidate split beats the current best candidate by at least
 * {@code epsilon}. If so, the leaf is replaced by an inner node.
 * </p>
 * <p>
 * NOTE(review): the split candidate state ({@link #possibleSplitpoint},
 * {@link #splitPossible}) is kept on the tree rather than per leaf, so a candidate
 * remembered from one leaf can act as the incumbent while another leaf is evaluated.
 * </p>
 *
 * @author Mattias Balke, Christian Bockermann &lt;christian.bockermann@udo.edu&gt;
 */
public class RegressionTree extends AbstractRegressor<Data> {

    /** The unique class ID */
    private static final long serialVersionUID = -2683782830606679008L;

    static final Logger log = LoggerFactory.getLogger(RegressionTree.class);

    /** The resulting RegressionTree which is incrementally updated */
    final RegressionTreeModel tree;

    /** Current best split candidate, or {@code null} if none has been recorded yet */
    BTreeNode possibleSplitpoint;

    /** true if the current split candidate fulfils the chernoff bound */
    boolean splitPossible;

    /** epsilon, computed by chernoff bound */
    double epsilon;

    /** value for delta error in chernoff bound */
    final double delta;

    /**
     * Creates an empty regression tree.
     *
     * @param delta      the delta error used in the chernoff bound; must not be {@code null}
     * @param regression the regressor handed to the underlying {@link RegressionTreeModel}
     * @throws IllegalArgumentException if {@code delta} is {@code null}
     * @throws Exception declared for backward compatibility with existing callers
     */
    public RegressionTree(Double delta, Regressor<Data> regression) throws Exception {
        if (delta == null) {
            // Fail with a descriptive message instead of a bare NPE from auto-unboxing.
            throw new IllegalArgumentException("delta must not be null");
        }
        // FIXME warning: removed by balke: because it's never used
        // this.linearRegressionClass = linearRegressionClass;
        // this.parameters = parameters;
        this.tree = new RegressionTreeModel(regression);
        this.possibleSplitpoint = null;
        this.splitPossible = false;
        this.delta = delta;
    }

    /**
     * @return the incrementally updated tree model
     */
    @Override
    public RegressionTreeModel getModel() {
        return this.tree;
    }

    /**
     * Updates the tree with a single item: trains the regression model of the leaf the
     * item falls into, updates that leaf's BTrees, recomputes {@link #epsilon} and
     * splits the leaf if {@link #checkSplitpoints(LeafNode, Data)} reports a candidate.
     * A failing split is logged and the tree structure is left untouched.
     *
     * @param item the item to learn from
     */
    @Override
    public void learn(Data item) {
        LeafNode leaf = tree.getLeaf(item);
        leaf.getRegressionModel().learn(item);
        leaf.updateBTrees(item);
        this.computeChernoffBound(leaf, item);

        if (this.checkSplitpoints(leaf, item) == null) {
            return;
        }

        try {
            this.doSplit(leaf);
        } catch (Exception e) {
            // Splitting is best-effort: keep learning with the unsplit leaf, but do not
            // hide the failure.
            log.error("Splitting the leaf failed, tree structure is left unchanged", e);
        } finally {
            // The candidate belonged to the leaf that was just handled; never reuse it.
            this.possibleSplitpoint = null;
            this.splitPossible = false;
        }
    }

    /**
     * Feeds the BTree root of every attribute of {@code item} into
     * {@link #computeSDRs(BTreeNode, double)} and reports the resulting candidate.
     * Attributes without a BTree in the leaf, or whose root has a standard deviation
     * that is not positive, are skipped.
     *
     * @param leaf the leaf whose BTrees are inspected
     * @param item the item whose attributes select the BTrees
     * @return the current split candidate if {@link #splitPossible} was set during this
     *         call, {@code null} otherwise
     */
    protected BTreeNode checkSplitpoints(LeafNode leaf, Data item) {
        // The decision is taken per call; only the candidate itself is carried over.
        this.splitPossible = false;

        for (String feature : LearnerUtils.getAttributes(item)) {
            BTreeNode root = leaf.getBTrees().get(feature);
            if (root == null) {
                continue;
            }
            double preSd = root.getNodeInfo().getStandardDeviation();
            if (preSd > 0) {
                computeSDRs(root, preSd);
            }
        }

        return this.splitPossible ? this.possibleSplitpoint : null;
    }

    /**
     * Compares the standard deviation reduction (SDR) of {@code node} with the SDR of
     * the current candidate and updates {@link #possibleSplitpoint} and
     * {@link #splitPossible}:
     * <ul>
     * <li>no candidate yet: {@code node} becomes the candidate;</li>
     * <li>{@code node} has a larger SDR: it becomes the candidate, and
     * {@code splitPossible} is set if the gap is at least {@link #epsilon};</li>
     * <li>otherwise: {@code splitPossible} is cleared if the candidate's lead is
     * smaller than {@code epsilon}.</li>
     * </ul>
     * Nothing happens if {@code preSd} is not positive.
     *
     * @param node  the node to evaluate
     * @param preSd the standard deviation before the split
     */
    protected void computeSDRs(BTreeNode node, double preSd) {
        if (!(preSd > 0)) {
            return;
        }
        if (this.possibleSplitpoint == null) {
            this.possibleSplitpoint = node;
            return;
        }

        double sdr = node.getStandardDeviationReduction();
        double best = this.possibleSplitpoint.getStandardDeviationReduction();

        if (sdr > best) {
            if ((sdr - best - this.epsilon) >= 0) {
                this.splitPossible = true;
            }
            this.possibleSplitpoint = node;
        } else if ((best - sdr - this.epsilon) < 0) {
            this.splitPossible = false;
        }
    }

    /*
     * Legacy helpers, kept for reference only:
     *
     * private double computeSD(double n, double deltaSum, double squaredDeltaSum){
     *     double tmp = 1/n *(squaredDeltaSum - 1/n * Math.pow(deltaSum, 2));
     *     if(tmp < 0){ return 0.0; }
     *     return Math.sqrt(tmp);
     * }
     *
     * private double computeSDR(double preSd, double leqElements, double leqSd,
     *                           double greaterElements, double greaterSd){
     *     double n = leqElements + greaterElements;
     *     double leq = leqElements / n * leqSd;
     *     double greater = greaterElements / n * greaterSd;
     *     return preSd - (leq + greater);
     * }
     */

    /**
     * Replaces {@code leaf} by a new {@link InnerNode} built from the current split
     * candidate's name and value, the leaf's regression model and the leaf's N. The new
     * node takes the leaf's place below its parent, or becomes the root of the tree if
     * the leaf has no parent.
     *
     * @param leaf the leaf to split
     * @throws IllegalStateException if no split candidate is available
     * @throws InstantiationException declared for backward compatibility
     * @throws IllegalAccessException declared for backward compatibility
     * @throws ClassNotFoundException declared for backward compatibility
     */
    protected void doSplit(LeafNode leaf)
            throws InstantiationException, IllegalAccessException, ClassNotFoundException {
        if (this.possibleSplitpoint == null) {
            throw new IllegalStateException("doSplit called without a split candidate");
        }
        log.info("perform split");

        InnerNode splitPoint = new InnerNode(
                this.possibleSplitpoint.getName(),
                this.possibleSplitpoint.getValue(),
                leaf.getRegressionModel(),
                leaf.getN());

        InnerNode parent = (InnerNode) leaf.getParent();
        if (parent == null) {
            this.tree.setRoot(splitPoint);
        } else if (leaf.isRightChild()) {
            parent.setRightChild(splitPoint);
        } else {
            parent.setLeftChild(splitPoint);
        }
    }

    /**
     * Recomputes {@link #epsilon} from the leaf model's prediction for {@code item},
     * the leaf's N and {@link #delta}:
     * {@code sqrt(|3 * prediction / N * ln(2 / delta)|)}.
     *
     * @param leaf the leaf the item was routed to
     * @param item the item used to obtain the prediction
     */
    protected void computeChernoffBound(LeafNode leaf, Data item) {
        log.debug( "call predict: {}", item );
        double prediction = (Double) leaf.getRegressionModel().predict( item );
        this.epsilon = Math.sqrt(Math.abs(3*prediction/leaf.getN()*Math.log(2/this.delta)));
    }

    /**
     * Predicts the target value of {@code item} by delegating to the tree model.
     *
     * @param item the item to predict a value for
     * @return the prediction returned by the tree model
     */
    @Override
    public Double predict(Data item) {
        return tree.predict( item );
    }
}