/* * DistanceDependentCRPGibbsOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.antigenic; import java.util.logging.Logger; import dr.inference.model.Likelihood; import dr.inference.model.Parameter; import dr.inference.operators.GibbsOperator; import dr.inference.operators.MCMCOperator; import dr.inference.operators.SimpleMCMCOperator; import dr.math.MathUtils; import dr.xml.*; /** * A Gibbs operator for allocation of items to clusters under a distance dependent Chinese restaurant process. * * @author Gabriela Cybis * @author Marc Suchard */ public class DistanceDependentCRPGibbsOperator extends SimpleMCMCOperator implements GibbsOperator { public final static String DDCRP_GIBBS_OPERATOR = "distanceDependentCRPGibbsOperator"; private final Parameter chiParameter; private final double[][] depMatrix; public NPAntigenicLikelihood modelLikelihood; private Parameter links = null; private Parameter assignments = null; // double[][] x; double k0; double v0; double[] m; double[][] T0Inv; double logDetT0; public DistanceDependentCRPGibbsOperator(Parameter links, Parameter assignments, Parameter chiParameter, NPAntigenicLikelihood Likelihood, double weight) { this.links = links; this.assignments = assignments; this.modelLikelihood = Likelihood; this.chiParameter = chiParameter; this.depMatrix = Likelihood.getLogDepMatrix(); for (int i=0;i<links.getDimension();i++){ links.setParameterValue(i, i); } setWeight(weight); // double[][] x=modelLikelihood.getData(); // modelLikelihood.printInformtion(x[0][0]); this.m = new double[2]; m[0]= modelLikelihood.priorMean.getParameterValue(0); m[1]= modelLikelihood.priorMean.getParameterValue(1); this.v0 = 2; this.k0= modelLikelihood.priorPrec.getParameterValue(0)/modelLikelihood.clusterPrec.getParameterValue(0); this.T0Inv= new double[2][2]; T0Inv[0][0]= v0/modelLikelihood.clusterPrec.getParameterValue(0); T0Inv[1][1]= v0/modelLikelihood.clusterPrec.getParameterValue(0); T0Inv[1][0]= 0.0; T0Inv[0][1]= 0.0; this.logDetT0= -Math.log(T0Inv[0][0]*T0Inv[1][1]); } /** * @return the parameter this operator acts on. */ public Parameter getParameter() { return (Parameter) links; } /** * @return the Variable this operator acts on. public Variable getVariable() { return clusteringParameter; } */ /** * change the parameter and return the hastings ratio. */ public final double doOperation() { int index = MathUtils.nextInt(links.getDimension()); int oldGroup = (int)assignments.getParameterValue(index); /* * Set index customer link to index and all connected to it to a new assignment (min value empty) */ int minEmp = minEmpty(modelLikelihood.getLogLikelihoodsVector()); links.setParameterValue(index, index); int[] visited = connected(index, links); int ii = 0; while (visited[ii]!=0){ assignments.setParameterValue(visited[ii]-1, minEmp); ii++; } /* * Adjust likvector for group separated */ modelLikelihood.setLogLikelihoodsVector(oldGroup,modelLikelihood.getLogLikGroup(oldGroup) ); modelLikelihood.setLogLikelihoodsVector(minEmp,modelLikelihood.getLogLikGroup(minEmp) ); int maxFull = maxFull( modelLikelihood.getLogLikelihoodsVector()); double[] liks = modelLikelihood.getLogLikelihoodsVector(); /* * computing likelihoods of joint groups */ double[] crossedLiks = new double[maxFull+1]; for (int ll=0;ll<maxFull+1;ll++ ){ if (ll!=minEmp){ crossedLiks[ll]=getLogLik2Group(ll,minEmp); } } /* * Add logPrior */ double[] logP = new double[links.getDimension()]; for (int jj=0; jj<links.getDimension(); jj++){ logP[jj] += depMatrix[index][jj]; int n = (int)assignments.getParameterValue(jj); if (n!= minEmp){ logP[jj]+=crossedLiks[n] - liks[n] - liks[minEmp]; } } logP[index]= Math.log(chiParameter.getParameterValue(0)); /* * possibilidade de mandar p zero as probs muito pequenas */ /* * Gibbs sampling */ this.rescale(logP); // Improve numerical stability this.exp(logP); // Transform back to probability-scale int k = MathUtils.randomChoicePDF(logP); links.setParameterValue(index, k); int newGroup = (int)assignments.getParameterValue(k); ii = 0; while (visited[ii]!=0){ assignments.setParameterValue(visited[ii]-1, newGroup); ii++; } /* * updating conditional likelihood vector */ modelLikelihood.setLogLikelihoodsVector(newGroup, modelLikelihood.getLogLikGroup(newGroup)); if (newGroup!=minEmp){ modelLikelihood.setLogLikelihoodsVector(minEmp, 0); } sampleMeans(maxFull); return 0.0; } /* * find min Empty */ public int minEmpty(double[] logLikVector){ int isEmpty=0; int i =0; while (isEmpty==0){ if(logLikVector[i]==0){ isEmpty=1;} else { if(i==logLikVector.length-1){isEmpty=1;} i++;} } return i; } /* * find max Full */ public int maxFull(double[] logLikVector){ int isEmpty=1; int i =logLikVector.length-1; while (isEmpty==1){ if(logLikVector[i]!=0){ isEmpty=0;} else {i--;} } return i; } /* * find customers connected to i */ public int[] connected(int i, Parameter clusteringParameter){ int n = clusteringParameter.getDimension(); int[] visited = new int[n+1]; visited[0]=i+1; int tv=1; for(int j=0;j<n;j++){ if(visited[j]!=0){ int curr = visited[j]-1; /*look forward */ int forward = (int) clusteringParameter.getParameterValue(curr); visited[tv] = forward+1; tv++; // Check to see if is isn't already on the list for(int ii=0; ii<tv-1; ii++){ if(visited[ii]==forward+1){ tv--; visited[tv]=0; } } /*look back */ for (int jj=0; jj<n;jj++){ if((int)clusteringParameter.getParameterValue(jj)==curr){ visited[tv]= jj+1; tv++; for(int ii=0; ii<tv-1; ii++){ if(visited[ii]==jj+1){ tv--; visited[tv]=0; } } } } }} return visited; } private void printInformtion(Parameter par) { StringBuffer sb = new StringBuffer("parameter \n"); for(int j=0; j<par.getDimension(); j++){ sb.append(par.getParameterValue(j)); } Logger.getLogger("dr.evomodel").info(sb.toString()); }; /* OLD public double getLogLikGroup(int groupNumber){ double L =0.0; int ngroup=0; for (int i=0;i<assignments.getDimension(); i++){ if((int) assignments.getParameterValue(i) == groupNumber){ ngroup++;}} if (ngroup != 0){ double[] group = new double[2*ngroup]; int count = 0; for (int i=0;i<assignments.getDimension(); i++){ if((int) assignments.getParameterValue(i) == groupNumber){ group[count] = modelLikelihood.getData()[i][0]; group[ngroup+count] = modelLikelihood.getData()[i][1]; count+=1;}} double[][] var = new double[2*ngroup][2*ngroup]; double[] mean = new double[2*ngroup]; double m0 = modelLikelihood.getPriorMean().getParameterValue(0); double m1 = modelLikelihood.getPriorMean().getParameterValue(1); double vp = modelLikelihood.getPriorPrec().getParameterValue(0); double vc = modelLikelihood.getClusterPrec().getParameterValue(0); for (int i=0; i<ngroup; i++){ mean[i]=m0; mean[ngroup+i]=m1; for (int l=0;l<ngroup;l++){ var[i][ngroup+l]=0; var[ngroup+i][l]=0; if (l==i){var[i][l]= vp+ vc; var[ngroup+i][ngroup+l]= vp+ vc;} else { var[i][l] = vp; var[ngroup+i][ngroup+l]= vp;} } } double[][] precision = new SymmetricMatrix(var).inverse().toComponents(); L = new MultivariateNormalDistribution(mean, precision).logPdf(group); } return L; } */ public double getLogLik2Group(int group1, int group2){ throw new UnsupportedOperationException("This method has been commented out because of build errors"); // double L =0.0; // // // int ngroup1=0; // for (int i=0;i<assignments.getDimension(); i++){ // if((int) assignments.getParameterValue(i) == group1 ){ // ngroup1++;}} // // int ngroup2=0; // for (int i=0;i<assignments.getDimension(); i++){ // if((int) assignments.getParameterValue(i) == group2 ){ // ngroup2++;}} // // int ngroup = (ngroup1+ngroup2); // // if (ngroup != 0){ // double[][] group = new double[ngroup][2]; // // double mean[]=new double[2]; // // // int count = 0; // for (int i=0;i<assignments.getDimension(); i++){ // if((int) assignments.getParameterValue(i) == group1 ){ // group[count][0] = modelLikelihood.getData()[i][0]; // group[count][1] = modelLikelihood.getData()[i][1]; // mean[0]+=group[count][0]; // mean[1]+=group[count][1]; // count+=1;}} // // for (int i=0;i<assignments.getDimension(); i++){ // if((int) assignments.getParameterValue(i) == group2 ){ // group[count][0] = modelLikelihood.getData()[i][0]; // group[count][1] = modelLikelihood.getData()[i][1]; // mean[0]+=group[count][0]; // mean[1]+=group[count][1]; // count+=1;}} // // // // mean[0]/=ngroup; // mean[1]/=ngroup; // // // // double kn= k0+ngroup; // double vn= v0+ngroup; // // // double[][] sumdif=new double[2][2]; // // for(int i=0;i<ngroup;i++){ // sumdif[0][0]+= (group[i][0]-mean[0])*(group[i][0]-mean[0]); // sumdif[0][1]+= (group[i][0]-mean[0])*(group[i][1]-mean[1]); // sumdif[1][0]+= (group[i][0]-mean[0])*(group[i][1]-mean[1]); // sumdif[1][1]+= (group[i][1]-mean[1])*(group[i][1]-mean[1]); // } // // // // double[][] TnInv = new double[2][2]; // TnInv[0][0]=T0Inv[0][0]+ngroup*(k0/kn)*(mean[0]-m[0])*(mean[0]-m[0])+sumdif[0][0]; // TnInv[0][1]=T0Inv[0][1]+ngroup*(k0/kn)*(mean[1]-m[1])*(mean[0]-m[0])+sumdif[0][1]; // TnInv[1][0]=T0Inv[1][0]+ngroup*(k0/kn)* (mean[0]-m[0])*(mean[1]-m[1])+sumdif[1][0]; // TnInv[1][1]=T0Inv[1][1]+ngroup*(k0/kn)* (mean[1]-m[1])*(mean[1]-m[1])+sumdif[1][1]; // // // double logDetTn=-Math.log(TnInv[0][0]*TnInv[1][1]-TnInv[0][1]*TnInv[1][0]); // // // L+= -(ngroup)*Math.log(Math.PI); // L+= Math.log(k0) - Math.log(kn); // L+= (vn/2)*logDetTn - (v0/2)*logDetT0; // L+= GammaFunction.lnGamma(vn/2)+ GammaFunction.lnGamma((vn/2)-0.5); // L+=-GammaFunction.lnGamma(v0/2)- GammaFunction.lnGamma((v0/2)-0.5); // // // // // // // // // } // // return L; } /*public double getLogLik2Group(int group1, int group2){ double L =0.0; int ngroup1=0; for (int i=0;i<assignments.getDimension(); i++){ if((int) assignments.getParameterValue(i) == group1 ){ ngroup1++;}} int ngroup2=0; for (int i=0;i<assignments.getDimension(); i++){ if((int) assignments.getParameterValue(i) == group2 ){ ngroup2++;}} int ngroup = (ngroup1+ngroup2); if (ngroup != 0){ double[] group = new double[2*ngroup]; int count = 0; for (int i=0;i<assignments.getDimension(); i++){ if((int) assignments.getParameterValue(i) == group1 ){ group[count] = modelLikelihood.getData()[i][0]; group[count+ngroup] = modelLikelihood.getData()[i][1]; count+=1;}} for (int i=0;i<assignments.getDimension(); i++){ if((int) assignments.getParameterValue(i) == group2 ){ group[count] = modelLikelihood.getData()[i][0]; group[count+ngroup] = modelLikelihood.getData()[i][1]; count+=1;}} double[][] var = new double[2*ngroup][2*ngroup]; double[] mean = new double[2*ngroup]; double m0 = modelLikelihood.getPriorMean().getParameterValue(0); double m1 = modelLikelihood.getPriorMean().getParameterValue(1); double vp = modelLikelihood.getPriorPrec().getParameterValue(0); double vc = modelLikelihood.getClusterPrec().getParameterValue(0); for (int i=0; i<ngroup; i++){ mean[i]=m0; mean[i+ngroup]=m1; for (int l=0;l<ngroup;l++){ var[i][ngroup+l]=0; var[ngroup+i][l]=0; if (l==i){var[i][l]= vp+ vc; var[ngroup+i][ngroup+l]= vp+ vc;} else { var[i][l] = vp; var[ngroup+i][ngroup+l]= vp;} } } double[][] precision = new SymmetricMatrix(var).inverse().toComponents(); L = new MultivariateNormalDistribution(mean, precision).logPdf(group); } return L; } */ public void sampleMeans(int maxFull){ double[][] means=new double[maxFull+2][2]; throw new UnsupportedOperationException("This method has been commented out because of build errors"); //sample mean vector for each cluster // for (int i=0; i<maxFull+1; i++){ // // // Find all elements in cluster // // int ngroup=0; // for (int ii=0;ii<assignments.getDimension(); ii++){ // if((int) assignments.getParameterValue(ii) == i ){ // ngroup++;}} // // // if (ngroup != 0){ // double[][] group = new double[ngroup][2]; // double[] groupMean=new double[2]; // // int count = 0; // for (int ii=0;ii<assignments.getDimension(); ii++){ // if((int) assignments.getParameterValue(ii) == i ){ // group[count][0] = modelLikelihood.getData()[ii][0]; // group[count][1] = modelLikelihood.getData()[ii][1]; // groupMean[0]+=group[count][0]; // groupMean[1]+=group[count][1]; // count+=1;}} // // groupMean[0]/=ngroup; // groupMean[1]/=ngroup; // // // // double kn= k0+ngroup; // double vn= v0+ngroup; // // // double[][] sumdif=new double[2][2]; // // for(int jj=0;jj<ngroup;jj++){ // sumdif[0][0]+= (group[jj][0]-groupMean[0])*(group[jj][0]-groupMean[0]); // sumdif[0][1]+= (group[jj][0]-groupMean[0])*(group[jj][1]-groupMean[1]); // sumdif[1][0]+= (group[jj][0]-groupMean[0])*(group[jj][1]-groupMean[1]); // sumdif[1][1]+= (group[jj][1]-groupMean[1])*(group[jj][1]-groupMean[1]); // } // // // // double[][] TnInv = new double[2][2]; // TnInv[0][0]=T0Inv[0][0]+ngroup*(k0/kn)*(groupMean[0]-m[0])*(groupMean[0]-m[0])+sumdif[0][0]; // TnInv[0][1]=T0Inv[0][1]+ngroup*(k0/kn)*(groupMean[1]-m[1])*(groupMean[0]-m[0])+sumdif[0][1]; // TnInv[1][0]=T0Inv[1][0]+ngroup*(k0/kn)* (groupMean[0]-m[0])*(groupMean[1]-m[1])+sumdif[1][0]; // TnInv[1][1]=T0Inv[1][1]+ngroup*(k0/kn)* (groupMean[1]-m[1])*(groupMean[1]-m[1])+sumdif[1][1]; // // Matrix Tn = new SymmetricMatrix(TnInv).inverse(); // // // double[] posteriorMean=new double[2]; // // compute posterior mean // // posteriorMean[0]= (k0*m[0] +ngroup*groupMean[0])/(k0+ngroup); // posteriorMean[1]= (k0*m[1] +ngroup*groupMean[1])/(k0+ngroup); // // // // //compute posterior Precision // double[][] posteriorPrecision=new WishartDistribution(vn,Tn.toComponents()).nextWishart(); // posteriorPrecision[0][0]*= kn; // posteriorPrecision[1][0]*= kn; // posteriorPrecision[0][1]*= kn; // posteriorPrecision[1][1]*= kn; // // // // // double[] sample= new MultivariateNormalDistribution(posteriorMean,posteriorPrecision).nextMultivariateNormal(); // means[i][0]=sample[0]; // means[i][1]=sample[1]; // } // } // // //Fill in cluster means for each observation // // for (int j=0; j<assignments.getDimension();j++){ // double[] group=new double[2]; // group[0]=means[(int)assignments.getParameterValue(j)][0]; // group[1]=means[(int)assignments.getParameterValue(j)][1]; // // modelLikelihood.setMeans(j, group); // } } private void exp(double[] logX) { for (int i = 0; i < logX.length; ++i) { logX[i] = Math.exp(logX[i]); // if(logX[i]<1E-8){logX[i]=0;} } } private void rescale(double[] logX) { throw new UnsupportedOperationException("This method has been commented out because of build errors"); // double max = this.max(logX); // for (int i = 0; i < logX.length; ++i) { // if(logX[i] == Double.NEGATIVE_INFINITY){ // modelLikelihood.printInformtion(logX[i]); // logX[i]=-1E16; // } // if(logX[i]==Double.POSITIVE_INFINITY){ // modelLikelihood.printInformtion(logX[i]); // logX[i]=1E16; // } // logX[i] -= max; // // } } private double max(double[] x) { double max = x[0]; for (double xi : x) { if (xi > max) { max = xi; } } return max; } //MCMCOperator INTERFACE public final String getOperatorName() { return DDCRP_GIBBS_OPERATOR; } public final void optimize(double targetProb) { throw new RuntimeException("This operator cannot be optimized!"); } public boolean isOptimizing() { return false; } public void setOptimizing(boolean opt) { throw new RuntimeException("This operator cannot be optimized!"); } public double getMinimumAcceptanceLevel() { return 0.1; } public double getMaximumAcceptanceLevel() { return 0.4; } public double getMinimumGoodAcceptanceLevel() { return 0.20; } public double getMaximumGoodAcceptanceLevel() { return 0.30; } public String getPerformanceSuggestion() { if (Utils.getAcceptanceProbability(this) < getMinimumAcceptanceLevel()) { return ""; } else if (Utils.getAcceptanceProbability(this) > getMaximumAcceptanceLevel()) { return ""; } else { return ""; } } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public final static String CHI = "chi"; public final static String LIKELIHOOD = "likelihood"; public final static String ASSIGNMENTS = "assignments"; public final static String LINKS = "links"; public final static String DEP_MATRIX = "depMatrix"; public String getParserName() { return DDCRP_GIBBS_OPERATOR; } /* (non-Javadoc) * @see dr.xml.AbstractXMLObjectParser#parseXMLObject(dr.xml.XMLObject) */ public Object parseXMLObject(XMLObject xo) throws XMLParseException { double weight = xo.getDoubleAttribute(MCMCOperator.WEIGHT); XMLObject cxo = xo.getChild(ASSIGNMENTS); Parameter assignments = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(LINKS); Parameter links = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(CHI); Parameter chiParameter = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(LIKELIHOOD); NPAntigenicLikelihood likelihood = (NPAntigenicLikelihood)cxo.getChild(NPAntigenicLikelihood.class); return new DistanceDependentCRPGibbsOperator( links, assignments, chiParameter,likelihood, weight); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "An operator that picks a new allocation of an item to a cluster under the Dirichlet process."; } public Class getReturnType() { return DistanceDependentCRPGibbsOperator.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(MCMCOperator.WEIGHT), new ElementRule(CHI, new XMLSyntaxRule[] { new ElementRule(Parameter.class), }), new ElementRule(LIKELIHOOD, new XMLSyntaxRule[] { new ElementRule(Likelihood.class), }, true), new ElementRule(ASSIGNMENTS, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(LINKS, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), }; }; public int getStepCount() { return 1; } }