/*
* TraceAnalysis.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.trace;
import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.List;
/**
 * Static helpers that summarise MCMC log files.
 * <p>
 * The {@code report*} methods load a log file, apply a burn-in, analyse the traces and print
 * tab-delimited summaries to {@code System.out}; {@link #getStatisticSummary(List)} builds a
 * tab-delimited statistics table as a string for already-loaded {@link TraceList}s.
 *
 * @author Alexei Drummond
 * @version $Id: TraceAnalysis.java,v 1.23 2005/05/24 20:26:00 rambaut Exp $
 */
public class TraceAnalysis {

    /** Burn-in argument meaning "use 10% of the maximum state". */
    private static final int DEFAULT_BURNIN = -1;

    /** Statistics whose ESS is below this value are flagged with '*' by the report methods. */
    private static final double LOW_ESS_THRESHOLD = 100;

    /** Estimator name handed to {@link MarginalLikelihoodAnalysis}. */
    private static final String MARGINAL_LIKELIHOOD_TYPE = "aicm";

    /** Number of bootstrap replicates handed to {@link MarginalLikelihoodAnalysis}. */
    private static final int MARGINAL_LIKELIHOOD_BOOTSTRAP_LENGTH = 1000;

    /**
     * Loads a log file, applies the given burn-in and analyses every trace in it.
     *
     * @param fileName the name of the log file to analyze
     * @param burnin   the state to discard up to; passed to {@code setBurnIn} unchanged
     * @return the loaded traces, with every trace analysed
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted
     */
    public static LogFileTraces analyzeLogFile(String fileName, int burnin) throws java.io.IOException, TraceException {
        LogFileTraces traces = loadLogFile(fileName);
        traces.setBurnIn(burnin);
        analyseAllTraces(traces);
        return traces;
    }

    /**
     * Prints a summary of every statistic in a log file, using the default burn-in
     * (10% of the maximum state) and no marginal-likelihood analysis.
     *
     * @param fileName the name of the log file to report on
     * @return the traces loaded from the file
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted
     */
    public static TraceList report(String fileName) throws java.io.IOException, TraceException {
        return report(fileName, DEFAULT_BURNIN, null);
    }

    /**
     * Prints a summary of every statistic in a log file, including the standard-error,
     * median and 50% HPD columns.
     *
     * @param fileName       the name of the log file to report on
     * @param burnin         the number of states of burn-in, or -1 to use 10% of the maximum state
     * @param likelihoodName column to use for a marginal-likelihood analysis, or null to skip it
     * @return the traces loaded from the file
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted, or the likelihood column is missing
     */
    public static TraceList report(String fileName, int burnin, String likelihoodName) throws java.io.IOException, TraceException {
        return report(fileName, burnin, likelihoodName, true);
    }

    /**
     * Prints a tab-delimited summary of every statistic in a log file to {@code System.out}.
     * Statistics whose ESS is below {@value #LOW_ESS_THRESHOLD} are marked with '*' and a
     * warning is printed after the table.
     *
     * @param fileName       the name of the log file to report on
     * @param inBurnin       the number of states of burn-in, or -1 to use 10% of the maximum state
     * @param likelihoodName column to use for a marginal-likelihood analysis, or null to skip it
     * @param withStdError   if true, also print standard error, median and 50% HPD columns
     * @return the traces loaded from the file
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted, or the likelihood column is missing
     */
    public static TraceList report(String fileName, int inBurnin, String likelihoodName, boolean withStdError)
            throws java.io.IOException, TraceException {
        LogFileTraces traces = loadLogFile(fileName);
        int burnin = resolveBurnin(traces, inBurnin);
        traces.setBurnIn(burnin);

        System.out.println("burnIn <= " + burnin + ", maxState = " + traces.getMaxState());

        // Header row; its column order must match the values printed for each trace below.
        String[] names;
        if (withStdError) {
            names = new String[]{"mean", "stdErr", "median", "hpdLower", "hpdUpper", "ESS", "50hpdLower", "50hpdUpper"};
        } else {
            names = new String[]{"mean", "hpdLower", "hpdUpper", "ESS"};
        }
        System.out.print("statistic");
        for (String name : names) {
            System.out.print("\t" + name);
        }
        System.out.println();

        boolean anyLowEss = false;
        for (int i = 0; i < traces.getTraceCount(); i++) {
            traces.analyseTrace(i);
            TraceCorrelation distribution = traces.getCorrelationStatistics(i);
            double ess = distribution.getESS();

            System.out.print(traces.getTraceName(i));
            System.out.print("\t" + formattedNumber(distribution.getMean()));
            if (withStdError) {
                System.out.print("\t" + formattedNumber(distribution.getStdError()));
                System.out.print("\t" + formattedNumber(distribution.getMedian()));
            }
            System.out.print("\t" + formattedNumber(distribution.getLowerHPD()));
            System.out.print("\t" + formattedNumber(distribution.getUpperHPD()));
            System.out.print("\t" + formattedNumber(ess));
            if (withStdError) {
                System.out.print("\t" + formattedNumber(distribution.getHpdLowerCustom()));
                System.out.print("\t" + formattedNumber(distribution.getHpdUpperCustom()));
            }

            if (ess < LOW_ESS_THRESHOLD) {
                anyLowEss = true;
                System.out.println("\t" + "*");
            } else {
                System.out.println("\t");
            }
        }
        System.out.println();

        if (anyLowEss) {
            System.out.println(" * WARNING: The results of this MCMC analysis may be invalid as ");
            System.out.println(" one or more statistics had very low effective sample sizes (ESS)");
        }

        if (likelihoodName != null) {
            System.out.println();
            int traceIndex = findTraceIndex(traces, likelihoodName);
            if (traceIndex == -1) {
                throw new TraceException("Column '" + likelihoodName +
                        "' can not be found for marginal likelihood analysis.");
            }
            System.out.println(analyseMarginalLikelihood(traces, traceIndex, burnin).toString());
        }
        System.out.flush();
        return traces;
    }

