/*
* NPAntigenicLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.antigenic;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.GammaFunction;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.StringAttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
/**
* @author Gabriela Cybis
* @author Marc Suchard
*/
public class NPAntigenicLikelihood extends AbstractModelLikelihood {
/** Model name passed to the superclass constructor; also returned by the XML parser as its element name. */
public static final String NP_ANTIGENIC_LIKELIHOOD = "NPAntigenicLikelihood";
/**
* Builds the likelihood: stores its inputs, registers the variables and the tree model it listens to,
* computes the tip-to-tip dependency matrix and caches the prior constants used by
* {@link #getLogLikGroup(int)}.
*
* @param treeModel tree whose external nodes define the dependency matrix
* @param traitParameter one sub-parameter per data point; its parameter count becomes {@code numdata}
* @param assignments cluster index per data point; overwritten here so that entry i equals i
* @param links link target per data point; overwritten here so that entry i equals i
* @param chi stored in the field {@code alpha}; its first value is read in {@link #computeLogLikelihood()}
* @param clusterPrec first value is read once here to derive {@code k0} and {@code T0Inv}
* @param priorMean values 0 and 1 are read once here into {@code m}
* @param priorPrec first value is read once here to derive {@code k0}
* @param transformFactor first value is the exponent passed to {@code logCorrectMatrix}
* @param means1 written by {@link #setMeans(int, double[])} (first coordinate)
* @param means2 written by {@link #setMeans(int, double[])} (second coordinate)
* @param locationDrift stored but not registered as a variable
* @param offsets stored and registered as a variable
* @param hasDrift currently ignored: the field is forced to {@code false}
*/
public NPAntigenicLikelihood (TreeModel treeModel,CompoundParameter traitParameter , Parameter assignments, Parameter links,
Parameter chi, Parameter clusterPrec, Parameter priorMean, Parameter priorPrec,
Parameter transformFactor,Parameter means1, Parameter means2, Parameter locationDrift,
Parameter offsets, Boolean hasDrift){
super(NP_ANTIGENIC_LIKELIHOOD);
this.assignments = assignments;
this.links = links;
this.clusterPrec = clusterPrec;
this.priorPrec = priorPrec;
this.priorMean = priorMean;
this.treeModel= treeModel;
this.traitParameter= traitParameter;
this.transformFactor=transformFactor;
this.means1=means1;
this.means2=means2;
// alpha and chi are the same object from here on.
this.alpha= chi;
this.locationDrift=locationDrift;
this.offsets=offsets;
// The hasDrift argument is deliberately not used; drift handling is disabled.
// this.hasDrift=hasDrift;
this.hasDrift=false;
addVariable(traitParameter);
addVariable(assignments);
addVariable(links);
addModel(treeModel);
addVariable(chi);
addVariable(transformFactor);
// NOTE(review): alpha == chi, so this registers the same object a second time - confirm it is harmless.
addVariable(alpha);
// addVariable(locationDrift);
addVariable(offsets);
// NOTE(review): clusterPrec, priorMean and priorPrec are not registered and are only read below,
// so later changes to them would not update m, k0 or T0Inv - confirm they are meant to be fixed.
numdata = traitParameter.getParameterCount();
this.allTips= TreeUtils.getExternalNodes(treeModel,treeModel.getRoot());
// Needs treeModel, transformFactor, allTips and numdata, all assigned above.
setDepMatrix();
// Initial state: every data point is its own cluster and links to itself.
for (int i=0; i<numdata; i++){
assignments.setParameterValue(i, i);
links.setParameterValue(i,i);
}
// One slot per possible cluster index plus one extra slot.
this.logLikelihoodsVector = new double[links.getDimension()+1];
this.logLikelihoodsVectorKnown = new boolean[links.getDimension()+1];
this.storedLogLikelihoodsVector = new double[links.getDimension()+1];
// Prior constants for the two-dimensional cluster likelihood.
this.m = new double[2];
m[0]= priorMean.getParameterValue(0);
m[1]= priorMean.getParameterValue(1);
this.v0 = 2;
// double v1 = 3;
this.k0= priorPrec.getParameterValue(0)/clusterPrec.getParameterValue(0);
// double k1= k0+1;
// T0Inv is diagonal, so its determinant is the product of the diagonal entries.
this.T0Inv= new double[2][2];
T0Inv[0][0]= v0/clusterPrec.getParameterValue(0);
T0Inv[1][1]= v0/clusterPrec.getParameterValue(0);
T0Inv[1][0]= 0.0;
T0Inv[0][1]= 0.0;
this.logDetT0= -Math.log(T0Inv[0][0]*T0Inv[1][1]);
/* for(int i=0;i<logLikelihoodsVector.length-1;i++){
double[][] T1Inv = new double[2][2];
T1Inv[0][0]=T0Inv[0][0]+(k0/k1)* data[i][0]*data[i][0];
T1Inv[0][1]=T0Inv[0][1]+(k0/k1)* data[i][0]*data[i][1];
T1Inv[1][0]=T0Inv[1][0]+(k0/k1)* data[i][1]*data[i][0];
T1Inv[1][1]=T0Inv[1][1]+(k0/k1)* data[i][1]*data[i][1];
double logDetT1=-Math.log(T1Inv[0][0]*T1Inv[1][1]-T1Inv[0][1]*T1Inv[1][0]);
logLikelihoodsVector[i]= -(1*2/2)*Math.log(Math.PI);
logLikelihoodsVector[i]+= Math.log(k0) - Math.log(k1);
logLikelihoodsVector[i]+= (v1/2)*logDetT1 - (v0/2)*logDetT0;
logLikelihoodsVector[i]+= GammaFunction.lnGamma(v1/2)+ GammaFunction.lnGamma((v1/2)-0.5);
logLikelihoodsVector[i]+=-GammaFunction.lnGamma(v0/2)- GammaFunction.lnGamma((v0/2)-0.5);
}
*/
}
/*
public void setInitialAssignmentsToDates(){
double[] offsetValues = new double[numdata];
for (int i=0; i<numdata; i++){
offsetValues[i] =offsets.getParameterValue(findOffsetIndex(i));
}
boolean[] assigned = new boolean[numdata];
for (int i=0; i<numdata;i++ ){
assigned[i]=false;
}
int group = 0;
for (int i=0; i<numdata; i++){
if (!assigned[i]){
int last = i;
for (int j=0;j<numdata;j++){
if (offsetValues[j]==offsetValues[i]||offsetValues[j]==offsetValues[i]+1||offsetValues[j]==offsetValues[i]+2||offsetValues[j]==offsetValues[i]+3||offsetValues[j]==offsetValues[i]+4){
links.setParameterValue(j,last);
assignments.setParameterValue(j, group);
assigned[j]=true;
last=j;
}
}
links.setParameterValue(i,last);
group ++;
}
}
printInformation(links);
printInformation(assignments);
}
private int findOffsetIndex(int traitParameterIndex){
String NAME = traitParameter.getParameter(traitParameterIndex).getParameterName();
boolean notFound =true;
int i=0;
while (notFound){
if(offsets.getDimensionName(i).compareTo(NAME)==0){
notFound=false;
}else{
i++;
}
}
// printInformation((double) i);
// printInformation(offsets.getDimensionName(i),NAME);
return i;
}
private void setData(){
dataMatrixKnown=true;
int dim = traitParameter.getParameter(0).getSize();
for (int i=0; i<numdata; i++){
for (int j=1; j<dim; j++){
data[i][j]= traitParameter.getParameter(i).getParameterValue(j);
}
//if (hasDrift){
// int offsetIndex = findOffsetIndex(i);
// Data[i][0] += locationDrift.getParameterValue(0)*offsets.getParameterValue(offsetIndex);
//} else{
data[i][0]= traitParameter.getParameter(i).getParameterValue(0);
//}
}
}
private void setDatum(int virus){
int dim = traitParameter.getParameter(0).getSize();
for (int j=1; j<dim; j++){
data[virus][j]= traitParameter.getParameter(virus).getParameterValue(j);
}
//if (hasDrift){
// int offsetIndex = findOffsetIndex(i);
// Data[i][0] += locationDrift.getParameterValue(0)*offsets.getParameterValue(offsetIndex);
//} else{
data[virus][0]= traitParameter.getParameter(virus).getParameterValue(0);
//}
}
*/
/**
 * Rebuilds {@code depMatrix} from the current tree and transform factor, then fills
 * {@code logDepMatrix} with the element-wise natural log of its off-diagonal entries.
 *
 * Steps: accumulate branch lengths between tip pairs via {@code recursion}, turn each
 * off-diagonal value d into 1/d^p via {@code logCorrectMatrix} (p is the first value of
 * {@code transformFactor}), then take logs. Diagonal entries of both matrices are never
 * written and keep their default value of 0.0.
 *
 * Marks the dependency matrix as known.
 */
private void setDepMatrix(){
    depMatrixKnown = true;
    depMatrix = new double[numdata][numdata];

    List<NodeRef> childList = new ArrayList<NodeRef>();
    recursion(treeModel.getRoot(), childList);
    logCorrectMatrix(transformFactor.getParameterValue(0));

    logDepMatrix = new double[numdata][numdata];
    for (int i = 0; i < numdata; i++) {
        for (int j = 0; j < i; j++) {
            final double logDep = Math.log(depMatrix[i][j]);
            logDepMatrix[i][j] = logDep;
            // Mirror into the upper triangle. The previous code assigned the cell to
            // itself, leaving logDepMatrix[j][i] at 0.0 for every j < i.
            logDepMatrix[j][i] = logDep;
        }
    }
}
/**
 * Computes the log marginal likelihood of the two-dimensional data points currently
 * assigned to one cluster, and flags that cluster's cached entry as known.
 *
 * The value is built from the prior constants {@code k0}, {@code v0}, {@code m},
 * {@code T0Inv} and {@code logDetT0} set in the constructor, the group mean and the
 * group scatter matrix. The expression has the shape of a conjugate
 * normal-inverse-Wishart marginal likelihood for dimension 2 (presumably the intended
 * model; confirm against the method description).
 *
 * @param groupNumber cluster index to evaluate; must be a valid index into the
 *                    per-cluster cache
 * @return the log marginal likelihood of the group, or 0.0 if no data point is
 *         assigned to it
 */
public double getLogLikGroup(int groupNumber){
    double logLik = 0.0;

    // First pass: how many data points belong to this cluster?
    int ngroup = 0;
    for (int i = 0; i < assignments.getDimension(); i++) {
        if ((int) assignments.getParameterValue(i) == groupNumber) {
            ngroup++;
        }
    }

    if (ngroup != 0) {
        // Second pass: collect the members' coordinates and accumulate their sum.
        double[][] group = new double[ngroup][2];
        double[] mean = new double[2];
        int count = 0;
        for (int i = 0; i < assignments.getDimension(); i++) {
            if ((int) assignments.getParameterValue(i) == groupNumber) {
                group[count][0] = getData(i, 0);
                // Second coordinate comes from dimension 1. The previous code read
                // dimension 0 twice, which made both coordinates identical.
                group[count][1] = getData(i, 1);
                mean[0] += group[count][0];
                mean[1] += group[count][1];
                count++;
            }
        }
        mean[0] /= ngroup;
        mean[1] /= ngroup;

        double kn = k0 + ngroup;
        double vn = v0 + ngroup;

        // Scatter matrix of the group around its own mean.
        double[][] sumdif = new double[2][2];
        for (int i = 0; i < ngroup; i++) {
            sumdif[0][0] += (group[i][0]-mean[0])*(group[i][0]-mean[0]);
            sumdif[0][1] += (group[i][0]-mean[0])*(group[i][1]-mean[1]);
            sumdif[1][0] += (group[i][0]-mean[0])*(group[i][1]-mean[1]);
            sumdif[1][1] += (group[i][1]-mean[1])*(group[i][1]-mean[1]);
        }

        // Updated inverse scale: prior + shrinkage of the group mean towards m + scatter.
        double[][] TnInv = new double[2][2];
        TnInv[0][0] = T0Inv[0][0]+ngroup*(k0/kn)*(mean[0]-m[0])*(mean[0]-m[0])+sumdif[0][0];
        TnInv[0][1] = T0Inv[0][1]+ngroup*(k0/kn)*(mean[1]-m[1])*(mean[0]-m[0])+sumdif[0][1];
        TnInv[1][0] = T0Inv[1][0]+ngroup*(k0/kn)* (mean[0]-m[0])*(mean[1]-m[1])+sumdif[1][0];
        TnInv[1][1] = T0Inv[1][1]+ngroup*(k0/kn)* (mean[1]-m[1])*(mean[1]-m[1])+sumdif[1][1];

        // Log determinant of the inverse of TnInv (2x2 determinant written out).
        double logDetTn = -Math.log(TnInv[0][0]*TnInv[1][1]-TnInv[0][1]*TnInv[1][0]);

        logLik += -(ngroup)*Math.log(Math.PI);
        logLik += Math.log(k0) - Math.log(kn);
        logLik += (vn/2)*logDetTn - (v0/2)*logDetT0;
        logLik += GammaFunction.lnGamma(vn/2)+ GammaFunction.lnGamma((vn/2)-0.5);
        logLik += -GammaFunction.lnGamma(v0/2)- GammaFunction.lnGamma((v0/2)-0.5);
    }

    logLikelihoodsVectorKnown[groupNumber] = true;
    return logLik;
}
/** @return this object, which is itself the model backing the likelihood */
public Model getModel() {
return this;
}
/** @return the internal per-cluster log-likelihood cache (not a copy) */
public double[] getLogLikelihoodsVector(){
return logLikelihoodsVector;
}
/** @return the parameter holding each data point's link target */
public Parameter getLinks(){
return links;
}
/** @return the parameter holding each data point's cluster index */
public Parameter getAssignments(){
return assignments;
}
/**
* Reads one coordinate of one data point directly from the trait parameter.
*
* @param virus index of the data point (sub-parameter of the trait parameter)
* @param dim coordinate index within that data point
* @return the current value of that coordinate
*/
public double getData(int virus, int dim){
return traitParameter.getParameter(virus).getParameterValue(dim);
}
/** @return the internal dependency matrix (not a copy) */
public double[][] getDepMatrix(){
return depMatrix;
}
/** @return the internal log dependency matrix (not a copy) */
public double[][] getLogDepMatrix(){
return logDepMatrix;
}
/** @return the prior mean parameter supplied at construction */
public Parameter getPriorMean(){
return priorMean;
}
/** @return the prior precision parameter supplied at construction */
public Parameter getPriorPrec(){
return priorPrec;
}
/** @return the cluster precision parameter supplied at construction */
public Parameter getClusterPrec(){
return clusterPrec;
}
/**
* Overwrites one cached per-cluster log-likelihood. Does not touch the matching "known" flag.
*
* @param pos cluster index
* @param value log-likelihood to store
*/
public void setLogLikelihoodsVector(int pos, double value){
logLikelihoodsVector[pos]=value;
}
/**
* Sets the cluster index of one data point. (The method name is misspelled; it is public and kept as is.)
*
* @param pos data point index
* @param value cluster index to assign
*/
public void setAssingments(int pos, double value){
assignments.setParameterValue(pos,value);
}
/**
* Sets the link target of one data point.
*
* @param pos data point index
* @param value index of the data point it links to
*/
public void setLinks(int pos, double value){
links.setParameterValue(pos,value);
}
/**
* Writes a two-dimensional mean into the two mean parameters.
*
* @param pos index written in both {@code means1} and {@code means2}
* @param value array whose element 0 goes to {@code means1} and element 1 to {@code means2}
*/
public void setMeans(int pos, double[] value){
means1.setParameterValue(pos,value[0]);
means2.setParameterValue(pos,value[1]);
}
/**
 * Returns the cached log-likelihood, recomputing it first when the cache is stale.
 * The "known" flag itself is raised inside {@link #computeLogLikelihood()}.
 *
 * @return the current log-likelihood
 */
public double getLogLikelihood() {
    if (logLikelihoodKnown) {
        return logLikelihood;
    }
    logLikelihood = computeLogLikelihood();
    return logLikelihood;
}
/**
 * Computes the full log-likelihood: the sum of the per-cluster terms plus, for every
 * data point, the log weight of its link minus the log of the total weight it could
 * have chosen from.
 *
 * A self-link contributes log(alpha); a link to another point contributes the log of
 * the corresponding dependency-matrix entry. Stale per-cluster terms are refreshed via
 * {@link #getLogLikGroup(int)}, and the dependency matrix is rebuilt first if needed.
 * Marks the total log-likelihood as known.
 *
 * @return the log-likelihood
 */
public double computeLogLikelihood() {
    if (!depMatrixKnown) {
        setDepMatrix();
    }

    double total = 0.0;

    // Per-cluster data terms, recomputed only where flagged as unknown.
    final int groupSlots = logLikelihoodsVector.length;
    for (int g = 0; g < groupSlots; g++) {
        if (!logLikelihoodsVectorKnown[g]) {
            logLikelihoodsVector[g] = getLogLikGroup(g);
        }
        total += logLikelihoodsVector[g];
    }

    // Link terms.
    final double selfWeight = alpha.getParameterValue(0);
    for (int j = 0; j < links.getDimension(); j++) {
        final double target = links.getParameterValue(j);
        if (target == j) {
            total += Math.log(selfWeight);
        } else {
            total += Math.log(depMatrix[j][(int) target]);
        }

        // Normalising constant: self weight plus the weights of all other points.
        double otherWeights = 0.0;
        for (int i = 0; i < numdata; i++) {
            if (i == j) {
                continue;
            }
            otherWeights += depMatrix[i][j];
        }
        total -= Math.log(selfWeight + otherWeights);
    }

    logLikelihoodKnown = true;
    return total;
}
/**
 * Post-order walk that accumulates branch lengths into {@code depMatrix} and collects
 * the tips found below {@code node}.
 *
 * For an internal node, each child's branch length is added (symmetrically) to every
 * pair made of one tip below that child and one tip not below it. Summed over the
 * whole tree, each pair therefore receives the lengths of the branches separating
 * its two tips.
 *
 * @param node subtree root to process; only children 0 and 1 of an internal node are visited
 * @param childList receives the tips found below {@code node} (the node itself if it is a tip)
 */
void recursion( NodeRef node, List childList){
    if (treeModel.isExternal(node)) {
        childList.add(node);
        return;
    }

    final NodeRef leftChild = treeModel.getChild(node, 0);
    final NodeRef rightChild = treeModel.getChild(node, 1);

    List<NodeRef> leftTips = new ArrayList<NodeRef>();
    List<NodeRef> rightTips = new ArrayList<NodeRef>();
    recursion(leftChild, leftTips);
    recursion(rightChild, rightTips);

    final double leftLength = treeModel.getBranchLength(leftChild);
    final double rightLength = treeModel.getBranchLength(rightChild);

    // Tips that are NOT below the left child.
    Set<NodeRef> outsideLeft = new HashSet<NodeRef>();
    outsideLeft.addAll(allTips);
    for (NodeRef tip : leftTips) {
        outsideLeft.remove(tip);
    }

    // Tips that are NOT below the right child.
    Set<NodeRef> outsideRight = new HashSet<NodeRef>();
    outsideRight.addAll(allTips);
    for (NodeRef tip : rightTips) {
        outsideRight.remove(tip);
    }

    for (NodeRef tip : leftTips) {
        final int inside = tip.getNumber();
        for (NodeRef other : outsideLeft) {
            final int outside = other.getNumber();
            depMatrix[outside][inside] += leftLength;
            depMatrix[inside][outside] += leftLength;
        }
    }

    for (NodeRef tip : rightTips) {
        final int inside = tip.getNumber();
        for (NodeRef other : outsideRight) {
            final int outside = other.getNumber();
            depMatrix[outside][inside] += rightLength;
            depMatrix[inside][outside] += rightLength;
        }
    }

    childList.addAll(leftTips);
    childList.addAll(rightTips);
}
/**
 * Replaces every off-diagonal entry d of {@code depMatrix} by 1/d^p, keeping the matrix
 * symmetric. Only the lower triangle is read; the diagonal is left untouched.
 *
 * @param p exponent applied to each entry
 */
void logCorrectMatrix(double p){
    for (int row = 0; row < numdata; row++) {
        for (int col = 0; col < row; col++) {
            final double transformed = 1/Math.pow(depMatrix[row][col],p);
            depMatrix[row][col] = transformed;
            depMatrix[col][row] = transformed;
        }
    }
}
/**
 * Computes the path length between two tips by walking from each tip up to their
 * most recent common ancestor and summing the branch lengths along the way.
 *
 * @param i index of the first external node
 * @param j index of the second external node
 * @return the summed branch length from tip i to the common ancestor and on to tip j
 */
public double getTreeDist(int i, int j){
    final NodeRef ancestor = findMRCA(i, j);
    final int[] tips = {i, j};

    double total = 0;
    for (int t = 0; t < tips.length; t++) {
        NodeRef current = treeModel.getExternalNode(tips[t]);
        while (current != ancestor) {
            total += treeModel.getBranchLength(current);
            current = treeModel.getParent(current);
        }
    }
    return total;
}
/**
 * Looks up the common ancestor node of two tips, identified by their taxon ids.
 *
 * @param iTip index of the first taxon
 * @param jTip index of the second taxon
 * @return the node returned by {@code TreeUtils.getCommonAncestorNode} for the two taxon ids
 */
private NodeRef findMRCA(int iTip, int jTip) {
    final String firstId = treeModel.getTaxonId(iTip);
    final String secondId = treeModel.getTaxonId(jTip);

    final Set<String> taxonPair = new HashSet<String>();
    taxonPair.add(firstId);
    taxonPair.add(secondId);

    return TreeUtils.getCommonAncestorNode(treeModel, taxonPair);
}
/**
 * Logs the leading numdata x numdata part of a matrix, one row per line, at INFO level.
 * (The method name is misspelled; it is public and kept as is.)
 *
 * @param Mat matrix to print; must have at least numdata rows and columns
 */
public void printInformtion(double[][] Mat) {
    final StringBuilder text = new StringBuilder("matrix \n");
    for (int row = 0; row < numdata; row++) {
        text.append(" \n");
        for (int col = 0; col < numdata; col++) {
            text.append(Mat[row][col]).append(" \t");
        }
    }
    Logger.getLogger("dr.evomodel").info(text.toString());
}

/**
 * Logs the first numdata values of a parameter on one line at INFO level.
 *
 * @param Vec parameter to print
 */
public void printInformation(Parameter Vec) {
    final StringBuilder text = new StringBuilder("Vector \n");
    for (int k = 0; k < numdata; k++) {
        text.append(Vec.getParameterValue(k)).append(" \t");
    }
    Logger.getLogger("dr.evomodel").info(text.toString());
}

/**
 * Logs the first numdata entries of an int array on one line at INFO level.
 *
 * @param Vec array to print; must hold at least numdata entries
 */
public void printInformation(int[] Vec) {
    final StringBuilder text = new StringBuilder("Vector \n");
    for (int k = 0; k < numdata; k++) {
        text.append(Vec[k]).append(" \t");
    }
    Logger.getLogger("dr.evomodel").info(text.toString());
}

/** Logs the taxon id of the first numdata taxa of the tree, one per line, at INFO level. */
public void printOrder() {
    final StringBuilder text = new StringBuilder("taxa \n");
    for (int k = 0; k < numdata; k++) {
        text.append(" \n").append(treeModel.getTaxonId(k));
    }
    Logger.getLogger("dr.evomodel").info(text.toString());
}

/**
 * Logs a single number at INFO level.
 *
 * @param x value to print
 */
public void printInformation(double x) {
    final String text = "Info \n" + x;
    Logger.getLogger("dr.evomodel").info(text);
}

/**
 * Logs a single string at INFO level.
 *
 * @param x text to print
 */
public void printInformation(String x) {
    final String text = "Info \n" + x;
    Logger.getLogger("dr.evomodel").info(text);
}

/**
 * Logs two strings joined by " and " at INFO level.
 *
 * @param x first text
 * @param y second text
 */
public void printInformation(String x, String y) {
    final String text = "Info \n" + x + " and " + y;
    Logger.getLogger("dr.evomodel").info(text);
}
/**
 * Saves a copy of the per-cluster log-likelihood cache so that
 * {@link #restoreState()} can bring it back. Nothing else is stored here.
 */
@Override
protected void storeState() {
    final double[] current = logLikelihoodsVector;
    final double[] backup = storedLogLikelihoodsVector;
    System.arraycopy(current, 0, backup, 0, current.length);
}
/**
 * Brings back the per-cluster cache saved by {@link #storeState()} by swapping the two
 * arrays, and invalidates the total log-likelihood.
 *
 * The dependency matrix is flagged as known unless a transform-factor change was
 * proposed since the last accept/restore; that proposal flag is cleared afterwards.
 */
@Override
protected void restoreState() {
    // Swap current and stored caches instead of copying.
    final double[] discarded = logLikelihoodsVector;
    logLikelihoodsVector = storedLogLikelihoodsVector;
    storedLogLikelihoodsVector = discarded;

    logLikelihoodKnown = false;

    if (proposedChangeDepMatrix) {
        depMatrixKnown = false;
    } else {
        depMatrixKnown = true;
    }
    proposedChangeDepMatrix = false;
}
/**
* Intentionally empty in the current code: calling it does not reset any of the cached
* values or "known" flags of this class.
* NOTE(review): if callers rely on this to force a full recalculation, the cached
* log-likelihood, per-cluster terms and dependency matrix would stay as they are - confirm
* whether that is intended.
*/
public void makeDirty() {
}
/**
 * Clears both "a change was proposed" flags once a proposed state has been accepted.
 * No cached value is modified.
 */
public void acceptState() {
    proposedChangeDataMatrix = false;
    proposedChangeDepMatrix = false;
}
/**
 * Reacts to a change in a registered model: the total log-likelihood is always
 * invalidated, and the dependency matrix is invalidated too when the tree changed.
 *
 * @param model the model that changed
 * @param object not used
 * @param index not used
 */
protected void handleModelChangedEvent(Model model, Object object, int index) {
    if (model == treeModel) {
        depMatrixKnown = false;
    }
    // Unconditional: every model change makes the cached total stale.
    logLikelihoodKnown = false;
}
/**
 * Reacts to a change in a registered variable. The total log-likelihood is always
 * invalidated. A transform-factor change also invalidates the dependency matrix and
 * records that such a change was proposed. A trait change invalidates the cached term
 * of the cluster that the affected data point belongs to.
 *
 * @param variable the variable that changed
 * @param index position of the changed value; for the trait parameter it is divided by 2
 *              to obtain the data point index (two coordinates per data point)
 * @param type not used
 */
protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
    logLikelihoodKnown = false;

