package test.dr.evomodel.arg.operators;
import dr.evolution.io.Importer;
import dr.evolution.io.NewickImporter;
import dr.evomodel.arg.ARGLogger;
import dr.evomodel.arg.ARGModel;
import dr.evomodel.arg.ARGReassortmentNodeCountStatistic;
import dr.evomodel.arg.coalescent.ARGUniformPrior;
import dr.evomodelxml.tree.TreeModelParser;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.loggers.ArrayLogFormatter;
import dr.inference.loggers.MCLogger;
import dr.inference.loggers.TabDelimitedFormatter;
import dr.inference.mcmc.MCMC;
import dr.inference.mcmc.MCMCOptions;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.operators.CoercionMode;
import dr.inference.operators.OperatorSchedule;
import dr.inference.operators.ScaleOperator;
import dr.inference.operators.SimpleOperatorSchedule;
import dr.inference.trace.ArrayTraceList;
import dr.inference.trace.Trace;
import dr.inference.trace.TraceCorrelation;
import dr.math.distributions.GammaDistribution;
import dr.math.distributions.PoissonDistribution;
import test.dr.inference.trace.TraceCorrelationAssert;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
* @author Marc Suchard
*/
public class ARGAddRemoveOperatorTest extends TraceCorrelationAssert {
public ARGAddRemoveOperatorTest(String name) {
super(name);
}
public void setUp() throws Exception {
super.setUp();
NewickImporter importer = new NewickImporter(
"(((A:1.0,B:1.0):1.0,C:2.0):1.0,D:3.0);");
arg4 = new ARGModel(importer.importTree(null));
arg4.setupHeightBounds();
arg4.addLikelihoodCalculator(null);
arg4.addLikelihoodCalculator(null);
importer = new NewickImporter(
"((((A:1.0,B:1.0):1.0,C:2.0):1.0,D:3.0):1.0, E:4.0);");
ARGModel arg5 = new ARGModel(importer.importTree(null));
importer = new NewickImporter(
"(((((A:1.0,B:1.0):1.0,C:2.0):1.0,D:3.0):1.0, E:4.0),F:5.0);");
ARGModel arg6 = new ARGModel(importer.importTree(null));
}
// 4 taxa args
public void testFlatPrior4() throws IOException, Importer.ImportException {
flatPriorTester(arg4, 2000000, 1000, 2.0, 100.0, 0.5, 3);
}
private void flatPriorTester(ARGModel arg, int chainLength, int sampleTreeEvery,
double nodeCountSetting, double rootHeightAlpha, double rootHeightBeta, int maxCount)
throws IOException, Importer.ImportException {
MCMC mcmc = new MCMC("mcmc1");
MCMCOptions options = new MCMCOptions(chainLength);
// double nodeCountSetting = 2.0;
// double rootHeightAlpha = 100;
// double rootHeightBeta = 0.5;
OperatorSchedule schedule = getSchedule(arg);
ARGUniformPrior uniformPrior = new ARGUniformPrior(arg, maxCount, arg.getExternalNodeCount());
PoissonDistribution poisson = new PoissonDistribution(nodeCountSetting);
DistributionLikelihood nodeCountPrior = new DistributionLikelihood(poisson, 0.0);
ARGReassortmentNodeCountStatistic nodeCountStatistic = new ARGReassortmentNodeCountStatistic("nodeCount", arg);
nodeCountPrior.addData(nodeCountStatistic);
DistributionLikelihood rootPrior = new DistributionLikelihood(new GammaDistribution(rootHeightAlpha, rootHeightBeta), 0.0);
CompoundParameter rootHeight = (CompoundParameter) arg.createNodeHeightsParameter(true, false, false);
rootPrior.addData(rootHeight);
List<Likelihood> likelihoods = new ArrayList<Likelihood>();
likelihoods.add(uniformPrior);
likelihoods.add(rootPrior);
likelihoods.add(nodeCountPrior);
CompoundLikelihood compoundLikelihood = new CompoundLikelihood(1, likelihoods);
compoundLikelihood.setId("likelihood1");
MCLogger[] loggers = new MCLogger[3];
loggers[0] = new MCLogger(new TabDelimitedFormatter(System.out), 10000, false);
loggers[0].add(compoundLikelihood);
loggers[0].add(arg);
File file = new File("test.args");
file.deleteOnExit();
FileOutputStream out = new FileOutputStream(file);
loggers[1] = new ARGLogger(arg, new TabDelimitedFormatter(out), sampleTreeEvery, "test");
ArrayLogFormatter formatter = new ArrayLogFormatter(false);
loggers[2] = new MCLogger(formatter, sampleTreeEvery, false);
loggers[2].add(arg);
arg.getRootHeightParameter().setId("root");
loggers[2].add(arg.getRootHeightParameter());
mcmc.setShowOperatorAnalysis(true);
mcmc.init(options, compoundLikelihood, schedule, loggers);
mcmc.run();
out.flush();
out.close();
List<Trace> traces = formatter.getTraces();
// Set<String> uniqueTrees = new HashSet<String>();
//
// NexusImporter importer = new NexusImporter(new FileReader(file));
// while (importer.hasTree()) {
// Tree t = importer.importNextTree();
// uniqueTrees.add(Tree.Utils.uniqueNewick(t, t.getRoot()));
// }
//
// TestCase.assertEquals(numTopologies, uniqueTrees.size()); List<Trace> traces = formatter.getTraces();
ArrayTraceList traceList = new ArrayTraceList("ARGTest", traces, 0);
for (int i = 1; i < traces.size(); i++) {
traceList.analyseTrace(i);
}
TraceCorrelation nodeCountStats = traceList.getCorrelationStatistics(1);
TraceCorrelation rootHeightStats = traceList.getCorrelationStatistics(4);
assertExpectation("nodeCount", nodeCountStats, poisson.truncatedMean(maxCount));
assertExpectation(TreeModelParser.ROOT_HEIGHT, rootHeightStats, rootHeightAlpha * rootHeightBeta);
}
public static OperatorSchedule getSchedule(ARGModel arg) {
CompoundParameter rootHeight = (CompoundParameter) arg.createNodeHeightsParameter(true, false, false);
CompoundParameter internalHeights = (CompoundParameter) arg.createNodeHeightsParameter(false, true, false);
//CompoundParameter allInternalNodeHeights = (CompoundParameter) arg.createNodeHeightsParameter(true, true, false);
// CompoundParameter rates = (CompoundParameter) arg.createNodeRatesParameter(false, true, true);
// ARGAddRemoveEventOperator operator1 = new ARGAddRemoveEventOperator(arg, 5, 0.5,
// CoercionMode.COERCION_ON, internalHeights, allInternalNodeHeights, rates, 0.9, null,-1);
ScaleOperator operator2 = new ScaleOperator(rootHeight, 0.75, CoercionMode.COERCION_ON, 5);
ScaleOperator operator3 = new ScaleOperator(internalHeights, 0.75, CoercionMode.COERCION_ON, 10);
OperatorSchedule schedule = new SimpleOperatorSchedule();
// schedule.addOperator(operator1);
schedule.addOperator(operator2);
schedule.addOperator(operator3);
return schedule;
// <scaleOperator id="rootOperator" scaleFactor="0.5"
// weight="10">
// <parameter idref="argModel.rootHeight" />
// </scaleOperator>
//
// <scaleOperator scaleFactor="0.95" weight="10">
// <parameter idref="argModel.internalNodeHeights" />
// </scaleOperator>
// <ARGEventOperator weight="5" addProbability="0.5"
// autoOptimize="false">
// <argTreeModel idref="argModel" />
// <internalNodes>
// <parameter idref="argModel.internalNodeHeights" />
// </internalNodes>
// <internalNodesPlusRoot>
// <parameter idref="argModel.allInternalNodeHeights" />
// </internalNodesPlusRoot>
// <nodeRates>
// <parameter idref="argModel.rates" />
// </nodeRates>
// </ARGEventOperator>
}
private ARGModel arg4;
}