/*
 * TransmissionSubtreeSlideB.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.epidemiology.casetocase.operators;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.epidemiology.casetocase.AbstractCase;
import dr.evomodel.epidemiology.casetocase.BranchMapModel;
import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood;
import dr.evomodel.operators.AbstractTreeOperator;
import dr.evomodel.tree.TreeModel;
import dr.inference.operators.*;
import dr.math.MathUtils;
import dr.xml.*;

import java.util.ArrayList;
import java.util.List;

/**
 * Implements the subtree slide move where the slide moves an entire subtree of the transmission tree.
 *
 * The operator picks a node {@code i} whose parent {@code iP} is mapped to a different case than
 * {@code i} itself, slides {@code iP} to a new height (re-attaching it to another branch when the new
 * height leaves its current branch), and then re-assigns the case of {@code iP} in the branch map.
 *
 * @author Matthew Hall
 */
public class TransmissionSubtreeSlideB extends AbstractTreeOperator implements CoercableMCMCOperator {

    private final CaseToCaseTreeLikelihood c2cLikelihood;
    private final TreeModel tree;
    private double size = 1.0;
    private boolean gaussian = false;
    private final boolean swapInRandomRate;
    private final boolean swapInRandomTrait;
    private final boolean resampleInfectionTimes;
    private CoercionMode mode = CoercionMode.DEFAULT;
    private static final boolean DEBUG = false;

    public static final String TRANSMISSION_SUBTREE_SLIDE_B = "transmissionSubtreeSlideB";
    public static final String SWAP_RATES = "swapInRandomRate";
    public static final String SWAP_TRAITS = "swapInRandomTrait";
    public static final String TARGET_ACCEPTANCE = "targetAcceptance";

    /**
     * Creates the operator.
     *
     * @param c2cLikelihood          likelihood supplying the tree model and the branch map
     * @param weight                 operator weight, passed to {@code setWeight}
     * @param size                   window size of the slide; when exactly {@code 0.0} it is replaced by
     *                               the summed branch length divided by twice the node count
     * @param gaussian               if true the slide distance is Gaussian with scale {@code size},
     *                               otherwise uniform on a window of width {@code size} centred on zero
     * @param swapRates              if true, swap the rate of the moved node with that of a random node
     * @param swapTraits             if true, swap all traits of the moved node with those of a random node
     * @param mode                   coercion (auto-optimisation) mode
     * @param resampleInfectionTimes if true, redraw infection branch positions of the affected cases
     */
    public TransmissionSubtreeSlideB(CaseToCaseTreeLikelihood c2cLikelihood, double weight, double size,
                                     boolean gaussian, boolean swapRates, boolean swapTraits,
                                     CoercionMode mode, boolean resampleInfectionTimes) {
        this.c2cLikelihood = c2cLikelihood;
        tree = c2cLikelihood.getTreeModel();
        setWeight(weight);

        if (size == 0.0) {
            // No size supplied: derive one from the current tree.
            double b = 0.0;
            for (int k = 0; k < tree.getNodeCount(); ++k) {
                b += tree.getBranchLength(tree.getNode(k));
            }
            size = b / (2 * tree.getNodeCount());
        }

        this.size = size;
        this.gaussian = gaussian;
        this.swapInRandomRate = swapRates;
        this.swapInRandomTrait = swapTraits;
        this.resampleInfectionTimes = resampleInfectionTimes;
        this.mode = mode;
    }

