/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.fst.semi_supervised.pr;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.logging.Logger;
import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable.ByGradientValue;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
/**
 * M-step/M-projection for PR.
 *
 * <p>Fits the CRF parameters to the distribution produced by a PR auxiliary
 * model. The value returned by {@link #getValue()} is
 * {@code weight * (sum_i (labeled_i - unlabeled_i) + gaussianPrior)}, where,
 * for instance {@code i}, {@code labeled_i} is the total weight of a
 * {@link SumLatticeKL} built from the auxiliary model's marginals and
 * {@code unlabeled_i} is the total weight of a
 * {@link SumLatticeDefaultCachedDot} over the same input. The gradient is
 * {@code -weight * (expectations - constraints + prior term)}.
 *
 * <p>Per-instance work is split across a fixed-size thread pool; call
 * {@link #shutdown()} when finished.
 *
 * @author Kedar Bellare
 * @author Gregory Druck
 */
public class CRFOptimizableByKL implements Serializable, ByGradientValue {
  private static final Logger logger =
      MalletLogger.getLogger(CRFOptimizableByKL.class.getName());
  private static final long serialVersionUID = 1L;

  /** CRF weights-change stamp at which {@link #cachedValue} was computed (-1 = never). */
  protected int cachedValueWeightsStamp;
  /** CRF weights-change stamp at which {@link #cachedGradient} was computed (-1 = never). */
  protected int cachedGradientWeightsStamp;
  protected int numParameters;
  /** Number of worker tasks/threads used by {@link #getExpectationValue()}. */
  protected int numThreads;
  /** Positive scale applied to both the value and the gradient. */
  protected double weight;
  protected double gaussianPriorVariance = 1.0;
  protected double cachedValue = -123456789;
  protected double[] cachedGradient;
  /** Per-instance auxiliary-model start/end marginals, indexed like {@link #trainingSet}. */
  protected List<double[]> initialProbList, finalProbList;
  /** Per-instance exponentiated auxiliary-model transition weights (xis). */
  protected List<double[][][]> transitionProbList;
  protected InstanceList trainingSet;
  protected CRF crf;
  /**
   * {@code constraints} is accumulated once from the auxiliary distribution;
   * {@code expectations} is recomputed from the CRF on every value evaluation.
   */
  protected CRF.Factors constraints, expectations;
  protected ThreadPoolExecutor executor;
  protected PRAuxiliaryModel auxModel;

  /**
   * Creates the optimizable and precomputes the constraints from the auxiliary model.
   *
   * @param crf the model whose parameters are optimized
   * @param trainingSet instances whose data are {@link Sequence}s
   * @param auxModel PR auxiliary model passed to {@link SumLatticePR}
   * @param cachedDots per-instance cached dot products passed to {@link SumLatticePR}
   * @param numThreads number of worker threads; must be positive
   * @param weight positive scale applied to the value and gradient
   * @throws IllegalArgumentException if {@code weight} or {@code numThreads} is not positive
   */
  public CRFOptimizableByKL(CRF crf, InstanceList trainingSet,
      PRAuxiliaryModel auxModel, double[][][][] cachedDots, int numThreads, double weight) {
    // Validate before any expensive work; an assert would be skipped when assertions are disabled.
    if (!(weight > 0)) {
      throw new IllegalArgumentException("weight must be positive: " + weight);
    }
    if (numThreads <= 0) {
      throw new IllegalArgumentException("numThreads must be positive: " + numThreads);
    }
    this.crf = crf;
    this.trainingSet = trainingSet;
    this.auxModel = auxModel;
    this.numParameters = crf.getParameters().getNumFactors();
    this.cachedGradient = new double[numParameters];
    this.cachedValueWeightsStamp = -1;
    this.cachedGradientWeightsStamp = -1;
    this.weight = weight;
    this.numThreads = numThreads;
    gatherConstraints(auxModel, cachedDots);
    this.executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(numThreads);
  }

  /**
   * Exponentiates and normalizes a vector of (presumably log-space) weights.
   *
   * <p>The largest finite entry is subtracted before exponentiating. Otherwise,
   * uniformly very negative weights would all underflow to 0 and the
   * normalization would divide 0 by 0. The shift cancels in the normalization.
   */
  private double[] toProbabilities(double weights[]) {
    double max = Double.NEGATIVE_INFINITY;
    for (double w : weights) {
      if (w > max) {
        max = w;
      }
    }
    // Shifting by an infinite max would turn every entry into NaN; skip the shift then.
    double shift = Double.isInfinite(max) ? 0 : max;
    double probs[] = new double[weights.length];
    for (int i = 0; i < weights.length; i++) {
      probs[i] = Math.exp(weights[i] - shift);
    }
    // TODO this shouldn't be necessary
    MatrixOps.normalize(probs);
    return probs;
  }

