/*
* GeneralizedSteppingStoneSamplingAnalysis.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.trace;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.util.*;
import dr.util.Attribute;
import dr.util.FileHelpers;
import dr.xml.*;
/**
 * Computes a log marginal likelihood estimate ("generalized stepping stone sampling", as
 * labelled in the report produced by {@link #toString()}) from three parallel sample columns:
 * a path parameter (theta) column, a "source" column and a "destination" column.
 * <p>
 * For every distinct theta value the per-sample differences {@code source - destination} are
 * collected. The estimate is then accumulated over consecutive pairs of sorted theta values;
 * see {@link #calculateBF()} for the exact expression.
 * <p>
 * The estimate is computed lazily on the first call to {@link #getLogBayesFactor()} or
 * {@link #toString()} and cached afterwards. The sample lists handed to the constructor are
 * stored by reference (not copied).
 *
 * @author Guy Baele
 */
public class GeneralizedSteppingStoneSamplingAnalysis {

    public static final String GENERALIZED_STEPPING_STONE_SAMPLING_ANALYSIS = "generalizedSteppingStoneSamplingAnalysis";
    public static final String RESULT_FILE_NAME = "resultsFileName";
    public static final String THETA_COLUMN = "thetaColumn";
    public static final String SOURCE_COLUMN = "sourceColumn";
    public static final String DESTINATION_COLUMN = "destinationColumn";
    /** Format used for every numeric cell of the table produced by {@link #toString()}. */
    public static final String FORMAT = "%5.5g";

    private final String sourceName, destinationName;
    private final List<Double> thetaSample;
    private final List<Double> sourceSample;
    private final List<Double> destinationSample;

    // Lazily computed results; only valid once logBayesFactorCalculated is true.
    private boolean logBayesFactorCalculated = false;
    private double logBayesFactor;
    // Per distinct theta (ascending): maximum of (source - destination).
    private List<Double> maxLogLikelihood;
    // Distinct theta values in ascending order.
    private List<Double> orderedTheta;
    // One entry per pair of consecutive theta values (size = orderedTheta.size() - 1, or 0).
    private List<Double> mlContribution;

    /**
     * @param sourceName        name of the source column (only used in the report text)
     * @param destinationName   name of the destination column (only used in the report text)
     * @param thetaSample       path parameter value of every sample
     * @param sourceSample      source column value of every sample; its size determines how
     *                          many samples are read from all three lists
     * @param destinationSample destination column value of every sample
     */
    public GeneralizedSteppingStoneSamplingAnalysis(String sourceName, String destinationName, List<Double> thetaSample, List<Double> sourceSample, List<Double> destinationSample) {
        this.sourceName = sourceName;
        this.destinationName = destinationName;
        this.thetaSample = thetaSample;
        this.sourceSample = sourceSample;
        this.destinationSample = destinationSample;
    }

    /**
     * Returns the estimate, computing it on first use.
     *
     * @return the accumulated log estimate; 0.0 when fewer than two distinct theta values exist
     */
    public double getLogBayesFactor() {
        if (!logBayesFactorCalculated) {
            calculateBF();
        }
        return logBayesFactor;
    }

