/*
 * ARGCoalescentLikelihood.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.arg.coalescent;

import dr.evolution.tree.NodeRef;
import dr.evomodel.arg.ARGModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.xml.*;

import java.util.ArrayList;
import java.util.logging.Logger;

/**
 * A coalescent likelihood for an ARG model.
 * <p>
 * The node heights of the ARG are turned into a height-ordered list of
 * intervals, each ending in a coalescent, recombination or new-sample event.
 * The likelihood is accumulated over that list from a population size and a
 * recombination rate parameter.
 */
public class ARGCoalescentLikelihood extends VeryOldCoalescentLikelihood {

    public static final String ARG_COALESCENT_MODEL = "argCoalescentLikelihood";
    public static final String RECOMBINATION_RATE = "recombinationRate";
    public static final String POPULATION_SIZE = "populationSize";
    public static final String ARG_MODEL = "argModel";
    public static final String MAX_REASSORTMENTS = "maxReassortments";
    public static final String ANCESTRAL_RESTRICTION = "ancestralRestriction";

    /** Interval type used for intervals that end at a reassortment node. */
    public static final int RECOMBINATION = 3;

    private Parameter popSize;
    private Parameter recomRate;

    protected ARGModel arg;
    private int taxaNumber;
    protected int maxReassortments;
    private boolean ancestralRestriction = false;

    private ArrayList<CoalescentInterval> intervals;
    private ArrayList<CoalescentInterval> storedIntervals;

    /**
     * Creates a likelihood without population size or recombination rate
     * parameters. {@link #getLogLikelihood()} dereferences those parameters on
     * its default path, so this constructor is only usable where that path is
     * not taken (NOTE(review): presumably intended for subclasses - confirm).
     *
     * @param name model name passed to the superclass
     * @param arg  the ARG whose node heights define the intervals
     * @param max  maximum number of reassortment nodes allowed
     */
    public ARGCoalescentLikelihood(String name, ARGModel arg, int max) {
        super(name);

        this.arg = arg;
        intervals = new ArrayList<CoalescentInterval>();
        taxaNumber = arg.getExternalNodeCount();
        this.maxReassortments = max;
    }

    /**
     * Creates a fully parameterised likelihood.
     *
     * @param popSize              population size parameter (element 0 is used)
     * @param recomRate            recombination rate parameter (element 0 is used)
     * @param arg                  the ARG whose node heights define the intervals
     * @param setupIntervals       if true, the intervals are computed immediately
     * @param maxReassort          maximum number of reassortment nodes allowed
     * @param ancestralRestriction if true, an ARG for which
     *                             {@code arg.isAncestral()} is false gets
     *                             likelihood negative infinity
     */
    public ARGCoalescentLikelihood(Parameter popSize, Parameter recomRate,
                                   ARGModel arg, boolean setupIntervals,
                                   int maxReassort, boolean ancestralRestriction) {
        super(ARG_COALESCENT_MODEL);

        this.popSize = popSize;
        this.recomRate = recomRate;
        this.arg = arg;
        this.ancestralRestriction = ancestralRestriction;

        addVariable(popSize);
        addVariable(recomRate);
        addModel(arg);

        intervals = new ArrayList<CoalescentInterval>(arg.getNodeCount());

        intervalsKnown = false;
        likelihoodKnown = false;

        if (setupIntervals) {
            intervalsKnown = true;
            calculateIntervals();
        }

        taxaNumber = arg.getExternalNodeCount();
        this.maxReassortments = maxReassort;
    }

    /**
     * Rebuilds the interval list from the current node heights of the ARG.
     * <p>
     * Every internal node contributes one entry (typed {@link #RECOMBINATION}
     * for reassortment nodes, {@code COALESCENT} otherwise) and every external
     * node with height above zero contributes a {@code NEW_SAMPLE} entry. The
     * entries are sorted by height and the heights are then replaced by the
     * difference to the preceding entry.
     */
    public void calculateIntervals() {
        intervals.clear();
        intervals.ensureCapacity(arg.getNodeCount());

        for (int i = 0; i < arg.getInternalNodeCount(); i++) {
            NodeRef node = arg.getInternalNode(i);
            int type = arg.isReassortment(node) ? RECOMBINATION : COALESCENT;
            intervals.add(new CoalescentInterval(arg.getNodeHeight(node), type));
        }

        for (int i = 0; i < arg.getExternalNodeCount(); i++) {
            NodeRef node = arg.getExternalNode(i);
            if (arg.getNodeHeight(node) > 0.0) {
                intervals.add(new CoalescentInterval(arg.getNodeHeight(node), NEW_SAMPLE));
            }
        }

        dr.util.HeapSort.sort(intervals);

        // Up to here 'length' holds an absolute height; turn it into the
        // waiting time since the previous (lower) event.
        double previousHeight = 0.0;
        for (CoalescentInterval interval : intervals) {
            double height = interval.length;
            interval.length = height - previousHeight;
            previousHeight = height;
        }

        intervalsKnown = true;
    }

