/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.List;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.regex.Pattern;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import cc.mallet.grmm.types.*;
import gnu.trove.THashMap;
import bsh.Interpreter;
import bsh.EvalError;
/**
 * Reads factor graphs and assignments from whitespace-delimited text.
 * <p>
 * A model file is processed line by line by {@link #readModel}:
 * <ul>
 * <li>Blank lines are skipped.</li>
 * <li>A line whose first token is <tt>VAR</tt> (case-insensitive) declares
 *  variables: <tt>VAR name1 name2 ... : N</tt>, where <tt>N</tt> is either an
 *  integer number of outcomes or the word <tt>continuous</tt>.</li>
 * <li>Any other line describes a factor: <tt>v1 v2 ... ~ factorName arg1 arg2 ...</tt>.
 *  The factor name is looked up (case-insensitively) in a fixed table of known
 *  factor classes, and a constructor whose parameter types match the tokens
 *  (variables and numbers) is invoked reflectively.</li>
 * </ul>
 * Variables are remembered by name across lines, so one reader instance
 * resolves repeated names to the same Variable object.
 *
 * $Id: ModelReader.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
 */
public class ModelReader {

  /** Maps the lower-case factor name used in model files to the factor class to instantiate. */
  private static final THashMap allClasses;

  static {
    allClasses = new THashMap ();
    // add new classes here
    allClasses.put ("potts", PottsTableFactor.class);
    allClasses.put ("unary", BoltzmannUnaryFactor.class);
    allClasses.put ("binaryunary", BinaryUnaryFactor.class);
    allClasses.put ("binarypair", BoltzmannPairFactor.class);
    allClasses.put ("uniform", UniformFactor.class);
    allClasses.put ("normal", UniNormalFactor.class);
    allClasses.put ("beta", BetaFactor.class);
  }

  /**
   * Matches tokens that are numeric literals rather than variable names.
   * The exponent marker may be either case and its sign is optional, so that
   * values such as <tt>1.0E10</tt> and <tt>1e-5</tt> are recognized as numbers.
   */
  private static final Pattern nbrRegex =
      Pattern.compile ("[+-]?\\d+(?:\\.\\d+)?(?:[eE][+-]?\\d+)?");

  /** Variables seen so far (declared or implicitly created), keyed by name. */
  private final THashMap name2var = new THashMap ();

  /**
   * Reads an assignment from a matrix of values, one row per non-blank line.
   * Each whitespace-separated field is evaluated with a BeanShell interpreter
   * and the resulting objects are added as one row for the variables in
   * <tt>vars</tt>.
   * <p>
   * NOTE(review): every field is passed to {@link Interpreter#eval}, so the
   * input is executed as BeanShell code. Only use this on trusted input.
   *
   * @param vars variables that the columns of each row correspond to
   * @param in source of the matrix text; it is wrapped in a BufferedReader and not closed here
   * @return the assignment containing one row per non-blank input line
   * @throws IOException if reading from <tt>in</tt> fails
   * @throws RuntimeException if a field cannot be evaluated
   */
  public static Assignment readFromMatrix (VarSet vars, Reader in) throws IOException
  {
    Variable[] varr = vars.toVariableArray ();
    Interpreter interpreter = new Interpreter ();
    BufferedReader bIn = new BufferedReader (in);
    Assignment assn = new Assignment ();

    String line;
    while ((line = bIn.readLine ()) != null) {
      // Trim first: splitting a line with leading whitespace would yield an
      // empty first field, and a blank line would yield a single empty field.
      String trimmed = line.trim ();
      if (trimmed.length () == 0) { continue; }

      String[] fields = trimmed.split ("\\s+");
      Object[] vals = new Object [fields.length];
      for (int i = 0; i < fields.length; i++) {
        try {
          vals[i] = interpreter.eval (fields[i]);
        } catch (EvalError e) {
          throw new RuntimeException ("Error reading line: "+line, e);
        }
      }
      assn.addRow (varr, vals);
    }

    return assn;
  }

