package dr.inference.distribution; import dr.inference.model.*; import dr.math.MathUtils; import dr.math.matrixAlgebra.CholeskyDecomposition; import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.SymmetricMatrix; import java.util.List; import java.util.Vector; /** * Created by max on 4/6/16. */ public class DeterminentalPointProcessPrior extends AbstractModelLikelihood{ double theta; MatrixParameterInterface data; boolean likelihoodKnown = false; boolean storedLikelihoodKnown; double logLikelihood; double storedLogLikelihood; double[][] relationshipList; double[][] storedRelationshipList; int size; int storedSize; Vector<Integer> changedList; Vector<Integer> storedChangedList; public DeterminentalPointProcessPrior(String name, double theta, MatrixParameterInterface data) { super(name); this.theta = theta; this.data = data; addVariable(data); relationshipList = new double[data.getColumnDimension()][data.getColumnDimension()]; storedRelationshipList = new double[data.getColumnDimension()][data.getColumnDimension()]; size = data.getColumnDimension(); for (int i = 0; i < data.getRowDimension(); i++) { for (int j = 0; j < data.getColumnDimension(); j++) { if(i % (j+1) == 0) { data.setParameterValueQuietly(i, j, 0); } } } reset(); changedList = new Vector<Integer>(); } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void storeState() { storedLikelihoodKnown = likelihoodKnown; storedLogLikelihood = logLikelihood; storedSize = size; // System.out.println("first"); // for (int i = 0; i < relationshipList.length; i++) { // for (int j = 0; j < relationshipList.length ; j++) { // System.out.println(i + " + " + j + ": " + relationshipList[i][j]); // } // } // if(relationshipList.length != storedRelationshipList.length) // storedRelationshipList = new double[size][size]; // for(int i = 0; i < relationshipList.length; i++) // storedRelationshipList[i] = relationshipList[i].clone(); // System.out.println("stored"); // for (int i = 0; i < relationshipList.length; i++) { // for (int j = 0; j < relationshipList.length ; j++) { // System.out.println(i + " + " + j + ": " + storedRelationshipList[i][j]); // } // } // System.arraycopy(relationshipList, 0, storedRelationshipList, 0, relationshipList.length); storedChangedList = (Vector<Integer>) changedList.clone(); } @Override protected void restoreState() { likelihoodKnown = storedLikelihoodKnown; logLikelihood = storedLogLikelihood; double[][] relationshipListTemp = relationshipList; relationshipList = storedRelationshipList; storedRelationshipList = relationshipListTemp; size = storedSize; Vector<Integer> changedListTemp = changedList; changedList = storedChangedList; storedChangedList = changedListTemp; } @Override protected void acceptState() { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { likelihoodKnown = false; if(index == -1) reset(); else { changedList.add(index); } } @Override public Model getModel() { return this; } @Override public double getLogLikelihood() { makeDirty(); if(!likelihoodKnown) { logLikelihood = computeLogLikelihood(); likelihoodKnown = true; } return logLikelihood; } public double computeLogLikelihood(){ int newSize = data.getColumnDimension(); while(!changedList.isEmpty()){ int index = changedList.remove(0); int row = index % data.getRowDimension(); int col = index / data.getRowDimension(); for (int i = 0; i < data.getColumnDimension(); i++) { if(col != i){ if(data.getParameterValue(row, col) == data.getParameterValue(row, i)){ relationshipList[col][i] *= Math.exp(1 / (theta * theta)); relationshipList[i][col]=relationshipList[col][i]; } else{ relationshipList[col][i] *= Math.exp(- 1 / (theta * theta)); relationshipList[i][col] = relationshipList[col][i]; } } } } // reset(); if (newSize != size){ size = newSize; relationshipList = new double[size][size]; reset(); } CholeskyDecomposition chol = null; try { chol = new CholeskyDecomposition(relationshipList); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); } double product = 0; for (int i = 0; i <newSize ; i++) { product += Math.log(chol.getL()[i][i]); } product *= 2; return product; } @Override public void makeDirty() { likelihoodKnown = false; changedList = new Vector<Integer>(); // storedChangedList = new Vector<Integer>(); reset(); } public void reset(){ if(relationshipList.length != data.getColumnDimension()){ relationshipList = new double[data.getColumnDimension()][data.getColumnDimension()]; } for (int i = 0; i < data.getColumnDimension(); i++) { for (int j = 0; j < i; j++) { int count = 0; for (int k = 0; k < data.getRowDimension(); k++) { count += Math.abs(data.getParameterValue(k, i) - data.getParameterValue(k, j)); } relationshipList[i][j] = Math.exp(- count / (theta * theta)); relationshipList[j][i] = relationshipList[i][j]; } } for (int i = 0; i <data.getColumnDimension() ; i++) { relationshipList[i][i] = 1; } } }