/*
* AvgPosteriorIBDReporter.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.ibd;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.oldevomodel.substmodel.AbstractSubstitutionModel;
import dr.oldevomodel.substmodel.HKY;
import dr.evomodel.tree.TreeModel;
import dr.oldevomodel.treelikelihood.NodePosteriorTreeLikelihood;
import dr.inference.model.*;
import dr.xml.*;
/**
* Package: dr.oldevomodel.ibd
* Description: Computes, for each tip, the expected number of other tips that are IBD to it given the tip labels,
* averaged over the full length of the alignment.
* <p/>
* Created by
* avaleks (alexander.alekseyenko@gmail.com)
* Date: 04-Aug-2008
* Time: 13:46:33
*/
public class AvgPosteriorIBDReporter extends AbstractModel implements TreeTraitProvider {

    /** Per-tip result of {@link #expectedIBD()}, indexed by external node number. */
    protected double[] ibdweights;
    /** Tip-to-root messages, indexed [nodeNumber][pattern * stateCount + state]. */
    protected double[][] ibdForward;
    /** Root-to-tip messages, indexed [nodeNumber][pattern * stateCount + state]. */
    protected double[][] ibdBackward;
    /** Scratch buffer filled by {@link #getDiagonalRates(double[])}; one entry per state. */
    protected double[] diag;
    /** True while {@link #ibdweights} reflects the current state of the registered models/parameter. */
    protected boolean weightsKnown;
    protected HKY substitutionModel;
    protected TreeModel treeModel;
    protected BranchRateModel branchRateModel;
    /** Parameter whose first value scales the diagonal rates in {@link #getDiagonalRates(double[])}. */
    protected Parameter mutationParameter;
    /** Source of per-node posteriors, per-node matrices, pattern count and pattern weights. */
    protected NodePosteriorTreeLikelihood likelihoodReporter;
    /** Scratch buffer (stateCount * stateCount, flattened) receiving a node's matrix from the likelihood. */
    protected double[] probabilities;

    /**
     * Creates the reporter and registers the substitution model, tree model, branch rate model
     * and mutation parameter so that changes to them invalidate the cached weights.
     *
     * @param likelihoodReporter supplies posteriors, node matrices and pattern weights
     * @param mutationParameter  first value scales the HKY diagonal rates
     * @param treeModel          the tree whose tips receive the trait
     * @param branchRateModel    per-branch rate multipliers
     * @param substitutionModel  must be an {@link HKY} instance
     * @throws IllegalArgumentException if {@code substitutionModel} is not an {@link HKY} model
     */
    AvgPosteriorIBDReporter(NodePosteriorTreeLikelihood likelihoodReporter, Parameter mutationParameter, TreeModel treeModel, BranchRateModel branchRateModel, AbstractSubstitutionModel substitutionModel) {
        super("AvgPosteriorIBDReporter");
        // getDiagonalRates() reads HKY-specific quantities (kappa, 4 frequencies), so fail clearly
        // here instead of with a bare ClassCastException.
        if (!(substitutionModel instanceof HKY)) {
            throw new IllegalArgumentException("AvgPosteriorIBDReporter requires an HKY substitution model, got "
                    + (substitutionModel == null ? "null" : substitutionModel.getClass().getName()));
        }
        this.substitutionModel = (HKY) substitutionModel;
        addModel(this.substitutionModel);
        this.treeModel = treeModel;
        addModel(this.treeModel);
        this.branchRateModel = branchRateModel;
        addModel(this.branchRateModel);
        this.mutationParameter = mutationParameter;
        addVariable(this.mutationParameter);
        this.likelihoodReporter = likelihoodReporter;
        this.probabilities = new double[substitutionModel.getStateCount() * substitutionModel.getStateCount()];
    }

