/*
 * ARGSwapOperator.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.arg.operators;

import dr.evolution.tree.MutableTree;
import dr.evolution.tree.NodeRef;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.arg.ARGModel.Node;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.xml.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.logging.Logger;

/**
 * This operator moves the ARG model around by swapping nodes. Use of both the
 * reassortment and bifurcation modes, as well as the event operator,
 * satisfies irreducibility.
 * <p>
 * The behaviour is selected by the {@code type} attribute (see {@link #doOperation()}):
 * <ul>
 * <li>{@link #NARROW_SWAP}: exchange a child of a bifurcation with that bifurcation's sibling;</li>
 * <li>{@link #BIFURCATION_SWAP}: swap below one randomly chosen non-root internal bifurcation;</li>
 * <li>{@link #REASSORTMENT_SWAP}: swap below one randomly chosen reassortment node;</li>
 * <li>{@link #DUAL_SWAP}: one reassortment swap followed by one bifurcation swap;</li>
 * <li>any other value (e.g. {@link #FULL_SWAP}): swap at every candidate node, lowest first.</li>
 * </ul>
 *
 * @author ebloomqu
 * @author Marc A. Suchard
 */
public class ARGSwapOperator extends SimpleMCMCOperator {

    public static final String ARG_SWAP_OPERATOR = "argSwapOperator";
    public static final String SWAP_TYPE = "type";
    public static final String BIFURCATION_SWAP = "bifurcationSwap";
    public static final String REASSORTMENT_SWAP = "reassortmentSwap";
    public static final String DUAL_SWAP = "dualSwap";
    public static final String FULL_SWAP = "fullSwap";
    public static final String NARROW_SWAP = "narrowSwap";

    /** The ARG being edited by this operator. */
    private final ARGModel arg;

    /** One of the swap-type constants above; also reported as the operator name. */
    private final String mode;

    /**
     * @param arg    the ARG model to operate on
     * @param mode   the swap type (one of the {@code *_SWAP} constants)
     * @param weight the operator weight, passed to {@code setWeight}
     */
    public ARGSwapOperator(ARGModel arg, String mode, int weight) {
        this.arg = arg;
        this.mode = mode;

        setWeight(weight);
    }

    /**
     * Performs one move according to {@code mode}.
     * <p>
     * If the selected mode has no candidate node to start from (no reassortment
     * node for the reassortment/dual modes, no non-root internal bifurcation for
     * the bifurcation/dual modes) the ARG is left untouched and {@code 0.0} is
     * returned.
     *
     * @return {@code log(swapsBefore / swapsAfter)} for the narrow swap, {@code 0}
     *         for every other mode (presumably consumed by the sampler as the log
     *         Hastings ratio - confirm against {@code SimpleMCMCOperator})
     */
    public double doOperation() {

        if (mode.equals(NARROW_SWAP)) {
            return narrowSwap();
        }

        if ((mode.equals(REASSORTMENT_SWAP) || mode.equals(DUAL_SWAP))
                && arg.getReassortmentNodeCount() == 0) {
            return 0.0;
        }

        ArrayList<NodeRef> bifurcationNodes = new ArrayList<NodeRef>(arg.getNodeCount());
        ArrayList<NodeRef> reassortmentNodes = new ArrayList<NodeRef>(arg.getNodeCount());

        setupBifurcationNodes(bifurcationNodes);
        setupReassortmentNodes(reassortmentNodes);

        // Same treatment as the "no reassortment" case above: with no candidate
        // bifurcation there is no valid index to draw, so do nothing instead of
        // asking the random number generator for a value in an empty range.
        if ((mode.equals(BIFURCATION_SWAP) || mode.equals(DUAL_SWAP))
                && bifurcationNodes.isEmpty()) {
            return 0.0;
        }

        if (mode.equals(BIFURCATION_SWAP)) {
            return bifurcationSwap(bifurcationNodes.get(MathUtils.nextInt(bifurcationNodes.size())));
        } else if (mode.equals(REASSORTMENT_SWAP)) {
            return reassortmentSwap(reassortmentNodes.get(MathUtils.nextInt(reassortmentNodes.size())));
        } else if (mode.equals(DUAL_SWAP)) {
            reassortmentSwap(reassortmentNodes.get(MathUtils.nextInt(reassortmentNodes.size())));
            return bifurcationSwap(bifurcationNodes.get(MathUtils.nextInt(bifurcationNodes.size())));
        }

        // Remaining modes (FULL_SWAP): visit every candidate node in order of
        // increasing height and apply the swap matching its type.
        bifurcationNodes.addAll(reassortmentNodes);

        Collections.sort(bifurcationNodes, nodeSorter);

        for (NodeRef x : bifurcationNodes) {
            if (arg.isBifurcation(x)) {
                bifurcationSwap(x);
            } else {
                reassortmentSwap(x);
            }
        }

        return 0;
    }

