/*
* GaussianProcessMultilocusSkytrackLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.coalescent;
//import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervals;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeModel;
//import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.evomodelxml.coalescent.GaussianProcessSkytrackLikelihoodParser;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
//import no.uib.cipr.matrix.DenseVector;
//import no.uib.cipr.matrix.SymmTridiagMatrix;
import java.util.ArrayList;
//import java.util.Arrays;
import java.util.List;
/**
* @author Julia Palacios
* @author Marc A. Suchard
* @author Vladimir Minin
* @author Michael Karcher
*/
public class GaussianProcessMultilocusSkytrackLikelihood extends GaussianProcessSkytrackLikelihood implements MultiLociTreeSet {
// private double cutOff;
// private int numGridPoints;
// protected int oldFieldLength;
// number of coalescent events which occur in an interval with constant population size
// protected double[] numCoalEvents;
// protected double[] storedNumCoalEvents;
// protected double[] gridPoints;
// protected double theLastTime;
// protected double diagonalValue;
// sortedPoints[i][0] is the time of the i-th grid point or sampling or coalescent event
// sortedPoints[i][1] is 0 if the i-th point is a grid point, 1 if it's a sampling point, and 2 if it's a coalescent point
// sortedPoints[i][2] is the number of lineages present in the interval starting at time sortedPoints[i][0]
// protected Parameter phiParameter;
// protected SymmTridiagMatrix precMatrix;
// protected SymmTridiagMatrix storedPrecMatrix;
public GaussianProcessMultilocusSkytrackLikelihood(List<Tree> treeList,
Parameter precParameter,
boolean rescalebyRootHeight,
Parameter lambda_bound,
Parameter lambda_parameter,
Parameter popParameter,
Parameter alpha_parameter,
Parameter beta_parameter,
Parameter change_points,
Parameter GPtype,
Parameter GPcounts,
Parameter coalfactor,
Parameter CoalCounts,
Parameter numPoints,
Parameter Tmrca) {
super(GaussianProcessSkytrackLikelihoodParser.SKYTRACK_LIKELIHOOD);
this.popSizeParameter = popParameter;
this.Tmrca = Tmrca;
this.changePoints=change_points;
this.numPoints=numPoints;
this.precisionParameter = precParameter;
this.lambdaParameter = lambda_parameter;
this.betaParameter = beta_parameter;
this.alphaParameter=alpha_parameter;
this.rescaleByRootHeight=rescalebyRootHeight;
this.lambda_boundParameter=lambda_bound;
this.GPcounts=GPcounts;
this.GPtype=GPtype;
this.coalfactor=coalfactor;
this.CoalCounts=CoalCounts;
// System.out.println("numGridPoints: " + numGridPoints);
// setupGridPoints();
addVariable(popSizeParameter);
addVariable(precisionParameter);
addVariable(changePoints);
addVariable(numPoints);
addVariable(GPcounts);
addVariable(GPtype);
addVariable(coalfactor);
addVariable(lambda_boundParameter);
addVariable(CoalCounts);
setTree(treeList);
numintervals=getIntervalCount(); //Not sure, check. I want the number of intervals defined by either
// sampling times or coalescent times of all trees
System.err.println("by getIntervalCount, I get:"+getIntervalCount());
numcoalpoints = getCorrectFieldLength();
System.err.println("by getCorrectField"+getCorrectFieldLength());
GPcoalfactor = new double[numintervals];
backupIntervals=new double[numintervals];
GPCoalInterval=new double[numcoalpoints];
storedGPCoalInterval=new double[numcoalpoints];
CoalPosIndicator= new int[numcoalpoints];
storedCoalPosIndicator=new int[numcoalpoints];
CoalTime=new double[numcoalpoints];
storedCoalTime=new double[numcoalpoints];
storedGPcoalfactor = new double[numintervals];
GPcounts.setDimension(numintervals);
CoalCounts.setDimension(numcoalpoints);
GPtype.setDimension(numcoalpoints);
numPoints.setParameterValue(0,numcoalpoints);
popSizeParameter.setDimension(numcoalpoints);
changePoints.setDimension(numcoalpoints);
coalfactor.setDimension(numcoalpoints);
initializationReport();
// wrapSetupIntervals();
// coalescentIntervals = new double[oldFieldLength];
// storedCoalescentIntervals = new double[oldFieldLength];
// sufficientStatistics = new double[fieldLength];
// storedSufficientStatistics = new double[fieldLength];
// numCoalEvents = new double[fieldLength];
// storedNumCoalEvents = new double[fieldLength];
// setupGMRFWeights();
setupSufficientStatistics();
// addStatistic(new DeltaStatistic());
/* Force all entries in groupSizeParameter = 1 for compatibility with Tracer */
}
// OK
protected void setTree(List<Tree> treeList) {
treesSet = this;
this.treeList = treeList;
makeTreeIntervalList(treeList, true);
numTrees = treeList.size();
}
// OK
private void makeTreeIntervalList(List<Tree> treeList, boolean add) {
if (intervalsList == null) {
intervalsList = new ArrayList<TreeIntervals>();
} else {
intervalsList.clear();
}
for (Tree tree : treeList) {
numIntervals+=tree.getExternalNodeCount();
intervalsList.add(new TreeIntervals(tree));
if (add && tree instanceof TreeModel) {
addModel((TreeModel) tree);
}
}
}
protected int getCorrectFieldLength() {
return numIntervals-treeList.size();
// TODO add correction when not all samples are gathered at the same time
}
//
// protected int getCorrectOldFieldLength() {
// int tips = 0;
// for (Tree tree : treeList) {
// tips += tree.getExternalNodeCount();
// }
// return tips - treeList.size();
// }
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (model instanceof TreeModel) {
TreeModel treeModel = (TreeModel) model;
int tn = treeList.indexOf(treeModel);
if (tn >= 0) {
// intervalsList.get(tn).setIntervalsUnknown(); // TODO Why is this slower (?) than remaking whole list?
makeTreeIntervalList(treeList, false);
intervalsKnown = false;
likelihoodKnown = false;
} else {
throw new RuntimeException("Unknown tree modified in GMRFMultilocusSkyrideLikelihood");
}
} else {
throw new RuntimeException("Unknown object modified in GMRFMultilocusSkyrideLikelihood");
}
}
public void initializationReport() {
System.out.println("Creating a GP based estimation of effective population size trajectories for multiple loci:");
System.out.println("\tIf you publish results using this model, please reference: ");
System.out.println("\t\tPalacios, Minin and Suchard (XXXX)");
}
//
// public void wrapSetupIntervals() {
// // Do nothing
// }
int numTrees;
int numIntervals;
//
// protected void setupGridPoints() {
// if (gridPoints == null) {
// gridPoints = new double[numGridPoints];
// } else {
// Arrays.fill(gridPoints, 0);
// }
//
// for (int pt = 0; pt < numGridPoints; pt++) {
// gridPoints[pt] = (pt + 1) * (cutOff / numGridPoints);
// }
// }
protected void setupSufficientStatistics() {
// //numCoalEvents = new double[fieldLength];
// //sufficientStatistics = new double[fieldLength];
//
// Arrays.fill(numCoalEvents, 0);
// Arrays.fill(sufficientStatistics, 0);
//
// //index of smallest grid point greater than at least one sampling/coalescent time in current tree
// int minGridIndex;
// //index of greatest grid point less than at least one sampling/coalescent time in current tree
// int maxGridIndex;
//
// int numLineages;
//
// int currentGridIndex;
// int currentTimeIndex;
//
// double currentTime;
// double nextTime;
//
// //time of last coalescent event in tree
// double lastCoalescentTime;
//
// for (int i = 0; i < numTrees; i++) {
// currentTimeIndex = 0;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// while (nextTime <= currentTime) {
// currentTimeIndex++;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// }
//
//
// numLineages = intervalsList.get(i).getLineageCount(currentTimeIndex + 1);
// minGridIndex = 0;
// while (gridPoints[minGridIndex] <= currentTime) {
// minGridIndex++;
// }
// currentGridIndex = minGridIndex;
//
// lastCoalescentTime = currentTime + intervalsList.get(i).getTotalDuration();
//
// theLastTime = lastCoalescentTime;
//
// maxGridIndex = numGridPoints - 1;
// while ((maxGridIndex >= 0) && (gridPoints[maxGridIndex] >= lastCoalescentTime)) {
// maxGridIndex = maxGridIndex - 1;
// }
//
// if (maxGridIndex >= 0) {
//
// //from likelihood of interval between first sampling time and gridPoints[minGridIndex]
//
// while (nextTime < gridPoints[currentGridIndex]) {
//
// //check to see if interval ends with coalescent event
// if (intervalsList.get(i).getCoalescentEvents(currentTimeIndex + 1) > 0) {
//
// numCoalEvents[currentGridIndex]++;
// }
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5;
// currentTime = nextTime;
// currentTimeIndex++;
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
//
// while (nextTime <= currentTime) {
// currentTimeIndex++;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// }
//
// numLineages = intervalsList.get(i).getLineageCount(currentTimeIndex + 1);
//
// }
//
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (gridPoints[currentGridIndex] - currentTime) * numLineages * (numLineages - 1) * 0.5;
//
// currentGridIndex++;
//
//
// //from likelihood of intervals between gridPoints[minGridIndex] and gridPoints[maxGridIndex]
//
// while (currentGridIndex <= maxGridIndex) {
// if (nextTime >= gridPoints[currentGridIndex]) {
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (gridPoints[currentGridIndex] - gridPoints[currentGridIndex - 1]) * numLineages * (numLineages - 1) * 0.5;
//
// currentGridIndex++;
// } else {
//
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (nextTime - gridPoints[currentGridIndex - 1]) * numLineages * (numLineages - 1) * 0.5;
//
// //check to see if interval ends with coalescent event
// if (intervalsList.get(i).getCoalescentEvents(currentTimeIndex + 1) > 0) {
// numCoalEvents[currentGridIndex]++;
// }
// currentTime = nextTime;
// currentTimeIndex++;
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// while (nextTime <= currentTime) {
// currentTimeIndex++;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// }
//
// numLineages = intervalsList.get(i).getLineageCount(currentTimeIndex + 1);
//
// while (nextTime < gridPoints[currentGridIndex]) {
// //check to see if interval is coalescent interval or sampling interval
// if (intervalsList.get(i).getCoalescentEvents(currentTimeIndex + 1) > 0) {
// numCoalEvents[currentGridIndex]++;
// }
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5;
//
// currentTime = nextTime;
// currentTimeIndex++;
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// while (nextTime <= currentTime) {
// currentTimeIndex++;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// }
//
// numLineages = intervalsList.get(i).getLineageCount(currentTimeIndex + 1);
//
// }
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (gridPoints[currentGridIndex] - currentTime) * numLineages * (numLineages - 1) * 0.5;
//
// currentGridIndex++;
// }
// }
//
// //from likelihood of interval between gridPoints[maxGridIndex] and lastCoalescentTime
//
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (nextTime - gridPoints[currentGridIndex - 1]) * numLineages * (numLineages - 1) * 0.5;
//
// //check to see if interval ends with coalescent event
// if (intervalsList.get(i).getCoalescentEvents(currentTimeIndex + 1) > 0) {
// numCoalEvents[currentGridIndex]++;
// }
//
// currentTime = nextTime;
// currentTimeIndex++;
//
// while ((currentTimeIndex + 1) < intervalsList.get(i).getIntervalCount()) {
// // currentTime = nextTime;
// // currentTimeIndex++;
//
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// while (nextTime <= currentTime) {
// currentTimeIndex++;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// }
//
// numLineages = intervalsList.get(i).getLineageCount(currentTimeIndex + 1);
//
// //check to see if interval is coalescent interval or sampling interval
//
//
// if (intervalsList.get(i).getCoalescentEvents(currentTimeIndex + 1) > 0) {
// numCoalEvents[currentGridIndex]++;
// }
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5;
// currentTime = nextTime;
// currentTimeIndex++;
//
// }
// } else {
// while ((currentTimeIndex + 1) < intervalsList.get(i).getIntervalCount()) {
// //check to see if interval is coalescent interval or sampling interval
// if (intervalsList.get(i).getCoalescentEvents(currentTimeIndex + 1) > 0) {
// numCoalEvents[currentGridIndex]++;
// }
// sufficientStatistics[currentGridIndex] = sufficientStatistics[currentGridIndex] + (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5;
//
// currentTime = nextTime;
// currentTimeIndex++;
// if ((currentTimeIndex + 1) < intervalsList.get(i).getIntervalCount()) {
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
//
// while (nextTime <= currentTime) {
// currentTimeIndex++;
// currentTime = intervalsList.get(i).getIntervalTime(currentTimeIndex);
// nextTime = intervalsList.get(i).getIntervalTime(currentTimeIndex + 1);
// }
//
// numLineages = intervalsList.get(i).getLineageCount(currentTimeIndex + 1);
//
// }
//
// }
//
//
// }
// }
}
//
// public double[] getNumCoalEvents() {
// return numCoalEvents;
// }
protected double calculateLogCoalescentLikelihood() {
// if (!intervalsKnown) {
// // intervalsKnown -> false when handleModelChanged event occurs in super.
// wrapSetupIntervals();
// setupSufficientStatistics();
// intervalsKnown = true;
// }
// Matrix operations taken from block update sampler to calculate data likelihood and field prior
double currentLike = 0;
// double[] currentGamma = popSizeParameter.getParameterValues();
//
// for (int i = 0; i < fieldLength; i++) {
// currentLike += -numCoalEvents[i] * currentGamma[i] - sufficientStatistics[i] * Math.exp(-currentGamma[i]);
// }
return currentLike;
}
protected double calculateLogFieldLikelihood() {
// if (!intervalsKnown) {
// //intervalsKnown -> false when handleModelChanged event occurs in super.
// wrapSetupIntervals();
// setupSufficientStatistics();
// intervalsKnown = true;
// }
double currentLike = 0;
// DenseVector diagonal1 = new DenseVector(fieldLength);
// DenseVector currentGamma = new DenseVector(popSizeParameter.getParameterValues());
//
// SymmTridiagMatrix currentQ = getScaledWeightMatrix(precisionParameter.getParameterValue(0), lambdaParameter.getParameterValue(0));
// currentQ.mult(currentGamma, diagonal1);
//
// // currentLike += 0.5 * logGeneralizedDeterminant(currentQ) - 0.5 * currentGamma.dot(diagonal1);
//
// currentLike += 0.5 * (fieldLength - 1) * Math.log(precisionParameter.getParameterValue(0)) - 0.5 * currentGamma.dot(diagonal1);
// if (lambdaParameter.getParameterValue(0) == 1) {
// currentLike -= (fieldLength - 1) / 2.0 * LOG_TWO_TIMES_PI;
// } else {
// currentLike -= fieldLength / 2.0 * LOG_TWO_TIMES_PI;
// }
return currentLike;
}
public double getLogLikelihood() {
if (!likelihoodKnown) {
logLikelihood = calculateLogCoalescentLikelihood();
// logFieldLikelihood = calculateLogFieldLikelihood();
likelihoodKnown = true;
}
return logLikelihood;
// return logLikelihood + logFieldLikelihood;
}
// protected void setupGMRFWeights() {
//
// //setupSufficientStatistics();
//
// //Set up the weight Matrix
// double[] offdiag = new double[fieldLength - 1];
// double[] diag = new double[fieldLength];
//
// diagonalValue = 2;
// //First set up the offdiagonal entries;
//
// for (int i = 0; i < fieldLength - 1; i++) {
// offdiag[i] = -1;
// }
//
// //Then set up the diagonal entries;
// for (int i = 1; i < fieldLength - 1; i++) {
// // diag[i] = -(offdiag[i] + offdiag[i - 1]);
// diag[i] = diagonalValue;
// }
// //Take care of the endpoints
// //diag[0] = -offdiag[0];
// //diag[fieldLength - 1] = -offdiag[fieldLength - 2];
// diag[0] = diagonalValue - 1.0;
// diag[fieldLength - 1] = diagonalValue - 1.0;
//
//
// weightMatrix = new SymmTridiagMatrix(diag, offdiag);
//
// }
protected double getFieldScalar() {
return 1.0;
}
private List<Tree> treeList;
private List<TreeIntervals> intervalsList;
public int nLoci() {
return treeList.size();
}
public Tree getTree(int nt) {
return treeList.get(nt);
}
public TreeIntervals getTreeIntervals(int nt) {
return intervalsList.get(nt);
}
public double getPopulationFactor(int nt) {
return 1.0;
}
public void storeTheState() {
for (TreeIntervals intervals : intervalsList) {
intervals.storeState();
}
}
public void restoreTheState() {
for (TreeIntervals intervals : intervalsList) {
intervals.restoreState();
}
}
// protected void storeState() {
// // System.arraycopy(numCoalEvents, 0, storedNumCoalEvents, 0, numCoalEvents.length);
// super.storeState();
// System.arraycopy(numCoalEvents, 0, storedNumCoalEvents, 0, numCoalEvents.length);
// // storedPrecMatrix = precMatrix.copy();
// }
// protected void restoreState() {
// super.restoreState();
//
// // Swap pointers
// double[] tmp = numCoalEvents;
// numCoalEvents = storedNumCoalEvents;
// storedNumCoalEvents = tmp;
// }
// public int getCoalescentIntervalLineageCount(int i) {
// return 0; //To change body of implemented methods use File | Settings | File Templates.
// }
//
// public IntervalType getCoalescentIntervalType(int i) {
// return null; //To change body of implemented methods use File | Settings | File Templates.
// }
}