/* * BasicColourSampler.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evolution.colouring; import dr.evolution.alignment.Alignment; import dr.evolution.coalescent.structure.MetaPopulation; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.math.MathUtils; /** * @author Alexei Drummond * @author Gerton Lunter * @author Andrew Rambaut * <p/> * This is the old version. It samples like a substitution model, except that nodes are biased * towards populations with low Ne, since coalescences are more likely to occur there. * <p/> * It seems to work less well than a straight substitution-like model (such as Greg Ewing uses). * The reason is that although the bias is correct, along branches the bias works the other way. * By incorporating bias at the nodes, the problem along branches gets worse, and this seems to * affect the acceptance probabilities more than having the right bias at the nodes helps. So, * although the code still accepts the population sizes, it ignores it now. 
* <p/>
 * Note: branch histories are sampled by flipping between colours 0 and 1
 * ({@code currentColour = 1 - currentColour}) and the returned colouring is created with two
 * colours, so this sampler only supports two-colour models.
 *
 * @version $Id: BasicColourSampler.java,v 1.16 2006/09/11 09:33:01 gerton Exp $
 */
public class BasicColourSampler implements ColourSampler {

    /**
     * Maximum number of rejection-sampling rounds per branch before a forced colour change is
     * inserted (see sampleConditionalBranchColouring).
     */
    static final int maxIterations = 1000;

    /** Number of colours (states) known to this sampler. */
    private final int colourCount;

    /** Current colour of each node, indexed by node number. */
    private final int[] nodeColours;

    /** Pruning partials for each internal node, indexed by node number (filled by prune()). */
    private final double[][] nodePartials;

    /** Number of tips of each colour, computed once at construction. */
    private final int[] leafColourCounts;

    /**
     * Creates a sampler whose tip colours are read from a single-column alignment.
     *
     * @param tipColours alignment with exactly one site; each taxon's state at that site is its colour
     * @param tree       the tree whose external nodes are coloured
     * @throws IllegalArgumentException if the alignment does not have exactly one site, if a tip
     *                                  taxon is not in the alignment, or if a tip state is not in
     *                                  [0, stateCount)
     */
    public BasicColourSampler(Alignment tipColours, Tree tree) {
        if (tipColours.getSiteCount() != 1) {
            throw new IllegalArgumentException("Tip colour alignment must consist of a single column!");
        }

        nodeColours = new int[tree.getNodeCount()];
        colourCount = tipColours.getDataType().getStateCount();
        leafColourCounts = new int[colourCount];

        // initialize external node colours
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef node = tree.getExternalNode(i);
            String taxonId = tree.getTaxonId(i);

            int taxonIndex = tipColours.getTaxonIndex(taxonId);
            if (taxonIndex == -1) {
                throw new IllegalArgumentException("Taxon " + taxonId + " is not in the tip colour alignment");
            }

            int colour = tipColours.getState(taxonIndex, 0);
            // A state outside [0, colourCount) would otherwise overflow leafColourCounts below.
            if (colour < 0 || colour >= colourCount) {
                throw new IllegalArgumentException("Tip state " + colour + " of taxon " + taxonId
                        + " is outside of range of colours, [0, " + (colourCount - 1) + "]");
            }

            nodeColours[node.getNumber()] = colour;
            leafColourCounts[colour]++;
        }

