/* * ARGUniformPrior.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.arg.coalescent; import dr.evomodel.arg.ARGModel; import dr.xml.*; import org.apache.commons.math.util.MathUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.logging.Logger; public class ARGUniformPrior extends ARGCoalescentLikelihood { public static final String ARG_UNIFORM_PRIOR = "argUniformPrior"; public static final String INITIAL_CALCULATIONS = "initialCalculations"; public static final int INITIAL_DEFAULT = 5; /** * First index represents <code>(number of taxa - 3)</code>. Second index * represents <code>(number of reassortments)</code>. */ public static final double[][] logARGCoalescentCount = { //3 external taxa {1.0986122886681098, 4.276666119016055, 8.265650165580329, 12.882968485504067, 18.022948777876554, 23.61061585799083, 29.589035299467483, 35.9136593725818, 42.54885778516801, 49.465568765313016, 56.63967062533004, 64.05083469593846, 71.6817017141028, 79.51727811078219, 87.54448452157624}, //4 external taxa {2.8903717578961645, 7.049254841255837, 11.72325121817378, 16.888557143580204, 22.489284935076803, 28.475109395164285, 34.804239049061785, 41.44233437710982, 48.360992220084476, 55.53645069653148, 62.948589014628716, 70.58017373916618, 78.41629051730933, 86.44391155847957, 94.65156231759461}, //5 external taxa {5.192956850890211, 10.127430784020902, 15.386536070471918, 21.031563309246433, 27.041414773420907, 33.38480151281157, 40.03192406503731, 46.95658863790172, 54.13620464676437, 61.55131458837549, 69.18508096220405, 77.02283620054189, 85.0517120025277, 93.26034012092553, 101.63861177667185}, //6 external taxa {7.90100705199242, 13.480736877978641, 19.249258325378765, 25.32054548939157, 31.698755104859107, 38.36733349428991, 45.30602732354885, 52.49523955146093, 59.917151710215464, 67.55588200276209, 75.39734855785784, 83.42905687418532, 91.63988547749693, 100.01989325102055, 108.56015364817871}, //7 external taxa {10.945529489715843, 17.07892753271249, 23.29838718386135, 29.753407176990198, 36.467041602426036, 43.43431706668038, 50.64269194407989, 58.077999881992824, 65.7264083635128, 73.57505740705963, 81.61221669682189, 89.8272615722345, 98.21058447998666, 106.75348902699808, 115.44808566836372}, //8 external taxa {14.277733999891046, 20.896472983408266, 27.5204956568849, 34.32463786712059, 41.34646380715327, 48.590254195514014, 56.04978852563933, 63.71543385947847, 71.57673212169905, 79.62341044573522, 87.84577046007783, 96.23481579273648, 104.78226604090887, 113.48052233027693, 122.32261419599219}, //9 external taxa {17.861252938347157, 24.912242385415205, 31.903176098451674, 39.02769936297709, 46.334764858303096, 53.836163974100344, 61.53092519032464, 69.4132445398445, 77.47555198269822, 85.70981250210274, 94.10810078400579, 102.66285174696894, 111.36695890554311, 120.21379947792728, 129.1972243144967}, //10 external taxa {21.667915428117478, 29.108649135506738, 36.43531279311537, 43.855945609774885, 51.42860726365353, 59.171269494486445, 67.08736948530374, 75.17436983012065, 83.4271948608094, 91.83976176721744, 100.40570480277214, 109.11872661806457, 117.97276669219758, 126.96207643311605, 136.08124559689557}, //11 external taxa {25.67524861334995, 33.470895149684544, 41.10703544874183, 48.802991754828916, 56.62422391351782, 64.59385162827832, 72.71905241455046, 81.00010003398279, 89.43409011267886, 98.0166551565224, 106.74281045596386, 115.60738869034444, 124.6052669677749, 133.73148419256228, 142.98129861277982}, //12 external taxa {29.864903355376374, 37.986383730127116, 45.90958634624165, 53.86285360119818, 61.91774245806648, 70.10170706136994, 78.42512661349667, 86.890712820353, 95.49746436335495, 104.24253036074631, 113.12215615381358, 122.13218720871606, 131.26834550980237, 140.52638287915215, 149.90216524590267}, //13 external taxa {34.22161218206597, 42.644274889635966, 50.83517511899286, 59.02998615127507, 67.30534956176908, 75.69240403974487, 84.20429128794189, 92.84585184006139, 101.61776251108103, 110.51852076251443, 119.54547309663305, 128.69537701990959, 137.96471924891867, 147.34989978600277, 156.84733922685047}, //14 external taxa {38.732471688582805, 47.43514810013058, 55.87684474986188, 64.2992772236583, 72.78337269607047, 81.36342882144557, 90.05498796388106, 98.86476107289788, 107.79491286633757, 116.84514548111035, 126.01379567653075, 135.29844575074105, 144.69627736969085, 154.20428208422834, 163.8193886161766}, //15 external taxa {43.38643203874033, 52.35074398686485, 61.028354989069626, 69.6660218145462, 78.34831802900398, 87.11227051188368, 95.97552137406072, 104.94643415554035, 114.02849935687857, 123.22250024459373, 132.52766774855348, 141.9423328912873, 151.46431112672838, 161.0911356940919, 170.82020228928107} }; private ArrayList<Double> argNumber; public ARGUniformPrior(ARGModel arg, int max, int initial) { super(ARG_UNIFORM_PRIOR, arg, max); addModel(arg); argNumber = new ArrayList<Double>(15); if (arg.getExternalNodeCount() - 3 < logARGCoalescentCount.length) { Logger.getLogger("dr.evomodel").info("Creating ARGUniformPrior using stored arg counts"); for (int i = 0, n = arg.getExternalNodeCount() - 3; i < logARGCoalescentCount[n].length; i++) argNumber.add(logARGCoalescentCount[n][i]); } else { Logger.getLogger("dr.evomodel").info("Creating ARGUniformPrior by calculating arg counts"); for (int i = 0, n = arg.getExternalNodeCount(); i < initial; i++) { argNumber.add(logNumberARGS(n, i)); } } } public double getLogARGNumber(int i) { if (i >= argNumber.size()) { argNumber.add(logNumberARGS(arg.getExternalNodeCount(), i)); } return argNumber.get(i); } public double getLogLikelihood() { if (likelihoodKnown) { return logLikelihood; } likelihoodKnown = true; logLikelihood = calculateLogLikelihood(); if (arg.getReassortmentNodeCount() > maxReassortments) logLikelihood = Double.NEGATIVE_INFINITY; else logLikelihood = calculateLogLikelihood(); if (!currentARGValid(true)) { logLikelihood = Double.NEGATIVE_INFINITY; } return logLikelihood; } public double calculateLogLikelihood() { double treeHeight = arg.getNodeHeight(arg.getRoot()); int internalNodes = arg.getInternalNodeCount() - 1; double logLike = logFactorial(internalNodes) - (double) internalNodes * Math.log(treeHeight) - getLogARGNumber(arg.getReassortmentNodeCount()); assert !Double.isInfinite(logLike) && !Double.isNaN(logLike); return logLike; } private double logFactorial(int n) { double rValue = 0; for (int i = n; i > 0; i--) { rValue += Math.log(i); } return rValue; } private int numberARGS(int taxa, int argNumber) { int x = taxa; int n = 2 * argNumber + taxa - 1; return shurikoRecursion(x, n); } private int shurikoRecursion(int x, int n) { int a = 0; if (x == 0) { a = 0; } else if (x == 1) { if (n == 0) { a = 1; } else { a = 0; } } else if (n == 0) { if (x == 1) { a = 1; } else { a = 0; } } else if (x == n + 1) { a = x * (x - 1) / 2 * shurikoRecursion(x - 1, n - 1); } else { a = x * shurikoRecursion(x + 1, n - 1) + x * (x - 1) / 2 * shurikoRecursion(x - 1, n - 1); } return a; } public static double logNumberARGS(int start, int reassortments) { Logger.getLogger("dr.evomodel").warning("Calculating ARG count for " + reassortments + " reassortments. This may take awhile"); if (reassortments == 0) { double a = 0; for (int i = start; i > 2; i--) { a += Math.log(i * (i - 1) / 2.0); } return a; } int[] max = new int[start - 3 + reassortments * 2]; int[] x = new int[max.length]; int i = 0; while (i < reassortments) { x[i] = max[i] = 1; i++; } while (i < max.length) { x[i] = max[i] = -1; i++; } double before = 100; double approx = 0; while (x[0] != -9 && !stopCombination(x, start)) { if (testCombination(x, start)) { before = approx; int[] y = generateValues(x, start); approx += reduceThenDivide(y, generateValues(max, start)); before = approx - before; } nextCombination(x); } approx = Math.log(approx); int[] y = new int[max.length + 2]; for (i = 0; i < max.length; i++) y[i] = max[i]; y[y.length - 2] = y[y.length - 1] = -1; max = generateValues(y, start); for (int k = 0; k < y.length; k++) approx += Math.log(max[k]); return approx; } private static double reduceThenDivide(int[] top, int[] bottom) { if (false) { for (int i = 0; i < top.length; i++) { for (int j = 0; j < bottom.length; j++) { int gcd = MathUtils.gcd(top[i], bottom[j]); if (gcd > 1) { top[i] = top[i] / gcd; bottom[j] = bottom[j] / gcd; } } } } Arrays.sort(top); Arrays.sort(bottom); double a = 1; for (int i = 0; i < top.length; i++) a *= (double) top[i] / bottom[i]; return a; } private static int[] generateValues(int[] x, int start) { int[] y = new int[x.length]; for (int i = 0; i < x.length; i++) { if (x[i] == 1) y[i] = start; else y[i] = start * (start - 1) / 2; start += x[i]; } return y; } private static boolean testCombination(int[] x, int start) { for (int i = 0; i < x.length; i++) { start += x[i]; if (start == 1) return false; } return true; } private static boolean stopCombination(int[] x, int start) { for (int i = 0; i < x.length; i++) { if (x[i] == -1) { start--; if (start == 1) { return true; } } else { break; } } return false; } private static void nextCombination(int[] x) { if (x[x.length - 1] == -1) { int i = x.length - 1; while (i > -1) { if (x[i] == 1) { x[i] = -1; x[i + 1] = 1; return; } else i--; } } else { int endOnes = 0; int i = x.length - 1; while (x[i] == 1) { endOnes++; i--; } int nextOne = -1; while (i > -1) { if (x[i] == 1) { nextOne = i; break; } else i--; } if (nextOne == -1) { x[0] = -9; return; } x[nextOne] = -1; x[nextOne + 1] = 1; for (i = 0; i < endOnes; i++) x[i + nextOne + 2] = 1; i = nextOne + 2 + endOnes; while (i < x.length) { x[i] = -1; i++; } } } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserDescription() { return "A uniform prior for an ARG model"; } public Class getReturnType() { return ARGUniformPrior.class; } public String getParserName() { return ARG_UNIFORM_PRIOR; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{ new ElementRule(ARGModel.class), }; public Object parseXMLObject(XMLObject xo) throws XMLParseException { ARGModel argModel = (ARGModel) xo.getChild(ARGModel.class); int max = Integer.MAX_VALUE; if (xo.hasAttribute(MAX_REASSORTMENTS)) { max = xo.getIntegerAttribute(MAX_REASSORTMENTS); } int initial = INITIAL_DEFAULT; if (xo.hasAttribute(INITIAL_CALCULATIONS)) { initial = xo.getIntegerAttribute(INITIAL_CALCULATIONS); } return new ARGUniformPrior(argModel, max, initial); } }; }