/* * TransmissionWilsonBaldingA.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.epidemiology.casetocase.operators; import dr.evolution.tree.NodeRef; import dr.evomodel.epidemiology.casetocase.AbstractCase; import dr.evomodel.epidemiology.casetocase.BranchMapModel; import dr.evomodel.epidemiology.casetocase.CaseToCaseTreeLikelihood; import dr.evomodel.epidemiology.casetocase.PartitionedTreeModel; import dr.evomodel.operators.AbstractTreeOperator; import dr.evomodel.tree.TreeModel; import dr.inference.operators.MCMCOperator; import dr.math.MathUtils; import dr.xml.*; import java.util.ArrayList; import java.util.HashSet; /** * Implements the Wilson-Balding branch swapping move if it does not change the transmission tree. * * @author Matthew Hall */ public class TransmissionWilsonBaldingA extends AbstractTreeOperator { private final CaseToCaseTreeLikelihood c2cLikelihood; public static final String TRANSMISSION_WILSON_BALDING_A = "transmissionWilsonBaldingA"; private double logq; private static final boolean DEBUG = false; private final int tipCount; private final boolean resampleInfectionTimes; public TransmissionWilsonBaldingA(CaseToCaseTreeLikelihood c2cLikelihood, double weight, boolean resampleInfectionTimes) { this.c2cLikelihood = c2cLikelihood; setWeight(weight); tipCount = c2cLikelihood.getTreeModel().getExternalNodeCount(); this.resampleInfectionTimes = resampleInfectionTimes; } public double doOperation() { proposeTree(); if (c2cLikelihood.getTreeModel().getExternalNodeCount() != tipCount) { int newCount = c2cLikelihood.getTreeModel().getExternalNodeCount(); throw new RuntimeException("Lost some tips in modified SPR! (" + tipCount + "-> " + newCount + ")"); } return logq; } public void proposeTree() { PartitionedTreeModel tree = c2cLikelihood.getTreeModel(); BranchMapModel branchMap = c2cLikelihood.getBranchMap(); NodeRef i; double oldMinAge, newMinAge, newRange, oldRange, newAge, q; // choose a random node avoiding root, and nodes that are ineligible for this move because they have nowhere to // go ArrayList<NodeRef> eligibleNodes = getEligibleNodes(tree, branchMap); i = eligibleNodes.get(MathUtils.nextInt(eligibleNodes.size())); double eligibleNodeCount = eligibleNodes.size(); final NodeRef iP = tree.getParent(i); Integer[] sameElements = tree.samePartitionElement(iP); HashSet<Integer> possibleDestinations = new HashSet<Integer>(); // we can insert the node above OR BELOW any node in the same partition for (Integer sameElement : sameElements) { possibleDestinations.add(sameElement); if (!tree.isExternal(tree.getNode(sameElement))) { possibleDestinations.add(tree.getChild(tree.getNode(sameElement), 0).getNumber()); possibleDestinations.add(tree.getChild(tree.getNode(sameElement), 1).getNumber()); } } Integer[] pd = possibleDestinations.toArray(new Integer[possibleDestinations.size()]); NodeRef j = tree.getNode(pd[MathUtils.nextInt(pd.length)]); NodeRef jP = tree.getParent(j); while ((jP != null && (tree.getNodeHeight(jP) <= tree.getNodeHeight(i))) || (i == j)) { j = tree.getNode(pd[MathUtils.nextInt(pd.length)]); jP = tree.getParent(j); } if (iP == tree.getRoot() || j == tree.getRoot()) { throw new RuntimeException("Root changes not allowed!"); } if (jP == iP || j == iP || jP == i) throw new RuntimeException("Move failed"); final NodeRef CiP = getOtherChild(tree, iP, i); NodeRef PiP = tree.getParent(iP); if(resampleInfectionTimes) { AbstractCase iCase = branchMap.get(i.getNumber()); AbstractCase iPCase = branchMap.get(iP.getNumber()); AbstractCase CiPCase = branchMap.get(CiP.getNumber()); AbstractCase PiPCase = null; if(PiP!=null){ PiPCase = branchMap.get(PiP.getNumber()); } // what happens on i's branch if (iCase != iPCase) { iCase.setInfectionBranchPosition(MathUtils.nextDouble()); } // what happens between PiP and CiP if (PiPCase == null || CiPCase != PiPCase) { CiPCase.setInfectionBranchPosition(MathUtils.nextDouble()); } // what happens between k and j AbstractCase jCase = branchMap.get(j.getNumber()); AbstractCase kCase = branchMap.get(jP.getNumber()); if(iPCase != jCase && iPCase != kCase){ throw new RuntimeException("TWBA misbehaving."); } jCase.setInfectionBranchPosition(MathUtils.nextDouble()); } newMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(j)); newRange = tree.getNodeHeight(jP) - newMinAge; newAge = newMinAge + (MathUtils.nextDouble() * newRange); oldMinAge = Math.max(tree.getNodeHeight(i), tree.getNodeHeight(CiP)); oldRange = tree.getNodeHeight(PiP) - oldMinAge; q = newRange / Math.abs(oldRange); tree.beginTreeEdit(); if (j == tree.getRoot()) { // 1. remove edges <iP, CiP> tree.removeChild(iP, CiP); tree.removeChild(PiP, iP); // 2. add edges <k, iP>, <iP, j>, <PiP, CiP> tree.addChild(iP, j); tree.addChild(PiP, CiP); // iP is the new root tree.setRoot(iP); } else if (iP == tree.getRoot()) { // 1. remove edges <k, j>, <iP, CiP>, <PiP, iP> tree.removeChild(jP, j); tree.removeChild(iP, CiP); // 2. add edges <k, iP>, <iP, j>, <PiP, CiP> tree.addChild(iP, j); tree.addChild(jP, iP); //CiP is the new root tree.setRoot(CiP); } else { // 1. remove edges <k, j>, <iP, CiP>, <PiP, iP> tree.removeChild(jP, j); tree.removeChild(iP, CiP); tree.removeChild(PiP, iP); // 2. add edges <k, iP>, <iP, j>, <PiP, CiP> tree.addChild(iP, j); tree.addChild(jP, iP); tree.addChild(PiP, CiP); } tree.setNodeHeight(iP, newAge); tree.endTreeEdit(); if(DEBUG){ c2cLikelihood.getTreeModel().checkPartitions(); } logq = Math.log(q); double reverseEligibleNodeCount = getEligibleNodes(tree, branchMap).size(); logq += Math.log(eligibleNodeCount/reverseEligibleNodeCount); } public String getPerformanceSuggestion() { return "Not implemented"; } private boolean eligibleForMove(NodeRef node, TreeModel tree, BranchMapModel branchMap){ // to be eligible for this move, the node's parent and grandparent, or parent and other child, must be in the // same partition (so removing the parent has no effect on the transmission tree) return (!tree.isRoot(node) && ((tree.getParent(tree.getParent(node))!=null && branchMap.get(tree.getParent(node).getNumber()) ==branchMap.get(tree.getParent(tree.getParent(node)).getNumber())) || branchMap.get(tree.getParent(node).getNumber())==branchMap.get(getOtherChild(tree, tree.getParent(node), node).getNumber()))); } private ArrayList<NodeRef> getEligibleNodes(TreeModel tree, BranchMapModel branchMap){ ArrayList<NodeRef> out = new ArrayList<NodeRef>(); for(NodeRef node : tree.getNodes()){ if(eligibleForMove(node, tree, branchMap)){ out.add(node); } } return out; } public String getOperatorName() { return TRANSMISSION_WILSON_BALDING_A + " (" + c2cLikelihood.getTreeModel().getId() +")"; } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public static final String RESAMPLE_INFECTION_TIMES = "resampleInfectionTimes"; public String getParserName() { return TRANSMISSION_WILSON_BALDING_A; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { final CaseToCaseTreeLikelihood c2cL = (CaseToCaseTreeLikelihood) xo.getChild(CaseToCaseTreeLikelihood.class); final double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT); boolean resampleInfectionTimes = false; if(xo.hasAttribute(RESAMPLE_INFECTION_TIMES)) { resampleInfectionTimes = xo.getBooleanAttribute(RESAMPLE_INFECTION_TIMES); } return new TransmissionWilsonBaldingA(c2cL, weight, resampleInfectionTimes); } // ************************************************************************ // AbstractXMLObjectParser implementation // ************************************************************************ public String getParserDescription(){ return "This element represents a Wilson-Balding move operator, such that the transplantation of the " + "subtree does not affect the topology of the transmission tree."; } public Class getReturnType(){ return TransmissionWilsonBaldingA.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules;{ rules = new XMLSyntaxRule[]{ AttributeRule.newDoubleRule(MCMCOperator.WEIGHT), AttributeRule.newBooleanRule(RESAMPLE_INFECTION_TIMES, true), new ElementRule(CaseToCaseTreeLikelihood.class) }; } }; }