/*
* MixtureModelLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.model;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.math.LogTricks;
import dr.math.MathUtils;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.xml.*;
import java.util.*;
import java.util.logging.Logger;
/**
* @author Marc A. Suchard
* @author Andrew Rambaut
* @author Alexander V. Alekseyenko
*/
public class MixtureModelLikelihood extends AbstractModelLikelihood implements Citable, Loggable {
public static final String MIXTURE_MODEL_ALIAS = "integratedMixtureModel";
public static final String MIXTURE_MODEL = "mixtureModelLikelihood";
// public static final String MIXTURE_WEIGHTS = "weights";
public static final String NORMALIZE = "normalize";
public MixtureModelLikelihood(List<Likelihood> likelihoodList, Parameter weights) {
super(MIXTURE_MODEL);
this.likelihoodList = likelihoodList;
this.mixtureWeights = weights;
for (Likelihood model : likelihoodList) {
if (model.getModel() != null) {
addModel(model.getModel());
}
}
addVariable(mixtureWeights);
StringBuilder sb = new StringBuilder();
sb.append("Constructing a finite mixture model\n");
sb.append("\tComponents:\n");
for (Likelihood like : likelihoodList) {
Model model = like.getModel();
sb.append("\t\t\t").append(
model != null ?
like.getModel().getId() : "anonymous"
).append("\n");
}
// sb.append("\tMixing parameter: ").append(mixtureWeights.getId()).append("\n");
sb.append("\tPlease cite:\n");
sb.append(Utils.getCitationString((this)));
Logger.getLogger("dr.inference.model").info(sb.toString());
}
protected void handleModelChangedEvent(Model model, Object object, int index) {
}
protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
}
protected void storeState() {
}
protected void restoreState() {
}
protected void acceptState() {
}
public Model getModel() {
return this;
}
public double getLogLikelihood() {
return getLogLikelihoodSum();
}
private double getLogLikelihoodSum() {
double logSum = Double.NEGATIVE_INFINITY;
double bad = 0;
double[] weights;
if(useParameter)
weights = MathUtils.getNormalized(mixtureWeights.getParameterValues());
else{
weights = new double[likelihoodList.size()];
for(int i=0; i<likelihoodList.size(); ++i) weights[i]=1.0/likelihoodList.size();
}
for (int i = 0; i < likelihoodList.size(); ++i) {
double pi = weights[i];
if (pi > 0.0) {
logSum = LogTricks.logSum(logSum,
Math.log(pi) + likelihoodList.get(i).getLogLikelihood());
}
bad += likelihoodList.get(i).getLogLikelihood() * pi;
}
if (powerPrior) {
return bad;
} else {
return logSum;
}
}
private double getWeight(final int dim) {
if (useParameter) {
return mixtureWeights.getParameterValue(dim);
} else {
return 1.0 / likelihoodList.size();
}
}
public void makeDirty() {
// Do nothing
}
public LogColumn[] getColumns() {
LogColumn[] columns = new LogColumn[likelihoodList.size()];
for (int i = 0; i < likelihoodList.size(); ++i) {
columns[i] = new MixtureColumn(MIXTURE_MODEL, i);
}
return columns;
}
private class MixtureColumn extends NumberColumn {
public MixtureColumn(String label, int dim) {
super(label);
this.dim = dim;
}
@Override
public double getDoubleValue() {
double logSum = getLogLikelihoodSum();
double logLike = likelihoodList.get(dim).getLogLikelihood() + Math.log(getWeight(dim));
double x = logLike - logSum;
if (inProbSpace) {
x = Math.exp(x);
}
// System.err.println(logLike + " : " + logSum + " " + dim + " " + x);
// System.exit(-1);
return x;
}
private final int dim;
private final boolean inProbSpace = true;
}
private static final boolean useParameter = true;
private static final boolean powerPrior = false;
public static XMLObjectParser PARSER_ALIAS = new AbstractXMLObjectParser() {
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
try{
return ((AbstractXMLObjectParser)PARSER).parseXMLObject(xo);
}
catch(XMLParseException e){
throw(e);
}
}
public XMLSyntaxRule[] getSyntaxRules() {
return PARSER.getSyntaxRules();
}
public String getParserDescription() {
return PARSER.getParserDescription();
}
public Class getReturnType() {
return PARSER.getReturnType();
}
public String getParserName() {
return MIXTURE_MODEL_ALIAS;
}
};
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
public String getParserName() {
return MIXTURE_MODEL;
}
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
Parameter weights = (Parameter) xo.getChild(Parameter.class);
List<Likelihood> likelihoodList = new ArrayList<Likelihood>();
for (int i = 0; i < xo.getChildCount(); i++) {
if (xo.getChild(i) instanceof Likelihood)
likelihoodList.add((Likelihood) xo.getChild(i));
}
if (weights.getDimension() != likelihoodList.size()) {
throw new XMLParseException("Dim of " + weights.getId() + " does not match the number of likelihoods");
}
if (xo.hasAttribute(NORMALIZE)) {
if (xo.getBooleanAttribute(NORMALIZE)) {
double sum = 0;
for (int i = 0; i < weights.getDimension(); i++)
sum += weights.getParameterValue(i);
for (int i = 0; i < weights.getDimension(); i++)
weights.setParameterValue(i, weights.getParameterValue(i) / sum);
}
}
return new MixtureModelLikelihood(likelihoodList, weights);
}
private boolean normalized(Parameter p) {
double sum = 0;
for (int i = 0; i < p.getDimension(); i++)
sum += p.getParameterValue(i);
return (sum == 1.0);
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
public String getParserDescription() {
return "This element represents a finite mixture of likelihood models.";
}
public Class getReturnType() {
return CompoundModel.class;
}
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
private final XMLSyntaxRule[] rules = {
AttributeRule.newBooleanRule(NORMALIZE, true),
new ElementRule(Likelihood.class, 2, Integer.MAX_VALUE),
new ElementRule(Parameter.class)
};
};
private final Parameter mixtureWeights;
List<Likelihood> likelihoodList;
@Override
public Citation.Category getCategory() {
return Citation.Category.MISC;
}
@Override
public String getDescription() {
return "Mixture model";
}
@Override
public List<Citation> getCitations() {
return Collections.singletonList(CommonCitations.LEMEY_MIXTURE_2012);
}
}