  /** Replaces every entry of {@code weights} with its exponential, in place (no normalization). */
  private void toProbabilities(double weights[][][]) {
    for (int i = 0; i < weights.length; i++) {
      for (int j = 0; j < weights[i].length; j++) {
        for (int k = 0; k < weights[i][j].length; k++) {
          weights[i][j][k] = Math.exp(weights[i][j][k]);
        }
      }
    }
  }

  /**
   * Runs the auxiliary-model lattice on every training instance and stores its
   * marginals.
   *
   * <p>For each instance, it exponentiates the first and last gammas (normalized)
   * and the xis (in place), stores them in {@link #initialProbList},
   * {@link #finalProbList} and {@link #transitionProbList}, and accumulates
   * {@link #constraints} through a {@link SumLatticeKL} incrementor.
   * {@link #expectations} is allocated with the same structure.
   */
  @SuppressWarnings("unchecked")
  protected void gatherConstraints(
      PRAuxiliaryModel auxModel, double[][][][] cachedDots) {
    initialProbList = new ArrayList<double[]>();
    finalProbList = new ArrayList<double[]>();
    transitionProbList = new ArrayList<double[][][]>();
    constraints = new CRF.Factors(crf.getParameters());
    expectations = new CRF.Factors(crf.getParameters());
    constraints.zero();
    for (int ii = 0; ii < trainingSet.size(); ii++) {
      Instance inst = trainingSet.get(ii);
      Sequence input = (Sequence) inst.getData();
      SumLatticePR geLatt =
          new SumLatticePR(crf, ii, input, null, auxModel, cachedDots[ii], false, null, null, true);
      double gammas[][] = geLatt.getGammas();
      double initialProbs[] = toProbabilities(gammas[0]);
      initialProbList.add(initialProbs);
      double finalProbs[] = toProbabilities(gammas[gammas.length - 1]);
      finalProbList.add(finalProbs);
      double transitionProbs[][][] = geLatt.getXis();
      toProbabilities(transitionProbs);
      transitionProbList.add(transitionProbs);
      new SumLatticeKL(crf, input, initialProbs,
          finalProbs, transitionProbs, null, constraints.new Incrementor());
    }
  }

