/* * CompleteHistorySimulator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.beagle.tools; import dr.evolution.tree.*; import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.substmodel.FrequencyModel; import dr.evolution.alignment.SimpleAlignment; import dr.evolution.datatype.Codons; import dr.evolution.datatype.DataType; import dr.evolution.sequence.Sequence; import dr.evomodel.branchratemodel.BranchRateModel; import dr.inference.markovjumps.MarkovJumpsRegisterAcceptor; import dr.inference.markovjumps.MarkovJumpsType; import dr.inference.markovjumps.StateHistory; import dr.inference.model.Parameter; import dr.math.MathUtils; import java.text.NumberFormat; import java.util.*; import java.util.logging.Logger; /** * Simulates a complete transition history and alignment of sequences given a tree, substitution model and * branch rate model. This code duplicates portions of dr.app.seqgen.SequenceSimulator. However, SequenceSimulator * is wed to dr.evomodel.substmodel.SubstitutionModel which does not emit the infinitesimal generator. * * @author Marc A. Suchard * @author remco@cs.waikato.ac.nz */ public class CompleteHistorySimulator extends SimpleAlignment implements MarkovJumpsRegisterAcceptor, TreeTraitProvider { /** * number of replications */ protected int nReplications; /** * tree used for generating samples * */ protected Tree tree; /** * site model used for generating samples * */ protected GammaSiteRateModel siteModel; /** * branch rate model used for generating samples * */ protected BranchRateModel branchRateModel; /** * nr of categories in site model * */ int categoryCount; /** * nr of states in site model * */ int stateCount; /** * an array used to hold infinitesimal generator */ // protected double[] lambda; // protected double[][] probabilities; private boolean branchSpecificLambda = false; private Parameter branchVariableParameter = null; private Parameter branchPossibleValuesParameter = null; private DataType dataType; protected List<double[]> registers; protected List<String> jumpTags; protected List<MarkovJumpsType> jumpTypes; protected List<double[][]> realizedJumps; protected int nJumpProcesses = 0; protected boolean sumAcrossSites; private final Map<String, Integer> idMap = new HashMap<String, Integer>(); private boolean saveAlignment = false; private Map<Integer,Sequence> alignmentTraitList; private boolean alignmentOnly = false; /** * Constructor * * @param tree * @param siteModel * @param branchRateModel * @param nReplications: nr of samples to generate */ // public CompleteHistorySimulator(Tree tree, GammaSiteRateModel siteModel, BranchRateModel branchRateModel, // int nReplications) { // this(tree, siteModel, branchRateModel, nReplications, false); // } // // public CompleteHistorySimulator(Tree tree, GammaSiteRateModel siteModel, BranchRateModel branchRateModel, // int nReplications, boolean sumAcrossSites) { // this(tree, siteModel, branchRateModel, nReplications, sumAcrossSites, null, null); // // } public CompleteHistorySimulator(Tree tree, GammaSiteRateModel siteModel, BranchRateModel branchRateModel, int nReplications, boolean sumAcrossSites, Parameter branchVariableParameter, Parameter branchPossibleValuesParameter) { this.tree = tree; this.siteModel = siteModel; this.branchRateModel = branchRateModel; this.nReplications = nReplications; stateCount = this.siteModel.getSubstitutionModel().getDataType().getStateCount(); categoryCount = this.siteModel.getCategoryCount(); // Codon models give exception when put inside report and when count statistics are done on them dataType = siteModel.getSubstitutionModel().getDataType(); // if (dataType instanceof Codons && !alignmentOnly) { // System.out.println("Codon models give exception when put inside report and when count statistics are done on them. " // + "You can supress this by setting alignmentOnly to true."); // } this.sumAcrossSites = sumAcrossSites; List<String> taxaIds = new ArrayList<String>(); for (int i = 0; i < tree.getTaxonCount(); i++) { taxaIds.add(tree.getTaxon(i).getId()); } int k = 1; for (String taxaId : taxaIds) { idMap.put(taxaId, k); k += 1; } format = NumberFormat.getNumberInstance(Locale.ENGLISH); format.setMaximumFractionDigits(3); if (branchVariableParameter != null && branchPossibleValuesParameter != null) { if (branchVariableParameter.getDimension() != 1) { throw new RuntimeException("branchVariableParameter has the wrong dimension; should be 1"); } if (branchPossibleValuesParameter.getDimension() != tree.getNodeCount()) { throw new RuntimeException("branchPossibleValuesParameter has the wrong dimension; should be " + tree.getNodeCount()); } branchSpecificLambda = true; this.branchPossibleValuesParameter = branchPossibleValuesParameter; this.branchVariableParameter = branchVariableParameter; StringBuilder sb = new StringBuilder(); sb.append("Doing a complete history simulation using branch-specific variables\n\tReplacing variable '"); sb.append(branchVariableParameter.getId()); sb.append("' with values from '"); sb.append(branchPossibleValuesParameter.getId()); sb.append("'"); Logger.getLogger("dr.app.beagle.tools").info(sb.toString()); } alignmentTraitList = new HashMap<Integer,Sequence>(tree.getNodeCount()); } /** * Convert integer representation of sequence into a Sequence * * @param seq integer representation of the sequence * @param node used to determine taxon for sequence * @return Sequence */ Sequence intArray2Sequence(int[] seq, NodeRef node) { String sSeq = ""; // DataType dataType = siteModel.getSubstitutionModel().getDataType(); for (int i = 0; i < nReplications; i++) { if (dataType instanceof Codons) { String s = dataType.getTriplet(seq[i]); sSeq += s; } else { String c = dataType.getCode(seq[i]); sSeq += c; } } return new Sequence(tree.getNodeTaxon(node), sSeq); } public void addAlignmentTrait() { saveAlignment = true; final String tag = "alignment"; treeTraits.addTrait(new TreeTrait.S() { public String getTraitName() { return tag; } public Intent getIntent() { return Intent.NODE; } public String getTrait(Tree tree, NodeRef node) { return alignmentTraitList.get(node.getNumber()).getSequenceString(); } }); } public void addRegister(Parameter addRegisterParameter, MarkovJumpsType type, boolean scaleByTime) { if (registers == null) { registers = new ArrayList<double[]>(); } if (jumpTags == null) { jumpTags = new ArrayList<String>(); } if (jumpTypes == null) { jumpTypes = new ArrayList<MarkovJumpsType>(); } if (realizedJumps == null) { realizedJumps = new ArrayList<double[][]>(); } final String tag = addRegisterParameter.getId(); registers.add(addRegisterParameter.getParameterValues()); jumpTags.add(tag); jumpTypes.add(type); realizedJumps.add(new double[tree.getNodeCount()][nReplications]); final int r = nJumpProcesses; treeTraits.addTrait(new TreeTrait.S() { public String getTraitName() { return tag; } public Intent getIntent() { return Intent.NODE; } public String getTrait(Tree tree, NodeRef node) { return formattedValue(tree, node, r); } }); nJumpProcesses++; // scaleByTime is currently ignored } public double[] getMarkovJumpsForNodeAndRegister(Tree tree, NodeRef node, int whichRegister) { if (this.tree != tree) { throw new RuntimeException("Wrong tree!"); } return realizedJumps.get(whichRegister)[node.getNumber()]; } public int getNumberOfJumpProcess() { return nJumpProcesses; } // public double[][] getMarkovJumpsForNode(Tree tree, NodeRef node) { // double[][] rtn = new double[nJumpProcesses][]; // for(int r = 0; r < nJumpProcesses; r++) { // rtn[r] = getMarkovJumpsForNodeAndRegister(tree, node, r); // } // return rtn; // } protected Helper treeTraits = new Helper(); public TreeTrait[] getTreeTraits() { return treeTraits.getTreeTraits(); } public TreeTrait getTreeTrait(String key) { return treeTraits.getTreeTrait(key); } private String formattedValue(Tree tree, NodeRef node, int jump) { StringBuffer sb = new StringBuffer(); double[] values = getMarkovJumpsForNodeAndRegister(tree, node, jump); if (sumAcrossSites) { double total = 0; for (double x : values) { total += x; } sb.append(total); } else { sb.append("{"); for (int i = 0; i < values.length; i++) { if (i > 0) { sb.append(","); } sb.append(values[i]); } sb.append("}"); } return sb.toString(); } private NumberFormat format; public String toString() { StringBuffer sb = new StringBuffer(); if (alignmentOnly) { this.setReportCountStatistics(false); sb.append(super.toString()); sb.append("\n"); } else { // alignment output sb.append("alignment\n"); sb.append(super.toString()); sb.append("\n"); // tree output sb.append("tree\n"); TreeUtils.newick(tree, tree.getRoot(), true, TreeUtils.BranchLengthType.LENGTHS_AS_TIME, format, null, (nJumpProcesses > 0 || saveAlignment ? new TreeTraitProvider[] { this } : null), idMap, sb); sb.append("\n"); } return sb.toString(); }// END: toString /** * perform the actual sequence generation * * @return alignment containing randomly generated sequences for the nodes in the * leaves of the tree */ public void simulate() { double[] lambda = new double[stateCount * stateCount]; if (!branchSpecificLambda) { siteModel.getSubstitutionModel().getInfinitesimalMatrix(lambda); // Assumes a single generator for whole tree } NodeRef root = tree.getRoot(); double[] categoryProbs = siteModel.getCategoryProportions(); int[] category = new int[nReplications]; for (int i = 0; i < nReplications; i++) { category[i] = MathUtils.randomChoicePDF(categoryProbs); } FrequencyModel frequencyModel = siteModel.getSubstitutionModel().getFrequencyModel(); int[] seq = new int[nReplications]; for (int i = 0; i < nReplications; i++) { seq[i] = MathUtils.randomChoicePDF(frequencyModel.getFrequencies()); } setDataType(siteModel.getSubstitutionModel().getDataType()); traverse(root, seq, category, this, lambda); } /** * recursively walk through the tree top down, and add sequence to alignment whenever * a leave node is reached. * * @param node reference to the current node, for which we visit all children * @param parentSequence randomly generated sequence of the parent node * @param category array of categories for each of the sites * @param alignment */ private void traverse(NodeRef node, int[] parentSequence, int[] category, SimpleAlignment alignment, double[] lambda) { if (saveAlignment) { alignmentTraitList.put(node.getNumber(),intArray2Sequence(parentSequence,node)); } for (int iChild = 0; iChild < tree.getChildCount(node); iChild++) { NodeRef child = tree.getChild(node, iChild); int[] seq = new int[nReplications]; StateHistory[] histories = new StateHistory[nReplications]; if (branchSpecificLambda) { final double branchValue = branchPossibleValuesParameter.getParameterValue(child.getNumber()); branchVariableParameter.setParameterValue(0, branchValue); // System.err.println("trying value = " + branchValue + " for " + child.getNumber()); siteModel.getSubstitutionModel().getInfinitesimalMatrix(lambda); } for (int i = 0; i < nReplications; i++) { histories[i] = simulateAlongBranch(tree, child, category[i], parentSequence[i], lambda); seq[i] = histories[i].getEndingState(); } processHistory(child, histories); if (tree.getChildCount(child) == 0) { alignment.addSequence(intArray2Sequence(seq, child)); } traverse(tree.getChild(node, iChild), seq, category, alignment, lambda); } } protected void processHistory(NodeRef node, StateHistory[] histories) { for (int jump = 0; jump < nJumpProcesses; jump++) { double[] register = registers.get(jump); MarkovJumpsType type = jumpTypes.get(jump); double[] realizedJump = realizedJumps.get(jump)[node.getNumber()]; for (int i = 0; i < nReplications; i++) { if (type == MarkovJumpsType.COUNTS) { realizedJump[i] = histories[i].getTotalRegisteredCounts(register); } else if (type == MarkovJumpsType.REWARDS) { realizedJump[i] = histories[i].getTotalReward(register); } else { throw new IllegalAccessError("Unknown MarkovJumps type"); } } } } private StateHistory simulateAlongBranch(Tree tree, NodeRef node, int rateCategory, int startingState, double[] lambda) { NodeRef parent = tree.getParent(node); final double branchRate = branchRateModel.getBranchRate(tree, node); // Get the operational time of the branch final double branchTime = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node)); if (branchTime < 0.0) { throw new RuntimeException("Negative branch length: " + branchTime); } double branchLength = siteModel.getRateForCategory(rateCategory) * branchTime; return StateHistory.simulateUnconditionalOnEndingState(0.0, startingState, branchLength, lambda, stateCount); } public void setAlignmentOnly() { alignmentOnly = true; } }//END: class