/*
* GMRFMultilocusSkyrideLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.coalescent;
//import com.sun.xml.internal.messaging.saaj.packaging.mime.internet.ParameterList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervals;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.SymmTridiagMatrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
 * Multi-locus GMRF smoothed coalescent likelihood on a fixed time grid (the "SkyGrid").
 * <p>
 * The log population-size field ({@code popSizeParameter}) is piecewise constant between
 * {@code numGridPoints} grid points, so it has {@code numGridPoints + 1} elements; the last element
 * covers all time after the final grid point. Coalescent events and lineage-time sufficient
 * statistics from every tree in the tree list are pooled onto this shared grid, each locus being
 * scaled by its ploidy factor. The field is smoothed by a tridiagonal GMRF prior (optionally
 * adjusted by covariates) evaluated by a {@link SkygridHelper}.
 *
 * @author Mandev Gill
 * @author Marc A. Suchard
 */
public class GMRFMultilocusSkyrideLikelihood extends GMRFSkyrideLikelihood
        implements MultiLociTreeSet, CoalescentIntervalProvider, Citable {

    public static final boolean DEBUG = false;

    /** Selects the design-matrix / covariate-list update in {@link SkygridCovariateHelper}. */
    private static final boolean NEW_APPROACH = true;

    /** Time of the last grid point. */
    private double cutOff;
    /** Number of grid points; the field has {@code numGridPoints + 1} elements. */
    private int numGridPoints;
    /** Total number of coalescent events over all trees (see {@link #getCorrectOldFieldLength()}). */
    protected int oldFieldLength;

    /** Number of coalescent events which occur in each interval with constant population size. */
    protected double[] numCoalEvents;
    protected double[] storedNumCoalEvents;
    /** Grid (change-point) times; expected to be increasing. */
    protected double[] gridPoints;
    /** End time of the most recently processed tree in {@link #setupSufficientStatistics()}. */
    protected double theLastTime;
    /** Interior diagonal value used by {@link #setupGMRFWeights()}. */
    protected double diagonalValue;
    protected Parameter phiParameter;
    /** One ploidy (population) factor per tree. */
    protected Parameter ploidyFactors;
    /** Per field interval: sum over its coalescent events of log(1 / ploidy factor of the event's tree). */
    protected double[] ploidySums;
    protected double[] storedPloidySums;
    protected SymmTridiagMatrix precMatrix;
    protected SymmTridiagMatrix storedPrecMatrix;
    private SkygridHelper skygridHelper;
    protected List<Parameter> missingCov;
    protected List<MatrixParameter> covariates;
    protected List<Parameter> beta;
    protected List<Parameter> covPrecParameters;
    protected List<SymmTridiagMatrix> weightMatricesForMissingCov;
    /** Per covariate, the index of its last observed value (later values are treated as missing). */
    protected int[] lastObservedIndex;
    private double[] coalescentEventStatisticValues;

    int numTrees;
    private List<Tree> treeList;
    private List<TreeIntervals> intervalsList;

    /**
     * Creates a SkyGrid likelihood with {@code numGridPoints} evenly spaced grid points ending at
     * {@code cutOff}.
     *
     * @throws IllegalArgumentException if the population-size or ploidy-factor parameters have the
     *                                  wrong dimension
     */
    public GMRFMultilocusSkyrideLikelihood(List<Tree> treeList,
                                           Parameter popParameter,
                                           Parameter groupParameter,
                                           Parameter precParameter,
                                           Parameter lambda,
                                           Parameter beta,
                                           MatrixParameter dMatrix,
                                           boolean timeAwareSmoothing,
                                           double cutOff,
                                           int numGridPoints,
                                           Parameter phi,
                                           Parameter ploidyFactorsParameter) {

        super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD);

        this.popSizeParameter = popParameter;
        this.groupSizeParameter = groupParameter;
        this.precisionParameter = precParameter;
        this.lambdaParameter = lambda;
        this.betaParameter = beta;
        this.dMatrix = dMatrix;
        if (dMatrix != null) {
            addVariable(dMatrix);
        }
        this.timeAwareSmoothing = timeAwareSmoothing;

        this.cutOff = cutOff;
        this.numGridPoints = numGridPoints;
        this.phiParameter = phi;
        this.ploidyFactors = ploidyFactorsParameter;

        setupGridPoints();

        addVariable(popSizeParameter);
        addVariable(precisionParameter);
        addVariable(lambdaParameter);

        if (betaParameter != null) {
            addVariable(betaParameter);
            skygridHelper = new SkygridCovariateHelper();
        } else {
            skygridHelper = new SkygridHelper();
        }

        if (phiParameter != null) {
            addVariable(phiParameter);
        }
        addVariable(ploidyFactors);

        setTree(treeList);

        initMultilocusFieldLength();

        if (ploidyFactors.getDimension() != treeList.size()) {
            throw new IllegalArgumentException("Ploidy factors parameter should have length " + treeList.size());
        }

        // Field length must be set by this point
        wrapSetupIntervals();
        allocateMultilocusStatistics();

        setupGMRFWeights();
        setupSufficientStatistics();

        addStatistic(new DeltaStatistic());

        initializationReport();

        /* Force all entries in groupSizeParameter = 1 for compatibility with Tracer */
        if (groupSizeParameter != null) {
            for (int i = 0; i < groupSizeParameter.getDimension(); i++) {
                groupSizeParameter.setParameterValue(i, 1.0);
            }
        }

        this.coalescentEventStatisticValues = new double[getNumberOfCoalescentEvents()];
    }

    /**
     * Creates a SkyGrid likelihood with user-specified grid points and optional covariates.
     * The last entry of {@code specGridPoints} becomes the cut-off.
     *
     * @throws IllegalArgumentException if the population-size or ploidy-factor parameters have the
     *                                  wrong dimension
     */
    public GMRFMultilocusSkyrideLikelihood(List<Tree> treeList,
                                           Parameter popParameter,
                                           Parameter groupParameter,
                                           Parameter precParameter,
                                           Parameter lambda,
                                           Parameter betaParameter,
                                           MatrixParameter dMatrix,
                                           boolean timeAwareSmoothing,
                                           Parameter specGridPoints,
                                           List<MatrixParameter> covariates,
                                           Parameter ploidyFactorsParameter,
                                           List<Parameter> lastObservedIndexParameter,
                                           List<Parameter> covPrecParameters,
                                           List<Parameter> betaList) {

        super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD);

        this.gridPoints = specGridPoints.getParameterValues();
        this.numGridPoints = gridPoints.length;
        this.cutOff = gridPoints[numGridPoints - 1];

        if (lastObservedIndexParameter != null) {
            lastObservedIndex = new int[lastObservedIndexParameter.size()];
            for (int i = 0; i < lastObservedIndexParameter.size(); i++) {
                this.lastObservedIndex[i] = (int) lastObservedIndexParameter.get(i).getParameterValue(0);
            }
        }

        this.betaParameter = betaParameter;
        if (betaParameter != null) {
            addVariable(betaParameter);
        }

        this.popSizeParameter = popParameter;
        this.groupSizeParameter = groupParameter;
        this.precisionParameter = precParameter;
        this.lambdaParameter = lambda;
        this.beta = betaList;
        this.dMatrix = dMatrix;
        if (dMatrix != null) {
            addVariable(dMatrix);
        }
        this.timeAwareSmoothing = timeAwareSmoothing;
        this.ploidyFactors = ploidyFactorsParameter;
        this.covariates = covariates;
        if (covariates != null) {
            for (MatrixParameter cov : covariates) {
                addVariable(cov);
            }
        }
        this.covPrecParameters = covPrecParameters;
        if (covPrecParameters != null) {
            for (Parameter covPrec : covPrecParameters) {
                addVariable(covPrec);
            }
        }

        addVariable(popSizeParameter);
        addVariable(precisionParameter);
        addVariable(lambdaParameter);
        addVariable(ploidyFactors);

        setTree(treeList);

        initMultilocusFieldLength();

        if (ploidyFactors.getDimension() != treeList.size()) {
            throw new IllegalArgumentException("Ploidy factor parameter should have length " + treeList.size());
        }

        // Field length must be set by this point: the missing-covariate weights depend on it.
        if (betaList != null || betaParameter != null) {
            if (betaList != null) {
                for (Parameter betaParam : betaList) {
                    addVariable(betaParam);
                }
            }
            if (lastObservedIndexParameter != null) {
                setupGMRFWeightsForMissingCov();
                skygridHelper = new SkygridMissingCovariateHelper();
            } else {
                skygridHelper = new SkygridCovariateHelper();
            }
        } else {
            skygridHelper = new SkygridHelper();
        }

        wrapSetupIntervals();
        allocateMultilocusStatistics();

        setupGMRFWeights();

        addStatistic(new DeltaStatistic());

        initializationReport();

        // Allocated here as in the other constructor; without it getCoalescentEventsStatisticValue()
        // would dereference a null array.
        this.coalescentEventStatisticValues = new double[getNumberOfCoalescentEvents()];
    }

    /**
     * Sets {@code fieldLength} from the population-size parameter and {@code oldFieldLength} from
     * the trees. If the population-size parameter has dimension at most 1, it is first resized to
     * {@link #getCorrectFieldLength()}.
     *
     * @throws IllegalArgumentException if the population-size parameter has the wrong dimension
     */
    private void initMultilocusFieldLength() {
        int correctFieldLength = getCorrectFieldLength();

        if (popSizeParameter.getDimension() <= 1) {
            // popSize dimension hasn't been set yet, set it here:
            popSizeParameter.setDimension(correctFieldLength);
        }

        fieldLength = popSizeParameter.getDimension();
        if (correctFieldLength != fieldLength) {
            throw new IllegalArgumentException("Population size parameter should have length " + correctFieldLength);
        }

        oldFieldLength = getCorrectOldFieldLength();
    }

    /** Allocates the current/stored statistic arrays; requires field lengths to be set. */
    private void allocateMultilocusStatistics() {
        coalescentIntervals = new double[oldFieldLength];
        storedCoalescentIntervals = new double[oldFieldLength];
        sufficientStatistics = new double[fieldLength];
        storedSufficientStatistics = new double[fieldLength];
        numCoalEvents = new double[fieldLength];
        storedNumCoalEvents = new double[fieldLength];
        ploidySums = new double[fieldLength];
        storedPloidySums = new double[fieldLength];
    }

    /**
     * Registers this object as the tree set, stores the tree list, builds one {@link TreeIntervals}
     * per tree, and adds every {@link TreeModel} as a sub-model.
     */
    protected void setTree(List<Tree> treeList) {
        treesSet = this;
        this.treeList = treeList;
        makeTreeIntervalList(treeList, true);
        numTrees = treeList.size();
    }

    /**
     * (Re)builds {@code intervalsList} with a fresh {@link TreeIntervals} for each tree.
     *
     * @param add if true, also registers each {@link TreeModel} as a sub-model of this likelihood
     */
    private void makeTreeIntervalList(List<Tree> treeList, boolean add) {
        if (intervalsList == null) {
            intervalsList = new ArrayList<TreeIntervals>();
        } else {
            intervalsList.clear();
        }

        for (Tree tree : treeList) {
            intervalsList.add(new TreeIntervals(tree));
            if (add && tree instanceof TreeModel) {
                addModel((TreeModel) tree);
            }
        }
    }

    /** Returns the required field dimension: one more than the number of grid points. */
    protected int getCorrectFieldLength() {
        return numGridPoints + 1;
    }

    /**
     * Returns the total tip count over all trees minus the number of trees, i.e. the total number
     * of coalescent events when every tree is bifurcating.
     */
    protected int getCorrectOldFieldLength() {
        int tips = 0;
        for (Tree tree : treeList) {
            tips += tree.getExternalNodeCount();
        }
        return tips - treeList.size();
    }

    /**
     * Rebuilds all tree intervals and invalidates cached values when one of this likelihood's trees
     * changes.
     *
     * @throws RuntimeException if the changed model is not one of this likelihood's trees
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model instanceof TreeModel) {
            TreeModel treeModel = (TreeModel) model;
            int tn = treeList.indexOf(treeModel);
            if (tn >= 0) {
                // intervalsList.get(tn).setIntervalsUnknown(); // TODO Why is this slower (?) than remaking whole list?
                makeTreeIntervalList(treeList, false);
                intervalsKnown = false;
                likelihoodKnown = false;
            } else {
                throw new RuntimeException("Unknown tree modified in GMRFMultilocusSkyrideLikelihood");
            }
        } else {
            throw new RuntimeException("Unknown object modified in GMRFMultilocusSkyrideLikelihood");
        }
    }

    /** Prints a short description of the model to standard output. */
    public void initializationReport() {
        System.out.println("Creating a GMRF smoothed skyride model for multiple loci (SkyGrid)");
        System.out.println("\tPopulation sizes: " + popSizeParameter.getDimension());
    }

    /** Hook called before sufficient statistics are (re)computed; does nothing here. */
    public void wrapSetupIntervals() {
        // Do nothing
    }

    /**
     * Fills {@code gridPoints} with {@code numGridPoints} evenly spaced times
     * {@code cutOff / numGridPoints, 2 * cutOff / numGridPoints, ..., cutOff}.
     */
    protected void setupGridPoints() {
        if (gridPoints == null || gridPoints.length != numGridPoints) {
            gridPoints = new double[numGridPoints];
        }
        final double spacing = cutOff / numGridPoints;
        for (int pt = 0; pt < numGridPoints; pt++) {
            gridPoints[pt] = (pt + 1) * spacing;
        }
    }

    /**
     * Recomputes, for every field interval g, from all trees:
     * <ul>
     * <li>{@code numCoalEvents[g]}: number of coalescent events falling in g;</li>
     * <li>{@code sufficientStatistics[g]}: sum of (interval length within g) * C(k, 2) / ploidy,
     * where k is the lineage count;</li>
     * <li>{@code ploidySums[g]}: sum over the events in g of log(1 / ploidy of the event's tree).</li>
     * </ul>
     * Field interval g covers (gridPoints[g - 1], gridPoints[g]); interval 0 starts at time 0 and
     * interval {@code numGridPoints} extends beyond the last grid point.
     */
    protected void setupSufficientStatistics() {
        Arrays.fill(numCoalEvents, 0);
        Arrays.fill(sufficientStatistics, 0);
        Arrays.fill(ploidySums, 0);

        // numCoalEvents is cumulative over trees; this snapshot lets each tree's ploidy correction be
        // applied only to that tree's own coalescent events.
        double[] countsBeforeTree = new double[numCoalEvents.length];

        // index of smallest grid point greater than the first event time of the current tree
        int minGridIndex;
        // index of greatest grid point less than the last coalescent time of the current tree
        int maxGridIndex;

        int numLineages;
        int currentGridIndex;
        int currentTimeIndex;

        double currentTime;
        double nextTime;
        double ploidyFactor;

        // time of last coalescent event in tree
        double lastCoalescentTime;

        for (int i = 0; i < numTrees; i++) {
            TreeIntervals intervals = intervalsList.get(i);
            System.arraycopy(numCoalEvents, 0, countsBeforeTree, 0, numCoalEvents.length);

            ploidyFactor = 1 / getPopulationFactor(i);
            currentTimeIndex = 0;
            currentTime = intervals.getIntervalTime(currentTimeIndex);
            nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
            // skip zero-length intervals
            while (nextTime <= currentTime) {
                currentTimeIndex++;
                currentTime = intervals.getIntervalTime(currentTimeIndex);
                nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
            }

            numLineages = intervals.getLineageCount(currentTimeIndex + 1);

            // May reach numGridPoints when the tree starts at or after the last grid point; that is
            // a valid index into the (numGridPoints + 1)-long field arrays.
            minGridIndex = 0;
            while (minGridIndex < numGridPoints && gridPoints[minGridIndex] <= currentTime) {
                minGridIndex++;
            }
            currentGridIndex = minGridIndex;

            lastCoalescentTime = currentTime + intervals.getTotalDuration();
            theLastTime = lastCoalescentTime;

            maxGridIndex = numGridPoints - 1;
            while ((maxGridIndex >= 0) && (gridPoints[maxGridIndex] >= lastCoalescentTime)) {
                maxGridIndex = maxGridIndex - 1;
            }

            if (minGridIndex <= maxGridIndex) {
                // The tree spans at least one grid point.

                // intervals between the first event time and gridPoints[minGridIndex]
                while (nextTime < gridPoints[currentGridIndex]) {

                    // check to see if interval ends with coalescent event
                    if (intervals.getCoalescentEvents(currentTimeIndex + 1) > 0) {
                        numCoalEvents[currentGridIndex]++;
                    }
                    sufficientStatistics[currentGridIndex] += (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                    currentTime = nextTime;
                    currentTimeIndex++;
                    nextTime = intervals.getIntervalTime(currentTimeIndex + 1);

                    while (nextTime <= currentTime) {
                        currentTimeIndex++;
                        currentTime = intervals.getIntervalTime(currentTimeIndex);
                        nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
                    }

                    numLineages = intervals.getLineageCount(currentTimeIndex + 1);
                }

                sufficientStatistics[currentGridIndex] += (gridPoints[currentGridIndex] - currentTime) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;
                currentGridIndex++;

                // intervals between gridPoints[minGridIndex] and gridPoints[maxGridIndex]
                while (currentGridIndex <= maxGridIndex) {
                    if (nextTime >= gridPoints[currentGridIndex]) {
                        // the current interval covers this whole grid cell
                        sufficientStatistics[currentGridIndex] += (gridPoints[currentGridIndex] - gridPoints[currentGridIndex - 1]) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;
                        currentGridIndex++;
                    } else {
                        sufficientStatistics[currentGridIndex] += (nextTime - gridPoints[currentGridIndex - 1]) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                        // check to see if interval ends with coalescent event
                        if (intervals.getCoalescentEvents(currentTimeIndex + 1) > 0) {
                            numCoalEvents[currentGridIndex]++;
                        }

                        currentTime = nextTime;
                        currentTimeIndex++;
                        nextTime = intervals.getIntervalTime(currentTimeIndex + 1);

                        while (nextTime <= currentTime) {
                            currentTimeIndex++;
                            currentTime = intervals.getIntervalTime(currentTimeIndex);
                            nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
                        }

                        numLineages = intervals.getLineageCount(currentTimeIndex + 1);

                        while (nextTime < gridPoints[currentGridIndex]) {
                            // check to see if interval is coalescent interval or sampling interval
                            if (intervals.getCoalescentEvents(currentTimeIndex + 1) > 0) {
                                numCoalEvents[currentGridIndex]++;
                            }
                            sufficientStatistics[currentGridIndex] += (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                            currentTime = nextTime;
                            currentTimeIndex++;
                            nextTime = intervals.getIntervalTime(currentTimeIndex + 1);

                            while (nextTime <= currentTime) {
                                currentTimeIndex++;
                                currentTime = intervals.getIntervalTime(currentTimeIndex);
                                nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
                            }

                            numLineages = intervals.getLineageCount(currentTimeIndex + 1);
                        }
                        sufficientStatistics[currentGridIndex] += (gridPoints[currentGridIndex] - currentTime) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                        currentGridIndex++;
                    }
                }

                // interval between gridPoints[maxGridIndex] and lastCoalescentTime
                sufficientStatistics[currentGridIndex] += (nextTime - gridPoints[currentGridIndex - 1]) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                // check to see if interval ends with coalescent event
                if (intervals.getCoalescentEvents(currentTimeIndex + 1) > 0) {
                    numCoalEvents[currentGridIndex]++;
                }

                currentTime = nextTime;
                currentTimeIndex++;

                while ((currentTimeIndex + 1) < intervals.getIntervalCount()) {
                    nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
                    while (nextTime <= currentTime) {
                        currentTimeIndex++;
                        currentTime = intervals.getIntervalTime(currentTimeIndex);
                        nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
                    }

                    numLineages = intervals.getLineageCount(currentTimeIndex + 1);

                    // check to see if interval is coalescent interval or sampling interval
                    if (intervals.getCoalescentEvents(currentTimeIndex + 1) > 0) {
                        numCoalEvents[currentGridIndex]++;
                    }
                    sufficientStatistics[currentGridIndex] += (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                    currentTime = nextTime;
                    currentTimeIndex++;
                }

            } else {
                // The tree lies entirely within one grid cell, so the log population size is
                // constant over it.
                while ((currentTimeIndex + 1) < intervals.getIntervalCount()) {
                    // check to see if interval is coalescent interval or sampling interval
                    if (intervals.getCoalescentEvents(currentTimeIndex + 1) > 0) {
                        numCoalEvents[currentGridIndex]++;
                    }
                    sufficientStatistics[currentGridIndex] += (nextTime - currentTime) * numLineages * (numLineages - 1) * 0.5 * ploidyFactor;

                    currentTime = nextTime;
                    currentTimeIndex++;
                    if ((currentTimeIndex + 1) < intervals.getIntervalCount()) {
                        nextTime = intervals.getIntervalTime(currentTimeIndex + 1);

                        while (nextTime <= currentTime) {
                            currentTimeIndex++;
                            currentTime = intervals.getIntervalTime(currentTimeIndex);
                            nextTime = intervals.getIntervalTime(currentTimeIndex + 1);
                        }

                        numLineages = intervals.getLineageCount(currentTimeIndex + 1);
                    }
                }
            }

            // Each of this tree's coalescent events contributes log(1 / ploidy) to its grid cell.
            final double logPloidyFactor = Math.log(ploidyFactor);
            for (int g = 0; g < numCoalEvents.length; g++) {
                ploidySums[g] += logPloidyFactor * (numCoalEvents[g] - countsBeforeTree[g]);
            }
        }
    }

    /** Returns the live (not copied) per-field-interval coalescent event counts. */
    public double[] getNumCoalEvents() {
        return numCoalEvents;
    }

    /** Returns the total number of coalescent events, as given by {@link #getCorrectOldFieldLength()}. */
    public int getNumberOfCoalescentEvents() {
        return getCorrectOldFieldLength();
    }

    /**
     * Returns, for the i-th coalescent event of the single tree, the sum of
     * interval * k * (k - 1) / 2 over the intervals ending at that event (including preceding
     * sampling intervals).
     * <p>
     * All values are recomputed when {@code i == 0}, so callers are expected to request index 0
     * first.
     *
     * @throws RuntimeException if more than one tree is present
     */
    public double getCoalescentEventsStatisticValue(int i) {
        if (i == 0) {
            if (DEBUG) {
                System.err.println("numTrees: " + numTrees);
                System.err.println("getCoalescentIntervalDimension(): " + super.getCoalescentIntervalDimension());
                System.err.println("getNumberOfCoalescentEvents(): " + getNumberOfCoalescentEvents());
                System.err.println("getIntervalCount(): " + getIntervalCount());
                System.err.println("intervalsList.size(): " + intervalsList.size());
                System.err.println("intervalsList.get(0).getIntervalCount(): " + intervalsList.get(0).getIntervalCount());
            }
            if (numTrees > 1) {
                throw new RuntimeException("Generalized stepping-stone sampling for the Skygrid not implemented for #trees > 1");
            }
            Arrays.fill(coalescentEventStatisticValues, 0.0);

            TreeIntervals intervals = intervalsList.get(0);
            int counter = 0;
            for (int j = 0; j < intervals.getIntervalCount(); j++) {
                int lineages = intervals.getLineageCount(j);
                coalescentEventStatisticValues[counter] += intervals.getInterval(j) * (lineages * (lineages - 1.0)) / 2.0;
                // sampling intervals accumulate into the next coalescent event's slot
                if (intervals.getIntervalType(j) == IntervalType.COALESCENT) {
                    counter++;
                }
            }
        }
        return coalescentEventStatisticValues[i];
    }

    /**
     * Returns the coalescent log likelihood (up to constants):
     * sum over g of -numCoalEvents[g] * gamma[g] + ploidySums[g] - sufficientStatistics[g] * exp(-gamma[g]),
     * where gamma is the log population-size field. Sufficient statistics are recomputed first if
     * the intervals are stale.
     */
    protected double calculateLogCoalescentLikelihood() {

        if (!intervalsKnown) {
            // intervalsKnown -> false when handleModelChanged event occurs in super.
            wrapSetupIntervals();
            setupSufficientStatistics();
            intervalsKnown = true;
        }

        // Matrix operations taken from block update sampler to calculate data likelihood and field prior
        double currentLike = 0;
        double[] currentGamma = popSizeParameter.getParameterValues();

        for (int i = 0; i < fieldLength; i++) {
            currentLike += -numCoalEvents[i] * currentGamma[i] + ploidySums[i] - sufficientStatistics[i] * Math.exp(-currentGamma[i]);
        }

        return currentLike;
    }

    /** Returns the cached (recomputed if stale) coalescent plus GMRF field log likelihood. */
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = calculateLogCoalescentLikelihood();
            logFieldLikelihood = skygridHelper.getLogFieldLikelihood();
            likelihoodKnown = true;
        }

        return logLikelihood + logFieldLikelihood;
    }

    /**
     * Builds the unscaled tridiagonal GMRF weight matrix: off-diagonal -1, interior diagonal
     * {@code diagonalValue} (2) and end points {@code diagonalValue - 1}.
     */
    protected void setupGMRFWeights() {
        double[] offdiag = new double[fieldLength - 1];
        double[] diag = new double[fieldLength];

        diagonalValue = 2;

        Arrays.fill(offdiag, -1);

        // Loop (not Arrays.fill) so fieldLength == 1 does not produce an invalid range.
        for (int i = 1; i < fieldLength - 1; i++) {
            diag[i] = diagonalValue;
        }

        // Take care of the endpoints
        diag[0] = diagonalValue - 1.0;
        diag[fieldLength - 1] = diagonalValue - 1.0;

        weightMatrix = new SymmTridiagMatrix(diag, offdiag);
    }

    protected double getFieldScalar() {
        return 1.0;
    }

    /**
     * Builds, for each covariate with missing trailing values, a tridiagonal weight matrix of size
     * {@code fieldLength - lastObservedIndex[i]}: off-diagonal -1, diagonal 2 except the last entry, which is 1.
     */
    protected void setupGMRFWeightsForMissingCov() {
        weightMatricesForMissingCov = new ArrayList<SymmTridiagMatrix>();

        for (int i = 0; i < covPrecParameters.size(); i++) {
            final int numMissing = fieldLength - lastObservedIndex[i];
            double[] offdiag = new double[numMissing - 1];
            double[] diag = new double[numMissing];

            Arrays.fill(offdiag, -1);
            Arrays.fill(diag, 0, numMissing - 1, 2.0);

            // Take care of the endpoint
            diag[numMissing - 1] = 1.0;

            weightMatricesForMissingCov.add(i, new SymmTridiagMatrix(diag, offdiag));
        }
    }

    /**
     * Returns a copy of the missing-covariate weight matrix for {@code covIndex} with every entry
     * multiplied by {@code precision}.
     *
     * @param lastObs retained for interface compatibility; the matrix dimension already encodes it
     */
    public SymmTridiagMatrix getScaledWeightMatrixForMissingCov(double precision, int covIndex, int lastObs) {
        SymmTridiagMatrix a = weightMatricesForMissingCov.get(covIndex).copy();
        final int last = a.numRows() - 1;
        for (int i = 0; i < last; i++) {
            a.set(i, i, a.get(i, i) * precision);
            a.set(i + 1, i, a.get(i + 1, i) * precision);
        }
        a.set(last, last, a.get(last, last) * precision);
        return a;
    }

    /** Returns the number of loci (trees). */
    public int nLoci() {
        return treeList.size();
    }

    public Tree getTree(int nt) {
        return treeList.get(nt);
    }

    public TreeIntervals getTreeIntervals(int nt) {
        return intervalsList.get(nt);
    }

    /** Returns the ploidy factor of tree {@code nt}. */
    public double getPopulationFactor(int nt) {
        return ploidyFactors.getParameterValue(nt);
    }

    public List<Parameter> getBetaListParameter() {
        return beta;
    }

    public List<MatrixParameter> getCovariates() {
        return covariates;
    }

    /** Calls {@code storeState()} on every tree's intervals. */
    public void storeTheState() {
        for (TreeIntervals intervals : intervalsList) {
            intervals.storeState();
        }
    }

    /** Calls {@code restoreState()} on every tree's intervals. */
    public void restoreTheState() {
        for (TreeIntervals intervals : intervalsList) {
            intervals.restoreState();
        }
    }

    /** Stores the superclass state plus copies of the coalescent counts and ploidy sums. */
    protected void storeState() {
        super.storeState();
        System.arraycopy(numCoalEvents, 0, storedNumCoalEvents, 0, numCoalEvents.length);
        System.arraycopy(ploidySums, 0, storedPloidySums, 0, ploidySums.length);
    }

    /** Restores the superclass state and swaps current/stored count and ploidy-sum arrays. */
    protected void restoreState() {
        super.restoreState();

        // Swap pointers
        double[] tmp = numCoalEvents;
        numCoalEvents = storedNumCoalEvents;
        storedNumCoalEvents = tmp;
        double[] tmp2 = ploidySums;
        ploidySums = storedPloidySums;
        storedPloidySums = tmp2;
    }

    /**
     * Evaluates the GMRF field log prior. Subclasses may shift the field by covariate effects and
     * add a missing-covariate term.
     */
    class SkygridHelper {

        public SkygridHelper() {
        }

        /** Adjusts the log population-size vector in place; no-op without covariates. */
        protected void updateGammaWithCovariates(DenseVector currentGamma) {
            // Do nothing
        }

        /** Returns the log prior of missing covariate values; zero here. */
        protected double handleMissingValues() {
            return 0.0;
        }

        /**
         * Returns the GMRF log density of the (covariate-adjusted) field gamma:
         * 0.5 * (n - 1) * log(tau) - 0.5 * gamma' Q gamma minus a 2*pi normalising term, plus any
         * missing-covariate contribution.
         */
        public double getLogFieldLikelihood() {

            if (!intervalsKnown) {
                // intervalsKnown -> false when handleModelChanged event occurs in super.
                wrapSetupIntervals();
                setupSufficientStatistics();
                intervalsKnown = true;
            }

            DenseVector diagonal1 = new DenseVector(fieldLength);
            DenseVector currentGamma = new DenseVector(popSizeParameter.getParameterValues());

            updateGammaWithCovariates(currentGamma);

            double currentLike = handleMissingValues();

            SymmTridiagMatrix currentQ = getScaledWeightMatrix(precisionParameter.getParameterValue(0), lambdaParameter.getParameterValue(0));
            currentQ.mult(currentGamma, diagonal1);

            currentLike += 0.5 * (fieldLength - 1) * Math.log(precisionParameter.getParameterValue(0)) - 0.5 * currentGamma.dot(diagonal1);
            if (lambdaParameter.getParameterValue(0) == 1) {
                currentLike -= (fieldLength - 1) / 2.0 * LOG_TWO_TIMES_PI;
            } else {
                currentLike -= fieldLength / 2.0 * LOG_TWO_TIMES_PI;
            }

            return currentLike;
        }
    }

    /** Field prior with the field shifted by design-matrix and/or covariate-list effects. */
    class SkygridCovariateHelper extends SkygridHelper {

        public SkygridCovariateHelper() {
        }

        /**
         * Subtracts dMatrix * betaParameter and, for each covariate k, covariate_k' * beta_k from
         * the field.
         *
         * @throws RuntimeException on dimension mismatches
         */
        @Override
        protected void updateGammaWithCovariates(DenseVector currentGamma) {

            // Handle betaParameter / designMatrix
            if (NEW_APPROACH) {

                final int N = currentGamma.size();
                double[] update = new double[N];

                if (dMatrix != null) {
                    final int K = dMatrix.getColumnDimension();

                    if (N != dMatrix.getRowDimension()) {
                        throw new RuntimeException("Incorrect covariate dimensions (" + N + " != "
                                + dMatrix.getRowDimension() + ")");
                    }

                    for (int i = 0; i < N; ++i) {
                        for (int j = 0; j < K; ++j) {
                            update[i] += dMatrix.getParameterValue(i, j) * betaParameter.getParameterValue(j);
                        }
                    }
                }

                if (covariates != null) {
                    if (beta.size() != covariates.size()) {
                        throw new RuntimeException("beta.size() != covariates.size()");
                    }

                    for (int k = 0; k < beta.size(); ++k) {

                        Parameter b = beta.get(k);
                        final int J = b.getDimension();
                        MatrixParameter covariate = covariates.get(k);

                        if ((J != covariate.getRowDimension()) ||
                                (N != covariate.getColumnDimension())) { // Note: XML current has covariates transposed
                            throw new RuntimeException("Incorrect dimensions in " + covariate.getId());
                        }

                        for (int i = 0; i < N; ++i) {
                            for (int j = 0; j < J; ++j) {
                                update[i] += covariate.getParameterValue(j, i) * b.getParameterValue(j);
                            }
                        }
                    }
                }

                for (int i = 0; i < N; ++i) {
                    currentGamma.set(i, currentGamma.get(i) - update[i]);
                }

            } else {
                DenseVector currentBeta = new DenseVector(beta.size());

                for (int i = 0; i < beta.size(); i++) {
                    currentBeta.set(i, beta.get(i).getParameterValue(0));
                }

                for (int i = 0; i < covariates.size(); i++) {
                    for (int j = 0; j < covariates.get(i).getColumnDimension(); j++) {
                        currentGamma.set(j, currentGamma.get(j) - covariates.get(i).getParameterValue(0, j) * currentBeta.get(i));
                    }
                }
            }
        }
    }

    /** Covariate field prior plus a GMRF prior on covariate values after their last observation. */
    class SkygridMissingCovariateHelper extends SkygridCovariateHelper {

        public SkygridMissingCovariateHelper() {
        }

        /**
         * For each covariate, returns 0.5 * m * log(tau) - 0.5 * d' Q d, where d holds the m
         * trailing (unobserved) values minus the last observed value.
         */
        @Override
        protected double handleMissingValues() {

            int numMissing;
            DenseVector tempVectMissingCov;
            SymmTridiagMatrix missingCovQ;
            DenseVector tempVectMissingCov2;

            double currentLike = 0.0;

            for (int i = 0; i < covPrecParameters.size(); i++) {

                numMissing = fieldLength - lastObservedIndex[i];
                tempVectMissingCov = new DenseVector(numMissing);
                tempVectMissingCov2 = new DenseVector(numMissing);

                missingCovQ = getScaledWeightMatrixForMissingCov(covPrecParameters.get(i).getParameterValue(0), i,
                        lastObservedIndex[i]);

                for (int j = 0; j < numMissing; j++) {
                    tempVectMissingCov.set(j, covariates.get(i).getParameterValue(0, lastObservedIndex[i] + j) -
                            covariates.get(i).getParameterValue(0, lastObservedIndex[i] - 1));
                }

                missingCovQ.mult(tempVectMissingCov, tempVectMissingCov2);
                currentLike += 0.5 * (numMissing) * Math.log(covPrecParameters.get(i).getParameterValue(0))
                        - 0.5 * tempVectMissingCov.dot(tempVectMissingCov2);
            }
            return currentLike;
        }
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TREE_PRIORS;
    }

    @Override
    public String getDescription() {
        return "Skyride coalescent";
    }

    @Override
    public List<Citation> getCitations() {
        return Arrays.asList(new Citation(
                        new Author[]{
                                new Author("MS", "Gill"),
                                new Author("P", "Lemey"),
                                new Author("NR", "Faria"),
                                new Author("A", "Rambaut"),
                                new Author("B", "Shapiro"),
                                new Author("MA", "Suchard")
                        },
                        "Improving Bayesian population dynamics inference: a coalescent-based model for multiple loci",
                        2013,
                        "Mol Biol Evol",
                        30, 713, 724
                ),
                new Citation(
                        new Author[]{
                                new Author("VN", "Minin"),
                                new Author("EW", "Bloomquist"),
                                new Author("MA", "Suchard")
                        },
                        "Smooth skyride through a rough skyline: Bayesian coalescent-based inference of population dynamics",
                        2008,
                        "Mol Biol Evol",
                        25, 1459, 1471,
                        "10.1093/molbev/msn090"
                )
        );
    }
}