package cc.mallet.classify;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
* Classifier for an ensemble of classifiers, combined with learned weights.
* The procedure is to obtain the score from each classifier (typically p(y|x)),
* perform the weighted sum of these scores, then exponentiate the summed
* score for each class, and re-normalize the resulting per-class scores.
* In other words, the scores of the ensemble classifiers are treated as
* input features in a Maximum Entropy classifier.
* @author <a href="mailto:mccallum@cs.umass.edu">Andrew McCallum</a>
*/
public class ClassifierEnsemble extends Classifier
{
  /** Member classifiers; all of them share the same LabelAlphabet instance. */
  Classifier[] ensemble;
  /** Mixing weights; weights[i] scales the scores produced by ensemble[i]. */
  double[] weights;

  /**
   * Builds an ensemble from the given classifiers and per-classifier weights.
   * Both arrays are copied, so later changes to the caller's arrays do not
   * affect this object.
   *
   * @param classifiers the member classifiers; must be non-empty, and every
   *        member must return the very same LabelAlphabet object
   * @param weights one weight per classifier, in the same order; entries
   *        beyond classifiers.length are never read by classify
   * @throws IllegalArgumentException if classifiers is empty, or if there are
   *         fewer weights than classifiers
   * @throws IllegalStateException if two adjacent classifiers do not share
   *         the same LabelAlphabet instance
   */
  public ClassifierEnsemble (Classifier[] classifiers, double[] weights)
  {
    // classify() reads ensemble[0] unconditionally, so an empty ensemble is unusable.
    if (classifiers.length == 0)
      throw new IllegalArgumentException("ClassifierEnsemble requires at least one classifier.");
    // classify() reads weights[i] for every classifier; fail here rather than
    // with an ArrayIndexOutOfBoundsException at classification time.
    if (weights.length < classifiers.length)
      throw new IllegalArgumentException("Expected at least " + classifiers.length
          + " weights (one per classifier) but got " + weights.length + ".");
    // Alphabets are compared by identity, pairwise along the array.
    for (int i = 1; i < classifiers.length; i++) {
      if (classifiers[i-1].getLabelAlphabet() != classifiers[i].getLabelAlphabet())
        throw new IllegalStateException("LabelAlphabet's do not match.");
    }
    this.ensemble = new Classifier[classifiers.length];
    System.arraycopy (classifiers, 0, ensemble, 0, classifiers.length);
    this.weights = weights.clone();
  }

  /**
   * Classifies the instance by running every member classifier on it,
   * accumulating each member's label scores scaled by that member's weight,
   * and then exponentiating and normalizing the accumulated scores.
   *
   * @param instance the instance to classify
   * @return a Classification whose LabelVector holds the normalized scores
   *         over the label alphabet of the first member classifier
   */
  public Classification classify (Instance instance)
  {
    int numLabels = ensemble[0].getLabelAlphabet().size();
    double[] scores = new double[numLabels];
    // Run each classifier on the instance, summing each one's per-class score, with a weight
    for (int i = 0; i < ensemble.length; i++) {
      Classification c = ensemble[i].classify(instance);
      c.getLabelVector().addTo(scores, weights[i]);
    }
    // Exponentiate and normalize scores
    expNormalize (scores);
    return new Classification (instance, this, new LabelVector (ensemble[0].getLabelAlphabet(), scores));
  }

  /**
   * Replaces each element of a, in place, with exp(a[i] - max) divided by the
   * sum of those exponentials, where max is the largest element of a.
   * Subtracting the maximum keeps the largest exponent at zero, so the
   * exponentials cannot overflow for finite inputs.
   *
   * @param a the scores to transform; modified in place
   */
  private static void expNormalize (double[] a)
  {
    double max = MatrixOps.max (a);
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      assert(!Double.isNaN(a[i]));
      a[i] = Math.exp (a[i] - max);
      sum += a[i];
    }
    for (int i = 0; i < a.length; i++) {
      a[i] /= sum;
    }
  }
}