    /**
     * Prints one tab-delimited line for a single trace: mean, standard error, median,
     * 95% HPD lower/upper and custom HPD lower/upper.
     *
     * @param fileName  the name of the log file to report on
     * @param inBurnin  the number of states of burn-in, or -1 to use 10% of the maximum state
     * @param traceName the name of the trace (column) to report
     * @throws IOException    if general error reading file
     * @throws TraceException if trace file in wrong format or corrupted, or the trace is not found
     */
    public static void reportTrace(String fileName, int inBurnin, String traceName) throws IOException, TraceException {
        LogFileTraces traces = loadLogFile(fileName);
        traces.setBurnIn(resolveBurnin(traces, inBurnin));

        int id = traces.getTraceIndex(traceName);
        // NOTE(review): assumes getTraceIndex reports an unknown name with a negative index - confirm.
        // Failing here gives a clear message instead of an index error inside analyseTrace.
        if (id < 0) {
            throw new TraceException("Column '" + traceName + "' can not be found in file " + fileName + ".");
        }
        traces.analyseTrace(id);
        TraceCorrelation distribution = traces.getCorrelationStatistics(id);

        System.out.print(formattedNumber(distribution.getMean()) + "\t");
        System.out.print(formattedNumber(distribution.getStdError()) + "\t");
        System.out.print(formattedNumber(distribution.getMedian()) + "\t");
        System.out.print(formattedNumber(distribution.getLowerHPD()) + "\t");
        System.out.print(formattedNumber(distribution.getUpperHPD()) + "\t");
        System.out.print(formattedNumber(distribution.getHpdLowerCustom()) + "\t");
        System.out.print(formattedNumber(distribution.getHpdUpperCustom()) + "\t");
        System.out.println();
    }

    /**
     * Prints a single tab-delimited row summarising a log file (optionally preceded by a header row).
     *
     * @param filename       the file name of the log file to report on
     * @param burnin         the number of states of burnin or if -1 then use 10% of the maximum state
     * @param drawHeader     if true then draw header
     * @param hpds           if true then report 95% hpd upper and lower
     * @param individualESSs if true then report the ESS of every statistic (the minimum ESS is always reported)
     * @param stdErr         if true then report the standard error of the mean
     * @param likelihoodName column to use for a marginal-likelihood analysis, or null to skip it
     * @return the traces loaded from given file to create this short report
     * @throws java.io.IOException if general error reading file
     * @throws TraceException      if trace file in wrong format or corrupted, or the likelihood column is missing
     */
    public static TraceList shortReport(String filename,
                                        final int burnin, boolean drawHeader,
                                        boolean hpds, boolean individualESSs, boolean stdErr,
                                        String likelihoodName) throws java.io.IOException, TraceException {
        LogFileTraces traces = loadLogFile(filename);
        // Honour the documented "-1 means 10%" convention, as report() and reportTrace() do,
        // so the marginal-likelihood analysis below also receives a real burn-in.
        int effectiveBurnin = resolveBurnin(traces, burnin);
        traces.setBurnIn(effectiveBurnin);
        analyseAllTraces(traces);

        long maxState = traces.getMaxState();
        double minESS = Double.MAX_VALUE;

        if (drawHeader) {
            System.out.print("file\t");
            for (int i = 0; i < traces.getTraceCount(); i++) {
                String traceName = traces.getTraceName(i);
                System.out.print(traceName + "\t");
                if (stdErr) {
                    System.out.print(traceName + " stdErr\t");
                }
                if (hpds) {
                    System.out.print(traceName + " hpdLower\t");
                    System.out.print(traceName + " hpdUpper\t");
                }
                if (individualESSs) {
                    System.out.print(traceName + " ESS\t");
                }
            }
            System.out.print("minESS\t");
            if (likelihoodName != null) {
                System.out.print("marginal likelihood\t");
                System.out.print("stdErr\t");
            }
            System.out.println("chainLength");
        }

        System.out.print(filename + "\t");
        for (int i = 0; i < traces.getTraceCount(); i++) {
            TraceCorrelation distribution = traces.getCorrelationStatistics(i);
            System.out.print(distribution.getMean() + "\t");
            if (stdErr) {
                System.out.print(distribution.getStdErrorOfMean() + "\t");
            }
            if (hpds) {
                System.out.print(distribution.getLowerHPD() + "\t");
                System.out.print(distribution.getUpperHPD() + "\t");
            }
            if (individualESSs) {
                System.out.print(distribution.getESS() + "\t");
            }
            double ess = distribution.getESS();
            if (ess < minESS) {
                minESS = ess;
            }
        }
        System.out.print(minESS + "\t");

        if (likelihoodName != null) {
            int traceIndex = findTraceIndex(traces, likelihoodName);
            if (traceIndex == -1) {
                throw new TraceException("Column '" + likelihoodName + "' can not be found in file " + filename + ".");
            }
            MarginalLikelihoodAnalysis analysis = analyseMarginalLikelihood(traces, traceIndex, effectiveBurnin);
            System.out.print(analysis.getLogMarginalLikelihood() + "\t");
            System.out.print(analysis.getBootstrappedSE() + "\t");
        }
        System.out.println(maxState);
        return traces;
    }