  /**
   * Recomputes {@link #expectations} in parallel and returns
   * {@code sum_i (labeled_i - unlabeled_i)} over the training set (no prior, no weight).
   *
   * <p>The instances are split into {@link #numThreads} contiguous ranges. The
   * last range also takes the remainder of the integer division.
   *
   * @throws IllegalStateException if a worker task fails or the calling thread is interrupted
   */
  @SuppressWarnings("unchecked")
  protected double getExpectationValue() {
    expectations.zero();
    int size = trainingSet.size();
    int increment = size / numThreads;
    List<ExpectationTask> tasks = new ArrayList<ExpectationTask>(numThreads);
    for (int taskIndex = 0; taskIndex < numThreads; taskIndex++) {
      int start = taskIndex * increment;
      int end = (taskIndex == numThreads - 1) ? size : start + increment;
      // same structure, but with zero values: each task accumulates privately
      tasks.add(new ExpectationTask(start, end, new CRF.Factors(expectations)));
    }
    double value = 0;
    try {
      for (Future<Double> result : executor.invokeAll(tasks)) {
        value += result.get();
      }
    } catch (ExecutionException ee) {
      // A failed task leaves a partial value and partially filled expectations;
      // returning them would give the optimizer a silently wrong objective/gradient.
      throw new IllegalStateException("Expectation task failed", ee.getCause());
    } catch (InterruptedException ie) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException("Interrupted while computing expectations", ie);
    }
    // combine results
    for (ExpectationTask task : tasks) {
      this.expectations.plusEquals(task.getExpectationsCopy(), 1);
    }
    return value;
  }

  /**
   * Returns {@code weight * (getExpectationValue() + gaussianPrior)}.
   *
   * <p>The value is cached and recomputed only when the CRF weights-change stamp changes.
   */
  public double getValue() {
    int stamp = crf.getWeightsValueChangeStamp();
    if (stamp != cachedValueWeightsStamp) {
      long startingTime = System.currentTimeMillis();
      double value = getExpectationValue();
      // Incorporate prior on parameters
      double priorValue = crf.getParameters().gaussianPrior(gaussianPriorVariance);
      value += priorValue;
      logger.info("Gaussian prior = " + priorValue);
      value *= weight;
      assert (!(Double.isNaN(value) || Double.isInfinite(value))) : "Label likelihood is NaN/Infinite";
      cachedValue = value;
      // Record the stamp only after a successful computation so failures are retried, not cached.
      cachedValueWeightsStamp = stamp;
      logger.info("getValue() (loglikelihood, optimizable by klDiv) = "+ cachedValue);
      long endingTime = System.currentTimeMillis();
      logger.fine("Inference milliseconds = " + (endingTime - startingTime));
    }
    return cachedValue;
  }

  /**
   * Copies {@code -weight * (expectations - constraints + prior gradient term)} into {@code buffer}.
   *
   * <p>The gradient is cached per weights-change stamp. Computing it modifies
   * {@link #expectations} in place.
   */
  public void getValueGradient(double[] buffer) {
    int stamp = crf.getWeightsValueChangeStamp();
    if (cachedGradientWeightsStamp != stamp) {
      // Ensures expectations are current for this stamp; if it throws, nothing is cached.
      getValue();
      cachedGradientWeightsStamp = stamp;
      expectations.plusEquals(constraints, -1.0);
      expectations.plusEqualsGaussianPriorGradient(crf.getParameters(), -gaussianPriorVariance);
      expectations.assertNotNaNOrInfinite();
      expectations.getParameters(cachedGradient);
      MatrixOps.timesEquals(cachedGradient, -weight);
    }
    System.arraycopy(cachedGradient, 0, buffer, 0, cachedGradient.length);
  }

  public int getNumParameters() {
    return numParameters;
  }

  public void getParameters(double[] buffer) {
    crf.getParameters().getParameters(buffer);
  }

  public double getParameter(int index) {
    return crf.getParameters().getParameter(index);
  }

  /** Sets all CRF parameters and bumps the CRF weights-change stamp (invalidating caches). */
  public void setParameters(double[] buff) {
    crf.getParameters().setParameters(buff);
    crf.weightsValueChanged();
  }

  /** Sets one CRF parameter and bumps the CRF weights-change stamp (invalidating caches). */
  public void setParameter(int index, double value) {
    crf.getParameters().setParameter(index, value);
    crf.weightsValueChanged();
  }

  public void setGaussianPriorVariance(double value) {
    gaussianPriorVariance = value;
  }

  /** Shuts down the worker thread pool; call once optimization is finished. */
  public void shutdown() {
    executor.shutdown();
  }

  /**
   * Computes {@code sum (labeled_i - unlabeled_i)} for instances in
   * {@code [start, end)}. It accumulates CRF expectations into its own
   * {@link CRF.Factors} copy, which the caller merges afterwards.
   */
  private class ExpectationTask implements Callable<Double> {
    private final int start;
    private final int end;
    private final CRF.Factors expectationsCopy;

    public ExpectationTask(int start, int end, CRF.Factors exCopy) {
      this.start = start;
      this.end = end;
      this.expectationsCopy = exCopy;
    }

    public CRF.Factors getExpectationsCopy() {
      return expectationsCopy;
    }

    public Double call() throws Exception {
      int numStates = crf.numStates();
      double value = 0;
      for (int ii = start; ii < end; ii++) {
        Sequence input = (Sequence) trainingSet.get(ii).getData();
        double initProbs[] = initialProbList.get(ii);
        double finalProbs[] = finalProbList.get(ii);
        double transProbs[][][] = transitionProbList.get(ii);
        // Fresh per-instance buffer initialized to IMPOSSIBLE_WEIGHT and shared by both
        // lattices below (presumably filled by SumLatticeKL and reused -- TODO confirm).
        double[][][] cachedDots = new double[input.size()][numStates][numStates];
        for (double[][] position : cachedDots) {
          for (double[] row : position) {
            Arrays.fill(row, Transducer.IMPOSSIBLE_WEIGHT);
          }
        }
        double labeledWeight = new SumLatticeKL(crf, input, initProbs,
            finalProbs, transProbs, cachedDots, null).getTotalWeight();
        value += labeledWeight;
        double unlabeledWeight = new SumLatticeDefaultCachedDot(crf, input, null,
            cachedDots, expectationsCopy.new Incrementor(), false, null).getTotalWeight();
        value -= unlabeledWeight;
      }
      return value;
    }
  }
}