/* * StructuredColourSampler.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evolution.colouring; import dr.evolution.alignment.Alignment; import dr.evolution.coalescent.structure.MetaPopulation; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.math.MathUtils; import java.text.DecimalFormat; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.TreeMap; /** * @author Gerton Lunter * @version $Id: StructuredColourSampler.java,v 1.11 2006/09/11 09:33:01 gerton Exp $ * <p/> * Samples tree colourings from a proposal distribution based on * Felsenstein mutation process; biased at the nodes for coalescent rates; * and biased at edges using a mean-field approximation of the branch * colourings. This should improve performance for unequal population sizes */ public class StructuredColourSampler implements ColourSampler { static final int maxIterations = 1000; static final double tinyTime = 1.0e-6; // to deal with discontinuous demograhpic models // debugging flags static final boolean debugMessages = false; static final boolean debugMeanColours = false; static final boolean debugNodePartials = false; static final boolean debugSampleLikelihoods = false; static final boolean debugRejectionSampler = false; static final boolean debugProposalProbabilityCalculator = false; double _totalIntegratedRate; static final DecimalFormat df = new DecimalFormat("###.####"); // tuning parameters static final double propAffected = 0.0; /* proportion of lineages deemed affected by conditioning on a single one */ private boolean useNodeBias = false; private boolean useBranchBias = false; private boolean useSecondColourIteration = true; public StructuredColourSampler(Alignment tipColours, Tree tree, boolean nodeBias, boolean branchBias, boolean secondIteration) { if (tipColours.getSiteCount() != 1) { throw new IllegalArgumentException("Tip colour alignment must consist of a single column!"); } nodeColours = new int[tree.getNodeCount()]; colourCount = tipColours.getDataType().getStateCount(); leafColourCounts = new int[colourCount]; // initialize external node colours for (int i = 0; i < tree.getExternalNodeCount(); i++) { NodeRef node = tree.getExternalNode(i); int colour = tipColours.getState(tipColours.getTaxonIndex(tree.getTaxonId(i)), 0); nodeColours[node.getNumber()] = colour; leafColourCounts[colour]++; } useNodeBias = nodeBias; useBranchBias = branchBias; useSecondColourIteration = secondIteration; initialize(tree); // uses only tree constants (number of tips, etc) } public StructuredColourSampler(TaxonList[] tipColours, Tree tree, boolean nodeBias, boolean branchBias, boolean secondIteration) { nodeColours = new int[tree.getNodeCount()]; colourCount = tipColours.length + 1; leafColourCounts = new int[colourCount]; // initialize external node colours for (int i = 0; i < tree.getExternalNodeCount(); i++) { NodeRef node = tree.getExternalNode(i); int colour = 0; for (int j = 0; j < tipColours.length; j++) { if (tipColours[j].getTaxonIndex(tree.getTaxonId(i)) != -1) { colour = j + 1; } } nodeColours[node.getNumber()] = colour; leafColourCounts[colour]++; } useNodeBias = nodeBias; useBranchBias = branchBias; useSecondColourIteration = secondIteration; initialize(tree); // uses only tree constants (number of tips etc) } public int[] getLeafColourCounts() { return leafColourCounts; } private void initialize(Tree tree) { nodePartials = new double[tree.getNodeCount()][colourCount]; meanColourCounts = new double[tree.getNodeCount()][colourCount]; nodeColoursEM = new int[tree.getNodeCount()][]; nodePartialsEM = new double[tree.getNodeCount()][][]; equilibriumColours = new double[colourCount]; } private void computeIntervals(Tree tree, MetaPopulation mp) { // dumb implementation using a sorted map - slowish but simple TreeMap<Double, ArrayList<NodeRef>> intervals = new TreeMap<Double, ArrayList<NodeRef>>(); int numnodes = tree.getNodeCount(); for (int i = 0; i < numnodes; i++) { NodeRef node = tree.getNode(i); Double height = new Double(tree.getNodeHeight(node)); if (intervals.containsKey(height)) { (intervals.get(height)).add(node); } else { ArrayList<NodeRef> list = new ArrayList<NodeRef>(1); list.add(node); intervals.put(height, list); } } node2Interval = new int[numnodes]; interval2Height = new double[intervals.size()]; avgN0 = new double[intervals.size()]; avgN1 = new double[intervals.size()]; Iterator<Double> iter = intervals.keySet().iterator(); int interval = 0; while (iter.hasNext()) { Double height = iter.next(); interval2Height[interval] = height.doubleValue(); List<NodeRef> nodes = intervals.get(height); for (int i = 0; i < nodes.size(); i++) { node2Interval[(nodes.get(i)).getNumber()] = interval; } // now initialize the effective (i.e. harmonic average) N's if (interval > 0) { double prevtime = interval2Height[interval - 1]; double curtime = height.doubleValue(); avgN0[interval - 1] = (curtime - prevtime) / mp.getIntegral(prevtime, curtime, 0); avgN1[interval - 1] = (curtime - prevtime) / mp.getIntegral(prevtime, curtime, 1); } interval += 1; } numIntervals = interval; } /** * Main entry point. Colours the tree probabilistically with the given migration rates * * @param colourChangeMatrix the colour change rate parameters */ // public TreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, double[] N) { public DefaultTreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation mp) { //double[] N = mp.getPopulationSizes(0); populateEquilibriumColourArray(colourChangeMatrix); // Build array of node indices, arranged by height computeIntervals(tree, mp); DefaultTreeColouring colouring = new DefaultTreeColouring(2, tree); // Calculate root partials, and those of all other nodes logNodePartialsRescaling = 0.0; prune(tree, tree.getRoot(), colourChangeMatrix); // fill meanColourCounts array calculateMeanColourCounts(tree, colourChangeMatrix); // Prune again, but now using mean colour counts for each interval logNodePartialsRescaling = 0.0; double[] rootPartials = pruneEM(tree, tree.getRoot(), colourChangeMatrix, mp); // Re-calculate meanColourCounts, and do a final pruning step if (useSecondColourIteration) { calculateMeanColourCountsEM(tree, tree.getRoot(), colourChangeMatrix); logNodePartialsRescaling = 0.0; rootPartials = pruneEM(tree, tree.getRoot(), colourChangeMatrix, mp); } // Sampling is conditional on data; so normalize by the probability of the // data under the proposal distribution double normalization = 0.0; for (int i = 0; i < colourCount; i++) { normalization += equilibriumColours[i] * rootPartials[i]; } double logNormalization = Math.log(normalization) + logNodePartialsRescaling; double logP = sampleEM(tree, tree.getRoot(), colourChangeMatrix, mp, colouring) - logNormalization; colouring.setLogProbabilityDensity(logP); //JFrame frame = new JFrame(); //ColouredTreeComponent comp = new ColouredTreeComponent() if (debugProposalProbabilityCalculator) { double logP2 = getProposalProbability(colouring, tree, colourChangeMatrix, mp); System.out.println("Sampling proposal probability " + logP + ", re-calculated as " + logP2 + " (norm=" + logNormalization + ")"); } return colouring; } /** * Returns proposal probability density of a tree colouring * * @param treeColouring * @param tree * @param colourChangeMatrix * @param mp * @return probability density */ public double getProposalProbability(TreeColouring treeColouring, Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation mp) { //double[] N = mp.getPopulationSizes(0); populateEquilibriumColourArray(colourChangeMatrix); // Build array of node indices, arranged by height computeIntervals(tree, mp); // Calculate root partials, and those of all other nodes logNodePartialsRescaling = 0.0; prune(tree, tree.getRoot(), colourChangeMatrix); // fill meanColourCounts array calculateMeanColourCounts(tree, colourChangeMatrix); // Prune again, but now using mean colour counts for each interval logNodePartialsRescaling = 0.0; double[] rootPartials = pruneEM(tree, tree.getRoot(), colourChangeMatrix, mp); // Re-calculate meanColourCounts, and do a final pruning step if (useSecondColourIteration) { calculateMeanColourCountsEM(tree, tree.getRoot(), colourChangeMatrix); logNodePartialsRescaling = 0.0; rootPartials = pruneEM(tree, tree.getRoot(), colourChangeMatrix, mp); } // Sampling is conditional on data; so normalize by the probability of the // data under the proposal distribution double normalization = 0.0; for (int i = 0; i < colourCount; i++) { normalization += equilibriumColours[i] * rootPartials[i]; } double logP = calculateEMProposal(tree, tree.getRoot(), colourChangeMatrix, mp, treeColouring); return logP - Math.log(normalization) - logNodePartialsRescaling; } /** * @param node */ private int getColour(NodeRef node) { return nodeColours[node.getNumber()]; } /** * @param node */ private void setColour(NodeRef node, int colour) { if (colour >= 0 && colour < colourCount) { nodeColours[node.getNumber()] = colour; } else { throw new IllegalArgumentException("colour value " + colour + " + is outside of range of colours, [0, " + Integer.toString(colourCount - 1) + "]"); } } /** * ************************************************************************************** * <p/> * Probability- and sampling-related code follows */ void populateEquilibriumColourArray(ColourChangeMatrix colourChangeMatrix) { // fill equilibrium colour array for (int clr = 0; clr < colourCount; clr++) { equilibriumColours[clr] = colourChangeMatrix.getEquilibrium(clr); } } /* Returns probability of colours at a node, based on pruning output */ double[] getMeanColours(int nodeNum, ColourChangeMatrix colourChangeMatrix) { double c[] = new double[colourCount]; double norm = 0.0; for (int i = 0; i < colourCount; i++) { c[i] = nodePartials[nodeNum][i] * equilibriumColours[i]; norm += c[i]; } for (int i = 0; i < colourCount; i++) { c[i] /= norm; } return c; } /* Returns probability of colours at an node-based intreval, based on pruning output */ double[] getMeanColoursEM(int nodeNum, int relativeInterval, ColourChangeMatrix colourChangeMatrix) { double c[] = new double[colourCount]; double norm = 0.0; for (int i = 0; i < colourCount; i++) { c[i] = nodePartialsEM[nodeNum][relativeInterval][i] * equilibriumColours[i]; norm += c[i]; } for (int i = 0; i < colourCount; i++) { c[i] /= norm; } return c; } // Recursively fills meanColourCounts with derived counts for intervals; // utility function for computeMeanColourCounts below void fillMeanColourCounts(Tree tree, NodeRef node, ColourChangeMatrix colourChangeMatrix) { if (!tree.isRoot(node)) { NodeRef parent = tree.getParent(node); int parentNum = parent.getNumber(); int nodeNum = node.getNumber(); double[] parentColours = getMeanColours(parentNum, colourChangeMatrix); double[] childColours = getMeanColours(nodeNum, colourChangeMatrix); for (int i = 0; i < colourCount; i++) { // calculate expected colour at parent and child end of branch double meanColour = (parentColours[i] + childColours[i]) / 2.0; meanColourCounts[node2Interval[nodeNum]][i] += meanColour; meanColourCounts[node2Interval[parentNum]][i] -= meanColour; } } if (!tree.isExternal(node)) { // Note: assuming binary tree! NodeRef leftChild = tree.getChild(node, 0); NodeRef rightChild = tree.getChild(node, 1); fillMeanColourCounts(tree, leftChild, colourChangeMatrix); fillMeanColourCounts(tree, rightChild, colourChangeMatrix); } } // Recursively fills meanColourCounts with derived counts for intervals; // utility function for computeMeanColourCountsEM below void calculateMeanColourCountsEM(Tree tree, NodeRef node, ColourChangeMatrix colourChangeMatrix) { if (tree.isRoot(node)) { // empty the meanColourCounts array for (int clr = 0; clr < colourCount; clr++) { for (int i = 0; i < meanColourCounts.length; i++) { meanColourCounts[i][clr] = 0.0; } } } else { NodeRef parent = tree.getParent(node); int parentNum = parent.getNumber(); int nodeNum = node.getNumber(); int bottomInterval = node2Interval[nodeNum]; int topInterval = node2Interval[parentNum]; double[] belowColours = getMeanColoursEM(nodeNum, 0, colourChangeMatrix); double[] aboveColours; for (int relInterval = 0; relInterval < topInterval - bottomInterval; relInterval++) { int aboveInterval = relInterval + 1; if (aboveInterval + bottomInterval < topInterval) { aboveColours = getMeanColoursEM(nodeNum, aboveInterval, colourChangeMatrix); } else { if (!tree.isRoot(parent)) { aboveColours = getMeanColoursEM(parentNum, 0, colourChangeMatrix); } else { aboveColours = belowColours; } } for (int i = 0; i < colourCount; i++) { double meanColour = (aboveColours[i] + belowColours[i]) / 2.0; meanColourCounts[relInterval + bottomInterval][i] += meanColour; } belowColours = aboveColours; } } if (!tree.isExternal(node)) { // Note: assuming binary tree! NodeRef leftChild = tree.getChild(node, 0); NodeRef rightChild = tree.getChild(node, 1); calculateMeanColourCountsEM(tree, leftChild, colourChangeMatrix); calculateMeanColourCountsEM(tree, rightChild, colourChangeMatrix); } if (tree.isRoot(node) && debugMeanColours) { for (int i = 0; i < numIntervals; i++) { System.out.println("EM Interval " + i + " height=" + interval2Height[i] + " "); double sum = 0.0; for (int clr = 0; clr < colourCount; clr++) { System.out.print("c=" + clr + " mean=" + meanColourCounts[i][clr] + " "); sum += meanColourCounts[i][clr]; } System.out.println(" sum=" + sum); } } } // Computes the meanColourCounts array for the intervals void calculateMeanColourCounts(Tree tree, ColourChangeMatrix colourChangeMatrix) { // empty the meanColourCounts array for (int clr = 0; clr < colourCount; clr++) { for (int i = 0; i < meanColourCounts.length; i++) { meanColourCounts[i][clr] = 0.0; } } // populate meanColourCounts array, so that sum of counts before and including an interval, // give average colour counts for that interval. fillMeanColourCounts(tree, tree.getRoot(), colourChangeMatrix); // Sum out, so that actual count gives average colour count on interval for (int clr = 0; clr < colourCount; clr++) { double sum = 0.0; for (int i = 0; i < meanColourCounts.length; i++) { sum += meanColourCounts[i][clr]; meanColourCounts[i][clr] = sum; } } if (debugMeanColours) { for (int i = 0; i < numIntervals; i++) { System.out.println("Interval " + i + " height=" + interval2Height[i] + " "); double sum = 0.0; for (int clr = 0; clr < colourCount; clr++) { System.out.print("c=" + clr + " mean=" + meanColourCounts[i][clr] + " "); sum += meanColourCounts[i][clr]; } System.out.println(" sum=" + sum); } } } /** * Calculate probability of data at descendants from node, given a color at the node ('partials'), * by a Felsenstein-like pruning algorithm. (First step in the color sampling algorithm) * Side effect: updates nodePartials[] for this node and all its descendants. * * @param node * @return the partials of this node */ private double[] prune(Tree tree, NodeRef node, ColourChangeMatrix mm) { double[] p = new double[colourCount]; if (tree.isExternal(node)) { p[getColour(node)] = 1.0; } else { // Note: assuming binary tree! NodeRef leftChild = tree.getChild(node, 0); NodeRef rightChild = tree.getChild(node, 1); double[] left = prune(tree, leftChild, mm); double[] right = prune(tree, rightChild, mm); double nodeHeight = tree.getNodeHeight(node); double leftTime = nodeHeight - tree.getNodeHeight(tree.getChild(node, 0)); double rightTime = nodeHeight - tree.getNodeHeight(tree.getChild(node, 1)); double maxp = 0.0; for (int i = 0; i < colourCount; i++) { double leftSum = 0.0; double rightSum = 0.0; // looping over colours in left and right children for (int j = 0; j < colourCount; j++) { // forwardTimeEvolution conditions on the parent state i, i.e. time // runs in the natural direction (forward from parent to child) leftSum += mm.forwardTimeEvolution(i, j, leftTime) * left[j]; rightSum += mm.forwardTimeEvolution(i, j, rightTime) * right[j]; } p[i] = leftSum * rightSum; // This version does not condition on the formal variable, because // that tends to bias towards the small population, whereas no correction // for bias towards the large population on branches is made at this stage if (p[i] > maxp) { maxp = p[i]; } } // take care of underflow if (maxp < 1.0e-100) { for (int i = 0; i < colourCount; i++) { p[i] *= 1.0e+100; } logNodePartialsRescaling -= Math.log(1.0e+100); } } nodePartials[node.getNumber()] = p; if (debugNodePartials) { prettyPrint("Node " + node.getNumber() + " prune=", p); } return p; } /* Calculates the result of multiplying the vector vec with the matrix Exp[{{-a,b},{c,-d}}] */ static double[] matrixEvolve(double[] mx, int current) { double a = mx[0]; double b = mx[1]; double c = mx[2]; double d = mx[3]; /* double[] vec = new double[2]; vec[current] = 1.0; double S = Math.sqrt( (a-d)*(a-d) + 4*b*c ); double c00,c01,c10,c11; if (S < 1.0e-5 ) { c00 = Math.exp( -a ); c11 = Math.exp( -d ); c01 = b*c00; c10 = c*c11; } else { double T = Math.exp( -(a+d+S)/2.0 ); double U = Math.exp( -(a+d-S)/2.0 ); c00 = ( (d-a+S)*U - (d-a-S)*T ) / (2*S); c01 = ( b*(U-T) ) / S; c10 = ( c*(U-T) ) / S; c11 = ( (a-d+S)*U - (a-d-S)*T ) / (2*S); } double v0 = vec[0] * c00 + vec[1] * c10; vec[1] = vec[0] * c01 + vec[1] * c11; vec[0] = v0; return vec; */ double S = Math.sqrt((a - d) * (a - d) + 4 * b * c); if (S < 1.0e-5) { if (current == 0) { a = Math.exp(-a); return new double[]{a, b * a}; } else { d = Math.exp(-d); return new double[]{c * d, d}; } } else { double T = Math.exp(-(a + d + S) / 2.0); double U = Math.exp(-(a + d - S) / 2.0); if (current == 0) { return new double[]{((d - a + S) * U - (d - a - S) * T) / (2 * S), (b * (U - T)) / S}; } else { return new double[]{(c * (U - T)) / S, ((a - d + S) * U - (a - d - S) * T) / (2 * S)}; } } } /* Calculates the result of multiplying the matrix Exp[{{-a,b},{c,-d}}] with vec */ static void matrixPullBack(double[] mx, double[] vec) { double a = mx[0]; double b = mx[1]; double c = mx[2]; double d = mx[3]; double S = Math.sqrt((a - d) * (a - d) + 4 * b * c); double c00, c01, c10, c11; if (S < 1.0e-5) { c00 = Math.exp(-a); c11 = Math.exp(-d); c01 = b * c00; c10 = c * c11; } else { double T = Math.exp(-(a + d + S) / 2.0); double U = Math.exp(-(a + d - S) / 2.0); c00 = ((d - a + S) * U - (d - a - S) * T) / (2 * S); c01 = (b * (U - T)) / S; c10 = (c * (U - T)) / S; c11 = ((a - d + S) * U - (a - d - S) * T) / (2 * S); } double v0 = vec[0] * c00 + vec[1] * c01; vec[1] = vec[0] * c10 + vec[1] * c11; vec[0] = v0; } double[] calculateMatrixElts(int curinterval, NodeRef affectedNode, Tree tree, double time, double N0, double N1, ColourChangeMatrix mm) { double meancolour0 = meanColourCounts[curinterval][0]; double meancolour1 = meanColourCounts[curinterval][1]; double lineages = meancolour0 + meancolour1; // get coalescent rate for colour 0 double coalescent0 = ((propAffected * (lineages - 1)) + (1.0 - propAffected) * (meancolour0 - 1)) / (2.0 * N0) * time; if (coalescent0 < 0.0) { coalescent0 = 0.0; } // same for colour 1 double coalescent1 = ((propAffected * (lineages - 1)) + (1.0 - propAffected) * (meancolour1 - 1)) / (2.0 * N1) * time; if (coalescent1 < 0.0) { coalescent1 = 0.0; } if (!useBranchBias) { coalescent0 = 0.0; coalescent1 = 0.0; } // partial conditioning on survival; take away coalescences that are // common to both colours double commonCoalescentRate = Math.min(coalescent0, coalescent1); /* debug */ //commonCoalescentRate = 0.0; coalescent0 -= commonCoalescentRate; coalescent1 -= commonCoalescentRate; // get forward rates double fwrate01 = mm.getForwardRate(0, 1) * time; double fwrate10 = mm.getForwardRate(1, 0) * time; // next, evolve 'with leakage'. Matrix: // 0->1 : rate01; diagonal 0 : -rate01-coalescent0 // 1->0 : rate10; diagonal 1 : -rate10-coalescent1 return new double[]{fwrate01 + coalescent0, fwrate01, fwrate10, fwrate10 + coalescent1}; } /* Calculates partial branch probabilities with mean-field coalescent correction */ /* Inputs are the partials at the child node; results on branch are stored in nodePartialsEM[][][] */ /* This version is for two colours only */ double[] pruneBranchEM(ColourChangeMatrix mm, double[] inState, NodeRef parent, NodeRef child, Tree tree, MetaPopulation mp) { // Get interval above parent (highest) and above child (lowest) int topInterval = node2Interval[parent.getNumber()]; int bottomInterval = node2Interval[child.getNumber()]; // Make a result array and a current state double[][] partials = new double[topInterval - bottomInterval][2]; double[] state = inState.clone(); int curinterval = bottomInterval; while (curinterval != topInterval) { // Store current state partials[curinterval - bottomInterval][0] = state[0]; partials[curinterval - bottomInterval][1] = state[1]; // Calculate partials at top of current interval double time = interval2Height[curinterval + 1] - interval2Height[curinterval]; //double effectiveN0 = time / mp.getIntegral( lowertime, uppertime, 0 ); //double effectiveN1 = time / mp.getIntegral( lowertime, uppertime, 1 ); double[] mxElts = calculateMatrixElts(curinterval, child, tree, time, avgN0[curinterval], avgN1[curinterval], mm); matrixPullBack(mxElts, state); // update for next interval curinterval += 1; } // Store branch partials (including child's, excluding parent's) at child nodePartialsEM[child.getNumber()] = partials; // Return parent's partials return state; } /** * Calculate probability of data at descendants from node, given a color at the node ('partials'), * by a Felsenstein-like pruning algorithm. (First step in the color sampling algorithm) * Side effect: updates nodePartials[] and nodePartialsEM[][] for this node and all its descendants. * This version uses mean-field coalescence correction and formal variable correction */ private double[] pruneEM(Tree tree, NodeRef node, ColourChangeMatrix mm, MetaPopulation mp) { double[] p = new double[colourCount]; if (tree.isExternal(node)) { p[getColour(node)] = 1.0; } else { // Note: assuming binary tree! NodeRef leftChild = tree.getChild(node, 0); NodeRef rightChild = tree.getChild(node, 1); // Obtain partials at the child nodes double[] left = pruneEM(tree, leftChild, mm, mp); double[] right = pruneEM(tree, rightChild, mm, mp); // Calculate partials at parent node // (Side effect: calculates partials for each interval along branch in nodePartialsEM) double[] leftBranchPartials = pruneBranchEM(mm, left, node, leftChild, tree, mp); double[] rightBranchPartials = pruneBranchEM(mm, right, node, rightChild, tree, mp); double maxp = 0.0; for (int i = 0; i < colourCount; i++) { p[i] = leftBranchPartials[i] * rightBranchPartials[i]; // Condition on the formal variable if (useNodeBias) { p[i] *= mm.getEquilibrium(i) / mp.getDemographic(tree.getNodeHeight(node) - tinyTime, i); } if (p[i] > maxp) { maxp = p[i]; } } // take care of underflow if (maxp < 1.0e-100) { for (int i = 0; i < colourCount; i++) { p[i] *= 1.0e+100; } logNodePartialsRescaling -= Math.log(1.0e+100); } } // Store results for the sampling step nodePartials[node.getNumber()] = p; if (debugNodePartials) { prettyPrint("Node " + node.getNumber() + " pruneEM=", p); } return p; } /** * Samples internal node colours (from root to tips) and events * Precondition: parent node has been sampled, descendants have not * Requires the results from Felsenstein Backwards pruning, in nodePartials(EM)[] (see pruneEM()) * Side effect: updates nodeColours[] */ private double sampleEM(Tree tree, NodeRef node, ColourChangeMatrix mm, MetaPopulation mp, DefaultTreeColouring colouring) { double[] forward; double[] posterior; int colour; DefaultBranchColouring history = null; double logLikelihood = 0.0; if (tree.isRoot(node)) { _totalIntegratedRate = 0.0; /* for debugging */ forward = mm.getEquilibrium(); double[] backward = nodePartials[node.getNumber()]; posterior = new double[colourCount]; double max = -1.0; double min = 1.0; for (int i = 0; i < forward.length; i++) { posterior[i] = forward[i] * backward[i]; max = Math.max(max, posterior[i]); min = Math.min(min, posterior[i]); } if (debugMessages && min < 0.0) { System.out.println("Aargh, negative probabilities " + min + " " + max); } if (debugMessages && max < 1.0e-200) { System.out.println("Hmm, very small numbers indeed " + max); } // Sample a colour colour = MathUtils.randomChoicePDF(posterior); // Add likelihood of (unconditional) probability logLikelihood += Math.log(forward[colour]); } else { int nodeNum = node.getNumber(); double[][] backward = nodePartialsEM[nodeNum]; int bottomInterval = node2Interval[nodeNum]; // Allocate array for interval node colours nodeColoursEM[nodeNum] = new int[backward.length]; // Start from parent, and sample colours for interval nodes colour = getColour(tree.getParent(node)); // Allocate event history history = new DefaultBranchColouring(colour, colour); // Allocate posterior probabilities posterior = new double[colourCount]; // Go down the branch (from parent to child) for (int relinterval = backward.length - 1; relinterval >= 0; relinterval--) { int interval = relinterval + bottomInterval; double childHeight = interval2Height[interval]; double uppertime = interval2Height[interval + 1]; double time = uppertime - childHeight; double[] mxElts = calculateMatrixElts(interval, node, tree, time, avgN0[interval], avgN1[interval], mm); forward = matrixEvolve(mxElts, colour); // Calculate posterior for interval node for (int i = 0; i < colourCount; i++) { posterior[i] = forward[i] * backward[relinterval][i]; } // Sample a colour int childColour = MathUtils.randomChoicePDF(posterior); // Store colour nodeColoursEM[nodeNum][relinterval] = childColour; // Sample events conditional on parent and child colour //System.out.println("node="+node.getNumber()+" parent="+colour+" child="+childColour+" interval="+relinterval+" time="+time); logLikelihood += sampleConditionalBranchColouringEM(node, colour, childColour, time, childHeight, mxElts, history); // update colour colour = childColour; } // Store history for branch at child node colouring.setBranchColouring(node, history); } // store colour in node setColour(node, colour); // add contribution of formal variable - except if this is a leaf, // since there is no corresponding coalescent event if (!tree.isExternal(node) && useNodeBias) { double childTime = tree.getNodeHeight(node); logLikelihood += Math.log(mm.getEquilibrium(colour) / mp.getDemographic(childTime - tinyTime, colour)); } // Recursively sample down the tree for (int i = 0; i < tree.getChildCount(node); i++) { NodeRef child = tree.getChild(node, i); logLikelihood += sampleEM(tree, child, mm, mp, colouring); } if (debugSampleLikelihoods && tree.isRoot(node)) { System.out.println("Sampled likelihood " + logLikelihood); System.out.println("Integrated exit rate " + _totalIntegratedRate); System.out.println("sampleEMProposal node=" + node.getNumber() + " logL=" + logLikelihood); } //System.out.println("sampleEM node="+node.getNumber()+"\tlogL="+logLikelihood); return logLikelihood; } /** * This samples events along an interval, and adds these events to a BranchColouring. * Also returns the (unconditional) log probability density for these events */ private double sampleConditionalBranchColouringEM( NodeRef node, int parentColour, int childColour, double time, double childHeight, double[] matrixElements, DefaultBranchColouring initialBranchColouring) { DefaultBranchColouring history = new DefaultBranchColouring(parentColour, childColour); int iterations = 0; int currentColour; double currentHeight, dt; double logLikelihood; boolean reject, firstEvent; double _sumRate = 0.0; String _msg = ""; // Reject until we get the child colour do { history.clear(); currentColour = parentColour; currentHeight = time; logLikelihood = 0.0; reject = false; firstEvent = true; if (debugRejectionSampler) { _sumRate = 0.0; _msg = "sample (iter=" + iterations + ") from " + parentColour + " to " + childColour + " at " + df.format(childHeight) + " over " + df.format(time) + " rate=" + df.format(matrixElements[3 * parentColour] / time) + "\n"; } // Sample events until we reach the child do { // Sample a waiting time double totalRate, migrationRate; if (currentColour == 0) { totalRate = matrixElements[0] / time; migrationRate = matrixElements[1] / time; } else { totalRate = matrixElements[3] / time; migrationRate = matrixElements[2] / time; } double U; do { U = MathUtils.nextDouble(); } while (U == 0.0); // Neat trick (Rasmus Nielsen): // If colours of parent and child differ, condition on at least 1 event if (firstEvent && (parentColour != childColour)) { double minU = Math.exp(-totalRate * time); U = minU + U * (1.0 - minU); } // Calculate the waiting time, and update currentHeight dt = -Math.log(U) / totalRate; currentHeight -= dt; if (currentHeight > 0) { // Not yet reached the child. Sample an event. // Neat trick II: condition on first event being a migration event // Also, don't bother calculating a random sample if coalescent events // have zero rate if (debugRejectionSampler) { _sumRate += totalRate * dt; _msg = _msg + "rate=" + df.format(totalRate) + "\ttime=" + df.format(dt) + "\tintrate=" + df.format(totalRate * dt) + "\n"; } if (firstEvent || (migrationRate == totalRate) || (MathUtils.nextDouble() < migrationRate / totalRate)) { // migration event //System.out.println("Event height="+(currentHeight + childHeight)); currentColour = 1 - currentColour; // Add it to the list history.addEvent(currentColour, currentHeight + childHeight); // Update likelihood logLikelihood += (-totalRate * dt) + Math.log(migrationRate); // ... replaces this to keep in log space // likelihood *= Math.exp( -totalRate * dt ) * migrationRate; } else { // coalescent event reject = true; } } else { if (debugRejectionSampler) { _sumRate += totalRate * (currentHeight + dt); _msg = _msg + "rate=" + df.format(totalRate) + "\ttime=" + df.format(currentHeight + dt) + "\tintrate=" + df.format(totalRate * (currentHeight + dt)) + "\n"; } // Update likelihood, no event until child //System.out.println(" totalRate="+totalRate+" dt="+(currentHeight+dt)+" time="+time); logLikelihood += -totalRate * (currentHeight + dt); // ... replaces this to keep in log space //likelihood *= Math.exp( -totalRate * (currentHeight + dt) ); } firstEvent = false; } while ((!reject) && (currentHeight > 0)); iterations += 1; if (currentColour != childColour) { reject = true; } } while (reject && (iterations < maxIterations)); if (debugRejectionSampler) { _totalIntegratedRate += _sumRate; System.out.print(_msg); } if (reject) { // Extreme migration rates may cause difficulty for the rejection sampler // Print a warning and add a bogus event somewhat near where you'd want it. if (currentColour != childColour) { double previousEventHeight = currentHeight + dt; double finalEventHeight = 0.01 * previousEventHeight; history.addEvent(childColour, finalEventHeight + childHeight); if (debugMessages) { System.out.println("dr.evolution.colouring.BranchColourSampler: failed to generate sample after " + maxIterations + " trials."); System.out.println(": height=" + time); System.out.println(": childColour=" + childColour); System.out.println(": childHeight=" + childHeight); System.out.println(": migration rate 0->1 = " + matrixElements[1]); System.out.println(": migration rate 1->0 = " + matrixElements[2]); System.out.println(": total rate 0 = " + matrixElements[0]); System.out.println(": total rate 1 = " + matrixElements[3]); } } } // Add this interval's history to the previously sampled history initialBranchColouring.addHistory(history); //System.out.println("sampleConditionalBranchColouringEM child="+node.getNumber()+"\tchildH="+childHeight+"\tlogL="+Math.log(likelihood)+"\tevents="+history.getNumEvents()+"\tfinal clr="+childColour); return logLikelihood; } /** * Calculates proposal density for given colouring * Requires the results from Felsenstein Backwards pruning, in nodePartials(EM)[] (see pruneEM()) */ private double calculateEMProposal(Tree tree, NodeRef node, ColourChangeMatrix mm, MetaPopulation mp, TreeColouring colouring) { double[] forward; int colour; BranchColouring history = null; double logLikelihood = 0.0; if (tree.isRoot(node)) { forward = mm.getEquilibrium(); // Get root colour colour = colouring.getNodeColour(node); // Add likelihood of (unconditional) probability logLikelihood += Math.log(forward[colour]); } else { int nodeNum = node.getNumber(); double[][] backward = nodePartialsEM[nodeNum]; int bottomInterval = node2Interval[nodeNum]; // Get branch colouring history = colouring.getBranchColouring(node); // Go down the branch (from parent to child) for (int relinterval = backward.length - 1; relinterval >= 0; relinterval--) { int interval = relinterval + bottomInterval; double childHeight = interval2Height[interval]; double uppertime = interval2Height[interval + 1]; double time = uppertime - childHeight; double[] mxElts = calculateMatrixElts(interval, node, tree, time, avgN0[interval], avgN1[interval], mm); logLikelihood += calculateConditionalBranchColouringEM(node, time, childHeight, mxElts, history); } colour = colouring.getNodeColour(node); } if (!tree.isExternal(node) && useNodeBias) { double childTime = tree.getNodeHeight(node); logLikelihood += Math.log(mm.getEquilibrium(colour) / mp.getDemographic(childTime - tinyTime, colour)); } // Recursively sample down the tree for (int i = 0; i < tree.getChildCount(node); i++) { NodeRef child = tree.getChild(node, i); logLikelihood += calculateEMProposal(tree, child, mm, mp, colouring); } //System.out.println("calculateEMProposal node="+node.getNumber()+"\tlogL="+logLikelihood); return logLikelihood; } /** * This samples events along an interval, and adds these events to a BranchColouring. * Also returns the (unconditional) log probability density for these events */ private double calculateConditionalBranchColouringEM( NodeRef node, double time, double childHeight, double[] matrixElements, BranchColouring branchColouring) { double currentHeight = time + childHeight; int nextEvent = branchColouring.getNextForwardEvent(currentHeight); int currentColour = branchColouring.getForwardColourBelow(nextEvent - 1); double logLikelihood = 0.0; int numEvents = 0; while (currentHeight > childHeight) { double nextEventHeight; //System.out.println("nextEvent="+nextEvent+" total="+branchColouring.getNumEvents()); if (nextEvent == branchColouring.getNumEvents() + 1) { nextEventHeight = childHeight - 1.0; // past child } else { nextEventHeight = branchColouring.getForwardTime(nextEvent); //System.out.println("Event height="+nextEventHeight); } double dt = currentHeight - nextEventHeight; double totalRate, migrationRate; if (currentColour == 0) { totalRate = matrixElements[0] / time; migrationRate = matrixElements[1] / time; } else { totalRate = matrixElements[3] / time; migrationRate = matrixElements[2] / time; } if (nextEventHeight < childHeight) { // no event - include exit probability, and done. dt = currentHeight - childHeight; //System.out.println(" totalRate="+totalRate+" dt="+dt+" time="+time); logLikelihood += -totalRate * dt; // ... replaces this to keep in log space //likelihood *= Math.exp( -totalRate * dt ); } else { // proper migration event logLikelihood += (-totalRate * dt) + Math.log(migrationRate); // ... replaces this to keep in log space //likelihood *= Math.exp( -totalRate * dt ) * migrationRate; currentColour = branchColouring.getForwardColourBelow(nextEvent); numEvents += 1; } currentHeight = nextEventHeight; nextEvent += 1; } //System.out.println("calculateConditionalBranchColouring child="+node.getNumber()+"\tchildH="+childHeight+"\tlogL="+Math.log(likelihood)+"\tevents="+numEvents+"\tfinal clr="+currentColour); return logLikelihood; } private void prettyPrint(String label, double[] vec) { System.out.print(label + "= ("); for (double x : vec) { System.out.print(x + ", "); } System.out.println(")"); } // Testing code static void testMatrix(double[] m, double[] r) { if (Math.abs(matrixEvolve(m, 0)[0] - r[0]) > 1.0e-6) { throw new Error("1"); } if (Math.abs(matrixEvolve(m, 0)[1] - r[1]) > 1.0e-6) { throw new Error("2"); } if (Math.abs(matrixEvolve(m, 1)[0] - r[2]) > 1.0e-6) { throw new Error("3"); } if (Math.abs(matrixEvolve(m, 1)[1] - r[3]) > 1.0e-6) { throw new Error("4"); } double vec0[] = {1, 0}; double vec1[] = {0, 1}; matrixPullBack(m, vec0); matrixPullBack(m, vec1); if (Math.abs(vec0[0] - r[0]) > 1.0e-6) { throw new Error("5"); } if (Math.abs(vec0[1] - r[2]) > 1.0e-6) { throw new Error("7"); } if (Math.abs(vec1[0] - r[1]) > 1.0e-6) { throw new Error("6"); } if (Math.abs(vec1[1] - r[3]) > 1.0e-6) { throw new Error("8"); } } public static void main(String[] args) { double m[] = {5.0, 3.0, 2.0, 3.0}; double result[] = {0.0811818, 0.145616, 0.097077, 0.178259}; testMatrix(m, result); System.out.println("First matrix OK"); double m2[] = {1.0, 1.0, 0.0, 1.0}; double result2[] = {0.367879, 0.367879, 0, 0.367879}; testMatrix(m2, result2); System.out.println("Second matrix OK"); double m3[] = {1.0, 0.0, 1.0, 1.0}; double result3[] = {0.367879, 0.0, 0.367879, 0.367879}; testMatrix(m3, result3); System.out.println("Third matrix OK"); } private final int colourCount; private final int[] nodeColours; private final int[] leafColourCounts; private double[][] meanColourCounts; // [i][c] = approximate number of lineages coloured c in interval i private int[][] nodeColoursEM; // [n][r] = colour of lineage above node n and interval r (relative) private double[][] nodePartials; // [node][colour] = p(descendant data | node colour) private double[][][] nodePartialsEM; // [node][rel.interval][colour] = p(desc | node clr) private double logNodePartialsRescaling; // rescaling factor to prevent underflow private double[] equilibriumColours; private int[] node2Interval; // linear index for interval corresponding to (child) node private double[] interval2Height; private double[] avgN0; private double[] avgN1; private int numIntervals; }