  /**
   * Reads a model description and builds the corresponding factor graph.
   * All lines are parsed first; the factors are then multiplied into a new
   * FactorGraph in the order in which they appeared.
   *
   * @param in source of the model text; not closed here
   * @return a factor graph containing every factor described in the input
   * @throws IOException if reading from <tt>in</tt> fails
   * @throws RuntimeException if a line cannot be parsed; the message contains
   *   the offending line and the cause is the underlying exception
   */
  public FactorGraph readModel (BufferedReader in) throws IOException
  {
    List factors = new ArrayList ();

    String line;
    while ((line = in.readLine ()) != null) {
      try {
        // Trim so that indentation does not produce an empty first token,
        // which would hide the VAR keyword or become a bogus variable name.
        String trimmed = line.trim ();
        if (trimmed.length () == 0) { continue; }

        String[] fields = trimmed.split ("\\s+");
        if (fields[0].equalsIgnoreCase ("VAR")) {
          // a variable declaration
          handleVariableDecl (fields);
        } else {
          // a factor line
          factors.add (factorFromLine (fields));
        }
      } catch (Exception e) {
        throw new RuntimeException ("Error reading line:\n"+line, e);
      }
    }

    FactorGraph fg = new FactorGraph ();
    for (Iterator it = factors.iterator (); it.hasNext ();) {
      fg.multiplyBy ((Factor) it.next ());
    }
    return fg;
  }

  /**
   * Handles a line of the form <tt>VAR name1 name2 ... : N</tt>, registering a
   * new variable for each name. <tt>N</tt> is an integer number of outcomes or
   * the word <tt>continuous</tt> (case-insensitive).
   *
   * @param fields tokens of the line; <tt>fields[0]</tt> is the VAR keyword
   * @throws IllegalArgumentException if there is no ":" token, or if it is not
   *   followed by exactly one token (NumberFormatException, a subclass, is
   *   thrown if that token is neither "continuous" nor an integer)
   */
  private void handleVariableDecl (String[] fields)
  {
    int colonIdx = findColon (fields);
    if (fields.length != colonIdx + 2) throw new IllegalArgumentException ("Invalid syntax");

    String numOutsString = fields[colonIdx+1];
    int numOutcomes;
    if (numOutsString.equalsIgnoreCase ("continuous")) {
      numOutcomes = Variable.CONTINUOUS;
    } else {
      numOutcomes = Integer.parseInt (numOutsString);
    }

    // Start at 1: fields[0] is the VAR keyword itself, not a variable name.
    for (int i = 1; i < colonIdx; i++) {
      String name = fields[i];
      Variable var = new Variable (numOutcomes);
      var.setLabel (name);
      name2var.put (name, var);
    }
  }

  /**
   * Returns the index of the first ":" token.
   *
   * @throws IllegalArgumentException if no token equals ":"
   */
  private int findColon (String[] fields)
  {
    for (int i = 0; i < fields.length; i++) {
      if (fields[i].equals (":")) {
        return i;
      }
    }
    throw new IllegalArgumentException ("Invalid syntax.");
  }

  /** Builds the factor described by the tokens of one factor line. */
  private Factor factorFromLine (String[] fields)
  {
    int idx = findTwiddle (fields);
    return constructFactor (fields, idx);
  }

  /**
   * Returns the index of the first "~" token, or -1 if there is none.
   * With -1, callers end up treating the first token as the factor name and
   * every remaining token as an argument after the twiddle.
   */
  private int findTwiddle (String[] fields)
  {
    for (int i = 0; i < fields.length; i++) {
      if (fields[i].equals ("~")) {
        return i;
      }
    }
    return -1;
  }

  /**
   * Instantiates a factor reflectively: resolves the factor class from its
   * name, converts the remaining tokens to constructor arguments, and invokes
   * the constructor whose parameter types match those arguments.
   *
   * @param fields tokens of the factor line
   * @param idx index of the "~" token, or -1 if absent
   * @throws RuntimeException if the constructor cannot be invoked
   */
  private Factor constructFactor (String[] fields, int idx)
  {
    Class factorClass = determineFactorClass (fields, idx);
    Object[] args = determineFactorArgs (fields, idx);
    Constructor factorCtor = findCtor (factorClass, args);

    try {
      return (Factor) factorCtor.newInstance (args);
    } catch (InstantiationException e) {
      throw new RuntimeException (e);
    } catch (IllegalAccessException e) {
      throw new RuntimeException (e);
    } catch (InvocationTargetException e) {
      throw new RuntimeException (e);
    }
  }