    /**
     * Picks one of the currently valid narrow swaps uniformly at random and applies it.
     *
     * @return {@code 0} if no narrow swap is available, otherwise the log of
     *         (number of valid swaps before the move / number of valid swaps after it)
     */
    private double narrowSwap() {
        ArrayList<NarrowSwap> possibleSwaps = new ArrayList<NarrowSwap>(arg.getNodeCount());
        findAllNarrowSwaps(possibleSwaps);
        int possibleSwapsBefore = possibleSwaps.size();

        if (possibleSwapsBefore == 0) {
            return 0;
        }

        doNarrowSwap(possibleSwaps.get(MathUtils.nextInt(possibleSwaps.size())));

        // Recount on the modified ARG to obtain the size of the reverse move set.
        possibleSwaps.clear();
        findAllNarrowSwaps(possibleSwaps);

        return Math.log((double) possibleSwapsBefore / possibleSwaps.size());
    }

    /**
     * Appends every valid narrow swap to {@code moves}. A candidate is generated for
     * each child of every internal, non-root bifurcation whose left parent is also a
     * bifurcation, and kept only if {@link NarrowSwap#isValid()} holds.
     *
     * @param moves list receiving the valid swaps (existing content is kept)
     * @return the size of {@code moves} after the additions
     */
    public int findAllNarrowSwaps(ArrayList<NarrowSwap> moves) {
        for (int i = 0, n = arg.getInternalNodeCount(); i < n; i++) {
            Node x = (Node) arg.getInternalNode(i);
            if (x.bifurcation && !x.isRoot() && x.leftParent.bifurcation) {
                NarrowSwap a = new NarrowSwap(x.leftChild, x, x.leftParent);
                NarrowSwap b = new NarrowSwap(x.rightChild, x, x.leftParent);

                if (a.isValid()) {
                    moves.add(a);
                }
                if (b.isValid()) {
                    moves.add(b);
                }
            }
        }

        return moves.size();
    }

