package cc.mallet.classify.tui;
import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.logging.Logger;

import cc.mallet.classify.FeatureConstraintUtil;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
/**
 * Create "feature constraints" from data for use in GE training.
 * <p>
 * Constrained features are either read from {@code --features-file} (one feature per line,
 * optionally followed by whitespace-separated label names) or selected automatically with
 * {@code --feature-selection}. Constraint targets are then estimated with the method named by
 * {@code --targets} and written to {@code --output}, one feature per line in the form
 * {@code feature label:prob label:prob ...} (just {@code feature} when targets are "none").
 *
 * @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
 */
public class Vectors2FeatureConstraints {

  private static final Logger logger =
      MalletLogger.getLogger(Vectors2FeatureConstraints.class.getName());

  public static CommandOption.File vectorsFile = new
      CommandOption.File(Vectors2FeatureConstraints.class, "input", "FILENAME",
          true, null, "Data file used to generate constraints.", null);

  public static CommandOption.File constraintsFile = new
      CommandOption.File(Vectors2FeatureConstraints.class, "output", "FILENAME",
          true, null, "Output file for constraints.", null);

  public static CommandOption.File featuresFile = new
      CommandOption.File(Vectors2FeatureConstraints.class, "features-file", "FILENAME",
          false, null, "File with list of features used to generate constraints.", null);

  public static CommandOption.File ldaFile = new
      CommandOption.File(Vectors2FeatureConstraints.class, "lda-file", "FILENAME",
          false, null, "File with serialized LDA object (if using LDA feature constraint selection).", null);

  public static CommandOption.Integer numConstraints = new
      CommandOption.Integer(Vectors2FeatureConstraints.class, "num-constraints", "INTEGER",
          true, 10, "Number of feature constraints.", null);

  public static CommandOption.String featureSelection = new
      CommandOption.String(Vectors2FeatureConstraints.class, "feature-selection", "STRING",
          true, "infogain | lda", "Method used to choose feature constraints.", null);

  public static CommandOption.String targets = new
      CommandOption.String(Vectors2FeatureConstraints.class, "targets", "STRING",
          true, "none | oracle | heuristic | voted", "Method used to estimate constraint targets.", null);

  public static CommandOption.Double majorityProb = new
      CommandOption.Double(Vectors2FeatureConstraints.class, "majority-prob", "DOUBLE",
          false, 0.9, "Probability for majority labels when using heuristic target estimation.", null);

  /**
   * Entry point: loads the data, determines the constrained features, estimates their
   * targets and writes the constraints file.
   *
   * @param args command-line arguments, parsed by {@link CommandOption#process}
   * @throws RuntimeException on an unsupported option value, an unreadable input, or an
   *     unknown feature/label name
   */
  public static void main(String[] args) {
    CommandOption.process(Vectors2FeatureConstraints.class, args);

    // Validate the target method before any (possibly expensive) loading or labeling work.
    String targetMethod = targets.value;
    if (!isSupportedTargetMethod(targetMethod)) {
      throw new RuntimeException("Unsupported value for targets: " + targetMethod);
    }

    InstanceList list = InstanceList.load(vectorsFile.value);
    Alphabet dataAlphabet = list.getDataAlphabet();
    Alphabet targetAlphabet = list.getTargetAlphabet();

    // Here we will assume that we use all labeled data available.
    // Exactly one of these is non-null after feature determination.
    ArrayList<Integer> features = null;
    HashMap<Integer, ArrayList<Integer>> featuresAndLabels = null;

    if (featuresFile.wasInvoked()) {
      // A features file was specified: load the features (and labels, if present) from it.
      if (fileContainsLabels(featuresFile.value)) {
        // better error message from dfrankow@gmail.com
        if ("oracle".equals(targetMethod)) {
          throw new RuntimeException("with --targets oracle, features file must be unlabeled");
        }
        featuresAndLabels =
            readFeaturesAndLabelsFromFile(featuresFile.value, dataAlphabet, targetAlphabet);
      } else {
        features = readFeaturesFromFile(featuresFile.value, dataAlphabet);
      }
    } else {
      // Otherwise select features using the specified method.
      features = selectFeatures(list);
    }

    HashMap<Integer, double[]> constraints =
        estimateTargets(targetMethod, list, features, featuresAndLabels);
    writeConstraints(constraints, constraintsFile.value, dataAlphabet, targetAlphabet);
  }

  /** Returns true if {@code method} is one of the target-estimation methods handled below. */
  private static boolean isSupportedTargetMethod(String method) {
    return "none".equals(method)
        || "oracle".equals(method)
        || "heuristic".equals(method)
        || "voted".equals(method);
  }