  /**
   * Finds the declared constructor of <tt>factorClass</tt> whose parameter
   * types are exactly the runtime classes of <tt>args</tt>, except that
   * Double arguments are matched against primitive <tt>double</tt> parameters.
   *
   * @throws RuntimeException if no such constructor exists; the message lists
   *   the arguments and their classes, and the cause is the NoSuchMethodException
   */
  private Constructor findCtor (Class factorClass, Object[] args)
  {
    Class[] argClass = new Class[args.length];
    for (int i = 0; i < args.length; i++) {
      argClass[i] = args[i].getClass ();
      // special case: numeric tokens are boxed, but ctors take primitive doubles
      if (argClass[i] == Double.class) { argClass[i] = double.class; }
    }

    try {
      return factorClass.getDeclaredConstructor (argClass);
    } catch (NoSuchMethodException e) {
      StringBuffer buf = new StringBuffer ("Invalid argments for factor "+factorClass+"\n");
      buf.append ("Args were:\n");
      for (int i = 0; i < args.length; i++) {
        buf.append (args[i]);
        buf.append (" ");
      }
      buf.append ("\n");
      for (int i = 0; i < args.length; i++) {
        buf.append (args[i].getClass ());
        buf.append (" ");
      }
      buf.append ("\n");
      // Keep the original exception as the cause so the stack trace is not lost.
      throw new RuntimeException (buf.toString (), e);
    }
  }

  /**
   * Looks up the factor class named by the token immediately after the "~".
   * The lookup is case-insensitive.
   *
   * @param fields tokens of the factor line
   * @param twiddleIdx index of the "~" token, or -1 if absent
   * @throws IllegalArgumentException if there is no token after the "~"
   * @throws RuntimeException if the name is not a known factor type
   */
  private Class determineFactorClass (String[] fields, int twiddleIdx)
  {
    int nameIdx = twiddleIdx + 1;
    if (nameIdx >= fields.length) {
      throw new IllegalArgumentException ("Missing factor name after '~'");
    }

    // Use a fixed locale: the default locale's case rules (e.g. Turkish
    // dotless i) must not affect matching against the ASCII keys of allClasses.
    String factorName = fields[nameIdx].toLowerCase (java.util.Locale.ENGLISH);
    Class theClass = (Class) allClasses.get (factorName);
    if (theClass == null) {
      throw new RuntimeException ("Could not determine factor class from "+factorName);
    }
    return theClass;
  }

  /**
   * Converts every token except the "~" and the factor name into a
   * constructor argument, preserving token order: first the tokens before the
   * "~", then the tokens after the factor name.
   *
   * @param fields tokens of the factor line
   * @param twiddleIdx index of the "~" token, or -1 if absent
   */
  private Object[] determineFactorArgs (String[] fields, int twiddleIdx)
  {
    List args = new ArrayList (fields.length);
    for (int i = 0; i < twiddleIdx; i++) {
      args.add (varFromName (fields[i], true));
    }
    for (int i = twiddleIdx+2; i < fields.length; i++) {
      args.add (varFromName (fields[i], false));
    }
    return args.toArray ();
  }

  /**
   * Interprets one token of a factor line.
   * A numeric literal becomes a Double. A name that is already known resolves
   * to its existing Variable. Any other name implicitly creates and registers
   * a new Variable: one with 2 outcomes if the token appeared before the "~",
   * a continuous one otherwise.
   *
   * @param name the token
   * @param preTwiddle whether the token appeared before the "~"
   * @return a Double or a Variable
   */
  private Object varFromName (String name, boolean preTwiddle)
  {
    if (nbrRegex.matcher (name).matches ()) {
      return Double.valueOf (name);
    }

    // Values stored in name2var are never null, so a single lookup suffices.
    Object known = name2var.get (name);
    if (known != null) {
      return known;
    }

    Variable var = (preTwiddle) ? new Variable (2) : new Variable (Variable.CONTINUOUS);
    var.setLabel (name);
    name2var.put (name, var);
    return var;
  }
}