/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
/**
* Classification methods of BalancedWinnow algorithm.
*
* @see BalancedWinnowTrainer
* @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
*/
public class BalancedWinnow extends Classifier implements Serializable
{
	/**
	 * Per-class weight vectors, indexed as <code>m_weights[classIndex][featureIndex]</code>.
	 * {@link #classify(Instance)} additionally reads one slot per row as the
	 * weight of a default (always-on) feature.
	 */
	double [][] m_weights;

	/**
	 * Passes along data pipe and weights from {@link BalancedWinnowTrainer}.
	 *
	 * <p>The weights are copied row by row, so later changes to the caller's
	 * array do not affect this classifier. Each row is copied with its own
	 * length; an empty matrix is accepted.
	 *
	 * @param dataPipe needed for dictionary, labels, feature vectors, etc
	 * @param weights weights calculated during training phase; must not be
	 *        <code>null</code> and must not contain <code>null</code> rows
	 */
	public BalancedWinnow (Pipe dataPipe, double [][] weights)
	{
		super (dataPipe);
		m_weights = copyWeights (weights);
	}

	/**
	 * @return a copy of the weight vectors
	 */
	public double[][] getWeights()
	{
		return copyWeights (m_weights);
	}

	/**
	 * Returns a deep copy of a weight matrix.
	 *
	 * <p>Every row is cloned using its own length, so the copy does not rely
	 * on row 0 existing or on all rows having the same width.
	 */
	private static double[][] copyWeights (double[][] source)
	{
		double[][] copy = new double[source.length][];
		for (int i = 0; i < source.length; i++) {
			copy[i] = source[i].clone();
		}
		return copy;
	}

	/**
	 * Classifies an instance using BalancedWinnow's weights.
	 *
	 * <p>Returns a Classification containing the normalized
	 * dot products between class weight vectors and the instance
	 * feature vector. Each class score also includes the weight of the
	 * default feature for that class. Scores are divided by the sum of
	 * all class scores; if that sum is exactly zero the scores are
	 * returned unnormalized rather than as NaN/Infinity.
	 *
	 * <p>Features whose index has no corresponding weight (for example
	 * features added to the alphabet after training) are ignored.
	 *
	 * <p>One can obtain the confidence of the classification by
	 * calculating weight(j')/weight(j), where j' is the
	 * highest weight prediction and j is the 2nd-highest.
	 * Another possibility is to calculate
	 * <pre>
	 *   e^{dot(w_j', x)} / sum_j[e^{dot(w_j, x)}]
	 * </pre>
	 *
	 * @param instance the instance to classify; its data must be a
	 *        {@link FeatureVector}
	 * @return the classification holding one score per label
	 */
	public Classification classify (Instance instance)
	{
		int numClasses = getLabelAlphabet().size();
		int numFeats = getAlphabet().size();
		double[] scores = new double[numClasses];
		FeatureVector fv = (FeatureVector) instance.getData ();
		// Make sure the feature vector's feature dictionary matches
		// what we are expecting from our data pipe (and thus our notion
		// of feature probabilities).
		assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());
		int fvisize = fv.numLocations();

		// Take dot products
		double sum = 0;
		for (int ci = 0; ci < numClasses; ci++) {
			double[] classWeights = m_weights[ci];
			// The default-feature weight is normally stored at index numFeats.
			// If the alphabet is now larger than the weight row allows, fall
			// back to the row's last slot instead of indexing out of bounds.
			// NOTE(review): this assumes the trainer allocates exactly one
			// trailing slot per class for the default feature -- confirm
			// against BalancedWinnowTrainer.
			int defaultIndex = Math.min (numFeats, classWeights.length - 1);
			for (int fvi = 0; fvi < fvisize; fvi++) {
				int fi = fv.indexAtLocation (fvi);
				// Skip features without a trained weight, and never let a real
				// feature be scored with the default-feature slot.
				if (fi < defaultIndex) {
					double contribution = fv.valueAtLocation (fvi) * classWeights[fi];
					scores[ci] += contribution;
					sum += contribution;
				}
			}
			// An empty weight row has no default-feature slot to add.
			if (defaultIndex >= 0) {
				scores[ci] += classWeights[defaultIndex];
				sum += classWeights[defaultIndex];
			}
		}

		// Normalize; dividing by a zero sum would turn every score into
		// NaN or Infinity, so leave the raw scores in that case.
		if (sum != 0) {
			MatrixOps.timesEquals(scores, 1.0 / sum);
		}

		// Create and return a Classification object
		return new Classification (instance, this, new LabelVector (getLabelAlphabet(), scores));
	}

	// Serialization
	// serialVersionUID is overridden to prevent innocuous changes in this
	// class from making the serialization mechanism think the external
	// format has changed.
	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 1;

	/**
	 * Writes the serial version number, the instance pipe and the weight
	 * matrix, in that order. {@link #readObject(ObjectInputStream)} reads
	 * them back in the same order.
	 */
	private void writeObject(ObjectOutputStream out) throws IOException
	{
		out.writeInt(CURRENT_SERIAL_VERSION);
		out.writeObject(getInstancePipe());
		// write weight vector for each class
		out.writeObject(m_weights);
	}

	/**
	 * Restores the instance pipe and the weight matrix written by
	 * {@link #writeObject(ObjectOutputStream)}.
	 *
	 * @throws ClassNotFoundException if the stored version number differs
	 *         from <code>CURRENT_SERIAL_VERSION</code>, or if a class of a
	 *         serialized object cannot be found
	 */
	private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
		int version = in.readInt();
		if (version != CURRENT_SERIAL_VERSION) {
			throw new ClassNotFoundException("Mismatched BalancedWinnow versions: wanted " +
					CURRENT_SERIAL_VERSION + ", got " +
					version);
		}
		instancePipe = (Pipe) in.readObject();
		m_weights = (double[][]) in.readObject();
	}
}