/* * GibbsSubtreeSwap.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ /** * */ package dr.evomodel.operators; import dr.evolution.tree.MutableTree.InvalidTreeException; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; import dr.inference.model.Likelihood; import dr.inference.operators.SimpleMetropolizedGibbsOperator; import dr.math.MathUtils; import java.util.ArrayList; import java.util.List; /** * @author Sebastian Hoehna * */ // Cleaning out untouched stuff. Can be resurrected if needed @Deprecated public class GibbsSubtreeSwap extends SimpleMetropolizedGibbsOperator { private int MAX_DISTANCE = 10; private final TreeModel tree; private final int[] distances; private boolean pruned = true; /** * */ public GibbsSubtreeSwap(TreeModel tree, boolean pruned, double weight) { this.tree = tree; this.pruned = pruned; setWeight(weight); MAX_DISTANCE = tree.getNodeCount() / 10; MAX_DISTANCE = 4; distances = new int[tree.getNodeCount() / 2]; } /* * (non-Javadoc) * * @see dr.evomodel.operators.SimpleGibbsOperator#getStepCount() */ @Override public int getStepCount() { return 1; } public double doOperation(Likelihood likelihood) { if( pruned ) { return prunedWide(likelihood); } else { return wide(likelihood); } } /** * WARNING: Assumes strictly bifurcating tree. * * @throws InvalidTreeException */ public double wide(Likelihood likelihood) { final int nodeCount = tree.getNodeCount(); final NodeRef root = tree.getRoot(); NodeRef i; int indexI; do { indexI = MathUtils.nextInt(nodeCount); i = tree.getNode(indexI); } while( root == i || (tree.getParent(i) == root && tree.getNodeHeight(i) > tree .getNodeHeight(getOtherChild(tree, tree.getParent(i), i))) ); List<Integer> secondNodeIndices = new ArrayList<Integer>(); List<Double> probabilities = new ArrayList<Double>(); NodeRef j, jP; NodeRef iP = tree.getParent(i); double heightIP = tree.getNodeHeight(iP); double heightI = tree.getNodeHeight(i); double sum = 0.0; double backward = calculateTreeLikelihood(likelihood, tree); int offset = (int) -backward; backward = Math.exp(backward + offset); for(int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if( j != root ) { jP = tree.getParent(j); if( (iP != jP) && (tree.getNodeHeight(j) < heightIP) && (heightI < tree.getNodeHeight(jP)) ) { secondNodeIndices.add(n); swap(tree, i, j, iP, jP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); probabilities.add(prob); swap(tree, i, j, jP, iP); sum += prob; } } } if( sum <= 1E-100 ) { // hack // the proposals have such a small likelihood that they can be // neglected throw new RuntimeException( "Couldn't find another proposal with a decent likelihood."); } double ran = Math.random() * sum; int index = 0; while( ran > 0.0 ) { ran -= probabilities.get(index); index++; } index--; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); double heightJP = tree.getNodeHeight(jP); double heightJ = tree.getNodeHeight(j); // int distance = getNodeDistance(i, j); // distances[distance]++; // ******************************************* // assuming we would have chosen j first double sumForward2 = 0.0; NodeRef k, kP; for(int n = 0; n < nodeCount; n++) { k = tree.getNode(n); if( k != root ) { kP = tree.getParent(k); if( (jP != kP) && (tree.getNodeHeight(k) < heightJP) && (heightJ < tree.getNodeHeight(kP)) ) { swap(tree, j, k, jP, kP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); sumForward2 += prob; swap(tree, j, k, kP, jP); } } } swap(tree, i, j, iP, jP); double forward = probabilities.get(index); iP = jP; heightIP = heightJP; double sumBackward = 0.0; for(int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if( j != root ) { jP = tree.getParent(j); if( (iP != jP) && (tree.getNodeHeight(j) < heightIP) && (heightI < tree.getNodeHeight(jP)) ) { swap(tree, i, j, iP, jP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); sumBackward += prob; swap(tree, i, j, jP, iP); } } } // ******************************************* // assuming we would have chosen j first double sumBackward2 = 0.0; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); heightJP = tree.getNodeHeight(jP); heightJ = tree.getNodeHeight(j); for(int n = 0; n < nodeCount; n++) { k = tree.getNode(n); if( k != root ) { kP = tree.getParent(k); if( (jP != kP) && (tree.getNodeHeight(k) < heightJP) && (heightJ < tree.getNodeHeight(kP)) ) { swap(tree, j, k, jP, kP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); sumBackward2 += prob; swap(tree, j, k, kP, jP); } } } double forwardProb = (forward / sum) + (forward / sumForward2); double backwardProb = (backward / sumBackward) + (backward / sumBackward2); double hastingsRatio = Math.log(backwardProb / forwardProb); // throw new OperatorFailedException( // "Couldn't find valid wide move on this tree!"); return hastingsRatio; } /** * WARNING: Assumes strictly bifurcating tree. * * @throws InvalidTreeException */ public double prunedWide(Likelihood likelihood) { final int nodeCount = tree.getNodeCount(); final NodeRef root = tree.getRoot(); NodeRef i; int indexI; do { indexI = MathUtils.nextInt(nodeCount); i = tree.getNode(indexI); } while( root == i || (tree.getParent(i) == root && tree.getNodeHeight(i) > tree .getNodeHeight(getOtherChild(tree, tree.getParent(i), i))) ); List<Integer> secondNodeIndices = new ArrayList<Integer>(); List<Double> probabilities = new ArrayList<Double>(); NodeRef j, jP; NodeRef iP = tree.getParent(i); double heightIP = tree.getNodeHeight(iP); double heightI = tree.getNodeHeight(i); double sum = 0.0; double backward = calculateTreeLikelihood(likelihood, tree); int offset = (int) -backward; backward = Math.exp(backward + offset); for(int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if( j != root ) { jP = tree.getParent(j); if( (iP != jP) && (tree.getNodeHeight(j) < heightIP) && (heightI < tree.getNodeHeight(jP)) && getNodeDistance(iP, jP) <= MAX_DISTANCE ) { secondNodeIndices.add(n); swap(tree, i, j, iP, jP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); probabilities.add(prob); swap(tree, i, j, jP, iP); sum += prob; } } } if( sum <= 1E-100 ) { // hack // the proposals have such a small likelihood that they can be // neglected throw new RuntimeException( "Couldn't find another proposal with a decent likelihood."); } double ran = Math.random() * sum; int index = 0; while( ran > 0.0 ) { ran -= probabilities.get(index); index++; } index--; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); double heightJP = tree.getNodeHeight(jP); double heightJ = tree.getNodeHeight(j); // ******************************************* // assuming we would have chosen j first double sumForward2 = 0.0; NodeRef k, kP; for(int n = 0; n < nodeCount; n++) { k = tree.getNode(n); if( k != root ) { kP = tree.getParent(k); if( (jP != kP) && (tree.getNodeHeight(k) < heightJP) && (heightJ < tree.getNodeHeight(kP)) && getNodeDistance(kP, jP) <= MAX_DISTANCE ) { swap(tree, j, k, jP, kP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); sumForward2 += prob; swap(tree, j, k, kP, jP); } } } swap(tree, i, j, iP, jP); double forward = probabilities.get(index); iP = jP; heightIP = heightJP; double sumBackward = 0.0; for(int n = 0; n < nodeCount; n++) { j = tree.getNode(n); if( j != root ) { jP = tree.getParent(j); if( (iP != jP) && (tree.getNodeHeight(j) < heightIP) && (heightI < tree.getNodeHeight(jP)) && getNodeDistance(iP, jP) <= MAX_DISTANCE ) { swap(tree, i, j, iP, jP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); sumBackward += prob; swap(tree, i, j, jP, iP); } } } // ******************************************* // assuming we would have chosen j first double sumBackward2 = 0.0; j = tree.getNode(secondNodeIndices.get(index)); jP = tree.getParent(j); heightJP = tree.getNodeHeight(jP); heightJ = tree.getNodeHeight(j); for(int n = 0; n < nodeCount; n++) { k = tree.getNode(n); if( k != root ) { kP = tree.getParent(k); if( (jP != kP) && (tree.getNodeHeight(k) < heightJP) && (heightJ < tree.getNodeHeight(kP)) && getNodeDistance(kP, jP) <= MAX_DISTANCE ) { swap(tree, j, k, jP, kP); double prob = Math.exp(calculateTreeLikelihood( likelihood, tree) + offset); sumBackward2 += prob; swap(tree, j, k, kP, jP); } } } double forwardProb = (forward / sum) + (forward / sumForward2); double backwardProb = (backward / sumBackward) + (backward / sumBackward2); double hastingsRatio = Math.log(backwardProb / forwardProb); // throw new OperatorFailedException( // "Couldn't find valid wide move on this tree!"); return hastingsRatio; } private double calculateTreeLikelihood(Likelihood likelihood, TreeModel tree) { return evaluate(likelihood, 1.0); // return 0.0; } /** * @param tree the tree * @param parent the parent * @param child the child that you want the sister of * @return the other child of the given parent. */ protected NodeRef getOtherChild(Tree tree, NodeRef parent, NodeRef child) { if( tree.getChild(parent, 0) == child ) { return tree.getChild(parent, 1); } else { return tree.getChild(parent, 0); } } /* exchange subtrees whose root are i and j */ private TreeModel swap(TreeModel tree, NodeRef i, NodeRef j, NodeRef iP, NodeRef jP) { tree.beginTreeEdit(); tree.removeChild(iP, i); tree.removeChild(jP, j); tree.addChild(jP, i); tree.addChild(iP, j); tree.endTreeEdit(); return tree; } private int getNodeDistance(NodeRef i, NodeRef j) { int count = 0; double heightI = tree.getNodeHeight(i); double heightJ = tree.getNodeHeight(j); while( i != j ) { count++; if( heightI < heightJ ) { i = tree.getParent(i); heightI = tree.getNodeHeight(i); } else { j = tree.getParent(j); heightJ = tree.getNodeHeight(j); } } return count; } public void printDistances() { System.out.println("Number of proposed trees in distances:"); for(int i = 0; i < distances.length; i++) { System.out.println(i + ")\t\t" + distances[i]); } } public String getOperatorName() { return "Gibbs Subtree Exchange"; } }