    /**
     * Do a probabilistic subtree slide move.
     *
     * @return the log-transformed hastings ratio, or {@code Double.NEGATIVE_INFINITY} when the proposed
     *         height is invalid (below the moved node, or no branch can receive the subtree)
     */
    public double doOperation() {

        if (DEBUG) {
            c2cLikelihood.outputTreeToFile("beforeTSSB.nex", false);
        }

        BranchMapModel branchMap = c2cLikelihood.getBranchMap();

        double logq;

        NodeRef i;

        // 1. choose a random eligible node avoiding root
        // NOTE: this loop only terminates once an eligible node has been drawn; if the tree holds no
        // eligible node it never exits.
        do {
            i = tree.getNode(MathUtils.nextInt(tree.getNodeCount()));
        } while (!eligibleForMove(i, tree, branchMap));

        final NodeRef iP = tree.getParent(i);
        final NodeRef CiP = getOtherChild(tree, iP, i);
        final NodeRef PiP = tree.getParent(iP);

        // 2. choose a delta to move
        final double delta = getDelta();
        final double oldHeight = tree.getNodeHeight(iP);
        final double newHeight = oldHeight + delta;

        AbstractCase iCase = branchMap.get(i.getNumber());
        AbstractCase CiPCase = branchMap.get(CiP.getNumber());
        AbstractCase PiPCase = null;
        if (PiP != null) {
            PiPCase = branchMap.get(PiP.getNumber());
        }

        if (resampleInfectionTimes) {
            // what happens on i's branch (there has always been a change)
            iCase.setInfectionBranchPosition(MathUtils.nextDouble());

            // what happens between PiP and CiP
            if (PiPCase == null || CiPCase != PiPCase) {
                CiPCase.setInfectionBranchPosition(MathUtils.nextDouble());
            }
        }

        // 3. delta > 0: the height of iP increases (it slides towards the root)
        if (delta > 0) {

            // 3.1 if the topology will change
            if (PiP != null && tree.getNodeHeight(PiP) < newHeight) {
                // find new parent: climb until a node at or above newHeight is reached,
                // or until we run off the top of the tree (newParent == null)
                NodeRef newParent = PiP;
                NodeRef newChild = iP;
                while (tree.getNodeHeight(newParent) < newHeight) {
                    newChild = newParent;
                    newParent = tree.getParent(newParent);
                    if (newParent == null) break;
                }

                tree.beginTreeEdit();

                // 3.1.1 if creating a new root
                if (tree.isRoot(newChild)) {
                    tree.removeChild(iP, CiP);
                    tree.removeChild(PiP, iP);
                    tree.addChild(iP, newChild);
                    tree.addChild(PiP, CiP);
                    tree.setRoot(iP);

                    // swap traits and rates so that root keeps its trait and rate values
                    if (tree.hasNodeTraits()) {
                        tree.swapAllTraits(newChild, iP);
                    }

                    if (tree.hasRates()) {
                        final double rootNodeRate = tree.getNodeRate(newChild);
                        tree.setNodeRate(newChild, tree.getNodeRate(iP));
                        tree.setNodeRate(iP, rootNodeRate);
                    }
                }
                // 3.1.2 no new root
                else {
                    tree.removeChild(iP, CiP);
                    tree.removeChild(PiP, iP);
                    tree.removeChild(newParent, newChild);
                    tree.addChild(iP, newChild);
                    tree.addChild(PiP, CiP);
                    tree.addChild(newParent, iP);
                }

                tree.setNodeHeight(iP, newHeight);

                tree.endTreeEdit();

                // 3.1.3 count the hypothetical sources of this destination.
                final int possibleSources = intersectingEdges(tree, newChild, oldHeight, null);

                logq = -Math.log(possibleSources);

                // Randomly assign iP the partition of either its parent or the child that is not i, and
                // adjust q appropriately
                if (PiPCase != CiPCase) {
                    logq += Math.log(0.5);
                }

                AbstractCase newiPCase;
                AbstractCase newChildCase = branchMap.get(newChild.getNumber());

                if (newParent != null
                        && branchMap.get(newParent.getNumber()) != branchMap.get(newChild.getNumber())) {
                    if (MathUtils.nextInt(2) == 0) {
                        newiPCase = branchMap.get(newParent.getNumber());
                    } else {
                        newiPCase = newChildCase;
                    }

                    if (resampleInfectionTimes) {
                        // whichever we picked for iP, it's the new child's case whose infection branch is
                        // modified (even if this infection branch is iP's branch)
                        newChildCase.setInfectionBranchPosition(MathUtils.nextDouble());
                    }

                    logq += Math.log(2);
                } else {
                    newiPCase = newChildCase;
                    // if iP is now the root its case infection time has changed. If not, then there is no
                    // infection on this branch
                    if (resampleInfectionTimes) {
                        if (newParent == null) {
                            newChildCase.setInfectionBranchPosition(MathUtils.nextDouble());
                        }
                    }
                }

                branchMap.set(iP.getNumber(), newiPCase, true);

            } else {
                // just change the node height
                // todo you could actually randomise whether the subtree containing iP is changed here
                tree.setNodeHeight(iP, newHeight);
                logq = 0.0;
            }
        }
        // 4. delta <= 0: the height of iP decreases (it slides towards the tips)
        else {

            // 4.0 is it a valid move? iP may not end up below its own child i
            if (tree.getNodeHeight(i) > newHeight) {
                return Double.NEGATIVE_INFINITY;
            }

            // 4.1 will the move change the topology?
            if (tree.getNodeHeight(CiP) > newHeight) {

                List<NodeRef> newChildren = new ArrayList<NodeRef>();
                final int possibleDestinations = intersectingEdges(tree, CiP, newHeight, newChildren);

                // if no valid destinations then return a failure
                if (newChildren.size() == 0) {
                    return Double.NEGATIVE_INFINITY;
                }

                // pick a random parent/child destination edge uniformly from options
                final int childIndex = MathUtils.nextInt(newChildren.size());
                NodeRef newChild = newChildren.get(childIndex);
                NodeRef newParent = tree.getParent(newChild);

                tree.beginTreeEdit();

                // 4.1.1 if iP was root
                if (tree.isRoot(iP)) {
                    // new root is CiP
                    tree.removeChild(iP, CiP);
                    tree.removeChild(newParent, newChild);
                    tree.addChild(iP, newChild);
                    tree.addChild(newParent, iP);
                    tree.setRoot(CiP);

                    // swap traits and rates, so that root keeps its trait and rate values
                    if (tree.hasNodeTraits()) {
                        tree.swapAllTraits(iP, CiP);
                    }

                    if (tree.hasRates()) {
                        final double rootNodeRate = tree.getNodeRate(iP);
                        tree.setNodeRate(iP, tree.getNodeRate(CiP));
                        tree.setNodeRate(CiP, rootNodeRate);
                    }
                } else {
                    tree.removeChild(iP, CiP);
                    tree.removeChild(PiP, iP);
                    tree.removeChild(newParent, newChild);
                    tree.addChild(iP, newChild);
                    tree.addChild(PiP, CiP);
                    tree.addChild(newParent, iP);
                }

                tree.setNodeHeight(iP, newHeight);

                tree.endTreeEdit();

                logq = Math.log(possibleDestinations);

                // Randomly assign iP the partition of either its parent or the child that is not i, and
                // adjust q appropriately
                if (PiP != null && PiPCase != CiPCase) {
                    logq += Math.log(0.5);
                }

                AbstractCase newiPCase;
                AbstractCase newChildCase = branchMap.get(newChild.getNumber());

                if (branchMap.get(newParent.getNumber()) != branchMap.get(newChild.getNumber())) {
                    if (MathUtils.nextInt(2) == 0) {
                        newiPCase = branchMap.get(newParent.getNumber());
                    } else {
                        newiPCase = newChildCase;
                    }

                    if (resampleInfectionTimes) {
                        // whichever we picked for iP, it's the new child's case whose infection branch is
                        // modified (even if this infection branch is iP's branch)
                        newChildCase.setInfectionBranchPosition(MathUtils.nextDouble());
                    }

                    logq += Math.log(2);
                } else {
                    // iP has been attached below newParent in this direction, so it cannot have become
                    // the root and no root-specific handling is needed
                    newiPCase = newChildCase;
                }

                branchMap.set(iP.getNumber(), newiPCase, true);

            } else {
                tree.setNodeHeight(iP, newHeight);
                logq = 0.0;
            }
        }

        if (swapInRandomRate) {
            final NodeRef j = tree.getNode(MathUtils.nextInt(tree.getNodeCount()));
            if (j != i) {
                final double tmp = tree.getNodeRate(i);
                tree.setNodeRate(i, tree.getNodeRate(j));
                tree.setNodeRate(j, tmp);
            }
        }

        if (swapInRandomTrait) {
            final NodeRef j = tree.getNode(MathUtils.nextInt(tree.getNodeCount()));
            if (j != i) {
                tree.swapAllTraits(i, j);
            }
        }

        if (DEBUG) {
            c2cLikelihood.getTreeModel().checkPartitions();
            c2cLikelihood.outputTreeToFile("afterTSSB.nex", false);
        }

        return logq;
    }

