/* * Partition.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.beagle.tools; import java.util.LinkedHashMap; import java.util.Map; import dr.evomodel.treedatalikelihood.BufferIndexHelper; import org.apache.commons.math.random.MersenneTwister; import beagle.Beagle; import beagle.BeagleFactory; import dr.evomodel.branchmodel.BranchModel; import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.treelikelihood.SubstitutionModelDelegate; import dr.app.bss.Utils; import dr.evolution.datatype.Codons; import dr.evolution.datatype.DataType; import dr.evolution.sequence.Sequence; import dr.evolution.tree.NodeRef; import dr.evolution.util.Taxon; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.tree.TreeModel; import dr.math.MathUtils; /** * @author Filip Bielejec * @version $Id$ */ public class Partition { private static final boolean DEBUG = false; // Constructor fields public int from; public int to; public int every; private BranchModel branchModel; private TreeModel treeModel; private GammaSiteRateModel siteRateModel; private BranchRateModel branchRateModel; private FrequencyModel freqModel; // Buffer helpers private BufferIndexHelper partialBufferHelper; private BufferIndexHelper scaleBufferHelper; private BufferIndexHelper matrixBufferHelper; // Beagle stuff private Beagle beagle; private SubstitutionModelDelegate substitutionModelDelegate; // int fields private Integer partitionNumber; private int partitionSiteCount; private int nodeCount; private int tipCount; private int internalNodeCount; private int stateCount; private int compactPartialsCount; private int patternCount; private int siteRateCategoryCount; // Sequence fields private LinkedHashMap<Taxon, int[]> alignmentMap; // private LinkedHashMap<NodeRef, int[]> sequencesMap = new LinkedHashMap<NodeRef, int[]>(); private DataType dataType; private boolean hasRootSequence = false; private Sequence rootSequence = null; private boolean outputAncestralSequences = false; // Random number generation private MersenneTwister random; // Annotating trees // private boolean annotateTree = true; public Partition(TreeModel treeModel, // BranchModel branchModel, // GammaSiteRateModel siteModel, // BranchRateModel branchRateModel, // FrequencyModel freqModel, // int from, // int to, // int every // ) { this.treeModel = treeModel; this.siteRateModel = siteModel; this.freqModel = freqModel; this.branchModel = branchModel; this.branchRateModel = branchRateModel; this.from = from; this.to = to; this.every = every; dataType = freqModel.getDataType(); partitionSiteCount = getPartitionSiteCount(); setBufferHelpers(); setSubstitutionModelDelegate(); loadBeagleInstance(); alignmentMap = new LinkedHashMap<Taxon, int[]>(); random = new MersenneTwister(MathUtils.nextLong()); }// END: Constructor private void setSubstitutionModelDelegate() { substitutionModelDelegate = new SubstitutionModelDelegate(treeModel, branchModel); }// END: setSubstitutionModelDelegate private void setBufferHelpers() { nodeCount = treeModel.getNodeCount(); matrixBufferHelper = new BufferIndexHelper(nodeCount, 0); tipCount = treeModel.getExternalNodeCount(); internalNodeCount = treeModel.getInternalNodeCount(); partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount); scaleBufferHelper = new BufferIndexHelper(internalNodeCount + 1, 0); }// END: setBufferHelpers public void loadBeagleInstance() { compactPartialsCount = tipCount; stateCount = dataType.getStateCount(); patternCount = partitionSiteCount; siteRateCategoryCount = siteRateModel.getCategoryCount(); int[] resourceList = new int[] { 0 }; long preferenceFlags = 0; long requirementFlags = 0; beagle = BeagleFactory.loadBeagleInstance(tipCount, // partialBufferHelper.getBufferCount(), // compactPartialsCount, // stateCount, // patternCount, // substitutionModelDelegate.getEigenBufferCount(), // substitutionModelDelegate.getMatrixBufferCount(), // siteRateCategoryCount, // scaleBufferHelper.getBufferCount(), // resourceList, // preferenceFlags, // requirementFlags); }// END: loadBeagleInstance public void simulatePartition() { try { NodeRef root = treeModel.getRoot(); // gamma category rates double[] categoryRates = siteRateModel.getCategoryRates(); beagle.setCategoryRates(categoryRates); // probabilities for gamma category rates double[] categoryProbs = siteRateModel.getCategoryProportions(); // beagle.setCategoryWeights(0, categoryProbs); // Utils.printArray(categoryRates); // Utils.printArray(categoryProbs); int[] category = new int[partitionSiteCount]; for (int i = 0; i < partitionSiteCount; i++) { category[i] = randomChoicePDF(categoryProbs, partitionNumber, "categories"); } // category = new int[] {1, 0, 0, 0, 0, 1, 0, 1, 0, 0 }; if(DEBUG){ System.out.println("category for each site:"); Utils.printArray(category); }//END: DEBUG int[] parentSequence = new int[partitionSiteCount]; // set ancestral sequence for partition if it exists if (hasRootSequence) { if (rootSequence.getLength() == partitionSiteCount) { parentSequence = sequence2intArray(rootSequence); } else if (dataType instanceof Codons && rootSequence.getLength() == 3 * partitionSiteCount) { parentSequence = sequence2intArray(rootSequence); } else { throw new RuntimeException("Ancestral sequence length of " + rootSequence.getLength() + " does not match partition site count of " + partitionSiteCount + "."); } } else { double[] frequencies = freqModel.getFrequencies(); for (int i = 0; i < partitionSiteCount; i++) { parentSequence[i] = randomChoicePDF(frequencies, partitionNumber, "root"); } }// END:ancestralSequence check if (DEBUG) { synchronized (this) { System.out.println(); System.out.println("root Sequence:"); Utils.printArray(parentSequence); } }//END: DEBUG substitutionModelDelegate.updateSubstitutionModels(beagle); traverse(root, parentSequence, category); if (DEBUG) { synchronized (this) { System.out.println("Simulated alignment:"); printSequences(); } }//END: DEBUG beagle.finalize(); } catch (Exception e) { e.printStackTrace(); } catch (Throwable e) { System.err.println("BeagleException: " + e.getMessage()); System.exit(-1); } }// END: simulatePartition private void traverse(NodeRef node, // int[] parentSequence, // int[] category // ) { if (DEBUG) { synchronized (this) { System.out.println(); System.out.println("I'm at: " + node.toString()); System.out.println(); } }//END: DEBUG for (int iChild = 0; iChild < treeModel.getChildCount(node); iChild++) { NodeRef child = treeModel.getChild(node, iChild); int[] partitionSequence = new int[partitionSiteCount]; double[] cProb = new double[stateCount]; double[][] probabilities = getTransitionProbabilities(child); if (DEBUG) { synchronized (this) { System.out.println("Going to child " + iChild + ": " + child.toString()); System.out.println("Child finite transition probs matrix:"); Utils.print2DArray(probabilities, stateCount); System.out.println(); } }// END: DEBUG for (int i = 0; i < partitionSiteCount; i++) { System.arraycopy(probabilities[category[i]], parentSequence[i] * stateCount, cProb, 0, stateCount); if (DEBUG) { synchronized (this) { System.out.println("site:" + i); System.out.println("site probs:"); Utils.printArray(cProb); } }// END: DEBUG partitionSequence[i] = randomChoicePDF(cProb, partitionNumber, "seq"); }// END: i loop if (DEBUG) { synchronized (this) { // partitionSequence = new int[]{1, 3, 2, 3, 0, 1, 0, 1, 0, 2, 2, 0, 1, 3, 3, 3, 0, 1, 2, 1, 3, 1, 1, 1, 1, 3, 0, 0, 3, 2, 3, 2, 3, 2, 1, 2, 1, 3, 2, 3, 3, 0, 2, 2, 3, 2, 3, 2, 3, 1, 2, 0, 2, 1, 3, 2, 3, 1, 1, 1, 1, 0, 2, 3, 1, 0, 2, 1, 2, 1, 3, 0, 0, 0, 0, 0, 2, 0, 2, 3, 1, 0, 1, 3, 0, 2, 1, 2, 1, 3, 0, 0, 3, 2, 2, 0, 1, 0, 0, 3 }; System.out.println("Simulated sequence:"); Utils.printArray(partitionSequence); } }// END: if DEBUG // if(annotateTree) { // // sequencesMap.put(child, partitionSequence); // // } if (treeModel.getChildCount(child) == 0) { Taxon taxon = treeModel.getNodeTaxon(child); alignmentMap.put(taxon, partitionSequence); if (DEBUG) { synchronized (this) { System.out.println("Simulated sequence (translated):"); System.out.println(Utils.intArray2Sequence(taxon, partitionSequence, BeagleSequenceSimulator.gapFlag, dataType).getSequenceString()); } }// END: DEBUG } else { if(outputAncestralSequences) { alignmentMap.put(new Taxon("internalNodeHeight" + treeModel.getNodeHeight(child)), partitionSequence); } } // END: tip node check traverse(treeModel.getChild(node, iChild), partitionSequence, category); }// END: child nodes loop }// END: traverse private double[][] getTransitionProbabilities(NodeRef node // ) { double[][] probabilities = new double[siteRateCategoryCount][stateCount * stateCount]; int nodeNum = node.getNumber(); matrixBufferHelper.flipOffset(nodeNum); int branchIndex = nodeNum; double branchRate = branchRateModel.getBranchRate(treeModel, node); double branchLength = treeModel.getBranchLength(node); double branchTime = branchLength * branchRate;// * siteRate; if (DEBUG) { synchronized (this) { System.out.println("Branch length: " + branchLength + " branch rate: " + branchRate + " branch time: " + branchTime);// + " site rate: " + siteRate); } }// END: DEBUG int count = 1; substitutionModelDelegate.updateTransitionMatrices(beagle, new int[] { branchIndex }, new double[] { branchTime }, count); double transitionMatrix[] = new double[siteRateCategoryCount * stateCount * stateCount]; beagle.getTransitionMatrix(branchIndex, // transitionMatrix // ); for (int siteRateCat = 0; siteRateCat < siteRateCategoryCount; siteRateCat++) { System.arraycopy(transitionMatrix, siteRateCat * stateCount * stateCount, probabilities[siteRateCat], 0, stateCount * stateCount); }// END: i loop return probabilities; }// END: getTransitionProbabilities // /////////////////////////// // ---START: EXPERIMENTAL---// // /////////////////////////// // void traverse(NodeRef node, // // int [] parentSequence, // // int [] category // // ) { // // for (int iChild = 0; iChild < treeModel.getChildCount(node); iChild++) { // // NodeRef child = treeModel.getChild(node, iChild); // // for (int i = 0; i < categoryCount; i++) { // getTransitionProbabilities(child, i, probabilities[i]); // } // //// if(DEBUG){ //// Utils.print2DArray(probabilities); //// } // // int [] seq = new int[partitionSiteCount]; // double [] cProb = new double[stateCount]; // // for (int i = 0; i < partitionSiteCount; i++) { // System.arraycopy(probabilities[category[i]], parentSequence[i] * stateCount, cProb, 0, stateCount); // seq[i] = MathUtils.randomChoicePDF(cProb); // } // // if (treeModel.getChildCount(child) == 0) { // Taxon taxon = treeModel.getNodeTaxon(child); // sequenceList.put(taxon, seq); // } // // traverse(treeModel.getChild(node, iChild), seq, category); // // }//END: child loop // } // traverse // // void getTransitionProbabilities( //NodeRef node, int rateCategory, double[] probs) { // // NodeRef parent = treeModel.getParent(node); // // double branchRate = branchRateModel.getBranchRate(treeModel, node); // // // Get the operational time of the branch // double branchTime = branchRate * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(node)); // // if (branchTime < 0.0) { // throw new RuntimeException("Negative branch length: " + branchTime); // } // // double branchLength = siteModel.getRateForCategory(rateCategory) * branchTime; // // siteModel.getSubstitutionModel().getTransitionProbabilities(branchLength, probs); // } // getTransitionProbabilities // ///////////////////////// // ---END: EXPERIMENTAL---// // ///////////////////////// private int[] sequence2intArray(Sequence sequence) { int array[] = new int[partitionSiteCount]; if (dataType instanceof Codons) { int k = 0; for (int i = 0; i < partitionSiteCount; i++) { array[i] = ((Codons) dataType).getState(sequence.getChar(k), sequence.getChar(k + 1), sequence.getChar(k + 2)); k += 3; }// END: replications loop } else { for (int i = 0; i < partitionSiteCount; i++) { array[i] = dataType.getState(sequence.getChar(i)); }// END: replications loop }// END: dataType check return array; }// END: sequence2intArray private int randomChoicePDF(double[] pdf, int partitionNumber, String error) { int samplePos = -Integer.MAX_VALUE; double cumProb = 0.0; double u = random.nextDouble(); for (int i = 0; i < pdf.length; i++) { cumProb += pdf[i]; if (u < cumProb) { samplePos = i; break; } } return samplePos; }// END: randomChoicePDF // ///////////// // --SETTERS--// // ///////////// public void setPartitionNumber(Integer partitionNumber) { this.partitionNumber = partitionNumber; } public void setRootSequence(Sequence rootSequence) { this.rootSequence = rootSequence; this.hasRootSequence = true; }// END: setAncestralSequence public void setOutputAncestralSequences(boolean outputAncestralSequences) { this.outputAncestralSequences = outputAncestralSequences; } // ///////////// // --GETTERS--// // ///////////// // public boolean isOutputAncestralSequences() { // return outputAncestralSequences; // } public TreeModel getTreeModel() { return treeModel; } public int getPartitionSiteCount() { return ((to - from) / every) + 1; }// END: getPartitionSiteCount public BranchModel getBranchModel() { return this.branchModel; }// END: getBranchModelic public FrequencyModel getFreqModel() { return freqModel; }// END: getFreqModel public Integer getPartitionNumber() { return partitionNumber; }// END: getPartitionNumber public DataType getDataType() { return dataType; }// END: getDataType public Map<Taxon, int[]> getTaxonSequencesMap() { return alignmentMap; }// END: getSequenceList // public LinkedHashMap<NodeRef, int[]> getSequenceMap() { // return sequencesMap; // } public Sequence getRootSequence() { return rootSequence; } // /////////////// // --DEBUGGING--// // /////////////// public void printSequences() { System.out.println("partition " + partitionNumber); Utils.printMap(alignmentMap); }// END: printSequences }// END: class