/*
* RootProcessDelegate.java
*
* Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.treedatalikelihood.continuous;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
/**
* @author Marc A. Suchard
* @version $Id$
*/
public interface RootProcessDelegate {
int getExtraPartialBufferCount();
int getExtraMatrixBufferCount();
void calculateRootLogLikelihood(ContinuousDiffusionIntegrator cdi, int rootIndex, final double[] logLike,
boolean incrementOuterProducts);
void setRootPartial(ContinuousDiffusionIntegrator cdi);
double getPseudoObservations();
int getPriorBufferIndex();
// int getDegreesOfFreedom();
abstract class Abstract implements RootProcessDelegate {
protected final ConjugateRootTraitPrior prior;
private final PrecisionType precisionType;
private final int priorBufferIndex;
private final int numTraits;
public abstract double getPseudoObservations();
public Abstract(final ConjugateRootTraitPrior prior,
final PrecisionType precisionType, int numTraits,
int partialBufferCount, int matrixBufferCount) {
this.prior = prior;
this.precisionType = precisionType;
this.numTraits = numTraits;
this.priorBufferIndex = partialBufferCount;
}
@Override
public int getExtraPartialBufferCount() {
return 2; // TODO Why does 1 not work?
}
@Override
public int getExtraMatrixBufferCount() {
return 0;
}
@Override
public int getPriorBufferIndex() { return priorBufferIndex; }
@Override
public void calculateRootLogLikelihood(ContinuousDiffusionIntegrator cdi, int rootBufferIndex,
final double[] logLike, boolean incrementOuterProducts) {
cdi.calculateRootLogLikelihood(rootBufferIndex, priorBufferIndex, logLike, incrementOuterProducts);
}
@Override
public void setRootPartial(ContinuousDiffusionIntegrator cdi) {
double[] mean = prior.getMean();
final int dimTrait = mean.length;
final int length = dimTrait + precisionType.getMatrixLength(dimTrait);
double[] partial = new double[length * numTraits];
int offset = 0;
for (int trait = 0; trait < numTraits; ++trait) {
System.arraycopy(mean, 0, partial, offset, dimTrait);
final double precision = getPseudoObservations();
for (int i = 0; i < dimTrait; ++i) {
precisionType.fillPrecisionInPartials(partial, offset, i, precision, dimTrait);
}
// partial[offset + dimTrait] = getPseudoObservations();
offset += length;
}
cdi.setPostOrderPartial(priorBufferIndex, partial);
}
}
class Fixed extends Abstract {
public Fixed(ConjugateRootTraitPrior prior, PrecisionType precisionType,
int numTraits, int partialBufferCount, int matrixBufferCount) {
super(prior, precisionType, numTraits, partialBufferCount, matrixBufferCount);
}
@Override
public double getPseudoObservations() {
return Double.POSITIVE_INFINITY;
}
// @Override
// public int getDegreesOfFreedom() { return 0; }
}
class FullyConjugate extends Abstract {
public FullyConjugate(ConjugateRootTraitPrior prior, PrecisionType precisionType,
int numTraits, int partialBufferCount, int matrixBufferCount) {
super(prior, precisionType, numTraits, partialBufferCount, matrixBufferCount);
}
@Override
public double getPseudoObservations() {
return prior.getPseudoObservations();
}
// @Override
// public int getDegreesOfFreedom() { return 1; }
}
}