/* * BeastCheckpointer.java * * Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.checkpoint; import dr.evolution.tree.NodeRef; import dr.evomodel.tree.TreeModel; import dr.evomodel.tree.TreeParameterModel; import dr.inference.state.StateSaverChainListener; import dr.inference.state.Factory; import dr.inference.state.StateLoader; import dr.inference.state.StateSaver; import dr.inference.markovchain.MarkovChain; import dr.inference.markovchain.MarkovChainListener; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.operators.CoercableMCMCOperator; import dr.inference.operators.MCMCOperator; import dr.inference.operators.OperatorSchedule; import dr.math.MathUtils; import java.io.*; import java.text.SimpleDateFormat; import java.util.*; /** * A state loader / saver * @author Andrew Rambaut * @author Guy Baele */ public class BeastCheckpointer implements StateLoader, StateSaver { private static final boolean DEBUG = false; // A debugging flag to do a check that the state gives the same likelihood after loading private static final boolean CHECK_LOAD_STATE = false; public final static String LOAD_STATE_FILE = "load.state.file"; public final static String SAVE_STATE_FILE = "save.state.file"; public final static String SAVE_STATE_AT = "save.state.at"; public final static String SAVE_STATE_EVERY = "save.state.every"; private final String loadStateFileName; private final String saveStateFileName; public BeastCheckpointer() { loadStateFileName = System.getProperty(LOAD_STATE_FILE, null); saveStateFileName = System.getProperty(SAVE_STATE_FILE, null); final List<MarkovChainListener> listeners = new ArrayList<MarkovChainListener>(); if (System.getProperty(SAVE_STATE_AT) != null) { final long saveStateAt = Long.parseLong(System.getProperty(SAVE_STATE_AT)); listeners.add(new StateSaverChainListener(BeastCheckpointer.this, saveStateAt,false)); } if (System.getProperty(SAVE_STATE_EVERY) != null) { final long saveStateEvery = Long.parseLong(System.getProperty(SAVE_STATE_EVERY)); listeners.add(new StateSaverChainListener(BeastCheckpointer.this, saveStateEvery,true)); } Factory.INSTANCE = new Factory() { @Override public StateLoader getInitialStateLoader() { if (loadStateFileName == null) { return null; } else { return getStateLoaderObject(); } } @Override public MarkovChainListener[] getStateSaverChainListeners() { return listeners.toArray(new MarkovChainListener[0]); } }; } private BeastCheckpointer getStateLoaderObject() { return this; } @Override public boolean saveState(MarkovChain markovChain, long state, double lnL) { String timeStamp = new SimpleDateFormat("yyyy.MM.dd.HH.mm.ss").format(Calendar.getInstance().getTime()); String fileName = (this.saveStateFileName != null ? this.saveStateFileName : "beast_state_" + timeStamp); return writeStateToFile(new File(fileName), state, lnL, markovChain); } @Override public long loadState(MarkovChain markovChain, double[] savedLnL) { return readStateFromFile(new File(loadStateFileName), markovChain, savedLnL); } @Override public void checkLoadState(double savedLnL, double lnL) { if (CHECK_LOAD_STATE) { //first perform a simple check for equality of two doubles //when this test fails, go over the digits if (lnL != savedLnL) { //15 is the floor value for the number of decimal digits when representing a double //checking for 15 identical digits below String originalString = Double.toString(savedLnL); String restoredString = Double.toString(lnL); System.out.println(lnL + " " + originalString); System.out.println(savedLnL + " " + restoredString); //assume values will be nearly identical int digits = 0; for (int i = 0; i < Math.max(originalString.length(), restoredString.length()); i++) { if (originalString.charAt(i) == restoredString.charAt(i)) { if (!(originalString.charAt(i) == '-' || originalString.charAt(i) == '.')) { digits++; } } else { break; } } //System.err.println("digits = " + digits); if (digits < 15) { throw new RuntimeException("Dumped lnL does not match loaded state: stored lnL: " + savedLnL + ", recomputed lnL: " + lnL + " (difference " + (savedLnL - lnL) + ")"); } } else { System.out.println("IDENTICAL LIKELIHOODS"); System.out.println("lnL" + " = " + lnL); System.out.println("savedLnL[0]" + " = " + savedLnL); } } } private boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) { OperatorSchedule operatorSchedule = markovChain.getSchedule(); OutputStream fileOut = null; try { fileOut = new FileOutputStream(file); PrintStream out = new PrintStream(fileOut); ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>(); int[] rngState = MathUtils.getRandomState(); out.print("rng"); for (int i = 0; i < rngState.length; i++) { out.print("\t"); out.print(rngState[i]); } out.println(); out.print("state\t"); out.println(state); out.print("lnL\t"); out.println(lnL); for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) { out.print("parameter"); out.print("\t"); out.print(parameter.getParameterName()); out.print("\t"); out.print(parameter.getDimension()); for (int dim = 0; dim < parameter.getDimension(); dim++) { out.print("\t"); out.print(parameter.getParameterValue(dim)); } out.println(); } for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) { MCMCOperator operator = operatorSchedule.getOperator(i); out.print("operator"); out.print("\t"); out.print(operator.getOperatorName()); out.print("\t"); out.print(operator.getAcceptCount()); out.print("\t"); out.print(operator.getRejectCount()); if (operator instanceof CoercableMCMCOperator) { out.print("\t"); out.print(((CoercableMCMCOperator)operator).getCoercableParameter()); } out.println(); } //check up front if there are any TreeParameterModel objects for (Model model : Model.CONNECTED_MODEL_SET) { if (model instanceof TreeParameterModel) { //System.out.println("\nDetected TreeParameterModel: " + ((TreeParameterModel) model).toString()); traitModels.add((TreeParameterModel) model); } } for (Model model : Model.CONNECTED_MODEL_SET) { if (model instanceof TreeModel) { out.print("tree"); out.print("\t"); out.println(model.getModelName()); //replace Newick format by printing general graph structure //out.println(((TreeModel) model).getNewick()); out.println("#node height taxon"); int nodeCount = ((TreeModel) model).getNodeCount(); out.println(nodeCount); for (int i = 0; i < nodeCount; i++) { out.print(((TreeModel) model).getNode(i).getNumber()); out.print("\t"); out.print(((TreeModel) model).getNodeHeight(((TreeModel) model).getNode(i))); if (((TreeModel) model).isExternal(((TreeModel) model).getNode(i))) { out.print("\t"); out.print(((TreeModel) model).getNodeTaxon(((TreeModel) model).getNode(i)).getId()); } out.println(); } out.println("#edges"); out.println("#child-node parent-node L/R-child traits"); out.println(nodeCount); for (int i = 0; i < nodeCount; i++) { NodeRef parent = ((TreeModel) model).getParent(((TreeModel) model).getNode(i)); if (parent != null) { out.print(((TreeModel) model).getNode(i).getNumber()); out.print("\t"); out.print(((TreeModel) model).getParent(((TreeModel) model).getNode(i)).getNumber()); out.print("\t"); if ((((TreeModel) model).getChild(parent, 0) == ((TreeModel) model).getNode(i))) { //left child out.print(0); } else if ((((TreeModel) model).getChild(parent, 1) == ((TreeModel) model).getNode(i))) { //right child out.print(1); } else { throw new RuntimeException("Operation currently only supported for nodes with 2 children."); } for (TreeParameterModel tpm : traitModels) { out.print("\t"); out.print(tpm.getNodeValue((TreeModel)model, ((TreeModel) model).getNode(i))); } out.println(); } } } } out.close(); fileOut.close(); } catch (IOException ioe) { System.err.println("Unable to write file: " + ioe.getMessage()); return false; } if (DEBUG) { for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) { System.err.println(likelihood.getId() + ": " + likelihood.getLogLikelihood()); } } return true; } private long readStateFromFile(File file, MarkovChain markovChain, double[] lnL) { OperatorSchedule operatorSchedule = markovChain.getSchedule(); long state = -1; ArrayList<TreeParameterModel> traitModels = new ArrayList<TreeParameterModel>(); try { FileReader fileIn = new FileReader(file); BufferedReader in = new BufferedReader(fileIn); int[] rngState = null; String line = in.readLine(); String[] fields = line.split("\t"); if (fields[0].equals("rng")) { // if there is a random number generator state present then load it... try { rngState = new int[fields.length - 1]; for (int i = 0; i < rngState.length; i++) { rngState[i] = Integer.parseInt(fields[i + 1]); } } catch (NumberFormatException nfe) { throw new RuntimeException("Unable to read state number from state file"); } line = in.readLine(); fields = line.split("\t"); } try { if (!fields[0].equals("state")) { throw new RuntimeException("Unable to read state number from state file"); } state = Long.parseLong(fields[1]); } catch (NumberFormatException nfe) { throw new RuntimeException("Unable to read state number from state file"); } line = in.readLine(); fields = line.split("\t"); try { if (!fields[0].equals("lnL")) { throw new RuntimeException("Unable to read lnL from state file"); } if (lnL != null) { lnL[0] = Double.parseDouble(fields[1]); } } catch (NumberFormatException nfe) { throw new RuntimeException("Unable to read lnL from state file"); } for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) { line = in.readLine(); fields = line.split("\t"); //if (!fields[0].equals(parameter.getParameterName())) { // System.err.println("Unable to match state parameter: " + fields[0] + ", expecting " + parameter.getParameterName()); //} int dimension = Integer.parseInt(fields[2]); if (dimension != parameter.getDimension()) { System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName()); System.err.print("Read from file: "); for (int i = 0; i < fields.length; i++) { System.err.print(fields[i] + "\t"); } System.err.println(); } if (fields[1].equals("branchRates.categories.rootNodeNumber")) { // System.out.println("eek"); double value = Double.parseDouble(fields[3]); parameter.setParameterValue(0, value); if (DEBUG) { System.out.println("restoring " + fields[1] + " with value " + value); } } else { if (DEBUG) { System.out.print("restoring " + fields[1] + " with values "); } for (int dim = 0; dim < parameter.getDimension(); dim++) { parameter.setParameterValue(dim, Double.parseDouble(fields[dim + 3])); if (DEBUG) { System.out.print(Double.parseDouble(fields[dim + 3]) + " "); } } if (DEBUG) { System.out.println(); } } } for (int i = 0; i < operatorSchedule.getOperatorCount(); i++) { MCMCOperator operator = operatorSchedule.getOperator(i); line = in.readLine(); fields = line.split("\t"); if (!fields[1].equals(operator.getOperatorName())) { throw new RuntimeException("Unable to match operator: " + fields[1]); } if (fields.length < 4) { throw new RuntimeException("Operator missing values: " + fields[1]); } operator.setAcceptCount(Integer.parseInt(fields[2])); operator.setRejectCount(Integer.parseInt(fields[3])); if (operator instanceof CoercableMCMCOperator) { if (fields.length != 5) { throw new RuntimeException("Coercable operator missing parameter: " + fields[1]); } ((CoercableMCMCOperator)operator).setCoercableParameter(Double.parseDouble(fields[4])); } } // load the tree models last as we get the node heights from the tree (not the parameters which // which may not be associated with the right node Set<String> expectedTreeModelNames = new HashSet<String>(); for (Model model : Model.CONNECTED_MODEL_SET) { if (model instanceof TreeModel) { if (DEBUG) { System.out.println("model " + model.getModelName()); } expectedTreeModelNames.add(model.getModelName()); if (DEBUG) { for (String s : expectedTreeModelNames) { System.out.println(s); } } } if (model instanceof TreeParameterModel) { traitModels.add((TreeParameterModel)model); } } line = in.readLine(); fields = line.split("\t"); // Read in all (possibly more than one) trees while (fields[0].equals("tree")) { for (Model model : Model.CONNECTED_MODEL_SET) { if (model instanceof TreeModel && fields[1].equals(model.getModelName())) { line = in.readLine(); line = in.readLine(); fields = line.split("\t"); //read number of nodes int nodeCount = Integer.parseInt(fields[0]); double[] nodeHeights = new double[nodeCount]; for (int i = 0; i < nodeCount; i++) { line = in.readLine(); fields = line.split("\t"); nodeHeights[i] = Double.parseDouble(fields[1]); } //on to reading edge information line = in.readLine(); line = in.readLine(); line = in.readLine(); fields = line.split("\t"); int edgeCount = Integer.parseInt(fields[0]); //create data matrix of doubles to store information from list of TreeParameterModels double[][] traitValues = new double[traitModels.size()][edgeCount]; //create array to store whether a node is left or right child of its parent //can be important for certain tree transition kernels int[] childOrder = new int[edgeCount]; for (int i = 0; i < childOrder.length; i++) { childOrder[i] = -1; } int[] parents = new int[edgeCount]; for (int i = 0; i < edgeCount; i++){ parents[i] = -1; } for (int i = 0; i < edgeCount; i++) { line = in.readLine(); if (line != null) { fields = line.split("\t"); parents[Integer.parseInt(fields[0])] = Integer.parseInt(fields[1]); childOrder[i] = Integer.parseInt(fields[2]); for (int j = 0; j < traitModels.size(); j++) { traitValues[j][i] = Double.parseDouble(fields[3+j]); } } } //perform magic with the acquired information //adopt the loaded tree structure; this does not yet copy the traits on the branches ((TreeModel) model).beginTreeEdit(); ((TreeModel) model).adoptTreeStructure(parents, nodeHeights, childOrder); ((TreeModel) model).endTreeEdit(); expectedTreeModelNames.remove(model.getModelName()); } } line = in.readLine(); if (line != null) { fields = line.split("\t"); } } if (expectedTreeModelNames.size() > 0) { StringBuilder sb = new StringBuilder(); for (String notFoundName : expectedTreeModelNames) { sb.append("Expecting, but unable to match state parameter:" + notFoundName + "\n"); } throw new RuntimeException(sb.toString()); } if (DEBUG) { System.out.println("\nDouble checking:"); for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) { if (parameter.getParameterName().equals("branchRates.categories.rootNodeNumber")) { System.out.println(parameter.getParameterName() + ": " + parameter.getParameterValue(0)); } } } if (rngState != null) { MathUtils.setRandomState(rngState); } in.close(); fileIn.close(); // This shouldn't be necessary and if it is then it might be hiding a bug... // for (Likelihood likelihood : Likelihood.CONNECTED_LIKELIHOOD_SET) { // likelihood.makeDirty(); // } } catch (IOException ioe) { throw new RuntimeException("Unable to read file: " + ioe.getMessage()); } return state; } }