    if (variable == transformFactor) {
        proposedChangeDepMatrix = true;
        depMatrixKnown = false;
    }

    if (variable != traitParameter) {
        return;
    }

    final int dataPoint = index / 2;
    final int affectedGroup = (int) assignments.getParameterValue(dataPoint);
    logLikelihoodsVectorKnown[affectedGroup] = false;
}
// External nodes below the tree root, collected once in the constructor.
Set<NodeRef> allTips;
// Data: one sub-parameter per data point; coordinates 0 and 1 are used by the cluster likelihood.
CompoundParameter traitParameter;
// Same object as the constructor's chi argument; weight of a self-link.
Parameter alpha;
// Hyper-parameters; their values are read once in the constructor.
Parameter clusterPrec ;
Parameter priorPrec ;
Parameter priorMean ;
// Cluster index per data point.
Parameter assignments;
// Link target per data point.
Parameter links;
// Per-position means, written by setMeans (means1 = coordinate 0, means2 = coordinate 1).
Parameter means2;
Parameter means1;
// Stored by the constructor but not read elsewhere in this class.
Parameter locationDrift;
Parameter offsets;
// Always false: the constructor ignores its hasDrift argument.
boolean hasDrift;
// True when depMatrix/logDepMatrix reflect the current tree and transform factor.
private boolean depMatrixKnown= false;
// Not read or written anywhere in this class.
private boolean[] dataMatrixKnown;
// True when logLikelihood holds an up-to-date value.
private boolean logLikelihoodKnown=false;
private double logLikelihood =0.0;
// Per-cluster flags: true when the matching logLikelihoodsVector entry is up to date.
private boolean[] logLikelihoodsVectorKnown;
// Set when a transform-factor change is seen; cleared on accept and restore.
boolean proposedChangeDepMatrix=false;
// Only ever cleared (in acceptState); never set to true in this class.
boolean proposedChangeDataMatrix=false;
TreeModel treeModel;
// Never assigned in this class.
String traitName;
//double[][] data;
// Symmetric tip-by-tip link weights: 1 / (accumulated branch length)^p off the diagonal.
double[][] depMatrix;
// Element-wise natural log of depMatrix (off-diagonal entries).
double[][] logDepMatrix;
// Cached per-cluster log-likelihood terms and their stored copy for restore.
double[] logLikelihoodsVector;
double[] storedLogLikelihoodsVector;
// Number of data points (sub-parameters of traitParameter).
int numdata;
// First value is the exponent p used when building depMatrix.
Parameter transformFactor;
// Prior constants computed in the constructor and used by getLogLikGroup.
double k0;
double v0;
double[][] T0Inv;
double[] m;
double logDetT0;
/**
 * XML parser that builds an {@link NPAntigenicLikelihood} from an element named
 * {@link #NP_ANTIGENIC_LIKELIHOOD}. Each hyper-parameter is read from a named child
 * element wrapping a single {@code Parameter}; the trait data are read from the taxon
 * attributes through {@code TreeTraitParserUtilities}.
 */
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

    public final static String CLUSTER_PREC = "clusterPrec";
    public final static String PRIOR_PREC = "priorPrec";
    public final static String PRIOR_MEAN = "priorMean";
    public final static String ASSIGNMENTS = "assignments";
    public final static String LINKS = "links";
    public final static String MEANS_1 = "clusterMeans1";
    public final static String MEANS_2 = "clusterMeans2";
    public final static String TRANSFORM_FACTOR = "transformFactor";
    public final static String CHI = "chi";
    public final static String OFFSETS = "offsets";
    public final static String LOCATION_DRIFT = "locationDrift";

    // Passed through unchanged to parseTraitsFromTaxonAttributes.
    boolean integrate = false;

    public String getParserName() {
        return NP_ANTIGENIC_LIKELIHOOD;
    }

    /** Returns the Parameter nested inside the child element of {@code xo} with the given name. */
    private Parameter nestedParameter(XMLObject xo, String elementName) throws XMLParseException {
        XMLObject wrapper = xo.getChild(elementName);
        return (Parameter) wrapper.getChild(Parameter.class);
    }

    public Object parseXMLObject(XMLObject xo) throws XMLParseException {
        TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);

        // Hyper-parameters, looked up in the same order as before.
        Parameter clusterPrec = nestedParameter(xo, CLUSTER_PREC);
        Parameter priorPrec = nestedParameter(xo, PRIOR_PREC);
        Parameter priorMean = nestedParameter(xo, PRIOR_MEAN);
        Parameter assignments = nestedParameter(xo, ASSIGNMENTS);
        Parameter links = nestedParameter(xo, LINKS);
        Parameter means2 = nestedParameter(xo, MEANS_2);
        Parameter means1 = nestedParameter(xo, MEANS_1);
        Parameter chi = nestedParameter(xo, CHI);
        Parameter transformFactor = nestedParameter(xo, TRANSFORM_FACTOR);
        Parameter locationDrift = nestedParameter(xo, LOCATION_DRIFT);
        Parameter offsets = nestedParameter(xo, OFFSETS);

        // Drift is flagged when more than one offset is supplied.
        boolean hasDrift = offsets.getDimension() > 1;

        // Trait data come from the taxon attributes under the default trait name.
        TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
        String traitName = TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;
        TreeTraitParserUtilities.TraitsAndMissingIndices parsedTraits =
                utilities.parseTraitsFromTaxonAttributes(xo, traitName, treeModel, integrate);
        CompoundParameter traitParameter = parsedTraits.traitParameter;

        return new NPAntigenicLikelihood(treeModel, traitParameter, assignments, links, chi,
                clusterPrec, priorMean, priorPrec, transformFactor, means1, means2,
                locationDrift, offsets, hasDrift);
    }

    //************************************************************************
    // AbstractXMLObjectParser implementation
    //************************************************************************

    public String getParserDescription() {
        return "conditional likelihood ddCRP";
    }

    public Class getReturnType() {
        return NPAntigenicLikelihood.class;
    }

    public XMLSyntaxRule[] getSyntaxRules() {
        return rules;
    }

    private final XMLSyntaxRule[] rules = {
            new StringAttributeRule(TreeTraitParserUtilities.TRAIT_NAME,
                    "The name of the trait for which a likelihood should be calculated"),
            new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(PRIOR_PREC,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(CLUSTER_PREC,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(PRIOR_MEAN,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(ASSIGNMENTS,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(LINKS,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(TRANSFORM_FACTOR,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(MEANS_1,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(MEANS_2,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(CHI,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
            new ElementRule(OFFSETS, Parameter.class),
            new ElementRule(LOCATION_DRIFT, Parameter.class),
            new ElementRule(TreeModel.class),
    };
};
}