    /**
     * Applies a narrow swap: afterwards {@code swap.c} hangs below {@code swap.gp}
     * and {@code swap.pb} hangs below {@code swap.p}. The resulting ARG is validated;
     * on a validation failure diagnostics are printed and the JVM exits.
     *
     * @param swap the swap to apply
     */
    private void doNarrowSwap(NarrowSwap swap) {
        arg.beginTreeEdit();

        // Snapshot used only for diagnostics if the edit corrupts the ARG.
        String before = arg.toARGSummary();

        if (swap.c == swap.pb) {
            // The child is also the parent's sibling, i.e. it already links to both
            // p and gp: exchanging them amounts to swapping its two parent slots.
            Node c = (Node) swap.c;
            Node p = (Node) swap.p;
            Node gp = (Node) swap.gp;

            if (c.leftParent == p) {
                c.leftParent = gp;
                c.rightParent = p;
            } else {
                c.leftParent = p;
                c.rightParent = gp;
            }
        } else if (arg.getChild(swap.p, 0) == arg.getChild(swap.p, 1)) {
            // Both child slots of p hold the same node: release one randomly chosen
            // link by hand, then rewire.
            Node p = (Node) swap.p;
            Node c = (Node) swap.c;

            if (MathUtils.nextBoolean()) {
                p.leftChild = c.leftParent = null;
            } else {
                p.rightChild = c.rightParent = null;
            }

            arg.removeChild(swap.gp, swap.pb);
            arg.singleAddChild(swap.gp, swap.c);
            arg.singleAddChild(swap.p, swap.pb);
        } else {
            arg.removeChild(swap.gp, swap.pb);
            arg.removeChild(swap.p, swap.c);
            arg.singleAddChild(swap.gp, swap.c);
            arg.singleAddChild(swap.p, swap.pb);
        }

        assert nodeCheck() : swap + " " + before + " " + arg.toARGSummary();

        arg.pushTreeChangedEvent(swap.gp);
        arg.pushTreeChangedEvent(swap.p);

        arg.endTreeEdit();
        try {
            arg.checkTreeIsValid();
        } catch (MutableTree.InvalidTreeException ite) {
            System.out.println(swap);
            System.out.println(before);
            System.err.println(ite.getMessage());
            System.exit(-1);
        } catch (NullPointerException e) {
            System.out.println(swap);
            System.out.println(before);
            System.err.println(e.getMessage());
            System.exit(-1);
        }
    }

    /**
     * A candidate narrow swap: child {@code c} of parent {@code p} trades places with
     * {@code pb}, the other child of the grandparent {@code gp}.
     */
    private class NarrowSwap {
        public NodeRef c;
        public NodeRef p;
        public NodeRef gp;
        public NodeRef pb;

        public NarrowSwap(NodeRef child, NodeRef parent, NodeRef gParent) {
            c = child;
            p = parent;
            gp = gParent;
            pb = arg.getOtherChild(gParent, parent);
        }

        /**
         * @return {@code true} if the parent's sibling is strictly lower than the
         *         parent, so that it can be attached below the parent
         */
        public boolean isValid() {
            return (arg.getNodeHeight(pb) < arg.getNodeHeight(p));
        }

        public String toString() {
            return "Child: " + c.toString()
                    + ", Parent: " + p.toString()
                    + ", G-parent: " + gp.toString()
                    + ", P-brother: " + pb.toString();
        }
    }

