package test.dr.calibration;
import dr.evolution.coalescent.DemographicFunction;
import dr.evolution.tree.SimpleNode;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxa;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.evolution.util.Units;
import dr.evomodel.coalescent.ConstantPopulationModel;
import dr.evomodel.coalescent.DemographicModel;
import dr.evomodel.operators.SubtreeSlideOperator;
import dr.evomodel.speciation.BirthDeathGernhard08Model;
import dr.evomodel.speciation.SpeciationLikelihood;
import dr.evomodel.speciation.SpeciationModel;
import dr.evomodel.tree.*;
import dr.evomodelxml.coalescent.ConstantPopulationModelParser;
import dr.evomodelxml.tree.TreeModelParser;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.loggers.ArrayLogFormatter;
import dr.inference.loggers.MCLogger;
import dr.inference.mcmc.MCMC;
import dr.inference.mcmc.MCMCOptions;
import dr.inference.model.BooleanLikelihood;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.*;
import dr.inference.trace.ArrayTraceList;
import dr.inference.trace.Trace;
import dr.inference.trace.TraceCorrelation;
import dr.inferencexml.model.CompoundLikelihoodParser;
import dr.math.distributions.LogNormalDistribution;
import dr.util.NumberFormatter;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
/**
* @auther Walter Xie
*/
public class TestCalibratedYuleModel {
protected static final String TL = "TL";
protected static final String TREE_HEIGHT = TreeModel.TREE_MODEL + "." + TreeModelParser.ROOT_HEIGHT;
// private final int treeSize;
private final BufferedWriter out;
Taxa taxa;
private static final double M = 1 + Math.log(10);
public TestCalibratedYuleModel(int treeSize, double S, int chainLength, BufferedWriter out) throws Exception {
// this.treeSize = treeSize;
this.out = out;
// out.write(Integer.toString(treeSize) + "\t");
TreeModel treeModel = createTreeModel(treeSize);
Parameter brParameter = new Parameter.Default("birthRate", 2.0, 0.0, 100.0);
OperatorSchedule schedule = new SimpleOperatorSchedule();
MCMCOperator operator = new SubtreeSlideOperator(treeModel, 10, 1, true, false, false, false, CoercionMode.COERCION_ON);
schedule.addOperator(operator);
operator = new ScaleOperator(brParameter, 0.5);
operator.setWeight(1.0);
schedule.addOperator(operator);
System.out.println("treeModel = " + TreeUtils.newickNoLengths(treeModel));
// out.write("\t");
// out.write("treeModel = \t");
out.write(TreeUtils.newickNoLengths(treeModel));
out.write("\t");
yuleTester(treeModel, schedule, brParameter, S, chainLength);
}
protected TreeModel createTreeModel(int treeSize) throws Exception {
taxa = new Taxa();
for (int i = 0; i < treeSize; i++) {
taxa.addTaxon(new Taxon("T" + Integer.toString(i)));
}
//System.out.println("taxaSubSet_size = " + taxaSubSet.getTaxonCount());
Parameter popSize = new Parameter.Default(treeSize);
popSize.setId(ConstantPopulationModelParser.POPULATION_SIZE);
ConstantPopulationModel startingTree = new ConstantPopulationModel(popSize, Units.Type.YEARS);
Tree tree = calibration(taxa, startingTree);
return new TreeModel(tree);//treeModel
}
private Tree calibration(final TaxonList taxa, DemographicModel demoModel) throws Exception {
dr.evolution.coalescent.CoalescentSimulator simulator = new dr.evolution.coalescent.CoalescentSimulator();
DemographicFunction demoFunction = demoModel.getDemographicFunction();
SimpleNode[] firstHalfNodes = new SimpleNode[taxa.getTaxonCount() / 2];
SimpleNode[] secondHalfNodes = new SimpleNode[taxa.getTaxonCount() - taxa.getTaxonCount() / 2];
for (int i = 0; i < firstHalfNodes.length; i++) {
firstHalfNodes[i] = new SimpleNode();
firstHalfNodes[i].setTaxon(taxa.getTaxon(i));
}
for (int i = 0; i < secondHalfNodes.length; i++) {
secondHalfNodes[i] = new SimpleNode();
secondHalfNodes[i].setTaxon(taxa.getTaxon(i + taxa.getTaxonCount() / 2));
}
SimpleNode firstHalfRootNode = simulator.simulateCoalescent(firstHalfNodes, demoFunction);
SimpleNode[] restNodes = simulator.simulateCoalescent(secondHalfNodes, demoFunction, 0, firstHalfRootNode.getHeight());
SimpleNode[] together = new SimpleNode[restNodes.length + 1];
for (int i = 0; i < restNodes.length; i++) {
together[i + 1] = restNodes[i];
}
together[0] = firstHalfRootNode;
SimpleNode root = simulator.simulateCoalescent(together, demoFunction);
Tree tree = new SimpleTree(root);
return tree;
}
private void yuleTester(TreeModel treeModel, OperatorSchedule schedule, Parameter brParameter, double S, int chainLength)
throws IOException, TreeUtils.MissingTaxonException {
MCMC mcmc = new MCMC("mcmc1");
MCMCOptions options = new MCMCOptions(chainLength);
TreeLengthStatistic tls = new TreeLengthStatistic(TL, treeModel);
TreeHeightStatistic rootHeight = new TreeHeightStatistic(TREE_HEIGHT, treeModel);
SpeciationModel speciationModel = new BirthDeathGernhard08Model("yule", brParameter, null, null,
BirthDeathGernhard08Model.TreeType.UNSCALED, Units.Type.SUBSTITUTIONS, false);
Likelihood speciationLikelihood = new SpeciationLikelihood(treeModel, speciationModel, "yule.like");
Taxa halfTaxa = new Taxa();
for (int i = 0; i < taxa.getTaxonCount() / 2; i++) {
halfTaxa.addTaxon(new Taxon("T" + Integer.toString(i)));
}
TMRCAStatistic tmrca = new TMRCAStatistic("tmrca(halfTaxa)", treeModel, halfTaxa, false, false);
DistributionLikelihood logNormalLikelihood = new DistributionLikelihood(
new LogNormalDistribution(M, S), 0); // meanInRealSpace="false"
logNormalLikelihood.addData(tmrca);
MonophylyStatistic monophylyStatistic = new MonophylyStatistic("monophyly(halfTaxa)", treeModel, halfTaxa, null);
BooleanLikelihood booleanLikelihood = new BooleanLikelihood();
booleanLikelihood.addData(monophylyStatistic);
//CompoundLikelihood
List<Likelihood> likelihoods = new ArrayList<Likelihood>();
likelihoods.add(speciationLikelihood);
likelihoods.add(logNormalLikelihood);
likelihoods.add(booleanLikelihood);
Likelihood prior = new CompoundLikelihood(0, likelihoods);
prior.setId(CompoundLikelihoodParser.PRIOR);
ArrayLogFormatter logformatter = new ArrayLogFormatter(false);
MCLogger[] loggers = new MCLogger[1];
loggers[0] = new MCLogger(logformatter, (int)(options.getChainLength() / 10000), false);
loggers[0].add(speciationLikelihood);
loggers[0].add(rootHeight);
loggers[0].add(tmrca);
loggers[0].add(tls);
loggers[0].add(brParameter);
mcmc.setShowOperatorAnalysis(false);
mcmc.init(options, prior, schedule, loggers);
mcmc.run();
List<Trace> traces = logformatter.getTraces();
ArrayTraceList traceList = new ArrayTraceList("yuleModelTest", traces, 1000);
for (int i = 1; i < traces.size(); i++) {
traceList.analyseTrace(i);
}
NumberFormatter formatter = new NumberFormatter(8);
TraceCorrelation tlStats = traceList.getCorrelationStatistics(traceList.getTraceIndex(TL));
TraceCorrelation treeHeightStats = traceList.getCorrelationStatistics(traceList.getTraceIndex("tmrca(halfTaxa)"));
// out.write("tmrcaHeight = \t");
out.write(formatter.format(treeHeightStats.getMean()));
out.write("\t");
double expectedNodeHeight = Math.pow(Math.E, (M + (Math.pow(S, 2) / 2)));
// out.write("expectation = \t");
out.write(formatter.format(expectedNodeHeight));
out.write("\t");
double error = Math.abs((treeHeightStats.getMean() - expectedNodeHeight) / expectedNodeHeight);
NumberFormat percentFormatter = NumberFormat.getNumberInstance();
percentFormatter.setMinimumFractionDigits(5);
percentFormatter.setMinimumFractionDigits(5);
// out.write("error = \t");
out.write(percentFormatter.format(error));
out.write("\t");
// out.write("tl.ess = \t");
out.write(Double.toString(tlStats.getESS()));
System.out.println("tmrcaHeight = " + formatter.format(treeHeightStats.getMean())
+ "; expectation = " + formatter.format(expectedNodeHeight)
+ "; error = " + percentFormatter.format(error)
+ "; tl.ess = " + tlStats.getESS());
}
public static void main(String[] args) {
try {
System.out.println("M = " + M);
BufferedWriter out = new BufferedWriter(new FileWriter("TestCalibratedYuleModel.txt"));
out.write("M = \t" + M);
out.newLine();
out.write("chainLeng\ttreeSize\tS\ttreeModel\ttmrcaHeight\texpectation\terror\ttl.ess");
out.newLine();
int[] taxaSchedule = new int[]{4, 8, 16, 32, 48, 64};
int[] chainLengthSchedule = new int[]{2000000, 2000000, 4000000, 8000000, 10000000, 10000000};
if (taxaSchedule.length != chainLengthSchedule.length) throw new Exception ();
double[] S_Schedule = new double[]{0.05, 0.1, 0.2, 0.4};
for (double S : S_Schedule) {
for (int i = 0; i < taxaSchedule.length; i++) {
int n = taxaSchedule[i];
int chainLength = chainLengthSchedule[i];
System.out.print("chainLeng = " + chainLengthSchedule[i] + "\tS = " + S + "\ttreeSize = " + n + "\t");
out.write(chainLengthSchedule[i] + "\t" + S + "\t" + n + "\t");
TestCalibratedYuleModel testCalibratedYuleModel = new TestCalibratedYuleModel(n, S, chainLength, out);
out.newLine();
}
}
out.close();
} catch (Exception e) {
e.printStackTrace();
}
}
}