package test.dr.evomodel.treelikelihood;
import dr.evolution.alignment.SimpleAlignment;
import dr.evolution.alignment.SitePatterns;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.sequence.Sequence;
import dr.evolution.util.Date;
import dr.evolution.util.Taxon;
import dr.evolution.util.Units;
import dr.oldevomodel.sitemodel.GammaSiteModel;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.substmodel.HKY;
import dr.evomodel.tree.TreeModel;
import dr.oldevomodel.treelikelihood.TreeLikelihood;
import dr.oldevomodelxml.sitemodel.GammaSiteModelParser;
import dr.oldevomodelxml.substmodel.HKYParser;
import dr.inference.model.Parameter;
import test.dr.inference.trace.TraceCorrelationAssert;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
/**
* @author Marc A. Suchard
*/
public class SequenceLikelihoodTest extends TraceCorrelationAssert {
protected NumberFormat format = NumberFormat.getNumberInstance(Locale.ENGLISH);
protected TreeModel treeModel;
protected static double tolerance = 1E-8;
public SequenceLikelihoodTest(String name) {
super(name);
}
private void recursivelyAddCharacter(String[] sequences, List<Integer> pattern,
DataType dataType) {
final int nTaxa = sequences.length;
if (pattern.size() == nTaxa) {
// Add pattern
for (int i = 0; i < nTaxa; i++) {
sequences[i] = sequences[i] + dataType.getCode(pattern.get(i));
}
} else {
// Continue recursion
final int stateCount = dataType.getStateCount();
for (int i = 0; i < stateCount; i++) {
List<Integer> newPattern = new ArrayList<Integer>();
newPattern.addAll(pattern);
newPattern.add(i);
recursivelyAddCharacter(sequences, newPattern, dataType);
}
}
}
public void testNull() {
// Do nothing; completely abstract JUnitTests are not allowed?
}
private String[] createAllUniquePatterns(int nTaxa, DataType dataType) {
String[] result = new String[nTaxa];
for (int i = 0; i < nTaxa; i++) {
result[i] = "";
}
List<Integer> pattern = new ArrayList<Integer>();
recursivelyAddCharacter(result, pattern, dataType);
return result;
}
protected void createAlignmentWithAllUniquePatterns(Object[][] taxa_sequence, DataType dataType) {
alignment = new SimpleAlignment();
alignment.setDataType(dataType);
int nTaxa = taxa_sequence[0].length;
String[] allUniquePatterns = createAllUniquePatterns(nTaxa, dataType);
taxa_sequence[1] = allUniquePatterns;
taxa = new Taxon[nTaxa]; // 6, 17
System.out.println("Taxon len = " + taxa_sequence[0].length);
System.out.println("Alignment len = " + taxa_sequence[1].length);
if (taxa_sequence.length > 2) System.out.println("Date len = " + taxa_sequence[2].length);
for (int i=0; i < taxa_sequence[0].length; i++) {
taxa[i] = new Taxon(taxa_sequence[0][i].toString());
if (taxa_sequence.length > 2) {
Date date = new Date((Double) taxa_sequence[2][i], Units.Type.YEARS, (Boolean) taxa_sequence[3][0]);
taxa[i].setDate(date);
}
//taxonList.addTaxon(taxon);
Sequence sequence = new Sequence(taxa_sequence[1][i].toString());
sequence.setTaxon(taxa[i]);
sequence.setDataType(dataType);
alignment.addSequence(sequence);
}
System.out.println("Sequence pattern count = " + alignment.getPatternCount());
}
protected double[] computeSitePatternLikelihoods(SitePatterns patterns) {
// Sub model
Parameter freqs = new Parameter.Default(alignment.getStateFrequencies());
Parameter kappa = new Parameter.Default(HKYParser.KAPPA, 29.739445, 0, 100);
FrequencyModel f = new FrequencyModel(Nucleotides.INSTANCE, freqs);
HKY hky = new HKY(kappa, f);
//siteModel
GammaSiteModel siteModel = new GammaSiteModel(hky);
Parameter mu = new Parameter.Default(GammaSiteModelParser.MUTATION_RATE, 1.0, 0, Double.POSITIVE_INFINITY);
siteModel.setMutationRateParameter(mu);
//treeLikelihood
TreeLikelihood treeLikelihood = new TreeLikelihood(patterns, treeModel, siteModel, null, null,
false, false, true, false, false);
return treeLikelihood.getPatternLogLikelihoods();
}
protected double computeSumOfPatterns(SitePatterns patterns) {
double[] patternLogLikelihoods = computeSitePatternLikelihoods(patterns);
double total = 0;
for (double x: patternLogLikelihoods) {
total += Math.exp(x);
}
return total;
}
}