    /**
     * Swaps a randomly chosen child of the bifurcation {@code x} with a randomly chosen
     * node whose branch spans the height of {@code x} (see {@link #findNodesAtHeight}).
     * The resulting ARG is validated; on failure diagnostics are printed and the JVM exits.
     *
     * @param x the bifurcation node to start from
     * @return always {@code 0}
     */
    private double bifurcationSwap(NodeRef x) {
        Node startNode = (Node) x;

        // Node keepChild = startNode.leftChild;
        Node moveChild = startNode.rightChild;

        if (MathUtils.nextBoolean()) {
            // keepChild = moveChild;
            moveChild = startNode.leftChild;
        }

        ArrayList<NodeRef> possibleNodes = new ArrayList<NodeRef>(arg.getNodeCount());

        findNodesAtHeight(possibleNodes, startNode.getHeight());

        assert !possibleNodes.contains(startNode);
        assert possibleNodes.size() > 0;

        Node swapNode = (Node) possibleNodes.get(MathUtils.nextInt(possibleNodes.size()));
        Node swapNodeParent = swapNode.leftParent;

        arg.beginTreeEdit();

        // Snapshot used only for diagnostics if the edit corrupts the ARG.
        String before = arg.toARGSummary();

        if (swapNode.bifurcation) {
            arg.singleRemoveChild(startNode, moveChild);
            replaceChild(swapNodeParent, swapNode, moveChild);
            arg.singleAddChild(startNode, swapNode);
        } else {
            // A reassortment node has two parents: use a side whose parent lies
            // above the start node, choosing at random when both qualify.
            boolean leftSide = true;
            boolean[] sideOk = {swapNode.leftParent.getHeight() > startNode.getHeight(),
                    swapNode.rightParent.getHeight() > startNode.getHeight()};

            if (sideOk[0] && sideOk[1]) {
                if (MathUtils.nextBoolean()) {
                    swapNodeParent = swapNode.rightParent;
                    leftSide = false;
                }
            } else if (sideOk[1]) {
                swapNodeParent = swapNode.rightParent;
                leftSide = false;
            }

            if (swapNode.leftParent == swapNode.rightParent) {
                // Double linked parents: release only the chosen side by hand.
                arg.singleRemoveChild(startNode, moveChild);

                if (leftSide) {
                    swapNode.leftParent = null;
                    swapNodeParent.leftChild = null;
                } else {
                    swapNode.rightParent = null;
                    swapNodeParent.rightChild = null;
                }

                arg.singleAddChild(startNode, swapNode);
                arg.singleAddChild(swapNodeParent, moveChild);
            } else if (swapNode.leftParent == startNode || swapNode.rightParent == startNode) {
                // The swap node is already a child of the start node: fill whichever
                // slots were left empty instead of going through singleAddChild.
                arg.singleRemoveChild(startNode, moveChild);
                replaceChild(swapNodeParent, swapNode, moveChild);

                if (startNode.leftChild == null) {
                    startNode.leftChild = swapNode;
                } else {
                    startNode.rightChild = swapNode;
                }

                if (swapNode.leftParent == null) {
                    swapNode.leftParent = startNode;
                } else {
                    swapNode.rightParent = startNode;
                }
            } else {
                arg.singleRemoveChild(startNode, moveChild);
                replaceChild(swapNodeParent, swapNode, moveChild);
                arg.singleAddChild(startNode, swapNode);
            }
        }

        arg.pushTreeChangedEvent();
        // TODO Send only changed nodes
        // arg.pushTreeChangedEvent(startNode);
        // arg.pushTreeChangedEvent(swapNodeParent);

        assert nodeCheck();

        arg.endTreeEdit();

        try {
            arg.checkTreeIsValid();
        } catch (MutableTree.InvalidTreeException ite) {
            System.out.println(before);
            System.err.println(ite.getMessage());
            System.exit(-1);
        }

        return 0;
    }

    /**
     * Replaces {@code oldChild} by {@code newChild} below {@code parent}, using the
     * single-link calls for a bifurcation parent and the double-link calls otherwise.
     */
    private void replaceChild(Node parent, Node oldChild, Node newChild) {
        if (parent.bifurcation) {
            arg.singleRemoveChild(parent, oldChild);
            arg.singleAddChild(parent, newChild);
        } else {
            arg.doubleRemoveChild(parent, oldChild);
            arg.doubleAddChild(parent, newChild);
        }
    }