    /**
     * Computes the tip-to-root messages into {@link #ibdForward} for every non-root node.
     * <p/>
     * Nodes are visited in post-order from the root, so each internal node's children are finished
     * before the node itself reads them. Nothing here guarantees that node numbers follow a post-order
     * of the current topology, so iterating by node number could read stale child messages.
     */
    public void forwardIBD() {
        getDiagonalRates(diag);
        computeForward(treeModel.getRoot(), substitutionModel.getStateCount(), likelihoodReporter.getPatternCount());
    }

    /** Post-order worker for {@link #forwardIBD()}: fills ibdForward for every non-root node below {@code node}. */
    private void computeForward(NodeRef node, int stateCount, int patternCount) {
        final int childCount = treeModel.getChildCount(node);
        for (int i = 0; i < childCount; ++i) {
            computeForward(treeModel.getChild(node, i), stateCount, patternCount);
        }

        final NodeRef parent = treeModel.getParent(node);
        if (parent == null) {
            return; // the root has no parent branch, so it carries no forward message
        }

        final int nodeId = node.getNumber();
        likelihoodReporter.getNodeMatrix(nodeId, probabilities);
        final double[] posteriors = likelihoodReporter.getPosteriors(nodeId);
        final double[] forward = ibdForward[nodeId];
        final double branchTime = getBranchTime(node, parent);
        final boolean isTip = treeModel.isExternal(node);
        final int[] childIds = getChildNumbers(node);

        for (int state = 0; state < stateCount; ++state) {
            // exp(-rate * t) divided by the diagonal entry of the node's matrix (presumably the branch
            // transition probability P_ii(t) -- TODO confirm against NodePosteriorTreeLikelihood).
            final double enorm = Math.exp(-diag[state] * branchTime) / probabilities[state + state * stateCount];
            for (int pattern = 0; pattern < patternCount; ++pattern) {
                final int k = pattern * stateCount + state;
                if (isTip) {
                    forward[k] = posteriors[k] * enorm;
                } else {
                    // Sum children in child-index order, then scale (same grouping as before).
                    double sum = 0.0;
                    for (int childId : childIds) {
                        sum += ibdForward[childId][k];
                    }
                    forward[k] = sum * (posteriors[k] * enorm);
                }
            }
        }
    }

    /**
     * Computes the root-to-tip messages into {@link #ibdBackward} for the subtree below {@code node}.
     * <p/>
     * Requires {@link #forwardIBD()} to have been run, since sibling forward messages are read.
     *
     * @param node subtree root; {@code null} means start at the tree root, whose backward message is zeroed
     */
    public void backwardIBD(NodeRef node) {
        final int stateCount = substitutionModel.getStateCount();
        final int patternCount = likelihoodReporter.getPatternCount();
        if (node == null) {
            node = treeModel.getRoot();
            final double[] rootBackward = ibdBackward[node.getNumber()];
            for (int i = 0; i < patternCount * stateCount; ++i) {
                rootBackward[i] = 0;
            }
        }
        // The diagonal rates do not change during the traversal, so compute them once.
        getDiagonalRates(diag);
        computeBackward(node, stateCount, patternCount);
    }

    /** Pre-order worker for {@link #backwardIBD(NodeRef)}: fills ibdBackward for each child, then recurses. */
    private void computeBackward(NodeRef node, int stateCount, int patternCount) {
        final int nodeId = node.getNumber();
        final double[] posteriors = likelihoodReporter.getPosteriors(nodeId);
        final double[] parentBackward = ibdBackward[nodeId];
        final int[] childIds = getChildNumbers(node);
        final int childCount = childIds.length;

        for (int child = 0; child < childCount; ++child) {
            final NodeRef childNode = treeModel.getChild(node, child);
            final int childNodeId = childIds[child];
            likelihoodReporter.getNodeMatrix(childNodeId, probabilities);
            final double branchTime = getBranchTime(childNode, node);
            final double[] childBackward = ibdBackward[childNodeId];

            for (int state = 0; state < stateCount; ++state) {
                final double decay = Math.exp(-diag[state] * branchTime);
                final double diagonal = probabilities[state + state * stateCount];
                for (int pattern = 0; pattern < patternCount; ++pattern) {
                    final int k = pattern * stateCount + state;
                    // Message from above plus forward messages of all siblings (in child-index order).
                    double value = parentBackward[k];
                    for (int sibling = 0; sibling < childCount; ++sibling) {
                        if (sibling != child) {
                            value += ibdForward[childIds[sibling]][k];
                        }
                    }
                    childBackward[k] = value * (posteriors[k] * decay / diagonal);
                }
            }
        }

        for (int child = 0; child < childCount; ++child) {
            computeBackward(treeModel.getChild(node, child), stateCount, patternCount);
        }
    }