    /**
     * Fills {@code orderedTheta}, {@code maxLogLikelihood}, {@code mlContribution} and
     * {@code logBayesFactor}.
     * <p>
     * With {@code d = theta[i] - theta[i-1]}, {@code v} ranging over the
     * {@code source - destination} values sampled at {@code theta[i-1]} and {@code m} their
     * maximum, the contribution of the pair {@code (i-1, i)} is
     * {@code d * m + log(mean(exp(d * (v - m))))}. Subtracting the maximum before
     * exponentiating keeps the exponent at or below zero when {@code d} is positive.
     * <p>
     * This method is purely computational and writes nothing to standard output.
     */
    private void calculateBF() {
        // A TreeMap iterates its keys in ascending order, which gives the sorted list of
        // distinct theta values directly.
        Map<Double, List<Double>> differencesByTheta = new TreeMap<Double, List<Double>>();

        // the log-likelihood*prior/refprior values are needed to calculate the marginal likelihood
        for (int i = 0; i < sourceSample.size(); i++) {
            Double theta = thetaSample.get(i);
            List<Double> differences = differencesByTheta.get(theta);
            if (differences == null) {
                differences = new ArrayList<Double>();
                differencesByTheta.put(theta, differences);
            }
            differences.add(sourceSample.get(i) - destinationSample.get(i));
        }

        List<Double> thetas = new ArrayList<Double>(differencesByTheta.keySet());

        // a list with the maxima of the log-likelihood*prior/refprior values is constructed
        List<Double> maxima = new ArrayList<Double>(thetas.size());
        for (List<Double> differences : differencesByTheta.values()) {
            maxima.add(Collections.max(differences));
        }

        // First pass: the "d * m" part of every contribution. The two passes are kept
        // separate so that the floating-point summation order of the total is unchanged.
        List<Double> contributions = new ArrayList<Double>();
        double total = 0.0;
        for (int i = 1; i < thetas.size(); i++) {
            double contribution = (thetas.get(i) - thetas.get(i - 1)) * maxima.get(i - 1);
            total += contribution;
            contributions.add(contribution);
        }

        // Second pass: the "log(mean(exp(d * (v - m))))" part of every contribution.
        for (int i = 1; i < thetas.size(); i++) {
            double delta = thetas.get(i) - thetas.get(i - 1);
            double max = maxima.get(i - 1);
            List<Double> differences = differencesByTheta.get(thetas.get(i - 1));

            double internalSum = 0.0;
            for (double difference : differences) {
                internalSum += Math.exp(delta * (difference - max));
            }
            internalSum /= differences.size();

            double logMean = Math.log(internalSum);
            contributions.set(i - 1, contributions.get(i - 1) + logMean);
            total += logMean;
        }

        // Publish the results only after the whole computation succeeded.
        orderedTheta = thetas;
        maxLogLikelihood = maxima;
        mlContribution = contributions;
        logBayesFactor = total;
        logBayesFactorCalculated = true;
    }

    /**
     * Renders a tab-separated table with one row per distinct theta value (theta, maximum of
     * source - destination, contribution of the step starting at that theta) followed by the
     * overall estimate. The last row has no contribution because no step starts there.
     */
    @Override
    public String toString() {
        // Triggers the computation if necessary, so the lists used below are populated.
        double bf = getLogBayesFactor();

        StringBuilder sb = new StringBuilder();
        sb.append("PathParameter\tMaxPathLikelihood\tMLContribution\n");

        int lastRow = orderedTheta.size() - 1;
        for (int i = 0; i < orderedTheta.size(); ++i) {
            sb.append(String.format(FORMAT, orderedTheta.get(i)));
            sb.append("\t");
            sb.append(String.format(FORMAT, maxLogLikelihood.get(i)));
            sb.append("\t");
            if (i != lastRow) {
                sb.append(String.format(FORMAT, mlContribution.get(i)));
            }
            sb.append("\n");
        }

        sb.append("\nlog marginal likelihood (using generalized stepping stone sampling) from (" + sourceName + " - " + destinationName + ") = " + bf + "\n");
        return sb.toString();
    }

    /**
     * XML parser: reads the three named columns from one or more whitespace-separated log
     * files, concatenates their samples, prints the report to standard output and optionally
     * appends it to a results file.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String getParserName() {
            return GENERALIZED_STEPPING_STONE_SAMPLING_ANALYSIS;
        }

        /**
         * @return the {@link GeneralizedSteppingStoneSamplingAnalysis} built from the log file(s)
         * @throws XMLParseException if no file name is given, a log file cannot be opened or
         *                           read, a column is missing, or the results file cannot be written
         */
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            String fileName = xo.getStringAttribute(FileHelpers.FILE_NAME);
            String resultFileName = null;
            if (xo.hasAttribute(RESULT_FILE_NAME)) {
                resultFileName = xo.getStringAttribute(RESULT_FILE_NAME);
            }

            // The attribute may list several files separated by whitespace.
            StringTokenizer tokenFileName = new StringTokenizer(fileName);
            int numberOfFiles = tokenFileName.countTokens();
            System.out.println(numberOfFiles + " file(s) found with marginal likelihood samples");

            // Without at least one file there are no samples to analyse; fail with a parse
            // error rather than passing null sample lists to the analysis.
            if (numberOfFiles == 0) {
                throw new XMLParseException("No log file name given in attribute '" + FileHelpers.FILE_NAME + "' for " + getParserName() + " element.");
            }

            GeneralizedSteppingStoneSamplingAnalysis analysis;

