/*
* GaussianProcessSkytrackLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.coalescent;
//import com.lowagie.text.Paragraph;
//import com.sun.servicetag.SystemEnvironment;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.coalescent.GaussianProcessSkytrackLikelihoodParser;
import dr.inference.loggers.LogColumn;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MathUtils;
//import dr.util.ComparableDouble;
//import dr.util.HeapSort;
import no.uib.cipr.matrix.*;
//import sun.font.TrueTypeFont;
import java.util.ArrayList;
import java.util.List;
//import dr.evolution.tree.TreeTrait;
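/**
* Coalescent likelihood for a Gaussian-process based estimate of the
* effective population size trajectory, evaluated on a single genealogy.
*
* @author Vladimir Minin
* @author Marc Suchard
* @author Julia Palacios
*/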
public class GaussianProcessSkytrackLikelihood extends OldAbstractCoalescentLikelihood {
// protected Parameter groupSizeParameter;
// Approximation of ln(2*pi) truncated to seven significant digits; used by calculateLogGP().
public static final double LOG_TWO_TIMES_PI = 1.837877;
// Element 0 is the precision that scales the tridiagonal matrix built by getQmatrix()/setupQmatrix().
protected Parameter precisionParameter;
// Element 0 is the upper bound used by calculateLogLikelihood() and passed as the value to getLogPriorLambda().
protected Parameter lambda_boundParameter;
// protected Parameter numGridPoints;
// Element 0 is passed as the mean to getLogPriorLambda() in getLogLikelihood();
// it is not registered with addVariable() in the constructor.
protected Parameter lambdaParameter;
// Only exposed through getBetaParameter()/getAlphaParameter() in this file.
protected Parameter betaParameter;
protected Parameter alphaParameter;
// One entry per change point; setupSufficientStatistics() sets it to 1 at coalescent events, and
// calculateLogLikelihood() multiplies it with the function value inside the logistic term.
protected Parameter GPtype;
// One entry per tree interval: number of change points counted in that interval.
protected Parameter GPcounts;
// One entry per coalescent event: k*(k-1)/2 for the lineage count k of its interval.
protected Parameter coalfactor;
// Function values at the change points (formerly called GPvalues).
protected Parameter popSizeParameter;
// Times of the change points; initialised to the coalescent times by setupSufficientStatistics().
protected Parameter changePoints;
// Element 0 is set to the time of the last coalescent event.
protected Parameter Tmrca;
// protected Parameter popValue;
// One entry per coalescent event; recomputeValues() stores (number of change points counted) - 1.
protected Parameter CoalCounts;
// Element 0 is set to numcoalpoints in the constructor.
protected Parameter numPoints;
// protected double [] GPchangePoints;
// protected double [] storedGPchangePoints;
// Per-interval k*(k-1)/2 and its stored copy for storeState()/restoreState().
protected double [] GPcoalfactor;
protected double [] storedGPcoalfactor;
// Time elapsed between consecutive coalescent events, and its stored copy.
protected double [] GPCoalInterval;
protected double [] storedGPCoalInterval;
// Allocated in the constructor; not otherwise read or written in this file.
protected double [] backupIntervals;
// protected double [] storedcoalfactor;
// protected int [] GPcounts; //It changes values, no need to storage
// protected int [] storedGPcounts;
// Interval index of each coalescent event, and its stored copy.
protected int [] CoalPosIndicator;
protected int [] storedCoalPosIndicator;
// Cumulative time of each coalescent event, and its stored copy.
protected double [] CoalTime;
protected double [] storedCoalTime;
// Set in the constructor from getIntervalCount() and getCorrectFieldLength() respectively.
protected int numintervals;
protected int numcoalpoints;
// Sum over intervals of GPcoalfactor[i] * interval length, and its stored copy.
protected double constlik;
protected double storedconstlik;
// Value computed by the most recent call to calculateLogLikelihood(...).
protected double logGPLikelihood;
// protected double storedLogGPLikelihood;
// Built by setupQmatrix(); its dimension changes, so it is not stored/restored.
protected SymmTridiagMatrix weightMatrix;
// protected MatrixParameter dMatrix;
// Stored by the constructor; not read anywhere in this file.
protected boolean rescaleByRootHeight;
/**
 * Wraps a single tree in a newly allocated one-element list.
 *
 * @param tree the tree to wrap
 * @return a new list whose only element is {@code tree}
 */
private static List<Tree> wrapTree(Tree tree) {
    final List<Tree> singleTreeList = new ArrayList<Tree>();
    singleTreeList.add(tree);
    return singleTreeList;
}
/**
 * Convenience constructor for a single tree: wraps the tree in a list and
 * delegates to the list-based constructor with all other arguments unchanged.
 */
public GaussianProcessSkytrackLikelihood(Tree tree,
                                         Parameter precParameter,
                                         boolean rescaleByRootHeight,
                                         Parameter lambda_bound,
                                         Parameter lambda_parameter,
                                         Parameter popParameter,
                                         Parameter alpha_parameter,
                                         Parameter beta_parameter,
                                         Parameter change_points,
                                         Parameter GPtype,
                                         Parameter GPcounts,
                                         Parameter coalfactor,
                                         Parameter CoalCounts,
                                         Parameter numPoints,
                                         Parameter Tmrca) {
    this(wrapTree(tree),
            precParameter,
            rescaleByRootHeight,
            lambda_bound,
            lambda_parameter,
            popParameter,
            alpha_parameter,
            beta_parameter,
            change_points,
            GPtype,
            GPcounts,
            coalfactor,
            CoalCounts,
            numPoints,
            Tmrca);
}
/**
 * Creates an instance that only forwards its name to the superclass; no
 * parameters, tree or working arrays are initialised by this constructor.
 *
 * @param name the name passed to the superclass constructor
 */
public GaussianProcessSkytrackLikelihood(String name) {
    super(name);
}
/**
 * Main constructor: stores the parameters, registers the ones this likelihood
 * listens to, attaches the tree, sizes all per-interval and per-coalescent
 * arrays/parameters, and initialises the sufficient statistics and the
 * function values.
 *
 * Note that {@code lambda_parameter}, {@code alpha_parameter},
 * {@code beta_parameter} and {@code Tmrca} are stored but not registered
 * with {@code addVariable}.
 *
 * @param treeList must contain exactly one tree (enforced by {@link #setTree})
 */
public GaussianProcessSkytrackLikelihood(List<Tree> treeList,
                                         Parameter precParameter,
                                         boolean rescaleByRootHeight,
                                         Parameter lambda_bound,
                                         Parameter lambda_parameter,
                                         Parameter popParameter,
                                         Parameter alpha_parameter,
                                         Parameter beta_parameter,
                                         Parameter change_points,
                                         Parameter GPtype,
                                         Parameter GPcounts,
                                         Parameter coalfactor,
                                         Parameter CoalCounts,
                                         Parameter numPoints,
                                         Parameter Tmrca) {
    super(GaussianProcessSkytrackLikelihoodParser.SKYTRACK_LIKELIHOOD);

    // Keep references to every supplied parameter.
    this.popSizeParameter = popParameter;
    this.Tmrca = Tmrca;
    this.changePoints = change_points;
    this.numPoints = numPoints;
    this.precisionParameter = precParameter;
    this.lambdaParameter = lambda_parameter;
    this.betaParameter = beta_parameter;
    this.alphaParameter = alpha_parameter;
    this.rescaleByRootHeight = rescaleByRootHeight;
    this.lambda_boundParameter = lambda_bound;
    this.GPcounts = GPcounts;
    this.GPtype = GPtype;
    this.coalfactor = coalfactor;
    this.CoalCounts = CoalCounts;

    // Register the parameters whose changes must invalidate the cached likelihood.
    addVariable(precisionParameter);
    addVariable(popSizeParameter);
    addVariable(changePoints);
    addVariable(numPoints);
    addVariable(GPcounts);
    addVariable(GPtype);
    addVariable(coalfactor);
    addVariable(lambda_boundParameter);
    addVariable(CoalCounts);

    // Attach the tree and compute its intervals before sizing anything.
    setTree(treeList);
    wrapSetupIntervals();

    numintervals = getIntervalCount();
    numcoalpoints = getCorrectFieldLength();

    // Working arrays and their stored copies.
    GPcoalfactor = new double[numintervals];
    backupIntervals = new double[numintervals];
    GPCoalInterval = new double[numcoalpoints];
    storedGPCoalInterval = new double[numcoalpoints];
    CoalPosIndicator = new int[numcoalpoints];
    storedCoalPosIndicator = new int[numcoalpoints];
    CoalTime = new double[numcoalpoints];
    storedCoalTime = new double[numcoalpoints];
    storedGPcoalfactor = new double[numintervals];

    // Resize the parameters to match the tree.
    GPcounts.setDimension(numintervals);
    CoalCounts.setDimension(numcoalpoints);
    GPtype.setDimension(numcoalpoints);
    numPoints.setParameterValue(0, numcoalpoints);
    popSizeParameter.setDimension(numcoalpoints);
    changePoints.setDimension(numcoalpoints);
    coalfactor.setDimension(numcoalpoints);

    initializationReport();
    setupSufficientStatistics();
    setupGPvalues();
}
// Methods that override existent methods
// Set when the tree reports a change; getLogLikelihood() then rebuilds the
// intervals and the derived counts before evaluating the likelihood.
private boolean flagForJulia = false;

/**
 * Marks the interval-derived quantities as stale when the tree changes.
 *
 * The superclass handler is always invoked first. Events from models other
 * than the tree are otherwise ignored here. For the tree, both a
 * {@code TreeModel.TreeChangedEvent} and a {@code Parameter} (e.g. a node
 * height change) raise the flag that triggers lazy recomputation.
 *
 * @param model  the model that changed
 * @param object the event payload supplied by the model
 * @param index  the index forwarded with the event
 * @throws IllegalArgumentException if the tree reports a payload that is
 *                                  neither a tree-changed event nor a parameter
 *                                  (including {@code null})
 */
protected void handleModelChangedEvent(Model model, Object object, int index) {
    super.handleModelChangedEvent(model, object, index); // Call super, since it may do something important

    if (model != tree) {
        return;
    }

    if (object instanceof TreeModel.TreeChangedEvent || object instanceof Parameter) {
        // Flag set, so the lazy work can occur in getLogLikelihood().
        flagForJulia = true;
    } else {
        // Guard against a null payload so the intended exception is raised instead of an NPE.
        final String eventType = (object == null) ? "null" : object.getClass().toString();
        throw new IllegalArgumentException("Not sure what type of model changed event occurred: " + eventType);
    }
}
/**
 * Returns the columns to log: the change points and the function values,
 * each rendered as a variable-length delimited list.
 *
 * @return a two-element array of log columns
 */
public LogColumn[] getColumns() {
    // Add more LogColumn to the array if there are more things to log
    final LogColumn changePointColumn = new VariableLengthColumn("changePoints", changePoints);
    final LogColumn functionValueColumn = new VariableLengthColumn("Gvalues", popSizeParameter);
    return new LogColumn[]{changePointColumn, functionValueColumn};
}
private class VariableLengthColumn extends LogColumn.Abstract {
private final Parameter param;
public VariableLengthColumn(String label, Parameter param) {
super(label);
this.param = param;
}
protected String getFormattedValue() {
return convertToDelimited(param.getParameterValues());
}
// TODO The following functionality is generic and should be moved somewhere else and made static
private static final String OPEN = "{";
private static final String CLOSE = "}";
private static final String DELIMIT = ",";
private String convertToDelimited(double[] x) {
StringBuilder sb = new StringBuilder(OPEN);
final int dim = x.length;
for (int i = 0; i < dim; ++i) {
sb.append(Double.toString(x[i]));
if (i < dim - 1) {
sb.append(DELIMIT);
}
}
sb.append(CLOSE);
return sb.toString();
}
}
/**
 * Attaches the single tree this likelihood is evaluated on. If the tree is
 * a {@code TreeModel} it is also registered as a sub-model so that its
 * change events reach this likelihood.
 *
 * @param treeList a list holding exactly one tree
 * @throws RuntimeException if the list does not contain exactly one tree
 */
protected void setTree(List<Tree> treeList) {
    final int treeCount = treeList.size();
    if (treeCount != 1) {
        throw new RuntimeException("GP-based method only implemented for one tree");
    }

    this.tree = treeList.get(0);
    this.treesSet = null;

    if (tree instanceof TreeModel) {
        addModel((TreeModel) tree);
    }
}
/**
 * Recomputes the tree intervals and marks them as known.
 */
protected void wrapSetupIntervals() {
    setupIntervals();
    intervalsKnown = true;
}
/**
 * Augmented log-likelihood for a fixed genealogy. For sequence data directly
 * this becomes the coalescent point process prior on an augmented "tree".
 *
 * The value is the sum of
 * <ul>
 * <li>{@code -upperBound * getConstlik()},</li>
 * <li>{@code latentCounts[i] * log(upperBound * Gfactor[i])} over every
 *     interval with a positive factor, and</li>
 * <li>{@code -log(1 + exp(-eventType[i] * Gfunction[i]))} over every point.</li>
 * </ul>
 * The result is also stored in {@code logGPLikelihood}.
 *
 * @param Gfunction    function values at the points
 * @param latentCounts number of points per interval
 * @param eventType    per-point multiplier applied to the function value
 * @param upper_Bound  parameter whose element 0 is the upper bound
 * @param Gfactor      per-interval factor; must have at least
 *                     {@code latentCounts.getSize()} entries
 * @return the augmented log-likelihood
 */
public double calculateLogLikelihood(Parameter Gfunction, Parameter latentCounts, Parameter eventType, Parameter upper_Bound, double [] Gfactor) {
    final double upperBound = upper_Bound.getParameterValue(0);
    double total = -upperBound * getConstlik();

    for (int i = 0; i < latentCounts.getSize(); i++) {
        if (Gfactor[i] > 0) {
            final double pointCount = latentCounts.getParameterValue(i);
            if (pointCount < 0) {
                System.err.println("WARNING");
            }
            total += pointCount * Math.log(upperBound * Gfactor[i]);
        }
    }

    final double[] currentGfunction = Gfunction.getParameterValues();
    for (int i = 0; i < Gfunction.getSize(); i++) {
        final double z = -eventType.getParameterValue(i) * currentGfunction[i];
        // log(1 + exp(z)) evaluated as max(z, 0) + log1p(exp(-|z|)): the naive form
        // overflows to +Infinity once z exceeds roughly 709 and loses precision for
        // very negative z.
        total -= Math.max(z, 0.0) + Math.log1p(Math.exp(-Math.abs(z)));
    }

    logGPLikelihood = total;
    return total;
}
/**
 * @return the current value of {@code constlik}, the sum over intervals of
 *         the per-interval factor times the interval length
 */
public double getConstlik() {
    return this.constlik;
}
/**
 * For a fixed genealogy this returns the augmented likelihood plus the GP
 * prior plus the prior on the upper bound. The value is cached until a
 * registered variable changes; if the tree has changed since the last
 * evaluation the intervals and counts are rebuilt first.
 *
 * @return the (possibly cached) log-likelihood
 */
public double getLogLikelihood() {
    if (likelihoodKnown) {
        return logLikelihood;
    }

    if (flagForJulia) {
        System.err.println("recalculating intervals and counts");
        wrapSetupIntervals();
        recomputeValues();
        flagForJulia = false;
    }

    final double augmentedTerm =
            calculateLogLikelihood(popSizeParameter, GPcounts, GPtype, lambda_boundParameter, GPcoalfactor);
    final double gpPriorTerm = calculateLogGP();
    final double boundPriorTerm = getLogPriorLambda(
            lambdaParameter.getParameterValue(0), 0.01, lambda_boundParameter.getParameterValue(0));

    logLikelihood = augmentedTerm + gpPriorTerm + boundPriorTerm;
    likelihoodKnown = true;
    return logLikelihood;
}
/**
 * Builds the symmetric tridiagonal matrix for the points {@code x}.
 *
 * Off-diagonal entry {@code i} is {@code -precision / (x[i+1] - x[i])}.
 * Each diagonal entry is the negated sum of its neighbouring off-diagonal
 * entries plus {@code precision * 1e-11}, a small ridge that keeps the
 * otherwise intrinsic (rank-deficient) matrix from being exactly singular.
 *
 * @param precision scale applied to every entry
 * @param x         point locations; needs at least two entries
 * @return the tridiagonal matrix for {@code x}
 */
protected SymmTridiagMatrix getQmatrix(double precision, double[] x ) {
    final double trick = 0.00000000001;
    final int n = x.length;
    final double[] offdiag = new double[n - 1];
    final double[] diag = new double[n];

    for (int i = 0; i < n - 1; i++) {
        offdiag[i] = precision*(-1.0 / (x[i+1]-x[i]));
    }

    // Interior diagonal entries.
    for (int k = 1; k < n - 1; k++) {
        diag[k] = -offdiag[k-1]+precision*(1.0/(x[k+1]-x[k])+trick);
    }

    // Diffuse prior correction - intrinsic: take care of the endpoints.
    diag[0] = -offdiag[0]+precision*trick;
    diag[n - 1] = -offdiag[n - 2]+precision*(trick);

    return new SymmTridiagMatrix(diag, offdiag);
}
/**
 * Calculates the prior on the g function: a quadratic form in the function
 * values with the tridiagonal matrix from {@link #getQmatrix}, a
 * log-determinant term and a {@code (size - 1) * ln(2*pi)} normalising term.
 *
 * NOTE(review): the log-determinant enters with a negative sign here; a
 * Gaussian log-density parameterised by a precision matrix normally carries
 * {@code +0.5 * log|Q|} -- confirm that this sign is intended.
 *
 * @return the log prior value
 */
protected double calculateLogGP() {
    final SymmTridiagMatrix qMatrix =
            getQmatrix(precisionParameter.getParameterValue(0), changePoints.getParameterValues());

    final DenseVector qTimesGamma = new DenseVector(popSizeParameter.getSize());
    final DenseVector gamma = new DenseVector(popSizeParameter.getParameterValues());

    // The product must be formed before the determinant routine is handed the matrix.
    qMatrix.mult(gamma, qTimesGamma);

    final double logDeterminant = logGeneralizedDeterminant(qMatrix);
    final double quadraticForm = gamma.dot(qTimesGamma);

    return -0.5 * logDeterminant - 0.5 * quadraticForm - 0.5 * (popSizeParameter.getSize() - 1) * LOG_TWO_TIMES_PI;
}
/**
 * Calculates the log prior density of the upper bound.
 *
 * The prior is a two-piece mixture that integrates to one:
 * <ul>
 * <li>for {@code lambdaValue < lambdaMean} the density is the constant
 *     {@code epsilon / lambdaMean} (total mass {@code epsilon});</li>
 * <li>otherwise it is {@code (1 - epsilon)} times an exponential density
 *     with rate {@code 1 / lambdaMean}, shifted to start at
 *     {@code lambdaMean} (total mass {@code 1 - epsilon}).</li>
 * </ul>
 * The previous implementation returned the raw density on the first branch
 * and {@code log(1 - epsilon)} multiplied by the density on the second,
 * neither of which is a log-density; since the result is added to the
 * log-likelihood, the logarithm of the density is returned here.
 *
 * @param lambdaMean  mean/threshold of the prior; expected to be positive
 * @param epsilon     prior mass placed below {@code lambdaMean}; expected in (0, 1)
 * @param lambdaValue value at which the prior is evaluated
 * @return the log prior density at {@code lambdaValue}
 */
private double getLogPriorLambda(double lambdaMean, double epsilon, double lambdaValue){
    final double rate = 1 / lambdaMean;
    if (lambdaValue < lambdaMean) {
        return Math.log(epsilon) + Math.log(rate);
    }
    return Math.log(1 - epsilon) + Math.log(rate) - rate * (lambdaValue - lambdaMean);
}
/**
 * Log pseudo-determinant: the sum of the logarithms of all eigenvalues of
 * {@code X} that exceed {@code 0.00001}. Smaller (including non-positive)
 * eigenvalues are skipped.
 *
 * NOTE(review): the eigen-solver may modify {@code X} in place -- confirm
 * before relying on the contents of {@code X} after this call.
 *
 * @param X the symmetric tridiagonal matrix
 * @return the sum of the logs of the eigenvalues above the threshold
 * @throws RuntimeException if the eigenvalue solver does not converge; the
 *                          original {@code NotConvergedException} is kept as the cause
 */
public static double logGeneralizedDeterminant(SymmTridiagMatrix X) {
    // Set up the eigenvalue solver (eigenvalues only, no vectors).
    final SymmTridiagEVD eigen = new SymmTridiagEVD(X.numRows(), false);

    try {
        eigen.factor(X);
    } catch (NotConvergedException e) {
        // Chain the cause so the solver's stack trace is not lost.
        throw new RuntimeException("Not converged error in generalized determinate calculation.\n" + e.getMessage(), e);
    }

    double logDeterminant = 0;
    for (double eigenvalue : eigen.getEigenvalues()) {
        if (eigenvalue > 0.00001) {
            logDeterminant += Math.log(eigenvalue);
        }
    }
    return logDeterminant;
}
/**
 * Invalidates the cached log-likelihood whenever any registered variable
 * changes, regardless of which variable, index or change type it is.
 */
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
    this.likelihoodKnown = false;
}
/**
 * Restores the superclass state and then copies the stored interval-derived
 * arrays and {@code constlik} back into the working fields.
 */
protected void restoreState() {
    super.restoreState();

    constlik = storedconstlik;

    System.arraycopy(storedCoalPosIndicator, 0, CoalPosIndicator, 0, storedCoalPosIndicator.length);
    System.arraycopy(storedGPCoalInterval, 0, GPCoalInterval, 0, storedGPCoalInterval.length);
    System.arraycopy(storedCoalTime, 0, CoalTime, 0, storedCoalTime.length);
    System.arraycopy(storedGPcoalfactor, 0, GPcoalfactor, 0, storedGPcoalfactor.length);
}
/**
 * Stores the superclass state and then snapshots the interval-derived
 * arrays and {@code constlik} so that {@code restoreState()} can roll back.
 */
protected void storeState() {
    super.storeState();

    storedconstlik = constlik;

    System.arraycopy(CoalPosIndicator, 0, storedCoalPosIndicator, 0, CoalPosIndicator.length);
    System.arraycopy(GPCoalInterval, 0, storedGPCoalInterval, 0, GPCoalInterval.length);
    System.arraycopy(CoalTime, 0, storedCoalTime, 0, CoalTime.length);
    System.arraycopy(GPcoalfactor, 0, storedGPcoalfactor, 0, GPcoalfactor.length);
}
/**
 * @return the id followed by the current log-likelihood in parentheses;
 *         note that this may trigger a likelihood evaluation
 */
public String toString() {
    final StringBuilder text = new StringBuilder();
    text.append(getId());
    text.append("(");
    text.append(Double.toString(getLogLikelihood()));
    text.append(")");
    return text.toString();
}
/**
 * Prints a short banner and a citation request to standard output.
 */
public void initializationReport() {
    final String banner = "Creating a GP based estimation of effective population trajectories:";
    final String citation = "\tIf you publish results using this model, please reference: Minin, Palacios, Suchard (XXXX), AAA";
    System.out.println(banner);
    System.out.println(citation);
}
/**
 * Redraws the height of every non-root internal node uniformly at random
 * between the height of its taller child and the height of its parent, then
 * pushes a tree-changed event.
 *
 * Only the first two children of each node are inspected.
 *
 * @param treeModel the tree whose internal node heights are redrawn
 */
public static void checkTree(TreeModel treeModel) {
    // todo Should only be run if there exists a zero-length interval - I don't actually understand why this is here
    for (int i = 0; i < treeModel.getInternalNodeCount(); i++) {
        final NodeRef node = treeModel.getInternalNode(i);
        if (node == treeModel.getRoot()) {
            continue;
        }

        final double parentHeight = treeModel.getNodeHeight(treeModel.getParent(node));
        final double firstChildHeight = treeModel.getNodeHeight(treeModel.getChild(node, 0));
        final double secondChildHeight = treeModel.getNodeHeight(treeModel.getChild(node, 1));

        final double tallestChild =
                (secondChildHeight > firstChildHeight) ? secondChildHeight : firstChildHeight;

        final double redrawnHeight = tallestChild + MathUtils.nextDouble() * (parentHeight - tallestChild);
        treeModel.setNodeHeight(node, redrawnHeight);
    }
    treeModel.pushTreeChangedEvent();
}
/**
 * Rebuilds every quantity derived from the tree intervals while keeping the
 * current change points: per-interval point counts (GPcounts), per-interval
 * factors (GPcoalfactor), constlik, and for every coalescent event its
 * interval index, point count, time, waiting time and factor. Finally sets
 * Tmrca to the time of the last coalescent event.
 *
 * Called from getLogLikelihood() after the tree has changed.
 */
protected void recomputeValues() {
// Cumulative time at the end of the current interval.
double length = 0.0;
// Time of the previous coalescent event.
double prevLength=0.0;
double count=0.0;
// Number of coalescent events seen so far.
int countcoal = 0;
// Number of change points already attributed to an interval (per-interval scan).
int pointer=0;
// Number of change points already attributed to a coalescent event (per-coalescent scan).
int pointer2=0;
constlik= 0;
for (int i = 0; i < getIntervalCount(); i++) {
length += getInterval(i);
count=0.0;
// Count the not-yet-attributed change points that fall at or before the end of this interval.
// This scan runs to the end of the array (no early exit).
// NOTE(review): presumably changePoints is sorted in increasing order - confirm.
for (int j=pointer;j<changePoints.getSize();j++){
if (changePoints.getParameterValue(j)<=length){
pointer++;
count++;
}
}
GPcounts.setParameterValue(i,count);
// k*(k-1)/2 for the lineage count k of this interval.
GPcoalfactor[i] =getLineageCount(i)*(getLineageCount(i)-1.0) / 2.0;
constlik+=GPcoalfactor[i]*getInterval(i);
if (getIntervalType(i) == CoalescentEventType.COALESCENT) {
CoalPosIndicator[countcoal]=i;
count=0;
// Count the change points since the previous coalescent event; stops at the
// first point beyond the current time by forcing the loop index to the end.
for (int s=pointer2;s<changePoints.getSize();s++){
if (changePoints.getParameterValue(s)<=length){
pointer2++;
count++;
} else
s=changePoints.getSize();
}
// Stored as count - 1; presumably this excludes the coalescent point itself - TODO confirm.
CoalCounts.setParameterValue(countcoal,count-1);
CoalTime[countcoal]=length;
GPCoalInterval[countcoal]=length-prevLength;
coalfactor.setParameterValue(countcoal,getLineageCount(i)*(getLineageCount(i)-1)/2.0);
countcoal++;
prevLength=length;
}
}
// TODO:This code should not be run in production, it is just for testing
// Consistency checks: the number of points with GPtype == 1 should equal the size of
// CoalCounts, and the CoalCounts total should equal the number of remaining points.
// Violations are only reported on standard error.
int sumcoal=0;
int sumlat=0;
for (int j=0;j<changePoints.getSize();j++){
if( GPtype.getParameterValue(j)==1) {sumcoal++;}
}
for (int j=0;j<CoalCounts.getSize();j++){
// Compound assignment narrows the double parameter value to int.
sumlat+=CoalCounts.getParameterValue(j);
}
if (sumcoal!=CoalCounts.getSize()){ System.err.println("WARNING CONSISTENCY 1"); }
if (sumlat!=(changePoints.getSize()-CoalCounts.getSize())){ System.err.println("WARNING CONSISTENCY 2:"+sumlat+"and changePts size"+changePoints.getSize());}
// Time of the last coalescent event; requires at least one coalescent interval.
Tmrca.setParameterValue(0,CoalTime[countcoal-1]);
}
/**
 * Initialises all interval-derived quantities from the tree, placing exactly
 * one change point at every coalescent time.
 *
 * For every interval the point count is reset to 0 and the factor
 * k*(k-1)/2 is accumulated into {@code constlik}. For every coalescent
 * interval the point count becomes 1, the point is typed as 1, and the
 * coalescent bookkeeping (interval index, change point, count, time,
 * waiting time, factor) is filled in. Finally {@code Tmrca} is set to the
 * time of the last coalescent event.
 */
protected void setupSufficientStatistics() {
    double elapsed = 0.0;
    double previousCoalescentTime = 0.0;
    int coalescentIndex = 0;
    constlik = 0;

    for (int i = 0; i < getIntervalCount(); i++) {
        elapsed += getInterval(i);

        GPcounts.setParameterValue(i, 0.0);
        GPcoalfactor[i] = getLineageCount(i) * (getLineageCount(i) - 1.0) / 2.0;
        constlik += GPcoalfactor[i] * getInterval(i);

        if (getIntervalType(i) != CoalescentEventType.COALESCENT) {
            continue;
        }

        GPcounts.setParameterValue(i, 1.0);
        GPtype.setParameterValue(coalescentIndex, 1.0);
        CoalPosIndicator[coalescentIndex] = i;
        changePoints.setParameterValue(coalescentIndex, elapsed);
        CoalCounts.setParameterValue(coalescentIndex, 0.0);
        CoalTime[coalescentIndex] = elapsed;
        GPCoalInterval[coalescentIndex] = elapsed - previousCoalescentTime;
        coalfactor.setParameterValue(coalescentIndex, getLineageCount(i) * (getLineageCount(i) - 1) / 2.0);

        coalescentIndex++;
        previousCoalescentTime = elapsed;
    }

    Tmrca.setParameterValue(0, CoalTime[coalescentIndex - 1]);
}
/**
 * @return the number of external nodes of the tree minus one
 */
protected int getCorrectFieldLength() {
    final int tipCount = tree.getExternalNodeCount();
    return tipCount - 1;
}
/**
 * Builds {@code weightMatrix}, the symmetric tridiagonal matrix over the
 * current change points.
 *
 * Off-diagonal entry {@code i} is
 * {@code -precision / (changePoints[i+1] - changePoints[i])}; each diagonal
 * entry is the negated sum of its neighbouring off-diagonal entries plus
 * {@code precision * 1e-6}.
 *
 * The matrix dimension, the interior range and the last diagonal index are
 * all taken from {@code changePoints.getSize()}. Previously the arrays were
 * sized by the number of change points while the interior guard and the
 * last index used {@code getCorrectFieldLength()}, which left diagonal
 * entries unset (or indexed out of bounds) whenever the two differed. When
 * they are equal, as at construction time, the result is unchanged.
 *
 * @param precision scale applied to every entry
 */
protected void setupQmatrix(double precision) {
    final double trick = 0.000001;
    final int n = changePoints.getSize();
    final double[] offdiag = new double[n - 1];
    final double[] diag = new double[n];

    for (int i = 0; i < n - 1; i++) {
        offdiag[i] = precision*(-1.0 / (changePoints.getParameterValue(i+1)-changePoints.getParameterValue(i)));
        if (i < n - 2) {
            diag[i+1] = -offdiag[i]+precision*(1.0/(changePoints.getParameterValue(i+2)-changePoints.getParameterValue(i+1))+trick);
        }
    }

    // Diffuse prior correction - intrinsic: take care of the endpoints.
    diag[0] = -offdiag[0]+precision*trick;
    diag[n - 1] = -offdiag[n - 2]+precision*(trick);

    weightMatrix = new SymmTridiagMatrix(diag, offdiag);
}
/**
 * Builds the weight matrix and initialises every function value to 1.0.
 *
 * A vector of standard normal draws is generated and solved against the
 * Cholesky factor of the weight matrix, but the solved vector is not used
 * afterwards; the draws and the factorisation are kept so that random
 * number consumption and failure behaviour stay as they were.
 */
protected void setupGPvalues() {
    setupQmatrix(precisionParameter.getParameterValue(0));

    final int fieldLength = getCorrectFieldLength();
    final DenseVector standardNormalDraws = new DenseVector(fieldLength);
    final DenseVector solvedDraws = new DenseVector(fieldLength);

    int drawIndex = 0;
    while (drawIndex < fieldLength) {
        standardNormalDraws.set(drawIndex, MathUtils.nextGaussian());
        drawIndex++;
    }

    final UpperSPDBandMatrix bandedWeights = new UpperSPDBandMatrix(weightMatrix, 1);
    final BandCholesky cholesky = new BandCholesky(fieldLength, 1, true);
    cholesky.factor(bandedWeights);
    final UpperTriangBandMatrix upperFactor = cholesky.getU();
    upperFactor.solve(standardNormalDraws, solvedDraws);

    int valueIndex = 0;
    while (valueIndex < fieldLength) {
        popSizeParameter.setParameterValue(valueIndex, 1.0);
        valueIndex++;
    }
}
/** @return the precision parameter */
public Parameter getPrecisionParameter() {
    return this.precisionParameter;
}

/** @return the parameter holding the function values at the change points */
public Parameter getPopSizeParameter() {
    return this.popSizeParameter;
}

/** @return the parameter holding the number of points */
public Parameter getNumPoints() {
    return this.numPoints;
}

/** @return the parameter used as the mean of the prior on the upper bound */
public Parameter getLambdaParameter() {
    return this.lambdaParameter;
}

/** @return the upper-bound parameter */
public Parameter getLambdaBoundParameter() {
    return this.lambda_boundParameter;
}

/** @return the change-point parameter */
public Parameter getChangePoints() {
    return this.changePoints;
}

/** @return element 0 of the alpha parameter */
public double getAlphaParameter() {
    return this.alphaParameter.getParameterValue(0);
}

/** @return element 0 of the beta parameter */
public double getBetaParameter() {
    return this.betaParameter.getParameterValue(0);
}

/** @return the internal per-interval factor array (not a copy) */
public double [] getGPcoalfactor() {
    return this.GPcoalfactor;
}

/** @return the per-coalescent factor parameter */
public Parameter getcoalfactor() {
    return this.coalfactor;
}

/** @return the per-coalescent point-count parameter */
public Parameter getCoalCounts() {
    return this.CoalCounts;
}

/** @return the per-point type parameter */
public Parameter getGPtype() {
    return this.GPtype;
}

/** @return the per-interval point-count parameter */
public Parameter getGPcounts() {
    return this.GPcounts;
}

/** @return a copy of the current weight matrix */
public SymmTridiagMatrix getWeightMatrix() {
    return this.weightMatrix.copy();
}

// Methods needed for GP-based

/** @return the internal array of waiting times between coalescent events (not a copy) */
public double [] getGPCoalInterval() {
    return this.GPCoalInterval;
}

/** @return the internal array of coalescent times (not a copy) */
public double [] getCoalTime() {
    return this.CoalTime;
}

/**
 * @param j index of the coalescent event
 * @return the waiting time that ended at coalescent event {@code j}
 */
public double getGPCoalInterval(int j) {
    return this.GPCoalInterval[j];
}

/** @return the internal array of interval indices of the coalescent events (not a copy) */
public int [] getCoalPosIndicator() {
    return this.CoalPosIndicator;
}
}