/*
* IndianBuffetProcessPrior.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.model;
import dr.math.Poisson;
import dr.math.distributions.PoissonDistribution;
import org.apache.commons.math.special.Beta;
/**
 * Two-parameter Indian buffet process prior on the matrix {@code data}.
 *
 * <p>With N = number of rows of {@code data}, m_k = number of nonzero entries in column k,
 * K+ = number of columns that contain at least one nonzero entry, and K_h = multiplicities of
 * groups of identical nonzero columns (columns compared entry-wise by absolute value), the log
 * density computed by {@link #getLogLikelihood()} is
 *
 * <pre>
 *   K+ log(alpha * beta) - sum_h log(K_h!) - alpha * H + sum_k logBeta(m_k, N - m_k + beta),
 *   H = sum_{i=0}^{N-1} beta / (beta + i).
 * </pre>
 *
 * {@link #getSizeLogLikelihood()} returns the zero-truncated Poisson(alpha * H) log density of K+.
 *
 * @author Max Tolkoff
 */
public class IndianBuffetProcessPrior extends AbstractModelLikelihood implements MatrixSizePrior {

    /**
     * Creates the prior and registers all three inputs as variables of this model.
     *
     * @param alpha single-valued mass parameter; a [0, +infinity) bound is attached to it
     * @param beta  single-valued concentration parameter; a [0, +infinity) bound is attached to it
     * @param data  the matrix the prior is placed on (rows are the N objects, columns the features)
     */
    public IndianBuffetProcessPrior(Parameter alpha, Parameter beta, AdaptableSizeFastMatrixParameter data) {
        super(null);
        this.alpha = alpha;
        alpha.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0, 1));
        addVariable(alpha);
        this.beta = beta;
        beta.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0, 1));
        addVariable(beta);
        this.data = data;
        addVariable(data);
        // Allocate and fill the per-column caches before any event or likelihood call reads them.
        recountAllColumns();
    }

    /**
     * Returns log(n!) as a sum of logarithms, so large multiplicities cannot overflow.
     *
     * @throws IllegalArgumentException if {@code n} is negative
     */
    private static double logFactorial(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("Cannot take a negative factorial");
        }
        double result = 0;
        for (int k = 2; k <= n; k++) {
            result += Math.log(k);
        }
        return result;
    }

    /**
     * Returns H = sum_{i=0}^{N-1} beta / (beta + i), with N the row count of {@code data}.
     * The cached value is recomputed whenever beta or the data may have changed (H depends on N).
     */
    private double H() {
        if (!betaKnown || !dataKnown) {
            double b = beta.getParameterValue(0);
            int n = data.getRowDimension();
            double total = 0;
            for (int i = 0; i < n; i++) {
                total += b / (b + i);
            }
            H = total;
        }
        return H;
    }

    /**
     * Recomputes {@link #rowCount} and {@link #containsNonZeroElements} for every column, resizing
     * both arrays when the column dimension of {@code data} has changed, and updates {@link #ncols}.
     */
    private void recountAllColumns() {
        int cols = data.getColumnDimension();
        if (rowCount == null || rowCount.length != cols) {
            rowCount = new int[cols];
        }
        if (containsNonZeroElements == null || containsNonZeroElements.length != cols) {
            containsNonZeroElements = new boolean[cols];
        }
        for (int col = 0; col < cols; col++) {
            recountColumn(col);
        }
        ncols = cols;
    }

    /** Recomputes the nonzero count and nonzero flag of one column from the current data. */
    private void recountColumn(int col) {
        int rows = data.getRowDimension();
        int count = 0;
        for (int row = 0; row < rows; row++) {
            if (data.getParameterValue(row, col) != 0) {
                count++;
            }
        }
        rowCount[col] = count;
        containsNonZeroElements[col] = count > 0;
    }

    /** Returns true when the two columns agree entry-wise in absolute value. */
    private boolean sameAbsoluteColumn(int first, int second) {
        int rows = data.getRowDimension();
        for (int row = 0; row < rows; row++) {
            if (Math.abs(data.getParameterValue(row, first)) != Math.abs(data.getParameterValue(row, second))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns sum_h log(K_h!) over groups of identical nonzero columns. Requires the per-column
     * caches to be current. Each column joins at most one group, and all-zero columns join none.
     */
    private double computeLogBottom() {
        int cols = rowCount.length;
        boolean[] grouped = new boolean[cols];
        double logProduct = 0;
        for (int i = 0; i < cols; i++) {
            if (grouped[i] || !containsNonZeroElements[i]) {
                continue;
            }
            int multiplicity = 1;
            for (int j = i + 1; j < cols; j++) {
                // Equal nonzero counts are necessary for equal columns and are cheap to check first;
                // since column i is nonzero this also excludes all-zero columns j.
                if (!grouped[j] && rowCount[j] == rowCount[i] && sameAbsoluteColumn(i, j)) {
                    grouped[j] = true;
                    multiplicity++;
                }
            }
            logProduct += logFactorial(multiplicity);
        }
        return logProduct;
    }

    /** No sub-models are registered by this class, so there is nothing to react to. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
    }

    /**
     * Invalidates the cached quantities that depend on {@code variable}. For changes to
     * {@code data}, the per-column counts are refreshed immediately so {@link #getRowCount()} stays
     * current: only the touched column for a single-element change, everything otherwise.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        if (variable == data) {
            int rows = data.getRowDimension();
            int cols = data.getColumnDimension();
            boolean singleElement = cols == ncols && rows > 0 && index >= 0 && index < rows * cols;
            if (singleElement) {
                // NOTE(review): assumes column-major flat indexing (index / rows = column), as the
                // previous implementation did — confirm against AdaptableSizeFastMatrixParameter.
                recountColumn(index / rows);
            } else {
                // Dimension change or whole-parameter event (negative index): rebuild everything.
                recountAllColumns();
            }
            dataKnown = false;
        } else if (variable == beta) {
            betaKnown = false;
        }
        likelihoodKnown = false;
    }

    /**
     * Snapshots all cached state. The arrays are copied because they are updated in place by
     * {@link #handleVariableChangedEvent}; sharing the reference would let a rejected proposal
     * leak into the stored state.
     */
    @Override
    protected void storeState() {
        storedBetaKnown = betaKnown;
        storedContainsNonZeroElements = containsNonZeroElements.clone();
        storedDataKnown = dataKnown;
        storedLikelihoodKnown = likelihoodKnown;
        storedLogLikelihood = logLikelihood;
        storedRowCount = rowCount.clone();
        storedKPlus = KPlus;
        storedH = H;
        storedLogBottom = logBottom;
        storedSum2 = sum2;
        storedncols = ncols;
    }

    /** Restores the snapshot taken by {@link #storeState()}; arrays are swapped, not shared. */
    @Override
    protected void restoreState() {
        betaKnown = storedBetaKnown;
        boolean[] rejectedFlags = containsNonZeroElements;
        containsNonZeroElements = storedContainsNonZeroElements;
        storedContainsNonZeroElements = rejectedFlags;
        dataKnown = storedDataKnown;
        likelihoodKnown = storedLikelihoodKnown;
        logLikelihood = storedLogLikelihood;
        int[] rejectedCounts = rowCount;
        rowCount = storedRowCount;
        storedRowCount = rejectedCounts;
        KPlus = storedKPlus;
        H = storedH;
        logBottom = storedLogBottom;
        sum2 = storedSum2;
        ncols = storedncols;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Model getModel() {
        return this;
    }

    /** Returns the IBP log density of {@code data}, recomputing it only when something changed. */
    @Override
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = calculateLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    /**
     * Computes the IBP log density (formula in the class documentation), refreshing the data- and
     * beta-dependent caches as needed, and marks those caches as current.
     */
    private double calculateLogLikelihood() {
        double a = alpha.getParameterValue(0);
        double b = beta.getParameterValue(0);
        if (!dataKnown) {
            recountAllColumns();
            logBottom = computeLogBottom();
        }
        if (!dataKnown || !betaKnown) {
            int n = data.getRowDimension();
            int nonZeroColumns = 0;
            double logBetaSum = 0;
            for (int col = 0; col < rowCount.length; col++) {
                if (containsNonZeroElements[col]) {
                    nonZeroColumns++;
                    logBetaSum += Beta.logBeta(rowCount[col], n + b - rowCount[col]);
                }
            }
            KPlus = nonZeroColumns;
            sum2 = logBetaSum;
        }
        // (alpha*beta)^{K+} / prod_h K_h!: the factorial product divides once, it is not raised to K+.
        double logPrefactor = KPlus * Math.log(a * b) - logBottom;
        // H() must be evaluated before the flags below mark the caches as current.
        double logExponential = -a * H();
        betaKnown = true;
        dataKnown = true;
        return logPrefactor + logExponential + sum2;
    }

    /**
     * Returns the log density of K+ under a Poisson(alpha * H) distribution truncated to exclude
     * zero.
     */
    @Override
    public double getSizeLogLikelihood() {
        getLogLikelihood(); // brings KPlus and H up to date
        PoissonDistribution poisson = new PoissonDistribution(alpha.getParameterValue(0) * H());
        // log(1 - e^{-mean}) computed via expm1 to keep precision when the mean is small.
        return poisson.logPdf(KPlus) - Math.log(-Math.expm1(-poisson.mean()));
    }

    /** Returns the per-column nonzero counts (the live internal array, not a copy). */
    public int[] getRowCount() {
        return rowCount;
    }

    public AdaptableSizeFastMatrixParameter getData() {
        return data;
    }

    /** Forces every cached quantity to be recomputed on the next evaluation. */
    @Override
    public void makeDirty() {
        likelihoodKnown = false;
        betaKnown = false;
        dataKnown = false;
    }

    boolean likelihoodKnown;
    boolean storedLikelihoodKnown;
    double logLikelihood;
    double storedLogLikelihood;
    // False when beta-dependent caches (H, sum2) may be stale.
    boolean betaKnown = false;
    // False when data-dependent caches (counts, KPlus, logBottom, sum2, H) may be stale.
    boolean dataKnown = false;
    boolean storedDataKnown;
    boolean storedBetaKnown;
    // Number of nonzero entries in each column of data.
    int[] rowCount;
    int[] storedRowCount;
    // Number of columns with at least one nonzero entry.
    int KPlus;
    int storedKPlus;
    // Whether each column of data has at least one nonzero entry.
    boolean[] containsNonZeroElements;
    boolean[] storedContainsNonZeroElements;
    // Cached H = sum_{i=0}^{N-1} beta / (beta + i).
    double H;
    double storedH;
    // Cached sum_h log(K_h!) over groups of identical nonzero columns.
    double logBottom;
    double storedLogBottom;
    // Cached sum over nonzero columns of logBeta(m_k, N - m_k + beta).
    double sum2;
    double storedSum2;
    // Column dimension the per-column caches were last sized for.
    int ncols;
    int storedncols;
    AdaptableSizeFastMatrixParameter data;
    Parameter alpha;
    Parameter beta;
}