/* * WorkingPriorParsers.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inferencexml.distribution; import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.MultivariateDistributionLikelihood; import dr.inference.model.Likelihood; import dr.inference.model.Statistic; import dr.inference.trace.LogFileTraces; import dr.inference.trace.TraceException; import dr.math.distributions.*; import dr.util.FileHelpers; import dr.xml.*; import java.io.File; import java.io.FileNotFoundException; /** * @author Guy Baele * @author Marc Suchard */ public class WorkingPriorParsers { public final static boolean DEBUG = true; public static final String NORMAL_REFERENCE_PRIOR = "normalReferencePrior"; public static final String NORMAL_WORKING_PRIOR = "normalWorkingPrior"; public static final String LOG_TRANSFORMED_NORMAL_REFERENCE_PRIOR = "logTransformedNormalReferencePrior"; public static final String LOG_TRANSFORMED_NORMAL_WORKING_PRIOR = "logTransformedNormalWorkingPrior"; public static final String LOGIT_TRANSFORMED_NORMAL_REFERENCE_PRIOR = "logitTransformedNormalReferencePrior"; public static final String LOGIT_TRANSFORMED_NORMAL_WORKING_PRIOR = "logitTransformedNormalWorkingPrior"; public static final String GAMMA_REFERENCE_PRIOR = "gammaReferencePrior"; public static final String GAMMA_WORKING_PRIOR = "gammaWorkingPrior"; public static final String PARAMETER_COLUMN = "parameterColumn"; public static final String DIMENSION = "dimension"; public static final String UPPERLIMIT = "upperLimit"; /** * A special parser that reads a convenient short form of reference priors on parameters. */ public static XMLObjectParser GAMMA_REFERENCE_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return GAMMA_REFERENCE_PRIOR; } public String[] getParserNames() { return new String[]{getParserName(), GAMMA_WORKING_PRIOR}; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String fileName = xo.getStringAttribute(FileHelpers.FILE_NAME); try { File file = new File(fileName); String parent = file.getParent(); if (!file.isAbsolute()) { parent = System.getProperty("user.dir"); } file = new File(parent, fileName); fileName = file.getAbsolutePath(); String parameterName = xo.getStringAttribute(PARAMETER_COLUMN); int dimension = 1; if (xo.hasAttribute(DIMENSION)) { dimension = xo.getIntegerAttribute(DIMENSION); } if (dimension <= 0) { throw new XMLParseException("Column '" + parameterName + "' has dimension smaller than 1."); } LogFileTraces traces = new LogFileTraces(fileName, file); traces.loadTraces(); long maxState = traces.getMaxState(); // leaving the burnin attribute off will result in 10% being used long burnin = xo.getAttribute("burnin", maxState / 10); if (burnin < 0 || burnin >= maxState) { burnin = maxState / 10; System.out.println("WARNING: Burn-in larger than total number of states - using 10%"); } traces.setBurnIn(burnin); if (dimension == 1) { int traceIndexParameter = -1; for (int i = 0; i < traces.getTraceCount(); i++) { String traceName = traces.getTraceName(i); if (traceName.trim().equals(parameterName)) { traceIndexParameter = i; } } if (traceIndexParameter == -1) { throw new XMLParseException("GammaKDEDistribution: Column '" + parameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; DistributionLikelihood likelihood = new DistributionLikelihood(new GammaKDEDistribution((Double[]) traces.getValues(traceIndexParameter).toArray(parameterSamples))); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } else { //dimension > 1 GammaKDEDistribution[] arrayKDE = new GammaKDEDistribution[dimension]; for (int i = 0; i < dimension; i++) { //look for parameterName1, parameterName2, ... if necessary String newParameterName = parameterName + (i+1); int traceIndexParameter = -1; for (int j = 0; j < traces.getTraceCount(); j++) { String traceName = traces.getTraceName(j); if (traceName.trim().equals(newParameterName)) { traceIndexParameter = j; } } if (traceIndexParameter == -1) { throw new XMLParseException("GammaKDEDistribution: Column '" + newParameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; traces.getValues(traceIndexParameter).toArray(parameterSamples); arrayKDE[i] = new GammaKDEDistribution(parameterSamples); } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(new MultivariateKDEDistribution(arrayKDE)); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { if (DEBUG) { System.out.println(((Statistic) xo.getChild(j)).toString()); System.out.println(((Statistic) xo.getChild(j)).getDimension()); } likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } } catch (FileNotFoundException fnfe) { throw new XMLParseException("File '" + fileName + "' can not be opened for " + getParserName() + " element."); } catch (java.io.IOException ioe) { throw new XMLParseException(ioe.getMessage()); } catch (TraceException e) { throw new XMLParseException(e.getMessage()); } } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule("fileName"), AttributeRule.newStringRule("parameterColumn"), AttributeRule.newIntegerRule("burnin"), new ElementRule(Statistic.class, 1, Integer.MAX_VALUE) }; public String getParserDescription() { return "Calculates the reference prior probability of some data under a given normal distribution."; } public Class getReturnType() { return Likelihood.class; } }; /** * A special parser that reads a convenient short form of reference priors on parameters. */ public static XMLObjectParser LOG_TRANSFORMED_NORMAL_REFERENCE_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return LOG_TRANSFORMED_NORMAL_REFERENCE_PRIOR; } public String[] getParserNames() { return new String[]{getParserName(), LOG_TRANSFORMED_NORMAL_WORKING_PRIOR}; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String fileName = xo.getStringAttribute(FileHelpers.FILE_NAME); try { File file = new File(fileName); String parent = file.getParent(); if (!file.isAbsolute()) { parent = System.getProperty("user.dir"); } file = new File(parent, fileName); fileName = file.getAbsolutePath(); String parameterName = xo.getStringAttribute(PARAMETER_COLUMN); int dimension = 1; if (xo.hasAttribute(DIMENSION)) { dimension = xo.getIntegerAttribute(DIMENSION); } if (dimension <= 0) { throw new XMLParseException("Column '" + parameterName + "' has dimension smaller than 1."); } LogFileTraces traces = new LogFileTraces(fileName, file); traces.loadTraces(); long maxState = traces.getMaxState(); // leaving the burnin attribute off will result in 10% being used long burnin = xo.getAttribute("burnin", maxState / 10); if (burnin < 0 || burnin >= maxState) { burnin = maxState / 10; System.out.println("WARNING: Burn-in larger than total number of states - using 10%"); } traces.setBurnIn(burnin); if (dimension == 1) { int traceIndexParameter = -1; for (int i = 0; i < traces.getTraceCount(); i++) { String traceName = traces.getTraceName(i); if (traceName.trim().equals(parameterName)) { traceIndexParameter = i; } } if (traceIndexParameter == -1) { throw new XMLParseException("LogTransformedNormalKDEDistribution: Column '" + parameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; traces.getValues(traceIndexParameter).toArray(parameterSamples); DistributionLikelihood likelihood = new DistributionLikelihood(new LogTransformedNormalKDEDistribution(parameterSamples)); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { if (DEBUG) { System.out.println(((Statistic) xo.getChild(j)).toString()); System.out.println(((Statistic) xo.getChild(j)).getDimension()); } likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } else { //dimension > 1 LogTransformedNormalKDEDistribution[] arrayKDE = new LogTransformedNormalKDEDistribution[dimension]; for (int i = 0; i < dimension; i++) { //look for parameterName1, parameterName2, ... if necessary String newParameterName = parameterName + (i+1); int traceIndexParameter = -1; for (int j = 0; j < traces.getTraceCount(); j++) { String traceName = traces.getTraceName(j); if (traceName.trim().equals(newParameterName)) { traceIndexParameter = j; } } if (traceIndexParameter == -1) { throw new XMLParseException("LogTransformedNormalKDEDistribution: Column '" + newParameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; traces.getValues(traceIndexParameter).toArray(parameterSamples); arrayKDE[i] = new LogTransformedNormalKDEDistribution(parameterSamples); } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(new MultivariateKDEDistribution(arrayKDE)); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { if (DEBUG) { System.out.println(((Statistic) xo.getChild(j)).toString()); System.out.println(((Statistic) xo.getChild(j)).getDimension()); } likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } } catch (FileNotFoundException fnfe) { throw new XMLParseException("File '" + fileName + "' can not be opened for " + getParserName() + " element."); } catch (java.io.IOException ioe) { throw new XMLParseException(ioe.getMessage()); } catch (TraceException e) { throw new XMLParseException(e.getMessage()); } } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule("fileName"), AttributeRule.newStringRule("parameterColumn"), AttributeRule.newIntegerRule("burnin"), new ElementRule(Statistic.class, 1, Integer.MAX_VALUE) }; public String getParserDescription() { return "Calculates the reference prior probability of some data under log transformed normal distribution."; } public Class getReturnType() { return Likelihood.class; } }; /** * A special parser that reads a convenient short form of reference priors on parameters. */ public static XMLObjectParser LOGIT_TRANSFORMED_NORMAL_REFERENCE_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return LOGIT_TRANSFORMED_NORMAL_REFERENCE_PRIOR; } public String[] getParserNames() { return new String[]{getParserName(), LOGIT_TRANSFORMED_NORMAL_WORKING_PRIOR}; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String fileName = xo.getStringAttribute(FileHelpers.FILE_NAME); try { File file = new File(fileName); String parent = file.getParent(); if (!file.isAbsolute()) { parent = System.getProperty("user.dir"); } file = new File(parent, fileName); fileName = file.getAbsolutePath(); //keep using String and not an array of Strings, append integers later on String parameterName = xo.getStringAttribute(PARAMETER_COLUMN); int dimension = 1; if (xo.hasAttribute(DIMENSION)) { dimension = xo.getIntegerAttribute(DIMENSION); } if (dimension <= 0) { throw new XMLParseException("Column '" + parameterName + "' has dimension smaller than 1."); } double upperlimit = 1.0; if (xo.hasAttribute(UPPERLIMIT)) { upperlimit = xo.getDoubleAttribute(UPPERLIMIT); } if (upperlimit <= 0.0) { throw new XMLParseException("Positive upper bound expected for logit transformed normal KDE distribution."); } LogFileTraces traces = new LogFileTraces(fileName, file); traces.loadTraces(); long maxState = traces.getMaxState(); // leaving the burnin attribute off will result in 10% being used long burnin = xo.getAttribute("burnin", maxState / 10); if (burnin < 0 || burnin >= maxState) { burnin = maxState / 10; System.out.println("WARNING: Burn-in larger than total number of states - using 10%"); } traces.setBurnIn(burnin); if (dimension == 1) { int traceIndexParameter = -1; for (int i = 0; i < traces.getTraceCount(); i++) { String traceName = traces.getTraceName(i); if (traceName.trim().equals(parameterName)) { traceIndexParameter = i; } } if (traceIndexParameter == -1) { throw new XMLParseException("LogitTransformedNormalKDEDistribution: Column '" + parameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; traces.getValues(traceIndexParameter).toArray(parameterSamples); DistributionLikelihood likelihood = new DistributionLikelihood(new LogitTransformedNormalKDEDistribution(parameterSamples, upperlimit)); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { if (DEBUG) { System.out.println(((Statistic) xo.getChild(j)).toString()); System.out.println(((Statistic) xo.getChild(j)).getDimension()); } likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } else { //dimension > 1 LogitTransformedNormalKDEDistribution[] arrayKDE = new LogitTransformedNormalKDEDistribution[dimension]; for (int i = 0; i < dimension; i++) { //look for parameterName1, parameterName2, ... if necessary String newParameterName = parameterName + (i+1); int traceIndexParameter = -1; for (int j = 0; j < traces.getTraceCount(); j++) { String traceName = traces.getTraceName(j); if (traceName.trim().equals(newParameterName)) { traceIndexParameter = j; } } if (traceIndexParameter == -1) { throw new XMLParseException("LogitTransformedNormalKDEDistribution: Column '" + newParameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; traces.getValues(traceIndexParameter).toArray(parameterSamples); arrayKDE[i] = new LogitTransformedNormalKDEDistribution(parameterSamples, upperlimit); } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(new MultivariateKDEDistribution(arrayKDE)); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { if (DEBUG) { System.out.println(((Statistic) xo.getChild(j)).toString()); System.out.println(((Statistic) xo.getChild(j)).getDimension()); } likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } } catch (FileNotFoundException fnfe) { throw new XMLParseException("File '" + fileName + "' can not be opened for " + getParserName() + " element."); } catch (java.io.IOException ioe) { throw new XMLParseException(ioe.getMessage()); } catch (TraceException e) { throw new XMLParseException(e.getMessage()); } } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule("fileName"), AttributeRule.newStringRule("parameterColumn"), AttributeRule.newIntegerRule("burnin"), //optional to provide a dimension attribute AttributeRule.newIntegerRule("dimension", true), //optional to provide an upperLimit attribute AttributeRule.newDoubleRule("upperLimit", true), new ElementRule(Statistic.class, 1, Integer.MAX_VALUE) }; public String getParserDescription() { return "Calculates the reference prior probability of some data under logit transformed normal distribution."; } public Class getReturnType() { return Likelihood.class; } }; /** * A special parser that reads a convenient short form of reference priors on parameters. */ public static XMLObjectParser NORMAL_REFERENCE_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return NORMAL_REFERENCE_PRIOR; } public String[] getParserNames() { return new String[]{getParserName(), NORMAL_WORKING_PRIOR}; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String fileName = xo.getStringAttribute(FileHelpers.FILE_NAME); try { File file = new File(fileName); String parent = file.getParent(); if (!file.isAbsolute()) { parent = System.getProperty("user.dir"); } file = new File(parent, fileName); fileName = file.getAbsolutePath(); String parameterName = xo.getStringAttribute(PARAMETER_COLUMN); int dimension = 1; if (xo.hasAttribute(DIMENSION)) { dimension = xo.getIntegerAttribute(DIMENSION); } if (dimension <= 0) { throw new XMLParseException("Column '" + parameterName + "' has dimension smaller than 1."); } LogFileTraces traces = new LogFileTraces(fileName, file); traces.loadTraces(); long maxState = traces.getMaxState(); // leaving the burnin attribute off will result in 10% being used long burnin = xo.getAttribute("burnin", maxState / 10); if (burnin < 0 || burnin >= maxState) { burnin = maxState / 10; System.out.println("WARNING: Burn-in larger than total number of states - using 10%"); } traces.setBurnIn(burnin); if (dimension == 1) { int traceIndexParameter = -1; for (int i = 0; i < traces.getTraceCount(); i++) { String traceName = traces.getTraceName(i); if (traceName.trim().equals(parameterName)) { traceIndexParameter = i; } } if (traceIndexParameter == -1) { throw new XMLParseException("NormalKDEDistribution: Column '" + parameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; DistributionLikelihood likelihood = new DistributionLikelihood(new NormalKDEDistribution((Double[]) traces.getValues(traceIndexParameter).toArray(parameterSamples))); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } else { //dimension > 1 NormalKDEDistribution[] arrayKDE = new NormalKDEDistribution[dimension]; for (int i = 0; i < dimension; i++) { //look for parameterName1, parameterName2, ... if necessary String newParameterName = parameterName + (i+1); int traceIndexParameter = -1; for (int j = 0; j < traces.getTraceCount(); j++) { String traceName = traces.getTraceName(j); if (traceName.trim().equals(newParameterName)) { traceIndexParameter = j; } } if (traceIndexParameter == -1) { throw new XMLParseException("NormalKDEDistribution: Column '" + newParameterName + "' can not be found for " + getParserName() + " element."); } Double[] parameterSamples = new Double[traces.getStateCount()]; traces.getValues(traceIndexParameter).toArray(parameterSamples); arrayKDE[i] = new NormalKDEDistribution(parameterSamples); } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(new MultivariateKDEDistribution(arrayKDE)); for (int j = 0; j < xo.getChildCount(); j++) { if (xo.getChild(j) instanceof Statistic) { if (DEBUG) { System.out.println(((Statistic) xo.getChild(j)).toString()); System.out.println(((Statistic) xo.getChild(j)).getDimension()); } likelihood.addData((Statistic) xo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } } catch (FileNotFoundException fnfe) { throw new XMLParseException("File '" + fileName + "' can not be opened for " + getParserName() + " element."); } catch (java.io.IOException ioe) { throw new XMLParseException(ioe.getMessage()); } catch (TraceException e) { throw new XMLParseException(e.getMessage()); } } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule("fileName"), AttributeRule.newStringRule("parameterColumn"), AttributeRule.newIntegerRule("burnin"), new ElementRule(Statistic.class, 1, Integer.MAX_VALUE) }; public String getParserDescription() { return "Calculates the reference prior probability of some data under a given normal distribution."; } public Class getReturnType() { return Likelihood.class; } }; }