    /**
     * Formats a number for display: non-zero values whose magnitude is below 0.01 or at least
     * 100000 use scientific notation ("0.####E0"); everything else uses "####0.####".
     * The magnitude test applies to negative values too, so e.g. -0.00001 is not shown as "-0".
     *
     * @param value the number to format
     * @return the formatted string
     */
    public static String formattedNumber(double value) {
        // DecimalFormat is not thread-safe, so fresh instances are created per call.
        DecimalFormat scientific = new DecimalFormat("0.####E0");
        DecimalFormat plain = new DecimalFormat("####0.####");
        double magnitude = Math.abs(value);
        if (value != 0 && (magnitude < 0.01 || magnitude >= 100000.0)) {
            return scientific.format(value);
        }
        return plain.format(value);
    }

    static final String[] colNamesNumeric = {"mean", "stderr_of_mean", "stdev", "variance", "median", "min", "max",
            "quantile1", "quantile3", "95_hpd_lower", "95_hpd_upper", "ACT", "ESS", "num_samples", "geometric_mean"};
    static final String[] colNamesCategorical = {"mode", "mode_frequency", "mode_probability",
            "unique_values", "95_credible_set"};

    /**
     * Output a tab-delimited result of the full statistic summary in a string,
     * given a list of <code>TraceList</code> (log or combined trace).
     * The rows are traces, columns are statistics.
     * The left section of statistics is for numbers, the right for categorical values
     * (the categorical columns are only present when any trace is categorical);
     * if a value is null or NaN, then the cell in that row and column is empty.
     * When more than one trace list is given, each row name is prefixed with the
     * (whitespace-free, extension-less) trace list name and a '.'.
     *
     * @param traceLists the trace lists to summarise; their statistics are read with
     *                   {@code getCorrelationStatistics}, so they are expected to be analysed already
     * @return the header line followed by one line per trace, or an empty string if
     *         {@code traceLists} is empty
     */
    public static String getStatisticSummary(List<TraceList> traceLists) {
        StringBuilder buffer = new StringBuilder();

        String[] colNames = colNamesNumeric;
        if (TraceTypeUtils.anyCategorical(traceLists, null)) {
            colNames = new String[colNamesNumeric.length + colNamesCategorical.length];
            System.arraycopy(colNamesNumeric, 0, colNames, 0, colNamesNumeric.length);
            System.arraycopy(colNamesCategorical, 0, colNames, colNamesNumeric.length, colNamesCategorical.length);
        }

        // Header row (first cell left empty for the row-name column).
        if (!traceLists.isEmpty()) {
            for (String colName : colNames) {
                buffer.append('\t').append(colName);
            }
            buffer.append('\n');
        }

        boolean multipleLogs = traceLists.size() > 1;
        for (TraceList tl : traceLists) {
            String prefix = multipleLogs ? rowPrefix(tl.getName()) : "";
            for (int row = 0; row < tl.getTraceCount(); row++) {
                buffer.append(prefix).append(tl.getTrace(row).getName());
                TraceCorrelation tc = tl.getCorrelationStatistics(row);
                for (int col = 0; col < colNames.length; col++) {
                    buffer.append('\t').append(getStatistic(col, tc));
                }
                buffer.append('\n');
            }
        }
        return buffer.toString();
    }