    /**
     * Draws the signed distance by which the parent node is slid.
     *
     * @return a Gaussian draw scaled by {@code size} when {@code gaussian} is set, otherwise a uniform
     *         draw scaled to the window {@code [-size/2, size/2)}
     */
    private double getDelta() {
        if (!gaussian) {
            return (MathUtils.nextDouble() * size) - (size / 2.0);
        } else {
            return MathUtils.nextGaussian() * size;
        }
    }

    /**
     * To be eligible for this move, the node's parent must exist and be in a different partition to
     * itself. This forces the parent to be in the same partition as either its grandchild or its child.
     */
    private boolean eligibleForMove(NodeRef node, TreeModel tree, BranchMapModel branchMap) {
        return (!tree.isRoot(node)
                && branchMap.get(tree.getParent(node).getNumber()) != branchMap.get(node.getNumber()));
    }

    /**
     * Counts the branches at or below {@code node} that cross {@code height}, i.e. whose parent end is
     * not below {@code height} and whose child end is below it.
     *
     * intersectingEdges is the same as in normal STS, since there's no additional restriction in this
     * case on where nodes can go, and the move does not modify eligibility for itself.
     *
     * @param tree           the tree to search
     * @param node           the node whose branch (and, if needed, descendants' branches) is examined;
     *                       its parent is dereferenced, so it must not be the root
     * @param height         the height the branches must cross
     * @param directChildren if non-null, receives the child-end node of every counted branch
     * @return the number of crossing branches found
     */
    private int intersectingEdges(Tree tree, NodeRef node, double height, List<NodeRef> directChildren) {

        final NodeRef parent = tree.getParent(node);

        // the whole branch lies below the target height: nothing under it can cross either
        if (tree.getNodeHeight(parent) < height) return 0;

        if (tree.getNodeHeight(node) < height) {
            if (directChildren != null) directChildren.add(node);
            return 1;
        }

        // the whole branch lies at or above the target height: look further down
        int count = 0;
        for (int i = 0; i < tree.getChildCount(node); i++) {
            count += intersectingEdges(tree, tree.getChild(node, i), height, directChildren);
        }
        return count;
    }