    /**
     * Swaps the child of the reassortment node {@code x} with a randomly chosen node
     * whose branch spans the height of {@code x} (see {@link #findNodesAtHeight}).
     * The resulting ARG is validated; on failure the message is printed and the JVM exits.
     *
     * @param x the reassortment node to start from
     * @return always {@code 0}
     */
    private double reassortmentSwap(NodeRef x) {
        Node startNode = (Node) x;
        Node startChild = startNode.leftChild;

        ArrayList<NodeRef> possibleNodes = new ArrayList<NodeRef>(arg.getNodeCount());

        findNodesAtHeight(possibleNodes, startNode.getHeight());

        assert !possibleNodes.contains(startNode);
        assert possibleNodes.size() > 0;

        Node swapNode = (Node) possibleNodes.get(MathUtils.nextInt(possibleNodes.size()));
        Node swapParent;

        arg.beginTreeEdit();

        if (swapNode.bifurcation) {
            swapParent = swapNode.leftParent;

            arg.doubleRemoveChild(startNode, startChild);
            removeChildFrom(swapParent, swapNode);

            arg.doubleAddChild(startNode, swapNode);

            linkToParent(startChild, swapParent);
            linkToChild(swapParent, startChild);
        } else {
            // A reassortment node has two parents: use a side whose parent lies
            // above the start node, choosing at random when both qualify.
            boolean leftSide = true;
            boolean[] sideOk = {swapNode.leftParent.getHeight() > startNode.getHeight(),
                    swapNode.rightParent.getHeight() > startNode.getHeight()};

            swapParent = swapNode.leftParent;

            if (sideOk[0] && sideOk[1]) {
                if (MathUtils.nextBoolean()) {
                    leftSide = false;
                    swapParent = swapNode.rightParent;
                }
            } else if (sideOk[1]) {
                leftSide = false;
                swapParent = swapNode.rightParent;
            }

            if (swapNode.leftParent == swapNode.rightParent) {
                // Double linked parents: rewire only the chosen side by hand.
                arg.doubleRemoveChild(startNode, startChild);

                if (leftSide) {
                    swapParent.leftChild = swapNode.leftParent = null;
                    swapParent.leftChild = startChild;
                    swapNode.leftParent = startNode;
                } else {
                    swapParent.rightChild = swapNode.rightParent = null;
                    swapParent.rightChild = startChild;
                    swapNode.rightParent = startNode;
                }

                startNode.leftChild = startNode.rightChild = swapNode;

                linkToParent(startChild, swapParent);
            } else {
                arg.doubleRemoveChild(startNode, startChild);
                removeChildFrom(swapParent, swapNode);

                startNode.leftChild = startNode.rightChild = swapNode;

                if (leftSide) {
                    swapNode.leftParent = startNode;
                } else {
                    swapNode.rightParent = startNode;
                }

                linkToChild(swapParent, startChild);
                linkToParent(startChild, swapParent);
            }
        }

        arg.pushTreeChangedEvent();
        // TODO Limit tree hit
        arg.endTreeEdit();

        try {
            arg.checkTreeIsValid();
        } catch (MutableTree.InvalidTreeException ite) {
            System.err.println(ite.getMessage());
            System.exit(-1);
        }

        return 0;
    }

    /**
     * Detaches {@code child} from {@code parent}, using the single-link call for a
     * bifurcation parent and the double-link call otherwise.
     */
    private void removeChildFrom(Node parent, Node child) {
        if (parent.bifurcation) {
            arg.singleRemoveChild(parent, child);
        } else {
            arg.doubleRemoveChild(parent, child);
        }
    }

    /**
     * Records {@code parent} on {@code child}: in both parent slots when the child is
     * a bifurcation, otherwise in the left slot if it is empty and in the right slot
     * if not.
     */
    private void linkToParent(Node child, Node parent) {
        if (child.bifurcation) {
            child.leftParent = parent;
            child.rightParent = parent;
        } else if (child.leftParent == null) {
            child.leftParent = parent;
        } else {
            child.rightParent = parent;
        }
    }

    /**
     * Records {@code child} on {@code parent}: in both child slots when the parent is
     * not a bifurcation, otherwise in the left slot if it is empty and in the right
     * slot if not.
     */
    private void linkToChild(Node parent, Node child) {
        if (!parent.bifurcation) {
            parent.leftChild = child;
            parent.rightChild = child;
        } else if (parent.leftChild == null) {
            parent.leftChild = child;
        } else {
            parent.rightChild = child;
        }
    }

    /**
     * Appends every internal, non-root bifurcation node of the ARG to {@code list}.
     */
    private void setupBifurcationNodes(ArrayList<NodeRef> list) {
        for (int i = 0, n = arg.getNodeCount(); i < n; i++) {
            NodeRef x = arg.getNode(i);
            if (arg.isInternal(x) && arg.isBifurcation(x) && !arg.isRoot(x)) {
                list.add(x);
            }
        }
    }

    /**
     * Appends every reassortment node of the ARG to {@code list}.
     */
    private void setupReassortmentNodes(ArrayList<NodeRef> list) {
        for (int i = 0, n = arg.getNodeCount(); i < n; i++) {
            NodeRef x = arg.getNode(i);
            if (arg.isReassortment(x)) {
                list.add(x);
            }
        }
    }

