/* * BeagleBranchLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.branchmodel.lineagespecific; import java.util.*; import beagle.Beagle; import beagle.BeagleFactory; import dr.evomodel.branchmodel.BranchModel; import dr.evomodel.branchmodel.HomogeneousBranchModel; import dr.evomodel.siteratemodel.GammaSiteRateModel; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.nucleotide.HKY; import dr.evomodel.substmodel.SubstitutionModel; import dr.evomodel.treedatalikelihood.BufferIndexHelper; import dr.evomodel.treelikelihood.BeagleTreeLikelihood; import dr.evomodel.treelikelihood.PartialsRescalingScheme; import dr.evomodel.treelikelihood.SubstitutionModelDelegate; import dr.app.beagle.tools.BeagleSequenceSimulator; import dr.app.beagle.tools.Partition; import dr.evolution.alignment.Alignment; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.DataType; import dr.evolution.datatype.Nucleotides; import dr.evolution.io.NewickImporter; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.StrictClockBranchRates; import dr.evomodel.tree.TreeModel; import dr.inference.loggers.LogColumn; import dr.inference.loggers.NumberColumn; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.math.MathUtils; @SuppressWarnings("serial") public class BeagleBranchLikelihood implements Likelihood { private static final boolean DEBUG = true; // Constructor fields private PatternList patternList; private TreeModel treeModel; private BranchModel branchModel; private SiteRateModel siteRateModel; private FrequencyModel freqModel; private BranchRateModel branchRateModel; // Likelihood fields private String id = null; private boolean used = true; // Beagle fields private Beagle beagle; private BufferIndexHelper matrixBufferHelper; private BufferIndexHelper partialBufferHelper; private SubstitutionModelDelegate substitutionModelDelegate; int nodeCount; boolean[] updateNode; public BeagleBranchLikelihood(PatternList patternList, // TreeModel treeModel, // BranchModel branchModel, // SiteRateModel siteRateModel, // FrequencyModel freqModel, // BranchRateModel branchRateModel // ) { this.patternList = patternList; this.treeModel = treeModel; this.branchModel = branchModel; this.siteRateModel = siteRateModel; this.freqModel = freqModel; this.branchRateModel = branchRateModel; this.loadBeagleInstance(); }// END: Constructor public double getBranchLogLikelihood(int branchIndex) { int count = 1; double[] loglikelihood = new double[count]; // gamma category rates double[] categoryRates = siteRateModel.getCategoryRates(); beagle.setCategoryRates(categoryRates); double[] categoryWeights = this.siteRateModel.getCategoryProportions(); beagle.setCategoryWeights(0, categoryWeights); double[] frequencies = substitutionModelDelegate .getRootStateFrequencies(); beagle.setStateFrequencies(0, frequencies); substitutionModelDelegate.updateSubstitutionModels(beagle); setTipPartials(); // flags to keep track of updated transition matrix buffers updateNode = new boolean[nodeCount]; Arrays.fill(updateNode, true); // // Do not update root node // int rootNum = treeModel.getRoot().getNumber(); // updateNode[rootNum] = false; int nodeNum = branchIndex; NodeRef node = treeModel.getNode(nodeNum); // traverse down that node populating buffers and calculating partials traverse(treeModel, node); // traverse down the parent node populating buffers and calculating partials NodeRef parent = treeModel.getParent(node); traverse(treeModel, parent); int parentNum = treeModel.getParent(node).getNumber(); // beagle.calculateEdgeLogLikelihoods(new int[] { parentNum }, // parentBufferIndices // new int[] { nodeNum }, // int[] childBufferIndices // new int[] { 0 }, // int[] probabilityIndices // null, // firstDerivativeIndices // null, // secondDerivativeIndices // new int[] { 0 }, // int[] categoryWeightsIndices // new int[] { 0 }, // int[] stateFrequenciesIndices // new int[] { Beagle.NONE }, // cumulativeScaleIndices // 1, // count // loglikelihood, // // null, // outSumFirstDerivative, // // null // outSumSecondDerivative // // ); return loglikelihood[0]; }// END: getLogLikelihood private boolean traverse(TreeModel treeModel, NodeRef node) { boolean update = false; // if (!treeModel.isRoot(node)) { int nodeNum = node.getNumber(); NodeRef parentNode = treeModel.getParent(node); // populate buffer if (parentNode != null && updateNode[nodeNum]) { double branchRate = branchRateModel.getBranchRate(treeModel, node); double parentHeight = treeModel.getNodeHeight(parentNode); double nodeHeight = treeModel.getNodeHeight(node); double branchLength = branchRate * (parentHeight - nodeHeight); substitutionModelDelegate.flipMatrixBuffer(nodeNum); // matrixBufferHelper.flipOffset(nodeNum); substitutionModelDelegate.updateTransitionMatrices(// beagle, // new int[] { nodeNum }, // new double[] { branchLength }, // 1 // ); if (DEBUG) { System.out.println("At branch " + nodeNum); System.out.println(" Length " + branchLength + ": node " + nodeNum + ", height=" + nodeHeight + " parent " + parentNode); System.out.println(" Populating transition matrix buffer"); }// END: DEBUG check updateNode[nodeNum] = false; update = true; }// END: parent check // update the partial likelihoods if (!treeModel.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = treeModel.getChild(node, 0); boolean update1 = traverse(treeModel, child1); NodeRef child2 = treeModel.getChild(node, 1); boolean update2 = traverse(treeModel, child2); // If either child node was updated then update this node too if (update1 || update2) { int[] operations = new int[Beagle.OPERATION_TUPLE_SIZE]; partialBufferHelper.flipOffset(nodeNum); // destinationPartials operations[Beagle.OPERATION_TUPLE_SIZE - 7] = partialBufferHelper .getOffsetIndex(nodeNum); // destinationScaleWrite operations[Beagle.OPERATION_TUPLE_SIZE - 6] = Beagle.NONE; // destinationScaleRead operations[Beagle.OPERATION_TUPLE_SIZE - 5] = Beagle.NONE; // source node 1 operations[Beagle.OPERATION_TUPLE_SIZE - 4] = partialBufferHelper .getOffsetIndex(child1.getNumber()); // source matrix 1 operations[Beagle.OPERATION_TUPLE_SIZE - 3] = substitutionModelDelegate .getMatrixIndex(child1.getNumber()); // source node 2 operations[Beagle.OPERATION_TUPLE_SIZE - 2] = partialBufferHelper .getOffsetIndex(child2.getNumber()); // source matrix 2 operations[Beagle.OPERATION_TUPLE_SIZE - 1] = substitutionModelDelegate .getMatrixIndex(child2.getNumber()); beagle.updatePartials(operations, 1, Beagle.NONE); if (DEBUG) { System.out.println("At branch " + nodeNum); System.out.println(" Child nodes updated"); System.out.println(" Populating partial buffer"); }// END: DEBUG check updateNode[nodeNum] = false; update = true; }// END: children updated check }// END: external branch check // }//END: root check return update; }// END: traverse // ////////////// // ---PUBLIC---// // ////////////// @Override public double getLogLikelihood() { double loglikelihood = 0; // TODO return loglikelihood; }// END: getLogLikelihood // /////////////// // ---PRIVATE---// // /////////////// private void populateTransitionBuffers() { for (NodeRef node : treeModel.getNodes()) { // if (!treeModel.isRoot(node)) { int nodeNum = node.getNumber(); matrixBufferHelper.flipOffset(nodeNum); int branchIndex = nodeNum; int[] childBufferIndices = new int[] { branchIndex }; double branchRate = branchRateModel.getBranchRate(treeModel, node); double branchLength = treeModel.getBranchLength(node); double branchTime = branchLength * branchRate;// * siteRate; substitutionModelDelegate.updateTransitionMatrices(beagle, // childBufferIndices, // new double[] { branchTime }, 1 // ); // }// END: root check }// END: nodes loop }//END: populateTransitionBuffers private void setTipPartials() { int patternCount = patternList.getPatternCount(); int tipCount = treeModel.getTaxonCount(); for (int i = 0; i < tipCount; i++) { String id = treeModel.getTaxonId(i); int sequenceIndex = patternList.getTaxonIndex(id); int[] states = new int[patternCount]; for (int j = 0; j < patternCount; j++) { states[j] = patternList.getPatternState(sequenceIndex, j); } beagle.setTipStates(i, states); }// END: i loop }// END: setTipStates public void finalizeBeagle() throws Throwable { beagle.finalize(); }// END: finalizeBeagle private void loadBeagleInstance() { this.substitutionModelDelegate = new SubstitutionModelDelegate( treeModel, branchModel); DataType dataType = freqModel.getDataType(); int partitionSiteCount = patternList.getPatternCount(); nodeCount = treeModel.getNodeCount(); this.matrixBufferHelper = new BufferIndexHelper(nodeCount, 0); int tipCount = treeModel.getExternalNodeCount(); int internalNodeCount = treeModel.getInternalNodeCount(); partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount); BufferIndexHelper scaleBufferHelper = new BufferIndexHelper( internalNodeCount + 1, 0); int compactPartialsCount = tipCount; int stateCount = dataType.getStateCount(); int patternCount = partitionSiteCount; int siteRateCategoryCount = siteRateModel.getCategoryCount(); int[] resourceList = new int[] { 0 }; long preferenceFlags = 0; long requirementFlags = 0; beagle = BeagleFactory.loadBeagleInstance(tipCount, // partialBufferHelper.getBufferCount(), // compactPartialsCount, // stateCount, // patternCount, // substitutionModelDelegate.getEigenBufferCount(), // substitutionModelDelegate.getMatrixBufferCount(), // siteRateCategoryCount, // scaleBufferHelper.getBufferCount(), // resourceList, // preferenceFlags, // requirementFlags// ); }// END: loadBeagleInstance // ///////////////// // ---INHERITED---// // ///////////////// @Override public LogColumn[] getColumns() { return new dr.inference.loggers.LogColumn[] { new LikelihoodColumn( getId() == null ? "likelihood" : getId()) }; } @Override public String getId() { return this.id; } @Override public void setId(String id) { this.id = id; } @Override public Model getModel() { // TODO Auto-generated method stub return null; } @Override public void makeDirty() { // TODO Auto-generated method stub } @Override public String prettyName() { return Abstract.getPrettyName(this); } @Override public Set<Likelihood> getLikelihoodSet() { return new HashSet<Likelihood>(Arrays.asList(this)); } @Override public boolean isUsed() { return used; } @Override public void setUsed() { used = true; } @Override public boolean evaluateEarly() { return false; } // /////////////////////// // ---PRIVATE CLASSES---// // /////////////////////// private class LikelihoodColumn extends NumberColumn { public LikelihoodColumn(String label) { super(label); }// END: Constructor public double getDoubleValue() { return getLogLikelihood(); } }// END: LikelihoodColumn class // //////////// // ---TEST---// // //////////// public static void main(String[] args) { try { MathUtils.setSeed(666); int sequenceLength = 1000; ArrayList<Partition> partitionsList = new ArrayList<Partition>(); // create tree NewickImporter importer = new NewickImporter( "((SimSeq1:22.0,SimSeq2:22.0):12.0,(SimSeq3:23.1,SimSeq4:23.1):10.899999999999999);"); Tree tree = importer.importTree(null); TreeModel treeModel = new TreeModel(tree); // create Frequency Model Parameter freqs = new Parameter.Default(new double[] { 0.25, 0.25, 0.25, 0.25 }); FrequencyModel freqModel = new FrequencyModel(Nucleotides.INSTANCE, freqs); // create branch model Parameter kappa1 = new Parameter.Default(1, 1); HKY hky1 = new HKY(kappa1, freqModel); BranchModel homogeneousBranchModel = new HomogeneousBranchModel( hky1); List<SubstitutionModel> substitutionModels = new ArrayList<SubstitutionModel>(); substitutionModels.add(hky1); List<FrequencyModel> freqModels = new ArrayList<FrequencyModel>(); freqModels.add(freqModel); // create branch rate model Parameter rate = new Parameter.Default(1, 1.000); BranchRateModel branchRateModel = new StrictClockBranchRates(rate); // create site model GammaSiteRateModel siteRateModel = new GammaSiteRateModel( "siteModel"); // create partition Partition partition1 = new Partition(treeModel, // homogeneousBranchModel,// siteRateModel, // branchRateModel, // freqModel, // 0, // from sequenceLength - 1, // to 1 // every ); partitionsList.add(partition1); // feed to sequence simulator and generate data BeagleSequenceSimulator simulator = new BeagleSequenceSimulator( partitionsList); Alignment alignment = simulator.simulate(false, false); System.out.println(alignment); BeagleTreeLikelihood btl = new BeagleTreeLikelihood(alignment, treeModel, homogeneousBranchModel, siteRateModel, branchRateModel, null, false, PartialsRescalingScheme.DEFAULT, true); System.out.println("BTL(homogeneous) = " + btl.getLogLikelihood()); BeagleBranchLikelihood bbl = new BeagleBranchLikelihood(alignment, treeModel, homogeneousBranchModel, siteRateModel, freqModel, branchRateModel); int branchIndex = 4; System.out.println(bbl.getBranchLogLikelihood(branchIndex)); bbl.finalizeBeagle(); } catch (Exception e) { e.printStackTrace(); System.exit(-1); } catch (Throwable e) { e.printStackTrace(); System.exit(-1); }// END: try-catch block }// END: main }// END: class