  /**
   * Selects {@code --num-constraints} features from the data with the method named by
   * {@code --feature-selection} ("infogain" or "lda").
   *
   * @throws RuntimeException if the method is unsupported or the LDA model cannot be read
   */
  private static ArrayList<Integer> selectFeatures(InstanceList list) {
    String method = featureSelection.value;
    if ("infogain".equals(method)) {
      return FeatureConstraintUtil.selectFeaturesByInfoGain(list, numConstraints.value);
    }
    if ("lda".equals(method)) {
      ParallelTopicModel lda = readLdaModel();
      return FeatureConstraintUtil.selectTopLDAFeatures(
          numConstraints.value, lda, list.getDataAlphabet());
    }
    throw new RuntimeException("Unsupported value for feature selection: " + method);
  }

  /**
   * Deserializes the {@link ParallelTopicModel} stored in {@code --lda-file}.
   * <p>
   * NOTE(review): this uses Java native deserialization; only load model files from a
   * trusted source.
   *
   * @throws RuntimeException if no LDA file was given or it cannot be read
   */
  private static ParallelTopicModel readLdaModel() {
    if (ldaFile.value == null) {
      throw new RuntimeException("--lda-file must be specified with --feature-selection lda");
    }
    FileInputStream in = null;
    try {
      in = new FileInputStream(ldaFile.value);
      ObjectInputStream ois = new ObjectInputStream(in);
      return (ParallelTopicModel) ois.readObject();
    } catch (IOException e) {
      throw new RuntimeException("Couldn't read LDA model from " + ldaFile.value, e);
    } catch (ClassNotFoundException e) {
      throw new RuntimeException("Couldn't read LDA model from " + ldaFile.value, e);
    } finally {
      // Closing the underlying stream releases the file handle even if the
      // ObjectInputStream header could not be read.
      closeQuietly(in);
    }
  }

  /**
   * Builds the constraint map (feature index -> target distribution over labels).
   *
   * @param targetMethod one of "none", "oracle", "heuristic", "voted"
   * @param features constrained features, or null if they were loaded together with labels
   * @param featuresAndLabels feature -> label indices loaded from disk, or null
   */
  private static HashMap<Integer, double[]> estimateTargets(String targetMethod,
      InstanceList list, ArrayList<Integer> features,
      HashMap<Integer, ArrayList<Integer>> featuresAndLabels) {
    if ("none".equals(targetMethod)) {
      // No targets: each feature maps to null and is written as a bare feature name.
      // The features may come from a plain list or from the keys of a labeled file.
      Iterable<Integer> constrained;
      if (features != null) {
        constrained = features;
      } else {
        constrained = featuresAndLabels.keySet();
      }
      HashMap<Integer, double[]> constraints = new HashMap<Integer, double[]>();
      for (int fi : constrained) {
        constraints.put(fi, null);
      }
      return constraints;
    }
    if ("oracle".equals(targetMethod)) {
      // If the target method is oracle, then we do not need feature "labels".
      return FeatureConstraintUtil.setTargetsUsingData(list, features);
    }
    // For other methods, we need feature labels, as long as they
    // haven't been already loaded from disk.
    if (featuresAndLabels == null) {
      featuresAndLabels = FeatureConstraintUtil.labelFeatures(list, features);
      logFeatureLabels(featuresAndLabels, list.getDataAlphabet(), list.getTargetAlphabet());
    }
    if ("heuristic".equals(targetMethod)) {
      return FeatureConstraintUtil.setTargetsUsingHeuristic(
          featuresAndLabels, list.getTargetAlphabet().size(), majorityProb.value);
    }
    if ("voted".equals(targetMethod)) {
      return FeatureConstraintUtil.setTargetsUsingFeatureVoting(featuresAndLabels, list);
    }
    throw new RuntimeException("Unsupported value for targets: " + targetMethod);
  }

  /** Logs one line per feature: {@code feature: label label ...}. */
  private static void logFeatureLabels(HashMap<Integer, ArrayList<Integer>> featuresAndLabels,
      Alphabet dataAlphabet, Alphabet targetAlphabet) {
    for (int fi : featuresAndLabels.keySet()) {
      StringBuilder line = new StringBuilder();
      line.append(dataAlphabet.lookupObject(fi)).append(": ");
      for (int li : featuresAndLabels.get(fi)) {
        line.append(targetAlphabet.lookupObject(li)).append(' ');
      }
      logger.info(line.toString());
    }
  }

  /**
   * Reports whether the first line of {@code file} holds more than one whitespace-separated
   * token, i.e. whether features are followed by label names. An empty file counts as unlabeled.
   *
   * @throws RuntimeException if the file cannot be read
   */
  private static boolean fileContainsLabels(File file) {
    BufferedReader reader = null;
    try {
      reader = new BufferedReader(new FileReader(file));
      String line = reader.readLine();
      if (line == null) {
        return false;
      }
      return line.trim().split("\\s+").length > 1;
    } catch (IOException e) {
      throw new RuntimeException("Couldn't read features file " + file, e);
    } finally {
      closeQuietly(reader);
    }
  }