    /**
     * Collects the nodes lying strictly below {@code height} that have a parent lying
     * strictly above it. A non-bifurcation node is tested against each of its two
     * parents separately and is therefore added once per qualifying parent.
     *
     * @param x      list receiving the matching nodes
     * @param height the height the branches must span
     */
    private void findNodesAtHeight(ArrayList<NodeRef> x, double height) {
        for (int i = 0, n = arg.getNodeCount(); i < n; i++) {
            Node test = (Node) arg.getNode(i);
            if (test.getHeight() < height) {
                if (test.bifurcation) {
                    if (test.leftParent.getHeight() > height) {
                        x.add(test);
                    }
                } else {
                    if (test.leftParent.getHeight() > height) {
                        x.add(test);
                    }
                    if (test.rightParent.getHeight() > height) {
                        x.add(test);
                    }
                }
            }
        }
    }

    /**
     * @return the swap type this operator was configured with
     */
    public String getOperatorName() {
        return mode;
    }

    public String getPerformanceSuggestion() {
        return "";
    }

    /** Orders nodes by increasing height in the ARG. */
    private final Comparator<NodeRef> nodeSorter = new Comparator<NodeRef>() {

        public int compare(NodeRef o1, NodeRef o2) {
            double[] heights = {arg.getNodeHeight(o1), arg.getNodeHeight(o2)};

            if (heights[0] < heights[1]) {
                return -1;
            } else if (heights[0] > heights[1]) {
                return 1;
            }

            return 0;
        }
    };

    /**
     * XML parser for the {@code argSwapOperator} element: reads the integer weight,
     * the {@code type} attribute and the enclosed {@link ARGModel}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String getParserDescription() {
            return "Swaps nodes on a tree";
        }

        public Class getReturnType() {
            return ARGSwapOperator.class;
        }

        private final String[] validFormats = {BIFURCATION_SWAP, REASSORTMENT_SWAP,
                DUAL_SWAP, FULL_SWAP, NARROW_SWAP};

        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newIntegerRule(WEIGHT),
                new StringAttributeRule(SWAP_TYPE, "The mode of the operator",
                        validFormats, false),
                new ElementRule(ARGModel.class),
        };

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            int weight = xo.getIntegerAttribute(WEIGHT);

            String mode = xo.getStringAttribute(SWAP_TYPE);

            Logger.getLogger("dr.evomodel").info("Creating ARGSwapOperator: " + mode);

            ARGModel arg = (ARGModel) xo.getChild(ARGModel.class);
            return new ARGSwapOperator(arg, mode, weight);
        }

        public String getParserName() {
            return ARG_SWAP_OPERATOR;
        }
    };

    /**
     * Consistency check over the raw parent/child links of every node.
     *
     * @return {@code false} if some node has both two distinct parents and two distinct
     *         children, or if a parent/child link is not mirrored on the other end
     *         (compared by node number); {@code true} otherwise
     */
    public boolean nodeCheck() {
        for (int i = 0, n = arg.getNodeCount(); i < n; i++) {
            Node x = (Node) arg.getNode(i);

            if (x.leftParent != x.rightParent
                    && x.leftChild != x.rightChild) {
                return false;
            }
            if (x.leftParent != null) {
                if (x.leftParent.leftChild.getNumber() != i
                        && x.leftParent.rightChild.getNumber() != i) {
                    return false;
                }
            }
            if (x.rightParent != null) {
                if (x.rightParent.leftChild.getNumber() != i
                        && x.rightParent.rightChild.getNumber() != i) {
                    return false;
                }
            }
            if (x.leftChild != null) {
                if (x.leftChild.leftParent.getNumber() != i
                        && x.leftChild.rightParent.getNumber() != i) {
                    return false;
                }
            }
            if (x.rightChild != null) {
                if (x.rightChild.leftParent.getNumber() != i
                        && x.rightChild.rightParent.getNumber() != i) {
                    return false;
                }
            }
        }

        return true;
    }
}