/* * ImportanceSubtreeSwap.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ /** * */ package dr.evomodel.operators; import dr.evolution.tree.MutableTree.InvalidTreeException; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.ConditionalCladeFrequency; import dr.evomodel.tree.TreeModel; import dr.evomodelxml.operators.ImportanceSubtreeSwapParser; import dr.inference.operators.*; import dr.math.MathUtils; import java.util.ArrayList; import java.util.List; /** * @author Sebastian Hoehna * <p/> * This class implements a subtree swap operator. The first subtree is * chosen randomly and the second one is chosen according to the * importance of the new tree. The importance are calculated by the * multiplied clade probabilities. */ // Cleaning out untouched stuff. Can be resurrected if needed @Deprecated public class ImportanceSubtreeSwap extends AbstractTreeOperator { public final int SAMPLE_EVERY = 10; private final TreeModel tree; private final int samples; private int sampleCount = 0; private boolean burnin = false; private final ConditionalCladeFrequency probabilityEstimater; private final OperatorSchedule schedule; /** * */ public ImportanceSubtreeSwap(TreeModel tree, double weight, int samples, int epsilon) { this.tree = tree; setWeight(weight); this.samples = samples; sampleCount = 0; probabilityEstimater = new ConditionalCladeFrequency(tree, epsilon); schedule = getOperatorSchedule(tree); } /** * */ public ImportanceSubtreeSwap(TreeModel tree, double weight, int samples) { this.tree = tree; setWeight(weight); this.samples = samples; sampleCount = 0; double epsilon = 1 - Math.pow(0.5, 1.0 / samples); probabilityEstimater = new ConditionalCladeFrequency(tree, epsilon); schedule = getOperatorSchedule(tree); } private OperatorSchedule getOperatorSchedule(TreeModel treeModel) { ExchangeOperator narrowExchange = new ExchangeOperator( ExchangeOperator.NARROW, treeModel, 10); ExchangeOperator wideExchange = new ExchangeOperator( ExchangeOperator.WIDE, treeModel, 3); SubtreeSlideOperator subtreeSlide = new SubtreeSlideOperator(treeModel, 10.0, 1.0, true, false, false, false, CoercionMode.COERCION_ON); NNI nni = new NNI(treeModel, 10.0); WilsonBalding wilsonBalding = new WilsonBalding(treeModel, 3.0); FNPR fnpr = new FNPR(treeModel, 5.0); OperatorSchedule schedule = new SimpleOperatorSchedule(); schedule.addOperator(narrowExchange); schedule.addOperator(wideExchange); schedule.addOperator(subtreeSlide); schedule.addOperator(nni); schedule.addOperator(wilsonBalding); schedule.addOperator(fnpr); return schedule; } /* * (non-Javadoc) * * @see dr.inference.operators.SimpleMCMCOperator#doOperation() */ @Override public double doOperation() { if (!burnin) { if (sampleCount < samples * SAMPLE_EVERY) { sampleCount++; if (sampleCount % SAMPLE_EVERY == 0) { probabilityEstimater.addTree(tree); } setAcceptCount(0); setRejectCount(0); setTransitions(0); return doUnguidedOperation(); } else { return importanceExchange(); } } else { return doUnguidedOperation(); } } private double doUnguidedOperation() { int index = schedule.getNextOperatorIndex(); SimpleMCMCOperator operator = (SimpleMCMCOperator) schedule.getOperator(index); return operator.doOperation(); } /** * WARNING: Assumes strictly bifurcating tree. * * @throws InvalidTreeException */ private double importanceExchange() { final int nodeCount = tree.getNodeCount(); final NodeRef root = tree.getRoot(); NodeRef i; int indexI; int indexJ; do { indexI = MathUtils.nextInt(nodeCount); i = tree.getNode(indexI); } while (root == i || (tree.getParent(i) == root && tree.getNodeHeight(i) > tree.getNodeHeight(getOtherChild(tree, tree.getParent(i), i)))); List<Integer> secondNodeIndices = new ArrayList<Integer>(); List<Double> probabilities = new ArrayList<Double>(); NodeRef j, iP, jP; iP = tree.getParent(i); double sum = 0.0; double backward = calculateTreeProbability(tree); int offset = (int) -backward; backward = Math.exp(backward + offset); tree.beginTreeEdit(); for (int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if (j != root) { jP = tree.getParent(j); if ((iP != jP) && (i != jP) && (j != iP) && (tree.getNodeHeight(j) < tree.getNodeHeight(iP)) && (tree.getNodeHeight(i) < tree.getNodeHeight(jP))) { secondNodeIndices.add(n); swap(tree, tree.getNode(indexI), tree.getNode(n)); double prob = Math.exp(calculateTreeProbability(tree) + offset); probabilities.add(prob); swap(tree, tree.getNode(indexI), tree.getNode(n)); sum += prob; } } } double ran = Math.random() * sum; int index = 0; while (ran > 0.0) { ran -= probabilities.get(index); index++; } index--; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); // ******************************************* // assuming we would have chosen j first double sumForward2 = 0.0; NodeRef k, kP; indexJ = secondNodeIndices.get(index); for (int n = 0; n < nodeCount; n++) { k = tree.getNode(n); if (k != root) { kP = tree.getParent(k); if ((jP != kP) && (j != kP) && (k != jP) && (tree.getNodeHeight(k) < tree.getNodeHeight(jP)) && (tree.getNodeHeight(j) < tree.getNodeHeight(kP))) { swap(tree, tree.getNode(indexJ), tree.getNode(n)); double prob = Math.exp(calculateTreeProbability(tree) + offset); sumForward2 += prob; swap(tree, tree.getNode(indexJ), tree.getNode(n)); } } } swap(tree, i, j); double forward = probabilities.get(index); iP = tree.getParent(i); double sumBackward = 0.0; for (int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if (j != root) { jP = tree.getParent(j); if ((iP != jP) && (i != jP) && (j != iP) && (tree.getNodeHeight(j) < tree.getNodeHeight(iP)) && (tree.getNodeHeight(i) < tree.getNodeHeight(jP))) { swap(tree, tree.getNode(indexI), tree.getNode(n)); double prob = Math.exp(calculateTreeProbability(tree) + offset); sumBackward += prob; swap(tree, tree.getNode(indexI), tree.getNode(n)); } } } // ******************************************* // assuming we would have chosen j first double sumBackward2 = 0.0; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); for (int n = 0; n < nodeCount; n++) { k = tree.getNode(n); if (k != root) { kP = tree.getParent(k); if ((jP != kP) && (j != kP) && (k != jP) && (tree.getNodeHeight(k) < tree.getNodeHeight(jP)) && (tree.getNodeHeight(j) < tree.getNodeHeight(kP))) { swap(tree, tree.getNode(indexJ), tree.getNode(n)); double prob = Math.exp(calculateTreeProbability(tree) + offset); sumBackward2 += prob; swap(tree, tree.getNode(indexJ), tree.getNode(n)); } } } tree.endTreeEdit(); // AR - not sure whether this check is necessary try { tree.checkTreeIsValid(); } catch (InvalidTreeException e) { throw new RuntimeException(e.getMessage()); } double forwardProb = (forward / sum) + (forward / sumForward2); double backwardProb = (backward / sumBackward) + (backward / sumBackward2); double hastingsRatio = Math.log(backwardProb / forwardProb); // throw new OperatorFailedException( // "Couldn't find valid wide move on this tree!"); return hastingsRatio; } /* exchange subtrees whose root are i and j */ private void swap(TreeModel tree, NodeRef i, NodeRef j) { NodeRef iP = tree.getParent(i); NodeRef jP = tree.getParent(j); tree.removeChild(iP, i); tree.removeChild(jP, j); tree.addChild(jP, i); tree.addChild(iP, j); } private double calculateTreeProbability(Tree tree) { // return calculateTreeProbabilityMult(tree); // return calculateTreeProbabilityLog(tree); return probabilityEstimater.getTreeProbability(tree); // return 0.0; } public void setBurnin(boolean burnin) { this.burnin = burnin; } /* * (non-Javadoc) * * @see dr.inference.operators.SimpleMCMCOperator#getOperatorName() */ @Override public String getOperatorName() { return ImportanceSubtreeSwapParser.IMPORTANCE_SUBTREE_SWAP; } /* * (non-Javadoc) * * @see dr.inference.operators.MCMCOperator#getPerformanceSuggestion() */ public String getPerformanceSuggestion() { // TODO Auto-generated method stub return ""; } }