  /**
   * Reads one feature name per line and maps each to its index in {@code dataAlphabet}.
   * Blank lines are skipped.
   *
   * @throws RuntimeException if the file cannot be read or a feature is not in the alphabet
   */
  private static ArrayList<Integer> readFeaturesFromFile(File file, Alphabet dataAlphabet) {
    ArrayList<Integer> features = new ArrayList<Integer>();
    BufferedReader reader = null;
    try {
      reader = new BufferedReader(new FileReader(file));
      String line;
      while ((line = reader.readLine()) != null) {
        String feature = line.trim();
        if (feature.length() == 0) {
          continue;
        }
        features.add(lookupFeature(feature, dataAlphabet));
      }
    } catch (IOException e) {
      throw new RuntimeException("Couldn't read features file " + file, e);
    } finally {
      closeQuietly(reader);
    }
    return features;
  }

  /**
   * Reads lines of the form {@code feature label1 label2 ...} and maps each feature index to
   * the indices of its labels. Blank lines are skipped.
   *
   * @param file features file with one labeled feature per line
   * @param dataAlphabet alphabet the feature names are looked up in
   * @param targetAlphabet alphabet the label names are looked up in (never modified)
   * @return map from feature index to label indices
   * @throws RuntimeException if the file cannot be read, or a feature or label name is unknown
   */
  public static HashMap<Integer, ArrayList<Integer>> readFeaturesAndLabelsFromFile(File file,
      Alphabet dataAlphabet, Alphabet targetAlphabet) {
    HashMap<Integer, ArrayList<Integer>> featuresAndLabels =
        new HashMap<Integer, ArrayList<Integer>>();
    BufferedReader reader = null;
    try {
      reader = new BufferedReader(new FileReader(file));
      String line;
      while ((line = reader.readLine()) != null) {
        line = line.trim();
        if (line.length() == 0) {
          continue;
        }
        String[] split = line.split("\\s+");
        int featureIndex = lookupFeature(split[0], dataAlphabet);
        ArrayList<Integer> labels = new ArrayList<Integer>();
        for (int i = 1; i < split.length; i++) {
          // Labels are given by name. Look them up without adding, so a mistyped
          // label is reported instead of silently growing the target alphabet.
          int li = targetAlphabet.lookupIndex(split[i], false);
          if (li == -1) {
            throw new RuntimeException("Couldn't find label '"
                + split[i] + "' in the target alphabet.");
          }
          labels.add(li);
          logger.info("found label " + li);
        }
        featuresAndLabels.put(featureIndex, labels);
      }
    } catch (IOException e) {
      throw new RuntimeException("Couldn't read features file " + file, e);
    } finally {
      closeQuietly(reader);
    }
    return featuresAndLabels;
  }

  /**
   * Returns the index of {@code feature} in {@code dataAlphabet} without adding it.
   *
   * @throws RuntimeException if the feature is not in the alphabet
   */
  private static int lookupFeature(String feature, Alphabet dataAlphabet) {
    int featureIndex = dataAlphabet.lookupIndex(feature, false);
    // better error message from dfrankow@gmail.com
    if (featureIndex == -1) {
      throw new RuntimeException("Couldn't find feature '"
          + feature + "' in the data alphabet.");
    }
    return featureIndex;
  }

  /**
   * Writes one line per constraint: the feature name followed by {@code label:prob} pairs
   * (nothing after the name when the target array is null). Logs a warning and writes
   * nothing if there are no constraints.
   *
   * @throws RuntimeException if the output file cannot be written
   */
  private static void writeConstraints(HashMap<Integer, double[]> constraints,
      File constraintsFile, Alphabet dataAlphabet, Alphabet targetAlphabet) {
    if (constraints.isEmpty()) {
      logger.warning("No constraints written!");
      return;
    }
    FileWriter writer = null;
    try {
      writer = new FileWriter(constraintsFile);
      for (int fi : constraints.keySet()) {
        writer.write(dataAlphabet.lookupObject(fi) + " ");
        double[] p = constraints.get(fi);
        if (p != null) {
          for (int li = 0; li < p.length; li++) {
            writer.write(targetAlphabet.lookupObject(li) + ":" + p[li] + " ");
          }
        }
        writer.write("\n");
      }
      // Close on the success path too (not only in finally) so a failed flush is reported.
      writer.close();
    } catch (IOException e) {
      throw new RuntimeException("Couldn't write constraints to " + constraintsFile, e);
    } finally {
      closeQuietly(writer);
    }
  }

  /** Closes {@code c} if non-null, ignoring errors; used only on cleanup paths. */
  private static void closeQuietly(Closeable c) {
    if (c == null) {
      return;
    }
    try {
      c.close();
    } catch (IOException ignored) {
      // Cleanup only: the primary result or exception has already been determined.
    }
  }
}