    /**
     * Invalidates the cached likelihood on any model change, and additionally
     * the cached intervals when the changed model is the ARG itself.
     */
    public void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == arg) {
            intervalsKnown = false;
        }
        likelihoodKnown = false;
    }

    /** Invalidates the cached likelihood when a parameter changes. */
    public void handleParameterChangedEvent(Parameter parameter, int index) {
        likelihoodKnown = false;
    }

    /**
     * Snapshots the current intervals and cached log-likelihood.
     * <p>
     * Both cache flags are cleared before they are recorded, so the current
     * and the stored state are always treated as needing recalculation.
     */
    public void storeState() {
        storedIntervals = copyIntervals(intervals);

        intervalsKnown = likelihoodKnown = false;

        storedIntervalsKnown = intervalsKnown;
        storedLikelihoodKnown = likelihoodKnown;
        storedLogLikelihood = logLikelihood;
    }

    /**
     * Restores the snapshot taken by {@link #storeState()}.
     * <p>
     * The intervals are restored from a copy, so the live list and the stored
     * snapshot never share a list instance (the previous implementation
     * aliased them and then cleared the shared list). As before, both cache
     * flags are cleared so that intervals and likelihood are recomputed on
     * next use. Calling this before any {@code storeState()} leaves the
     * current intervals untouched.
     */
    public void restoreState() {
        if (storedIntervals != null) {
            intervals = copyIntervals(storedIntervals);
        }

        logLikelihood = storedLogLikelihood;

        // Always force a recalculation after a restore.
        intervalsKnown = likelihoodKnown = false;
    }

    /** Returns a new list holding a clone of every interval in {@code source}. */
    private ArrayList<CoalescentInterval> copyIntervals(ArrayList<CoalescentInterval> source) {
        ArrayList<CoalescentInterval> copy = new ArrayList<CoalescentInterval>(source.size());
        for (CoalescentInterval interval : source) {
            copy.add(interval.clone());
        }
        return copy;
    }

    /**
     * Checks whether the current ARG has a usable sequence of events.
     *
     * @param allowDoubleParents if false, a reassortment node whose two
     *                           parents are the same node makes the ARG invalid
     * @return false if the lineage count is already 1 when another interval
     *         remains, or if a disallowed double parent is found; true otherwise
     * @throws RuntimeException if an interval is neither coalescent nor
     *                          recombination
     */
    public boolean currentARGValid(boolean allowDoubleParents) {
        if (!intervalsKnown) {
            calculateIntervals();
        }

        int lineages = taxaNumber;
        for (CoalescentInterval interval : intervals) {
            if (lineages == 1) {
                return false;
            }

            if (interval.type == COALESCENT) {
                lineages--;
            } else if (interval.type == RECOMBINATION) {
                lineages++;
            } else {
                throw new RuntimeException("Not implemented yet");
            }
        }

        if (!allowDoubleParents) {
            for (int i = 0, n = arg.getNodeCount(); i < n; i++) {
                NodeRef node = arg.getNode(i);
                if (arg.isReassortment(node)
                        && arg.getParent(node, 0) == arg.getParent(node, 1)) {
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * Returns the (cached) log-likelihood.
     *
     * @return negative infinity if the ARG has more reassortment nodes than
     *         {@code maxReassortments}, or if the ancestral restriction is on
     *         and {@code arg.isAncestral()} is false; otherwise the value
     *         computed from the current intervals and parameter values
     */
    public double getLogLikelihood() {
        if (likelihoodKnown) {
            return logLikelihood;
        }

        if (!intervalsKnown) {
            calculateIntervals();
        }

        likelihoodKnown = true;

        if (arg.getReassortmentNodeCount() > maxReassortments) {
            logLikelihood = Double.NEGATIVE_INFINITY;
        } else if (ancestralRestriction && !arg.isAncestral()) {
            logLikelihood = Double.NEGATIVE_INFINITY;
        } else {
            logLikelihood = calculateLogLikelihood(
                    popSize.getParameterValue(0),
                    recomRate.getParameterValue(0));
        }

        return logLikelihood;
    }

    /** Returns n choose 2 as a double. */
    private double chooseTwo(int n) {
        return n * (n - 1) / 2.0;
    }

    /**
     * Accumulates the log-likelihood over the current intervals.
     * <p>
     * With k lineages the total event rate is k(k - 1 + r) / (2N). Each
     * interval adds log(rate) - rate * length, plus a term for the event type:
     * log((k - 1) / (k - 1 + r)) - log(C(k, 2)) for a coalescence, or
     * log(r / (k - 1 + r)) - log(k) for a recombination.
     *
     * @param pSize population size N
     * @param rRate recombination rate r
     * @return the log-likelihood, or negative infinity if only one lineage is
     *         left while intervals remain
     * @throws RuntimeException if an interval is neither coalescent nor
     *                          recombination
     */
    private double calculateLogLikelihood(double pSize, double rRate) {
        double logLike = 0.0;
        int numberOfTaxa = taxaNumber;

        for (CoalescentInterval interval : intervals) {
            if (numberOfTaxa == 1) {
                return Double.NEGATIVE_INFINITY;
            }

            double rate = (double) numberOfTaxa * (numberOfTaxa - 1 + rRate) / (2.0 * pSize);

            logLike += Math.log(rate) - rate * interval.length;

            if (interval.type == COALESCENT) {
                logLike += Math.log((double) (numberOfTaxa - 1) / (numberOfTaxa - 1 + rRate))
                        - Math.log(chooseTwo(numberOfTaxa));
                numberOfTaxa--;
            } else if (interval.type == RECOMBINATION) {
                logLike += Math.log(rRate / (numberOfTaxa - 1 + rRate))
                        - Math.log((double) numberOfTaxa);
                numberOfTaxa++;
            } else {
                throw new RuntimeException("Not implemented yet");
            }
        }

        assert numberOfTaxa == 1;

        return logLike;
    }

    /**
     * One event in the ARG: an event type plus a length. While the list is
     * being built {@code length} holds the node height; after
     * {@link #calculateIntervals()} it holds the time since the previous event.
     */
    private class CoalescentInterval implements Comparable<CoalescentInterval>, Cloneable {
        public int type;
        public double length;

        public CoalescentInterval(double length, int type) {
            this.length = length;
            this.type = type;
        }

        /**
         * Orders intervals by ascending {@code length}. Equal lengths compare
         * as 0 and are reported to the log as a severe condition.
         */
        @Override
        public int compareTo(CoalescentInterval a) {
            if (a.length > this.length) {
                return -1;
            } else if (a.length == this.length) {
                Logger.getLogger("dr.evomodel.coalescent").severe(
                        "The current ARG Model has 2 internal nodes " +
                                "at the same height");
                return 0;
            }
            return 1;
        }

        /** Returns "(length, EventName)" using the name of this interval's type. */
        @Override
        public String toString() {
            final String label;
            if (type == COALESCENT) {
                label = "Coalescent";
            } else if (type == RECOMBINATION) {
                label = "Recombination";
            } else if (type == NEW_SAMPLE) {
                label = "New sample";
            } else {
                label = "type " + type;
            }
            return "(" + length + ", " + label + ")";
        }

        /** Returns an independent copy with the same length and type. */
        @Override
        public CoalescentInterval clone() {
            return new CoalescentInterval(length, type);
        }
    }

    /**
     * XML parser for the {@code argCoalescentLikelihood} element. It requires
     * {@code populationSize}, {@code recombinationRate} and {@code argModel}
     * children and accepts the optional attributes {@code ancestralRestriction}
     * (boolean, default false) and {@code maxReassortments} (integer, default
     * {@code Integer.MAX_VALUE}).
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        public String getParserDescription() {
            return "A coalescent likelihood for an ARG model";
        }

        public Class getReturnType() {
            return ARGCoalescentLikelihood.class;
        }

        public String getParserName() {
            return ARG_COALESCENT_MODEL;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = {
                new ElementRule(POPULATION_SIZE,
                        new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(RECOMBINATION_RATE,
                        new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(ARG_MODEL,
                        new XMLSyntaxRule[]{new ElementRule(ARGModel.class)}),
                AttributeRule.newBooleanRule(ANCESTRAL_RESTRICTION, true),
                // Read in parseXMLObject, so it has to be declared here as well.
                AttributeRule.newIntegerRule(MAX_REASSORTMENTS, true),
        };

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            XMLObject cxo = xo.getChild(RECOMBINATION_RATE);
            Parameter rRate = (Parameter) cxo.getChild(Parameter.class);

            cxo = xo.getChild(POPULATION_SIZE);
            Parameter pSize = (Parameter) cxo.getChild(Parameter.class);

            cxo = xo.getChild(ARG_MODEL);
            ARGModel argModel = (ARGModel) cxo.getChild(ARGModel.class);

            int maxreassort = Integer.MAX_VALUE;
            if (xo.hasAttribute(MAX_REASSORTMENTS)) {
                maxreassort = xo.getIntegerAttribute(MAX_REASSORTMENTS);
            }

            boolean ancestral = false;
            if (xo.hasAttribute(ANCESTRAL_RESTRICTION)) {
                ancestral = xo.getBooleanAttribute(ANCESTRAL_RESTRICTION);
            }

            return new ARGCoalescentLikelihood(pSize, rRate, argModel, false, maxreassort, ancestral);
        }
    };

    /** Returns the simple class name followed by the superclass description. */
    @Override
    public String toString() {
        return getClass().getSimpleName() + " " + super.toString();
    }
}