        nodePartials = new double[tree.getNodeCount()][colourCount];
    }

    /**
     * Creates a sampler whose tip colours are given by taxon membership.
     * Colour 0 is assigned to taxa in none of the lists; a taxon in list j gets colour j + 1
     * (if it appears in several lists, the last matching list wins).
     *
     * @param tipColours taxon lists, one per non-default colour
     * @param tree       the tree whose external nodes are coloured
     */
    public BasicColourSampler(TaxonList[] tipColours, Tree tree) {
        nodeColours = new int[tree.getNodeCount()];
        colourCount = tipColours.length + 1;
        leafColourCounts = new int[colourCount];

        // initialize external node colours
        for (int i = 0; i < tree.getExternalNodeCount(); i++) {
            NodeRef node = tree.getExternalNode(i);
            int colour = 0;
            for (int j = 0; j < tipColours.length; j++) {
                if (tipColours[j].getTaxonIndex(tree.getTaxonId(i)) != -1) {
                    colour = j + 1;
                }
            }
            nodeColours[node.getNumber()] = colour;
            leafColourCounts[colour]++;
        }

        nodePartials = new double[tree.getNodeCount()][colourCount];
    }

    /**
     * @return the number of tips of each colour (the internal array itself, not a copy)
     */
    public int[] getLeafColourCounts() {
        return leafColourCounts;
    }

    /**
     * Colours the tree probabilistically with the given migration rates.
     * Internal node colours are sampled conditional on the tip colours (pruning, then a pre-order
     * pass), and then a migration history is sampled along every branch.
     *
     * @param tree               the tree to colour; its tip node numbers are assumed to match those
     *                           used at construction
     * @param colourChangeMatrix the colour change rate parameters
     * @param mp                 meta-population; its population sizes are passed through but are not
     *                           used by the proposal, so it may be null
     * @return the sampled colouring, with its log proposal density set
     */
    public DefaultTreeColouring sampleTreeColouring(Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation mp) {
        // The branch sampler only supports two colours.
        DefaultTreeColouring colouring = new DefaultTreeColouring(2, tree);

        // Population sizes are threaded through but ignored (the 1/N node bias was removed).
        double[] N = (mp == null) ? null : mp.getPopulationSizes(0);

        double[] rootPartials = prune(tree, tree.getRoot(), colourChangeMatrix, N);

        // Sampling is conditional on data; so normalize by the probability of the
        // data under the proposal distribution
        double normalization = 0.0;
        for (int i = 0; i < rootPartials.length; i++) {
            normalization += colourChangeMatrix.getEquilibrium(i) * rootPartials[i];
        }

        sampleInternalNodes(tree, tree.getRoot(), colourChangeMatrix);
        sampleBranchColourings(colouring, tree, tree.getRoot(), colourChangeMatrix);

        double logP = calculateLogProbabilityDensity(colouring, tree, tree.getRoot(), colourChangeMatrix, N)
                - Math.log(normalization);
        colouring.setLogProbabilityDensity(logP);

        return colouring;
    }

    /**
     * @param node the node to look up
     * @return the current colour of node
     */
    private int getColour(NodeRef node) {
        return nodeColours[node.getNumber()];
    }

    /**
     * Sets the current colour of a node.
     *
     * @param node   the node to colour
     * @param colour the colour, which must be in [0, colourCount)
     * @throws IllegalArgumentException if colour is out of range
     */
    private void setColour(NodeRef node, int colour) {
        if (colour < 0 || colour >= colourCount) {
            throw new IllegalArgumentException("colour value " + colour
                    + " is outside of range of colours, [0, " + (colourCount - 1) + "]");
        }
        nodeColours[node.getNumber()] = colour;
    }

    /*****************************************************************************************
     *
     * Probability- and sampling-related code follows
     *
     */

    /**
     * Calculates the probability of the data at the descendants of node, given each colour at the
     * node ('partials'), using a Felsenstein-like pruning algorithm. This is the first step in the
     * colour sampling algorithm.
     * Side effect: updates nodePartials[] for this node and all its internal descendants.
     *
     * @param tree the tree (assumed binary)
     * @param node the node at which to compute partials
     * @param mm   colour change model
     * @param N    population sizes; currently unused
     * @return the partials of this node (for a tip, an indicator vector of its colour)
     */
    private double[] prune(Tree tree, NodeRef node, ColourChangeMatrix mm, double[] N) {
        double[] p = new double[colourCount];

        if (tree.isExternal(node)) {
            p[getColour(node)] = 1.0;
            return p;
        }

        // Note: assuming binary tree!
        NodeRef leftChild = tree.getChild(node, 0);
        NodeRef rightChild = tree.getChild(node, 1);

        double[] left = prune(tree, leftChild, mm, N);
        double[] right = prune(tree, rightChild, mm, N);

        double nodeHeight = tree.getNodeHeight(node);
        double leftTime = nodeHeight - tree.getNodeHeight(leftChild);
        double rightTime = nodeHeight - tree.getNodeHeight(rightChild);

        for (int i = 0; i < p.length; i++) {
            double leftSum = 0.0;
            double rightSum = 0.0;

            // looping over colours in left and right children
            for (int j = 0; j < left.length; j++) {
                // forwardTimeEvolution conditions on the parent state i, i.e. time
                // runs in the natural direction (forward from parent to child)
                leftSum += left[j] * mm.forwardTimeEvolution(i, j, leftTime);
                rightSum += right[j] * mm.forwardTimeEvolution(i, j, rightTime);
            }

            // A 1/N[i] "formal variable" factor used to be applied here; it was removed because the
            // sampler is more robust without it, so N is intentionally unused.
            p[i] = leftSum * rightSum;
        }

        nodePartials[node.getNumber()] = p;
        return p;
    }

    /**
     * Samples internal node colours (from the root towards the tips).
     * Requires the results of the pruning pass in nodePartials[] (see prune()).
     * Side effect: updates nodeColours[].
     *
     * @param tree the tree
     * @param node the internal node to colour; its parent (if any) must already be coloured
     * @param mm   colour change model
     */
    private void sampleInternalNodes(Tree tree, NodeRef node, ColourChangeMatrix mm) {
        double[] backward = nodePartials[node.getNumber()];

        // Unnormalized probability of each colour, given the data below and the colour above.
        // A fresh array is used so that the model's own arrays (e.g. the one returned by
        // getEquilibrium()) are never modified.
        double[] weights = new double[backward.length];

        if (tree.isRoot(node)) {
            double[] equilibrium = mm.getEquilibrium();
            for (int i = 0; i < backward.length; i++) {
                weights[i] = equilibrium[i] * backward[i];
            }
        } else {
            NodeRef parent = tree.getParent(node);
            int parentColour = getColour(parent);
            double time = tree.getNodeHeight(parent) - tree.getNodeHeight(node);

            for (int i = 0; i < backward.length; i++) {
                weights[i] = mm.forwardTimeEvolution(parentColour, i, time) * backward[i];
            }
        }

        // Sample a colour
        setColour(node, MathUtils.randomChoicePDF(weights));

        // Recursively sample down the tree
        for (int i = 0; i < tree.getChildCount(node); i++) {
            NodeRef child = tree.getChild(node, i);
            if (!tree.isExternal(child)) {
                sampleInternalNodes(tree, child, mm);
            }
        }
    }

    /**
     * Samples the colours on every branch at or below node. For each branch between a node and its
     * parent, sampling is conditional on the colours at both ends.
     *
     * @param colouring receives the sampled branch colourings
     * @param tree      the tree
     * @param node      the node above which to sample changes (the root itself has no branch)
     * @param mm        colour change model
     */
    private void sampleBranchColourings(DefaultTreeColouring colouring, Tree tree, NodeRef node, ColourChangeMatrix mm) {
        if (!tree.isRoot(node)) {
            NodeRef parent = tree.getParent(node);

            // Sample migration events on this branch, as a list of ColourChange-s
            DefaultBranchColouring history = sampleConditionalBranchColouring(
                    getColour(parent), tree.getNodeHeight(parent),
                    getColour(node), tree.getNodeHeight(node), mm);

            // Assign these migrations to the branch (attached to the child)
            colouring.setBranchColouring(node, history);
        }

        for (int i = 0; i < tree.getChildCount(node); i++) {
            sampleBranchColourings(colouring, tree, tree.getChild(node, i), mm);
        }
    }

    /**
     * Samples a migration history on one branch by forward simulation from the parent, rejecting
     * histories that do not end in the child colour. Only two colours are supported: every event
     * flips the colour (c -> 1 - c).
     * If no acceptable history is found within maxIterations rounds, a single forced change to the
     * child colour is inserted close to the child, and a warning is printed.
     *
     * @return the sampled history
     */
    private DefaultBranchColouring sampleConditionalBranchColouring(int parentColour, double parentHeight,
                                                                    int childColour, double childHeight,
                                                                    ColourChangeMatrix mm) {
        DefaultBranchColouring history = new DefaultBranchColouring(parentColour, childColour);

        double branchLength = parentHeight - childHeight;
        int currentColour;
        // Height of the most recent event (or of the parent when there is none). It is tracked
        // explicitly, so the fallback below never has to reconstruct it from a waiting time that
        // may be infinite or NaN.
        double lastEventHeight;
        int iterationsLeft = maxIterations;

        // Reject until we get the child colour
        do {
            history.clear();
            currentColour = parentColour;
            double currentHeight = parentHeight;
            lastEventHeight = parentHeight;

            // Sample events until we reach the child
            while (true) {
                // Sample a waiting time
                double totalRate = -mm.getForwardRate(currentColour, currentColour);
                double U;
                do {
                    U = MathUtils.nextDouble();
                } while (U == 0.0);

                // Neat trick (Rasmus Nielsen):
                // If colours of parent and child differ, sample conditioning on at least 1 event
                if ((parentColour != childColour) && (history.getNumEvents() == 0)) {
                    double minU = Math.exp(-totalRate * branchLength);
                    U = minU + U * (1.0 - minU);
                }

                currentHeight -= -Math.log(U) / totalRate;

                // Written as !(a > b) so that a NaN height also terminates the branch.
                if (!(currentHeight > childHeight)) {
                    break;
                }

                // Not yet reached the child: record an event (two-colour flip)
                currentColour = 1 - currentColour;
                history.addEvent(currentColour, currentHeight);
                lastEventHeight = currentHeight;
            }

            iterationsLeft -= 1;
        } while ((currentColour != childColour) && (iterationsLeft > 0));

        if (currentColour != childColour) {
            // Extreme migration rates may cause difficulty for the rejection sampler.
            // Print a warning and add a bogus event somewhat near where you'd want it.
            double finalEventHeight = childHeight + 0.01 * (lastEventHeight - childHeight);
            history.addEvent(childColour, finalEventHeight);

            System.out.println("dr.evolution.colouring.BasicColourSampler: failed to generate sample after " + maxIterations + " trials.");
            System.out.println(": parentColour=" + parentColour);
            System.out.println(": parentHeight=" + parentHeight);
            System.out.println(": childColour=" + childColour);
            System.out.println(": childHeight=" + childHeight);
            System.out.println(": migration rate 0->1 = " + mm.getForwardRate(0, 1));
            System.out.println(": migration rate 1->0 = " + mm.getForwardRate(1, 0));
        }

        return history;
    }

    /**
     * Calculates the log probability density of the proposal colouring on the branch leading to
     * this node and everything descending from it. For the root, this is the log equilibrium
     * probability of its colour.
     * Terms are accumulated directly in log space, so long branches or many events do not
     * underflow an intermediate product.
     *
     * @param N population sizes; currently unused (the 1/N factor was removed)
     */
    private double calculateLogProbabilityDensity(TreeColouring colouring, Tree tree, NodeRef node,
                                                  ColourChangeMatrix mm, double[] N) {
        double logP;

        if (tree.isRoot(node)) {
            logP = Math.log(mm.getEquilibrium(colouring.getNodeColour(node)));
        } else {
            NodeRef parent = tree.getParent(node);
            BranchColouring history = colouring.getBranchColouring(node); // note - it is attached to the child node

            int fromColour = colouring.getNodeColour(parent);
            double fromHeight = tree.getNodeHeight(parent);
            logP = 0.0;

            // Loop over all events, forward in time (i.e. down the tree)
            for (int i = 1; i <= history.getNumEvents(); i++) {
                // get colour below this event and its height
                int toColour = history.getForwardColourBelow(i);
                double toHeight = history.getForwardTime(i);

                // log of the no-exit probability exp(q_ff * dt); q_ff is the (non-positive) diagonal rate
                logP += (fromHeight - toHeight) * mm.getForwardRate(fromColour, fromColour);
                // and the event itself
                logP += Math.log(mm.getForwardRate(fromColour, toColour));

                fromHeight = toHeight;
                fromColour = toColour;
            }

            // No exit on the stretch from the last migration event down to the child.
            logP += (fromHeight - tree.getNodeHeight(node)) * mm.getForwardRate(fromColour, fromColour);
        }

        for (int i = 0; i < tree.getChildCount(node); i++) {
            logP += calculateLogProbabilityDensity(colouring, tree, tree.getChild(node, i), mm, N);
        }

        return logP;
    }

    /**
     * Calculates the log Lebesgue measure of the space of migration events for the number of
     * migration events on each branch, as specified by the colouring. A branch of length t with
     * k events contributes log(t^k / k!); the contributions of the subtree below node are summed.
     * Each contribution is accumulated as a sum of log(t / i) to avoid overflow/underflow.
     *
     * @param colouring the colouring whose per-branch event counts are used
     * @param tree      the tree
     * @param node      root of the subtree to measure
     * @return the log normalization for the subtree rooted at node
     */
    public static final double calculateLogNormalization(TreeColouring colouring, Tree tree, NodeRef node) {
        double logn = 0.0;

        if (!tree.isRoot(node)) {
            double t = tree.getNodeHeight(tree.getParent(node)) - tree.getNodeHeight(node);
            int events = colouring.getBranchColouring(node).getNumEvents();
            for (int i = 1; i <= events; i++) {
                logn += Math.log(t / i);
            }
        }

        for (int i = 0; i < tree.getChildCount(node); i++) {
            logn += calculateLogNormalization(colouring, tree, tree.getChild(node, i));
        }

        return logn;
    }

    /**
     * Not supported by this sampler.
     *
     * @throws IllegalArgumentException always
     */
    public double getProposalProbability(TreeColouring treeColouring, Tree tree, ColourChangeMatrix colourChangeMatrix, MetaPopulation mp) {
        throw new IllegalArgumentException("Not implemented for BasicColourSampler; you can only use <ColouredOperator>s");
    }
}