/* * EBSPAnalysis.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.coalescent; import dr.evolution.io.Importer; import dr.evolution.io.NexusImporter; import dr.evolution.io.TreeImporter; import dr.evolution.tree.Tree; import dr.inference.trace.LogFileTraces; import dr.inference.trace.TraceException; import dr.stats.DiscreteStatistics; import dr.util.HeapSort; import dr.util.TabularData; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.io.PrintWriter; import java.util.Arrays; /** * @author Joseph Heled */ public class EBSPAnalysis extends TabularData { private final double[] xPoints; private final double[] means; private final double[] medians; private final double[][] hpdLower; private final double[][] hpdHigh; private final double[] HPDLevels; // each bin covers xPoints[-1]/coalBins.length private int[] coalBins; private final boolean quantiles; public EBSPAnalysis(File log, File[] treeFiles, VariableDemographicModel.Type modelType, String firstColumnName, String firstIndicatorColumnName, String rootHeightColumnName, int coalPointBins, double burnIn, double[] inHPDLevels, boolean quantiles, boolean logSpace, boolean mid, int restrictToNchanges, PrintWriter allDemoWriter) throws IOException, Importer.ImportException, TraceException { LogFileTraces ltraces = new LogFileTraces(log.getCanonicalPath(), log); ltraces.loadTraces(); ltraces.setBurnIn(0); final int runLengthIncludingBurnin = ltraces.getStateCount(); int intBurnIn = (int) Math.floor(burnIn < 1 ? runLengthIncludingBurnin * burnIn : burnIn); final int nStates = runLengthIncludingBurnin - intBurnIn; //intBurnIn *= ltraces.getStepSize(); ltraces.setBurnIn(intBurnIn * ltraces.getStepSize()); assert ltraces.getStateCount() == nStates; this.quantiles = quantiles; HPDLevels = (inHPDLevels != null) ? inHPDLevels : new double[]{0.95}; int populationFirstColumn = -1; int indicatorsFirstColumn = -1; int rootHeightColumn = -1; for (int n = 0; n < ltraces.getTraceCount(); ++n) { final String traceName = ltraces.getTraceName(n); if (traceName.equals(firstColumnName)) { populationFirstColumn = n; } else if (traceName.equals(firstIndicatorColumnName)) { indicatorsFirstColumn = n; } else if (rootHeightColumnName != null && traceName.equals(rootHeightColumnName)) { rootHeightColumn = n; } } if (populationFirstColumn < 0 || indicatorsFirstColumn < 0) { throw new TraceException("incorrect trace column names: unable to find populations/indicators"); } double binSize = 0; if (coalPointBins > 0) { if (rootHeightColumn < 0) { throw new TraceException("incorrect tree height column"); } double hSum = -0; double[] h = new double[1]; for (int ns = 0; ns < nStates; ++ns) { ltraces.getStateValues(ns, h, rootHeightColumn); hSum += h[0]; } binSize = hSum / (nStates * coalPointBins); coalBins = new int[coalPointBins]; Arrays.fill(coalBins, 0); } TreeImporter[] treeImporters = new TreeImporter[treeFiles.length]; final boolean isStepWise = modelType == VariableDemographicModel.Type.STEPWISE; int nIndicators = 0; for (int k = 0; k < treeFiles.length; ++k) { // System.err.println("burnin " + treeFiles[k] + "(" + k + ")"); treeImporters[k] = new NexusImporter(new FileReader(treeFiles[k])); assert intBurnIn > 0; for (int z = 0; z < intBurnIn - 1; ++z) { treeImporters[k].importNextTree(); } nIndicators += treeImporters[k].importNextTree().getExternalNodeCount() - 1; } if (isStepWise) { nIndicators -= 1; } final int nXaxisPoints = nIndicators + (isStepWise ? 1 : 0) + 1; xPoints = new double[nXaxisPoints]; Arrays.fill(xPoints, 0.0); int nDataPoints = 0; VDdemographicFunction[] allDemog = new VDdemographicFunction[nStates]; { double[] indicators = new double[nIndicators]; double[] pop = new double[nIndicators + 1]; Tree[] tt = new Tree[treeFiles.length]; boolean match = true; for (int ns = 0; ns < nStates; ++ns) { ltraces.getStateValues(ns, indicators, indicatorsFirstColumn); ltraces.getStateValues(ns, pop, populationFirstColumn); if (match) { for (int nt = 0; nt < tt.length; ++nt) { tt[nt] = treeImporters[nt].importNextTree(); if( tt[nt] == null ) { throw new TraceException("All NEXUS tree files should contain the same number of states"); } } } //Get tree state number final String name1 = tt[0].getId(); final int state1 = Integer.parseInt(name1.substring(name1.indexOf('_') + 1, name1.length())); for (int j = 1; j < tt.length; ++j) { final String name2 = tt[j].getId(); int state2 = Integer.parseInt(name1.substring(name2.indexOf('_') + 1, name2.length())); if (state1 != state2) { throw new TraceException("NEXUS tree files have different rates or corrupted!!!!"); } } if ((ns + intBurnIn) * ltraces.getStepSize() == state1) { //Check if log state matches tree state match = true; final VDdemographicFunction demoFunction = new VDdemographicFunction(tt, modelType, indicators, pop, logSpace, mid); if (restrictToNchanges >= 0 && demoFunction.numberOfChanges() != restrictToNchanges) { continue; } double[] xs = demoFunction.allTimePoints(); for (int k = 0; k < xs.length; ++k) { xPoints[k + 1] += xs[k]; } if (coalPointBins > 0) { for (double x : xs) { coalBins[Math.min((int) (x / binSize), coalBins.length - 1)]++; } } allDemog[nDataPoints] = demoFunction; ++nDataPoints; demoFunction.freeze(); } else { match = false; } } for (int k = 0; k < xPoints.length; ++k) { xPoints[k] /= nStates; } if (nStates != nDataPoints) { //Warning if log file and tree files // have different rates System.err.println("Different Rates is \"main\" and \"tree\" log files"); } if (nDataPoints < 10) { //Warning if number of states is not sufficient // enough to do the analysis System.err.println("Warning!!! Not Sufficient number of data points"); } } double[] popValues = new double[nDataPoints]; means = new double[nXaxisPoints]; medians = new double[nXaxisPoints]; hpdLower = new double[HPDLevels.length][]; hpdHigh = new double[HPDLevels.length][]; for (int i = 0; i < HPDLevels.length; ++i) { hpdLower[i] = new double[nXaxisPoints]; hpdHigh[i] = new double[nXaxisPoints]; } for (int nx = 0; nx < xPoints.length; ++nx) { final double x = xPoints[nx]; for (int ns = 0; ns < nDataPoints; ++ns) { popValues[ns] = allDemog[ns].getDemographic(x); } int[] indices = new int[popValues.length]; HeapSort.sort(popValues, indices); means[nx] = DiscreteStatistics.mean(popValues); for (int i = 0; i < HPDLevels.length; ++i) { if (quantiles) { hpdLower[i][nx] = DiscreteStatistics.quantile((1 - HPDLevels[i]) / 2, popValues, indices); hpdHigh[i][nx] = DiscreteStatistics.quantile((1 + HPDLevels[i]) / 2, popValues, indices); } else { final double[] hpd = DiscreteStatistics.HPDInterval(HPDLevels[i], popValues, indices); hpdLower[i][nx] = hpd[0]; hpdHigh[i][nx] = hpd[1]; } } medians[nx] = DiscreteStatistics.median(popValues, indices); } if( allDemoWriter != null ) { for(double xPoint : xPoints) { allDemoWriter.print(xPoint); allDemoWriter.append(' '); } for (int ns = 0; ns < nDataPoints; ++ns) { allDemoWriter.println(); for(double xPoint : xPoints) { allDemoWriter.print(allDemog[ns].getDemographic(xPoint)); allDemoWriter.append(' '); } } allDemoWriter.close(); } } private final String[] columnNames = {"time", "mean", "median"}; public int nColumns() { return columnNames.length + 2 * HPDLevels.length + (coalBins != null ? 1 : 0); } public String columnName(int nColumn) { final int fixed = columnNames.length; if (nColumn < fixed) { return columnNames[nColumn]; } nColumn -= fixed; if (nColumn < 2 * HPDLevels.length) { final double p = HPDLevels[nColumn / 2]; final String s = (nColumn % 2 == 0) ? "lower" : "upper"; return (quantiles ? "cpd " : "hpd ") + s + " " + Math.round(p * 100); } assert (nColumn - 2 * HPDLevels.length) == 0; return "bins"; } public int nRows() { return Math.max(xPoints.length, (coalBins != null ? coalBins.length : 0)); } public Object data(int nRow, int nColumn) { switch (nColumn) { case 0: { if (nRow < xPoints.length) { return xPoints[nRow]; } break; } case 1: { if (nRow < means.length) { return means[nRow]; } break; } case 2: { if (nRow < medians.length) { return medians[nRow]; } break; } default: { final int j = nColumn - columnNames.length; if (j < 2 * HPDLevels.length) { if (nRow < xPoints.length) { final int k = j / 2; if (0 <= k && k < HPDLevels.length) { if (j % 2 == 0) { return hpdLower[k][nRow]; } else { return hpdHigh[k][nRow]; } } } } else { if (nRow < coalBins.length) { return coalBins[nRow]; } } break; } } return ""; } }