    /** Branch length of {@code node} scaled by its branch rate: rate * (height(parent) - height(node)). */
    private double getBranchTime(NodeRef node, NodeRef parent) {
        return branchRateModel.getBranchRate(treeModel, node) * (treeModel.getNodeHeight(parent) - treeModel.getNodeHeight(node));
    }

    /** Node numbers of {@code node}'s children, in child-index order (empty for a tip). */
    private int[] getChildNumbers(NodeRef node) {
        final int childCount = treeModel.getChildCount(node);
        final int[] numbers = new int[childCount];
        for (int i = 0; i < childCount; ++i) {
            numbers[i] = treeModel.getChild(node, i).getNumber();
        }
        return numbers;
    }

    /**
     * Recomputes {@link #ibdweights}: for each tip, the sum over patterns and states of
     * backward message * tip posterior, weighted by pattern weight and divided by the total weight.
     * <p/>
     * Buffers are allocated on first use. Assumes tips are numbered 0 .. externalNodeCount-1.
     */
    public void expectedIBD() {
        int stateCount = substitutionModel.getStateCount();
        int nodeCount = treeModel.getNodeCount();
        int patternCount = likelihoodReporter.getPatternCount();
        if (ibdweights == null) {
            ibdweights = new double[treeModel.getExternalNodeCount()];
            ibdForward = new double[nodeCount][stateCount * patternCount];
            ibdBackward = new double[nodeCount][stateCount * patternCount];
            diag = new double[stateCount];
        }
        forwardIBD();
        backwardIBD(null);
        int numTips = treeModel.getExternalNodeCount();
        double[] patternWeights = likelihoodReporter.getPatternWeights();
        double total = 0.0;
        for (int i = 0; i < patternCount; ++i) {
            total += patternWeights[i];
        }
        for (int i = 0; i < numTips; ++i) {
            double[] posteriors = likelihoodReporter.getPosteriors(i);
            ibdweights[i] = 0;
            for (int pattern = 0; pattern < patternCount; ++pattern) {
                for (int j = 0; j < stateCount; ++j) {
                    ibdweights[i] += ibdBackward[i][pattern * stateCount + j] * posteriors[pattern * stateCount + j] * patternWeights[pattern] / total;
                }
            }
        }
    }

    /**
     * Fills {@code diagonalRates} with the total rate of leaving each state under HKY,
     * scaled by the mutation rate.
     * <p/>
     * States 0/2 and 1/3 are the kappa-weighted (transition) pairs -- presumably A,C,G,T order.
     * {@code beta} normalizes the matrix to an expected rate of 1 under {@code freq}.
     *
     * @param diagonalRates output array of length at least 4
     */
    protected void getDiagonalRates(double[] diagonalRates) {
        double kappa = substitutionModel.getKappa();
        double[] freq = substitutionModel.getFrequencyModel().getFrequencies();
        double mutationRate = mutationParameter.getParameterValue(0);
        double beta = 0.5 / ((freq[0] + freq[2]) * (freq[1] + freq[3]) + kappa * (freq[0] * freq[2] + freq[1] * freq[3]));
        diagonalRates[0] = ((freq[1] + freq[3]) + freq[2] * kappa) * mutationRate * beta;
        diagonalRates[1] = ((freq[0] + freq[2]) + freq[3] * kappa) * mutationRate * beta;
        diagonalRates[2] = ((freq[1] + freq[3]) + freq[0] * kappa) * mutationRate * beta;
        diagonalRates[3] = ((freq[0] + freq[2]) + freq[1] * kappa) * mutationRate * beta;
    }

