/* * ARGSubtreeSlideOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ /* * RGSubtreeSlideOperator.java * * (c) 2002-2005 BEAST Development Core Team * * This package may be distributed under the * Lesser Gnu Public Licence (LGPL) */ package dr.evomodel.arg.operators; import dr.evolution.tree.MutableTree; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.arg.ARGModel; import dr.evomodel.arg.ARGModel.Node; import dr.evomodel.operators.SubtreeSlideOperator; import dr.inference.operators.*; import dr.math.MathUtils; import dr.xml.*; import java.util.ArrayList; /** * Implements the subtree slide move. * * @author Alexei Drummond * @version $Id: ARGSubtreeSlideOperator.java,v 1.1.2.2 2006/11/06 01:38:30 msuchard Exp $ */ public class ARGSubtreeSlideOperator extends AbstractCoercableOperator { // SimpleMCMCOperator implements CoercableMCMCOperator { public static final String SUBTREE_SLIDE = "argSubtreeSlide"; // public static final String public static final String SWAP_RATES = "swapRates"; public static final String SWAP_TRAITS = "swapTraits"; public static final String DIRICHLET_BRANCHES = "branchesAreScaledDirichlet"; private ARGModel tree = null; private double size = 1.0; private boolean gaussian = false; private boolean swapRates; private boolean swapTraits; private boolean scaledDirichletBranches; // private int mode = CoercableMCMCOperator.DEFAULT; // CoercionMode model; public ARGSubtreeSlideOperator(ARGModel tree, int weight, double size, boolean gaussian, boolean swapRates, boolean swapTraits, boolean scaledDirichletBranches, CoercionMode mode) { super(mode); this.tree = tree; setWeight(weight); this.size = size; this.gaussian = gaussian; this.swapRates = swapRates; this.swapTraits = swapTraits; this.scaledDirichletBranches = scaledDirichletBranches; // this.mode = mode; } public void sanityCheck() { int len = tree.getNodeCount(); for (int i = 0; i < len; i++) { Node node = (Node) tree.getNode(i); if (node.bifurcation) { boolean equalChild = (node.leftChild == node.rightChild); if ((equalChild && node.leftChild != null)) { if (!node.leftChild.bifurcation && ((node.leftChild).leftParent == node)) ; else { System.err.println("Node " + (i + 1) + " is insane."); System.err.println(tree.toGraphString()); System.exit(-1); } } } else { if ((node.leftChild != node.rightChild)) { System.err.println("Node " + (i + 1) + " is insane."); System.err.println(tree.toGraphString()); System.exit(-1); } } } } /** * Do a probablistic subtree slide move. * * @return the log-transformed hastings ratio */ public double doOperation() { // System.err.println("Starting Subtree Slide Operation."); double logq = 0; double oldTreeHeight = tree.getNodeHeight(tree.getRoot()); NodeRef i, newParent, newChild; // 1. choose a random node avoiding root ArrayList<NodeRef> potentialSubtrees = new ArrayList<NodeRef>(); int numPotentialSubtrees = this.getSlideableSubtrees(tree, potentialSubtrees); // System.err.println("Slide:\n"+tree.toGraphString()); i = potentialSubtrees.get(MathUtils.nextInt(numPotentialSubtrees)); // logq = - Math.log(numPotentialSubtrees); NodeRef iP = tree.getParent(i); // TODO Start rewriting here. NodeRef CiP = getOtherChild(tree, iP, i); NodeRef PiP; if (tree.isBifurcation(iP)) PiP = tree.getParent(iP); else { PiP = tree.getParent(iP, MathUtils.nextInt(2)); logq -= Math.log(2); // TODO check if really necessary. } // 2. choose a delta to move double delta = getDelta(); double oldHeight = tree.getNodeHeight(iP); double newHeight = oldHeight + delta; //newHeight = tree.getNodeHeight(tree.getRoot()) + delta; // 3. if the move is up if (delta > 0) { // 3.1 if the topology will change if (PiP != null && tree.getNodeHeight(PiP) < newHeight) { // find new parent newParent = PiP; newChild = iP; while (tree.getNodeHeight(newParent) < newHeight) { newChild = newParent; if (tree.isBifurcation(newParent)) newParent = tree.getParent(newParent); else { newParent = tree.getParent(newParent, MathUtils.nextInt(2)); logq -= Math.log(2); // TODO check if correct. } if (newParent == null) break; } //System.err.println("No problem climbing"); // logq += Math.log(numPotentialSubtrees); tree.beginTreeEdit(); // 3.1.1 if creating a new root if (tree.isRoot(newChild)) { if (true) { tree.endTreeEdit(); try { tree.checkTreeIsValid(); } catch (MutableTree.InvalidTreeException e) { e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. } throw new RuntimeException("Temporarily disable re-rooting"); } //Parameter rootParameter = ((Node)newChild).heightParameter; //Parameter otherParameter = ((Node)iP).heightParameter; //tree.swapHeightParameters(newChild,iP); tree.doubleRemoveChild(iP, CiP); // iP is always bifurcation tree.doubleRemoveChild(PiP, iP); // PiP can be reassortment tree.doubleAddChild(iP, newChild); // iP and newChild are always bifurcations, i still connected to iP if (tree.isBifurcation(PiP)) tree.addChild(PiP, CiP); else tree.addChild(PiP, CiP); tree.setRoot(iP); //System.err.println("Creating new root!"); } // 3.1.2 no new root else { boolean doubly = tree.isBifurcationDoublyLinked(newParent); tree.doubleRemoveChild(iP, CiP); // iP bifurcation tree.doubleRemoveChild(PiP, iP); // PiP can be reassortment tree.doubleRemoveChild(newParent, newChild); //tree.doubleAddChild(iP, newChild); tree.doubleAddChild(PiP, CiP); if (!doubly) { tree.doubleAddChild(iP, newChild); tree.doubleAddChild(newParent, iP); } else { tree.singleAddChild(newParent, iP); tree.singleAddChildWithOneParent(iP, newChild); tree.singleAddChild(newParent, newChild); } // System.err.println("No new root!"); } // System.err.println("i ="+((Node)i).number); // System.err.println("iP ="+((Node)iP).number); // if( newParent != null ) // System.err.println("newParent ="+((Node)newParent).number+" "+tree.getNodeHeight(newParent)); // else // System.err.println("newParent is above root"); // System.err.println("newChild ="+((Node)newChild).number+" "+tree.getNodeHeight(newChild)); tree.setNodeHeight(iP, newHeight); // System.err.println("iP height ="+tree.getNodeHeight(iP)); // System.err.println("Intermediate slide up:\n"+tree.toGraphString()); tree.endTreeEdit(); try { tree.checkTreeIsValid(); } catch (MutableTree.InvalidTreeException ite) { throw new RuntimeException(ite.toString()); } // 3.1.3 count the hypothetical sources of this destination. int possibleSources = intersectingEdges(tree, newChild, iP, oldHeight, null); // System.err.println("possible sources = " + possibleSources); logq -= Math.log(possibleSources); } else { // 3.2 // just change the node height tree.setNodeHeight(iP, newHeight); logq = 0.0; // TODO check is losing -Log(2) is correct. } // System.err.println("Sanity check up-slide"); // sanityCheck(); } // 4 if we are sliding the subtree down. else { // logq = 0; // 4.0 is it a valid move? if (tree.getNodeHeight(i) > newHeight) { return Double.NEGATIVE_INFINITY; } // 4.1 will the move change the topology if (tree.getNodeHeight(CiP) > newHeight) { // System.err.println("Starting down-slide:\n"+tree.toGraphString()); ArrayList<NodeRef[]> newChildren = new ArrayList<NodeRef[]>(); int possibleDestinations = intersectingEdges(tree, CiP, iP, newHeight, newChildren); // if no valid destinations then return a failure if (newChildren.size() == 0) //{ return Double.NEGATIVE_INFINITY; } throw new RuntimeException("no valid destinations"); // pick a random parent/child destination edge uniformly from options int childIndex = MathUtils.nextInt(newChildren.size()); NodeRef[] draw = (NodeRef[]) newChildren.get(childIndex); newChild = draw[1]; newParent = draw[0]; // ((NodeRef[])newChildren.get(childIndex))[1]; //int choice = MathUtils.nextInt(2); NodeRef oops = null; // if( tree.isBifurcation(newChild) ) // newParent = tree.getParent(newChild); // else { // //newParent = tree.getParent(newChild,MathUtils.nextInt(2)); // newParent = tree.getParent(newChild,choice); // if( tree.getNodeHeight(newParent) < newHeight ) // newParent = tree.getParent(newChild,1-choice); // oops = tree.getParent(newChild,1-choice); // System.err.println("Grabbing random child."); // logq -= Math.log(2); // TODO check ratio // } // //NodeRef oops //newParent = // logq += Math.log(possibleDestinations); tree.beginTreeEdit(); // 4.1.1 if iP was root if (tree.isRoot(iP)) { // new root is CiP, but root cannot be a reassortment if (!tree.isBifurcation(CiP)) throw new RuntimeException("root cannot be a reassortment"); boolean doubly = tree.isBifurcationDoublyLinked(newParent); tree.doubleRemoveChild(iP, CiP); tree.doubleRemoveChild(newParent, newChild); // System.err.println("Down-slide disconnect:\n"+tree.toGraphString()); if (tree.isBifurcation(newChild)) tree.doubleAddChild(iP, newChild); else tree.singleAddChildWithOneParent(iP, newChild); if (!doubly) tree.doubleAddChild(newParent, iP); else { tree.singleAddChild(newParent, iP); tree.singleAddChildWithOneParent(newParent, newChild); } //tree.doubleAddChild(newParent, iP); //tree.swapHeightParameters(iP,CiP); // TODO tree.setRoot(CiP); // System.err.println("DOWN: Creating new root!"); // TODO still not tested } else { boolean doubly = tree.isBifurcationDoublyLinked(newParent); // or is it CiP = newParent? tree.doubleRemoveChild(iP, CiP); tree.doubleRemoveChild(PiP, iP); tree.doubleRemoveChild(newParent, newChild); // System.err.println("Down-slide disconnect:\n"+tree.toGraphString()); if (tree.isBifurcation(newChild)) tree.doubleAddChild(iP, newChild); else tree.singleAddChildWithOneParent(iP, newChild); tree.doubleAddChild(PiP, CiP); if (!doubly) tree.doubleAddChild(newParent, iP); // Only works if iP is bifurcation else { tree.singleAddChild(newParent, iP); tree.singleAddChildWithOneParent(newParent, newChild); } // System.err.println("DOWN: no new root!"); /* if( true ) { try { tree.endTreeEdit(); } catch(MutableTree.InvalidTreeException ite) { throw new RuntimeException(ite.toString()); } throw new OperatorFailedException(""); } */ } tree.setNodeHeight(iP, newHeight); // System.err.println("i = " + ((Node) i).number); // System.err.println("iP = " + ((Node) iP).number); // System.err.println("CiP = " + ((Node) CiP).number); // System.err.println("newChild = " + ((Node) newChild).number + " " // + tree.getNodeHeight(newChild)); // System.err.println("newParent = " + ((Node) newParent).number + " " // + tree.getNodeHeight(newParent)); // if (oops != null) // System.err.println("oops = " + ((Node) oops).number + " " // + tree.getNodeHeight(oops)); // System.err.println("newHeight = " + newHeight); // System.err.println("After slide down:\n"+tree.toGraphString()); tree.endTreeEdit(); try { tree.checkTreeIsValid(); } catch (MutableTree.InvalidTreeException ite) { throw new RuntimeException(ite.toString()); } //logq = -Math.log((double) possibleDestinations); logq += Math.log((double) possibleDestinations); } else { try { tree.setNodeHeight(iP, newHeight); } catch (Exception e) { // System.err.println("iP =" + ((Node) iP).number); // System.err.println("newHeight =" + newHeight); //System.exit(-1); } logq = 0.0; } // System.err.println("After slide down:\n"+tree.toGraphString()); // System.err.println("Sanity check in down-slide."); // sanityCheck(); } /* if (swapRates) { NodeRef j = tree.getNode(MathUtils.nextInt(tree.getNodeCount())); if (j != i) { double tmp = tree.getNodeRate(i); tree.setNodeRate(i, tree.getNodeRate(j)); tree.setNodeRate(j, tmp); } } if (swapTraits) { NodeRef j = tree.getNode(MathUtils.nextInt(tree.getNodeCount())); if (j != i) { double tmp = tree.getNodeTrait(i); tree.setNodeTrait(i, tree.getNodeTrait(j)); tree.setNodeTrait(j, tmp); } } */ // todo fix // Reject all trees in which the root is doubly- if (tree.isBifurcationDoublyLinked(tree.getRoot())) throw new RuntimeException("invalid slide"); // todo -- check all ARGTree.Roots if (!tree.validRoot()) throw new RuntimeException("Roots are invalid"); if (logq == Double.NEGATIVE_INFINITY) throw new RuntimeException("invalid slide"); // System.err.println("Ending Subtree Slide Operation."); //System.err.println("logq = "+logq); // logq = 0; if (scaledDirichletBranches) { if (oldTreeHeight != tree.getNodeHeight(tree.getRoot())) throw new RuntimeException("Temporarily disabled."); // TODO calculate Hastings ratio } return logq; } private double getDelta() { if (!gaussian) { return (MathUtils.nextDouble() * size) - (size / 2.0); } else { return MathUtils.nextGaussian() * size; } } private int getSlideableSubtrees(ARGModel tree, ArrayList<NodeRef> potentials) { int count = 0; for (int i = 0, n = tree.getNodeCount(); i < n; i++) { NodeRef node = tree.getNode(i); if (!tree.isRoot(node) && tree.isBifurcation(node) && tree.isBifurcation(tree.getParent(node))) { if (potentials != null) potentials.add(node); count++; } } return count; } private int intersectingEdges(ARGModel tree, NodeRef node, NodeRef parent, double height, ArrayList<NodeRef[]> directChildren) { // if( tree.isBifurcation(node) ) { //ree.getParent(node); tree.getMinParentNodeHeight(node); //NodeRef parent = tree.getParent(node); if (tree.getNodeHeight(parent) < height) return 0; if (tree.getNodeHeight(node) < height) { if (directChildren != null) { NodeRef[] addition = new NodeRef[2]; addition[0] = parent; addition[1] = node; directChildren.add(addition); } //directChildren.add(node); return 1; } int count = 0; //for (int i = 0; i < tree.getChildCount(node); i++) { count += intersectingEdges(tree, tree.getChild(node, 0), node, height, directChildren); if (tree.isBifurcation(node)) count += intersectingEdges(tree, tree.getChild(node, 1), node, height, directChildren); return count; } // Handle reassortment nodes // // Only the parental edge which routines flows down is important // NodeRef parent0 = tree.getParent(node,0); // NodeRef parent1 = tree.getParent(node,1); // double thisHeight = tree.getNodeHeight(node); // double height0 = tree.getNodeHeight(parent0); // double height1 = tree.getNodeHeight(parent1); // // if( (height0 < height) && (height1 < height) ) return 0; // int count = 0; // if( (thisHeight < height) && (height0 >= height) ) { // if( directChildren !=null ) { // NodeRef[] addition = new NodeRef[2]; // addition[0] = parent0; // addition[1] = node; // directChildren.add(addition); // } // //directChildren.add(node); // count += 1; // } // if( (thisHeight < height) && (height1 >= height) ) { // if( directChildren != null ) { // NodeRef[] addition = new NodeRef[2]; // addition[0] = parent1; // addition[1] = node; // directChildren.add(addition); // } // //directChildren.add(node); // count += 1; // } // if( count > 0 ) // return count; // // return intersectingEdges(tree, tree.getChild(node,0), height, directChildren); // } /** * @return the other child of the given parent. */ private NodeRef getOtherChild(Tree tree, NodeRef parent, NodeRef child) { if (tree.getChild(parent, 0) == child) { return tree.getChild(parent, 1); } else { return tree.getChild(parent, 0); } } public double getSize() { return size; } public void setSize(double size) { this.size = size; } public double getCoercableParameter() { return Math.log(getSize()); } public void setCoercableParameter(double value) { setSize(Math.exp(value)); } public double getRawParameter() { return getSize(); } // public int getMode() { // return mode; // } public double getTargetAcceptanceProbability() { return 0.234; } public String getPerformanceSuggestion() { double prob = MCMCOperator.Utils.getAcceptanceProbability(this); double targetProb = getTargetAcceptanceProbability(); double ws = OperatorUtils.optimizeWindowSize(getSize(), Double.MAX_VALUE, prob, targetProb); if (prob < getMinimumGoodAcceptanceLevel()) { return "Try decreasing size to about " + ws; } else if (prob > getMaximumGoodAcceptanceLevel()) { return "Try increasing size to about " + ws; } else return ""; } public String getOperatorName() { return SUBTREE_SLIDE; } public static dr.xml.XMLObjectParser PARSER = new dr.xml.AbstractXMLObjectParser() { public String getParserName() { return SUBTREE_SLIDE; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean swapRates = false; boolean swapTraits = false; boolean scaledDirichletBranches = false; CoercionMode mode = CoercionMode.parseMode(xo); if (xo.hasAttribute(SWAP_RATES)) { swapRates = xo.getBooleanAttribute(SWAP_RATES); } if (xo.hasAttribute(SWAP_TRAITS)) { swapTraits = xo.getBooleanAttribute(SWAP_TRAITS); } if (xo.hasAttribute(DIRICHLET_BRANCHES)) { scaledDirichletBranches = xo.getBooleanAttribute(DIRICHLET_BRANCHES); } ARGModel treeModel = (ARGModel) xo.getChild(ARGModel.class); int weight = xo.getIntegerAttribute("weight"); double size = xo.getDoubleAttribute("size"); boolean gaussian = xo.getBooleanAttribute("gaussian"); return new ARGSubtreeSlideOperator(treeModel, weight, size, gaussian, swapRates, swapTraits, scaledDirichletBranches, mode); } public String getParserDescription() { return "An operator that slides a subtree."; } public Class getReturnType() { return SubtreeSlideOperator.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newIntegerRule("weight"), AttributeRule.newDoubleRule("size"), AttributeRule.newBooleanRule("gaussian"), AttributeRule.newBooleanRule(SWAP_RATES, true), AttributeRule.newBooleanRule(SWAP_TRAITS, true), AttributeRule.newBooleanRule(AUTO_OPTIMIZE, true), new ElementRule(ARGModel.class) }; }; }