            try {
                // The column names do not depend on the file, so they are read once.
                XMLObject cxo = xo.getChild(SOURCE_COLUMN);
                String sourceName = cxo.getStringAttribute(Attribute.NAME);

                cxo = xo.getChild(DESTINATION_COLUMN);
                String destinationName = cxo.getStringAttribute(Attribute.NAME);

                cxo = xo.getChild(THETA_COLUMN);
                String thetaName = cxo.getStringAttribute(Attribute.NAME);

                // Fresh accumulators, so values from later files are never appended into a
                // list handed out by an earlier file's traces object.
                // NOTE(review): kept as raw lists because the declared return type of
                // LogFileTraces.getValues(int) is not visible here - confirm and add
                // type arguments if it is List<Double>.
                List sampleTheta = new ArrayList();
                List sampleSource = new ArrayList();
                List sampleDestination = new ArrayList();

                for (int j = 0; j < numberOfFiles; j++) {

                    File file = new File(tokenFileName.nextToken());
                    String name = file.getName();
                    String parent = file.getParent();

                    // NOTE(review): a relative name is resolved against user.dir using only
                    // its last path component, so relative sub-directories are dropped -
                    // confirm this is intended.
                    if (!file.isAbsolute()) {
                        parent = System.getProperty("user.dir");
                    }
                    file = new File(parent, name);
                    fileName = file.getAbsolutePath();

                    LogFileTraces traces = new LogFileTraces(fileName, file);
                    traces.loadTraces();

                    int burnin = 0;
                    traces.setBurnIn(burnin);

                    int traceIndexTheta = -1;
                    int traceIndexSource = -1;
                    int traceIndexDestination = -1;

                    // Not else-if: one trace may serve as several columns, and when a name
                    // occurs more than once the last matching trace wins.
                    for (int i = 0; i < traces.getTraceCount(); i++) {
                        String traceName = traces.getTraceName(i).trim();
                        if (traceName.equals(thetaName)) {
                            traceIndexTheta = i;
                        }
                        if (traceName.equals(sourceName)) {
                            traceIndexSource = i;
                        }
                        if (traceName.equals(destinationName)) {
                            traceIndexDestination = i;
                        }
                    }

                    if (traceIndexTheta == -1) {
                        throw new XMLParseException("Column '" + thetaName + "' can not be found for " + getParserName() + " element.");
                    }
                    if (traceIndexSource == -1) {
                        throw new XMLParseException("Column '" + sourceName + "' can not be found for " + getParserName() + " element.");
                    }
                    if (traceIndexDestination == -1) {
                        throw new XMLParseException("Column '" + destinationName + "' can not be found for " + getParserName() + " element.");
                    }

                    sampleTheta.addAll(traces.getValues(traceIndexTheta));
                    sampleSource.addAll(traces.getValues(traceIndexSource));
                    sampleDestination.addAll(traces.getValues(traceIndexDestination));
                }

                analysis = new GeneralizedSteppingStoneSamplingAnalysis(sourceName, destinationName, sampleTheta, sampleSource, sampleDestination);

            } catch (FileNotFoundException fnfe) {
                // fileName holds the absolute path of the log file being processed.
                throw new XMLParseException("File '" + fileName + "' can not be opened for " + getParserName() + " element.");
            } catch (java.io.IOException ioe) {
                throw new XMLParseException(ioe.getMessage());
            } catch (TraceException e) {
                throw new XMLParseException(e.getMessage());
            }

            // Render once; the same text goes to standard output and to the results file.
            String report = analysis.toString();
            System.out.println(report);

            if (resultFileName != null) {
                // Handled separately from the log-file errors above so that a failure here
                // names the results file instead of an input file.
                try {
                    // 'true' opens the results file in append mode.
                    BufferedWriter bw = new BufferedWriter(new FileWriter(resultFileName, true));
                    try {
                        bw.write(report);
                        bw.flush();
                    } finally {
                        // Release the file handle even when writing fails.
                        bw.close();
                    }
                } catch (java.io.IOException ioe) {
                    throw new XMLParseException("Results file '" + resultFileName + "' can not be written for " + getParserName() + " element: " + ioe.getMessage());
                }
            }

            return analysis;
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "Performs a trace analysis.";
        }

        public Class getReturnType() {
            return GeneralizedSteppingStoneSamplingAnalysis.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new StringAttributeRule(FileHelpers.FILE_NAME,
                        "The traceName of a BEAST log file (can not include trees, which should be logged separately)"),
                new StringAttributeRule(RESULT_FILE_NAME,
                        "The name of the output file to which the generalized stepping-stone sampling estimate will be written", true),
                new ElementRule(THETA_COLUMN, new XMLSyntaxRule[]{
                        new StringAttributeRule(Attribute.NAME, "The column name")}),
                new ElementRule(SOURCE_COLUMN, new XMLSyntaxRule[]{
                        new StringAttributeRule(Attribute.NAME, "The column name")}),
                new ElementRule(DESTINATION_COLUMN, new XMLSyntaxRule[]{
                        new StringAttributeRule(Attribute.NAME, "The column name")})
        };
    };
}