    /** Node trait: for tips, 1 + the tip's average IBD weight (recomputed lazily); null for internal nodes. */
    TreeTrait avgPosteriorIBDWeight = new TreeTrait.D() {
        public String getTraitName() {
            return "AvgPosteriorIBDWeight";
        }

        public Intent getIntent() {
            return Intent.NODE;
        }

        public Double getTrait(Tree tree, NodeRef node) {
            if (!weightsKnown) {
                expectedIBD();
                weightsKnown = true;
            }
            if (tree.isExternal(node)) {
                int nodeNum = node.getNumber();
                // +1 counts the tip itself.
                return ibdweights[nodeNum] + 1;
            }
            return null;
        }
    };

    public TreeTrait[] getTreeTraits() {
        return new TreeTrait[] { avgPosteriorIBDWeight };
    }

    /** Returns the single trait this provider offers; {@code key} is not checked. */
    public TreeTrait getTreeTrait(String key) {
        // ignore the key - it must be the one they wanted, no?
        return avgPosteriorIBDWeight;
    }

    /**
     * The XML parser
     */
    public static final String IBD_REPORTER_LIKELIHOOD = "avgPosteriorIBDReporter";

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public String getParserName() {
            return IBD_REPORTER_LIKELIHOOD;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
            Parameter mutationParameter = (Parameter) xo.getChild(Parameter.class);
            AbstractSubstitutionModel substitutionModel =
                    (AbstractSubstitutionModel) xo.getChild(AbstractSubstitutionModel.class);
            // The syntax rules accept any AbstractSubstitutionModel, but the reporter is HKY-only.
            if (!(substitutionModel instanceof HKY)) {
                throw new XMLParseException("The <" + getParserName() + "> element requires an HKY substitution model");
            }
            BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);
            if (branchRateModel == null) {
                branchRateModel = new DefaultBranchRateModel();
            }
            NodePosteriorTreeLikelihood likelihoodReporter = (NodePosteriorTreeLikelihood) xo.getChild(NodePosteriorTreeLikelihood.class);
            return new AvgPosteriorIBDReporter(likelihoodReporter, mutationParameter, treeModel, branchRateModel, substitutionModel);
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "This element represents a reporter for average expected number of tips ibd conditional on observed patterns.";
        }

        public Class getReturnType() {
            // parseXMLObject returns an AvgPosteriorIBDReporter, which is not a Likelihood.
            return AvgPosteriorIBDReporter.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                new ElementRule(TreeModel.class),
                new ElementRule(BranchRateModel.class, true),
                new ElementRule(AbstractSubstitutionModel.class),
                new ElementRule(Parameter.class),
                new ElementRule(NodePosteriorTreeLikelihood.class)
        };
    };

    /** Invalidates the cached weights when a registered model changes. */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // NOTE(review): likelihoodReporter is never passed to addModel() in the constructor, so the
        // `model == likelihoodReporter` branch looks unreachable unless it is registered elsewhere.
        if (model == branchRateModel || model == treeModel || model == substitutionModel || model == likelihoodReporter) {
            weightsKnown = false;
        } else {
            System.err.println("Weird call back to IBDReporter from " + model.getModelName());
        }
    }

    /** Invalidates the cached weights when the mutation parameter changes. */
    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        if (variable == mutationParameter) {
            weightsKnown = false;
        } else {
            System.err.println("Weird call back to IBDReporter from " + variable.getVariableName());
        }
    }

    protected void storeState() {
    }

    protected void restoreState() {
        // Cached weights may have been computed for a now-rejected state; recompute on next request.
        weightsKnown = false;
    }

    protected void acceptState() {
    }
}