    /**
     * Builds the row-name prefix for a trace list in a multi-log summary: the name with all
     * whitespace removed, a trailing ".txt"/".log" extension dropped, and a '.' appended.
     */
    private static String rowPrefix(String traceListName) {
        String prefix = (traceListName + ".").replaceAll("\\s+", "");
        // Only strip the extension right before the appended '.', not ".log"/".txt" elsewhere in the name.
        return prefix.replaceAll("\\.(?:txt|log)\\.$", ".");
    }

    /**
     * Returns the text of one cell of the statistic summary.
     *
     * @param i  column index into {@code colNamesNumeric} followed by {@code colNamesCategorical}
     * @param tc the statistics of the trace, may be null
     * @return the value as a string, or "" when {@code tc} is null, the column does not apply
     *         to the trace type, or the value is null or NaN
     */
    private static String getStatistic(int i, TraceCorrelation tc) {
        if (tc == null)
            return "";
        if (tc.getTraceType().isContinuous() && i >= colNamesNumeric.length)
            return "";
        if (tc.getTraceType().isCategorical() && i < colNamesNumeric.length)
            return "";

        Object value = null;
        switch (i) { // i is the index of colNamesNumeric + colNamesCategorical
            case 0:
                value = tc.getMean();
                break;
            case 1:
                value = tc.getStdErrorOfMean();
                break;
            case 2:
                value = tc.getStdError();
                break;
            case 3:
                value = tc.getVariance();
                break;
            case 4:
                value = tc.getMedian();
                break;
            case 5:
                value = tc.getMinimum();
                break;
            case 6:
                value = tc.getMaximum();
                break;
            case 7:
                value = tc.getQ1();
                break;
            case 8:
                value = tc.getQ3();
                break;
            case 9:
                value = tc.getLowerHPD();
                break;
            case 10:
                value = tc.getUpperHPD();
                break;
            case 11:
                value = tc.getACT();
                break;
            case 12:
                value = tc.getESS();
                break;
            case 13:
                value = tc.getSize();
                break;
            case 14:
                if (!tc.hasGeometricMean()) return "";
                value = tc.getGeometricMean();
                break;
            //+++++ categorical +++++
            case 15: {
                // A missing mode becomes an empty cell instead of a NullPointerException.
                Object mode = tc.getMode();
                value = mode == null ? null : mode.toString();
                break;
            }
            case 16:
                value = tc.getFrequencyOfMode();
                break;
            case 17:
                value = tc.getProbabilityOfMode();
                break;
            case 18:
                value = tc.printUniqueValues();
                break;
            case 19:
                value = tc.printCredibleSet();
                break;
        }

        if (value == null)
            return "";
        else if (value instanceof Double && Double.isNaN(((Double) value).doubleValue()))
            return "";
        else
            return value.toString();
    }

    /** Creates a {@link LogFileTraces} for the given file and loads its traces. */
    private static LogFileTraces loadLogFile(String fileName) throws java.io.IOException, TraceException {
        LogFileTraces traces = new LogFileTraces(fileName, new File(fileName));
        traces.loadTraces();
        return traces;
    }

    /** Runs {@code analyseTrace} on every trace of the given traces. */
    private static void analyseAllTraces(LogFileTraces traces) throws java.io.IOException, TraceException {
        for (int i = 0; i < traces.getTraceCount(); i++) {
            traces.analyseTrace(i);
        }
    }

    /** Maps the {@value #DEFAULT_BURNIN} sentinel to 10% of the maximum state; other values pass through. */
    private static int resolveBurnin(LogFileTraces traces, int burnin) {
        if (burnin == DEFAULT_BURNIN) {
            return (int) (traces.getMaxState() / 10);
        }
        return burnin;
    }

    /** Returns the index of the trace named {@code name}, or -1 if there is none. */
    private static int findTraceIndex(TraceList traces, String name) {
        for (int i = 0; i < traces.getTraceCount(); i++) {
            if (traces.getTraceName(i).equals(name)) {
                return i;
            }
        }
        return -1;
    }

    /** Runs the marginal-likelihood analysis on the values of the given trace. */
    private static MarginalLikelihoodAnalysis analyseMarginalLikelihood(TraceList traces, int traceIndex, int burnin)
            throws java.io.IOException, TraceException {
        List<Double> sample = traces.getValues(traceIndex);
        return new MarginalLikelihoodAnalysis(sample, traces.getTraceName(traceIndex), burnin,
                MARGINAL_LIKELIHOOD_TYPE, MARGINAL_LIKELIHOOD_BOOTSTRAP_LENGTH);
    }
}