    /** @return the current slide window size */
    public double getSize() {
        return size;
    }

    /** @param size the new slide window size */
    public void setSize(double size) {
        this.size = size;
    }

    /** @return the natural log of the window size, the quantity tuned during coercion */
    public double getCoercableParameter() {
        return Math.log(getSize());
    }

    /** @param value the natural log of the new window size */
    public void setCoercableParameter(double value) {
        setSize(Math.exp(value));
    }

    /** @return the window size on its natural (untransformed) scale */
    public double getRawParameter() {
        return getSize();
    }

    /** @return the coercion mode this operator was constructed with */
    public CoercionMode getMode() {
        return mode;
    }

    public String getPerformanceSuggestion() {
        return "not implemented";
    }

    public String getOperatorName() {
        return TRANSMISSION_SUBTREE_SLIDE_B + " (" + tree.getId() + ")";
    }

    /** XML parser that builds a {@link TransmissionSubtreeSlideB} from its XML element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public static final String RESAMPLE_INFECTION_TIMES = "resampleInfectionTimes";

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            boolean swapRates = xo.getAttribute(SWAP_RATES, false);
            boolean swapTraits = xo.getAttribute(SWAP_TRAITS, false);

            CoercionMode mode = CoercionMode.DEFAULT;
            if (xo.hasAttribute(CoercableMCMCOperator.AUTO_OPTIMIZE)) {
                if (xo.getBooleanAttribute(CoercableMCMCOperator.AUTO_OPTIMIZE)) {
                    mode = CoercionMode.COERCION_ON;
                } else {
                    mode = CoercionMode.COERCION_OFF;
                }
            }

            CaseToCaseTreeLikelihood c2cL
                    = (CaseToCaseTreeLikelihood) xo.getChild(CaseToCaseTreeLikelihood.class);
            final double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT);

            final double targetAcceptance = xo.getAttribute(TARGET_ACCEPTANCE, 0.234);

            final double size = xo.getAttribute("size", 1.0);

            if (Double.isInfinite(size) || size <= 0.0) {
                throw new XMLParseException("size attribute must be positive and not infinite. was " + size +
                        " for tree " + c2cL.getTreeModel().getId() );
            }

            boolean resampleInfectionTimes = false;

            if (xo.hasAttribute(RESAMPLE_INFECTION_TIMES)) {
                resampleInfectionTimes = xo.getBooleanAttribute(RESAMPLE_INFECTION_TIMES);
            }

            final boolean gaussian = xo.getBooleanAttribute("gaussian");
            TransmissionSubtreeSlideB operator = new TransmissionSubtreeSlideB(c2cL, weight, size, gaussian,
                    swapRates, swapTraits, mode, resampleInfectionTimes);
            operator.setTargetAcceptanceProbability(targetAcceptance);

            return operator;
        }

        public String getParserDescription() {
            return "An operator that slides a phylogenetic subtree and a transmission subtree simultaneously.";
        }

        public Class getReturnType() {
            // The parser builds a TransmissionSubtreeSlideB, so that is the type it must advertise.
            return TransmissionSubtreeSlideB.class;
        }

        public String getParserName() {
            return TRANSMISSION_SUBTREE_SLIDE_B;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule(MCMCOperator.WEIGHT),
                // size is optional: it defaults to 1.0 and must be positive and finite.
                AttributeRule.newDoubleRule("size", true),
                AttributeRule.newDoubleRule(TARGET_ACCEPTANCE, true),
                AttributeRule.newBooleanRule("gaussian"),
                AttributeRule.newBooleanRule(SWAP_RATES, true),
                AttributeRule.newBooleanRule(SWAP_TRAITS, true),
                AttributeRule.newBooleanRule(CoercableMCMCOperator.AUTO_OPTIMIZE, true),
                AttributeRule.newBooleanRule(RESAMPLE_INFECTION_TIMES, true),
                new ElementRule(CaseToCaseTreeLikelihood.class)
        };
    };
}