/* * TreeTraitParserUtilities.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodelxml.treelikelihood; import dr.evolution.tree.MultivariateTraitTree; import dr.evomodel.continuous.StandardizeTraits; import dr.inference.model.*; import dr.math.MathUtils; import dr.xml.*; import java.util.*; import java.util.logging.Logger; /** * @author Marc A. Suchard */ public class TreeTraitParserUtilities { public static final String TRAIT_PARAMETER = "traitParameter"; public static final String TRAIT_NAME = "traitName"; public static final String MISSING = "missingIndicator"; public static final String RANDOM_SAMPLE = "randomSample"; public static final String DEFAULT_TRAIT_NAME = "trait"; public static final String RANDOMIZE = "randomize"; public static final String RANDOMIZE_LOWER = "lower"; public static final String RANDOMIZE_UPPER = "upper"; public static final String ALLOW_IDENTICAL = "allowIdentical"; public static final String JITTER = "jitter"; public static final String WINDOW = "window"; public static final String DUPLICATES = "duplicatesOnly"; public static final String STANDARDIZE = "standardize"; public static final String SAMPLE_MISSING_TRAITS = "sampleMissingTraits"; public void randomize(Parameter trait, double[] lower, double[] upper) { // Draws each dimension in each trait from U[lower, upper) for (int i = 0; i < trait.getDimension(); i++) { final int whichLower = i % lower.length; final int whichUpper = i % upper.length; final double newValue = MathUtils.uniform(lower[whichLower], upper[whichUpper]); trait.setParameterValue(i, newValue); } } // public void standardize(Parameter trait) { // for (int i = 0; i < trait.) // } public static ElementRule randomizeRules(boolean optional) { return new ElementRule(TreeTraitParserUtilities.RANDOMIZE, new XMLSyntaxRule[]{ AttributeRule.newDoubleRule(TreeTraitParserUtilities.RANDOMIZE_LOWER, true), AttributeRule.newDoubleRule(TreeTraitParserUtilities.RANDOMIZE_UPPER, true), new ElementRule(Parameter.class) }, optional); } public static ElementRule jitterRules(boolean optional) { return new ElementRule(JITTER, new XMLSyntaxRule[]{ AttributeRule.newDoubleArrayRule(WINDOW), AttributeRule.newBooleanRule(DUPLICATES, true), new ElementRule(Parameter.class), }, optional); } public void jitter(XMLObject xo, int length, List<Integer> missingIndices) throws XMLParseException { XMLObject cxo = xo.getChild(TreeTraitParserUtilities.JITTER); Parameter traits = (Parameter) cxo.getChild(Parameter.class); double[] window = cxo.getDoubleArrayAttribute(TreeTraitParserUtilities.WINDOW); // Must be included, no default value boolean duplicates = cxo.getAttribute(TreeTraitParserUtilities.DUPLICATES, true); // default = true jitter(traits, length, missingIndices, window, duplicates, true); } public void randomize(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(RANDOMIZE); Parameter traits = (Parameter) cxo.getChild(Parameter.class); double[] randomizeLower; double[] randomizeUpper; if (cxo.hasAttribute(RANDOMIZE_LOWER)) { randomizeLower = cxo.getDoubleArrayAttribute(RANDOMIZE_LOWER); } else { randomizeLower = new double[]{-90.0}; } if (cxo.hasAttribute(RANDOMIZE_UPPER)) { randomizeUpper = cxo.getDoubleArrayAttribute(RANDOMIZE_UPPER); } else { randomizeUpper = new double[]{+90.0}; } randomize(traits, randomizeLower, randomizeUpper); } private class DoubleArray implements Comparable { double[] value; int index; DoubleArray(double[] value, int index) { this.value = value; this.index = index; } public double[] getValues() { return value; } public int getIndex() { return index; } public int compareTo(Object o) { double[] x = ((DoubleArray) o).getValues(); for (int i = 0; i < value.length; i++) { if (value[i] > x[i]) { return 1; } else if (value[i] < x[i]) { return -1; } } return 0; } } public boolean hasIdenticalTraits(Parameter trait, List<Integer> missingIndices, int dim) { int numTraits = trait.getDimension() / dim; List<DoubleArray> traitArray = new ArrayList<DoubleArray>(); for (int i = 0; i < numTraits; i++) { if (!missingIndices.contains(i * dim)) { // TODO Assumes completely missing traits double[] x = new double[dim]; for (int j = 0; j < dim; j++) { x[j] = trait.getParameterValue(i * dim + j); } traitArray.add(new DoubleArray(x, i)); } } DoubleArray[] sortedTraits = traitArray.toArray(new DoubleArray[0]); Arrays.sort(sortedTraits); // Mark duplicates for (int i = 1; i < sortedTraits.length; i++) { if (sortedTraits[i].compareTo(sortedTraits[i - 1]) == 0) { return true; } } return false; } public void jitter(Parameter trait, int dim, List<Integer> missingIndices, double[] window, boolean duplicates, boolean verbose) { int numTraits = trait.getDimension() / dim; boolean[] update = new boolean[numTraits]; if (!duplicates) { Arrays.fill(update, true); } else { DoubleArray[] traitArray = new DoubleArray[numTraits]; for (int i = 0; i < numTraits; i++) { double[] x = new double[dim]; for (int j = 0; j < dim; j++) { x[j] = trait.getParameterValue(i * dim + j); } traitArray[i] = new DoubleArray(x, i); } Arrays.sort(traitArray); // Mark duplicates for (int i = 1; i < numTraits; i++) { if (traitArray[i].compareTo(traitArray[i - 1]) == 0) { update[traitArray[i - 1].getIndex()] = true; update[traitArray[i].getIndex()] = true; } } } for (int i = 0; i < numTraits; i++) { if (update[i]) { StringBuffer sb1 = null; StringBuffer sb2 = null; if (verbose) { sb1 = new StringBuffer(); sb2 = new StringBuffer(); } boolean hitAtLeastOneComponent = false; for (int j = 0; j < dim; j++) { final double oldValue = trait.getParameterValue(i * dim + j); final double newValue; if (!missingIndices.contains(i * dim + j)) { newValue = window[j % window.length] * (MathUtils.nextDouble() - 0.5) + oldValue; trait.setParameterValue(i * dim + j, newValue); hitAtLeastOneComponent = true; } else { newValue = oldValue; } if (verbose) { sb1.append(" ").append(oldValue); sb2.append(" ").append(newValue); } } if (verbose && hitAtLeastOneComponent) { Logger.getLogger("dr.evomodel.continuous").info( " Replacing trait #" + (i + 1) + " Old:" + sb1.toString() + " New: " + sb2.toString() ); } } } } public class TraitsAndMissingIndices { public CompoundParameter traitParameter; public List<Integer> missingIndices; public String traitName; public Parameter sampleMissingParameter; TraitsAndMissingIndices(CompoundParameter traitParameter, List<Integer> missingIndices, String traitName, Parameter sampleMissingParameter) { this.traitParameter = traitParameter; this.missingIndices = missingIndices; this.traitName = traitName; this.sampleMissingParameter = sampleMissingParameter; } } public TraitsAndMissingIndices parseTraitsFromTaxonAttributes( XMLObject xo, String inTraitName, MultivariateTraitTree treeModel, boolean integrateOutInternalStates) throws XMLParseException { XMLObject xoc = xo.getChild(TRAIT_PARAMETER); Parameter parameter = (Parameter) xoc.getChild(Parameter.class); boolean existingTraitParameter = false; int randomSampleSizeFlag = xo.getAttribute(RANDOM_SAMPLE, -1); String traitName = inTraitName; CompoundParameter traitParameter; List<Integer> missingIndices = null; Parameter sampleMissingParameter = null; boolean isMatrixParameter = false; if (parameter instanceof MatrixParameter || parameter instanceof FastMatrixParameter) { traitParameter = (CompoundParameter) parameter; isMatrixParameter = true; } else if (parameter instanceof CompoundParameter) { // if we have been passed a CompoundParameter, this will be a leaf trait // parameter from a tree model so use this to allow for individual sampling // of leaf parameters. traitParameter = (CompoundParameter) parameter; existingTraitParameter = true; } else { // create a compound parameter of appropriate dimensions traitParameter = new CompoundParameter(parameter.getId()); ParameterParser.replaceParameter(xoc, traitParameter); } if (xo.hasAttribute(TRAIT_NAME)) { Map<Integer, Integer> randomSample = null; traitName = xo.getStringAttribute(TRAIT_NAME); StringBuilder warnings = new StringBuilder(); int warningLength = 0; final int maxWarnings = 10; // Fill in attributeValues int taxonCount = treeModel.getTaxonCount(); for (int i = 0; i < taxonCount; i++) { String taxonName = treeModel.getTaxonId(i); // changed to just label the rows by the taxonName so it can be picked up elsewhere String paramName = taxonName; String altParamName = taxonName + "." + traitName; String object = (String) treeModel.getTaxonAttribute(i, traitName); if (object == null) { throw new RuntimeException("Trait \"" + traitName + "\" not found for taxa \"" + taxonName + "\""); } else { StringTokenizer st = new StringTokenizer(object); int count = st.countTokens(); Parameter traitParam; if (existingTraitParameter) { traitParam = getTraitParameterByName(traitParameter, paramName); if (traitParam == null) { // try the alternative param name traitParam = getTraitParameterByName(traitParameter, altParamName); if (traitParam == null) { throw new RuntimeException("Missing trait parameters for taxon, " + paramName); } } } else { if (isMatrixParameter) { traitParam = traitParameter.getParameter(i); traitParam.setId(paramName); } else { // Make multidimensional, in earlier revisions only first dimension was stored traitParam = new Parameter.Default(paramName, count); traitParameter.addParameter(traitParam); } } int sampleSize = count; if (randomSampleSizeFlag > 0) { if (randomSample == null) { randomSample = drawRandomSample(randomSampleSizeFlag, count); } sampleSize = randomSampleSizeFlag; } if (sampleSize != traitParam.getDimension()) { if (existingTraitParameter) { throw new RuntimeException("Trait length must match trait parameter dimension for taxon, " + taxonName + ": " + sampleSize + " != " + traitParam.getDimension()); } else { traitParam.setDimension(sampleSize); } } int index = 0; for (int j = 0; j < count; j++) { String oneValue = st.nextToken(); if (randomSampleSizeFlag == -1 || randomSample.containsKey(j)) { double value = Double.NaN; if (oneValue.equals("NA") || oneValue.equals("?") ) { if (warningLength < maxWarnings) { warnings.append( "Warning: Missing value in tip for taxon " + taxonName + " (filling with 0 as starting value when sampling only)\n" // See comment below ); ++warningLength; } } else { try { value = new Double(oneValue); if (Double.isNaN(value)) { if (warningLength < maxWarnings) { warnings.append( "Warning: Unrecognizable number " + oneValue + " for taxon " + taxonName + "\n" ); ++warningLength; } } } catch (NumberFormatException e) { throw new RuntimeException(e.getMessage()); } } int replicates = 1; if (randomSampleSizeFlag != -1) { // Count how many times to add this datum replicates = randomSample.get(j); } for (int k = 0; k < replicates; k++) { traitParam.setParameterValue(index, value); index++; } } } } } if (warningLength > 0) { Logger.getLogger("dr.evomodel.continuous").info(warnings.toString()); if (warningLength == maxWarnings) { Logger.getLogger("dr.evomodel.continuous").info("Warning: only first " + maxWarnings + " trait warnings were displayed\n"); } } // Standardize if (xo.getAttribute(STANDARDIZE, false) && traitParameter instanceof MatrixParameterInterface) { System.err.println(traitParameter.getClass().getCanonicalName()); System.err.println("Yes, standardize"); StandardizeTraits st = new StandardizeTraits((MatrixParameterInterface) traitParameter); String message = st.doStandardization(false); Logger.getLogger("dr.evomodel.continous").info(message); } // Find missing values double[] allValues = traitParameter.getParameterValues(); missingIndices = new ArrayList<Integer>(); for (int i = 0; i < allValues.length; i++) { if ((new Double(allValues[i])).isNaN()) { traitParameter.setParameterValue(i, 0); // Here, missings are set to zero missingIndices.add(i); } } if (xo.hasChildNamed(MISSING)) { XMLObject cxo = xo.getChild(MISSING); Parameter missingParameter = new Parameter.Default(allValues.length, 0.0); for (int i : missingIndices) { missingParameter.setParameterValue(i, 1.0); } missingParameter.addBounds(new Parameter.DefaultBounds(1.0, 0.0, allValues.length)); ParameterParser.replaceParameter(cxo, missingParameter); sampleMissingParameter = missingParameter; } // Give warnings if trait exist for internal and root nodes when integrating them out if (integrateOutInternalStates) { int numTraits = traitParameter.getParameterCount(); if (numTraits != treeModel.getExternalNodeCount()) { throw new XMLParseException( "Dimensionality of '" + traitParameter.getId() + "' (" + numTraits + ") is not equal to the number" + " of tree tips (" + treeModel.getExternalNodeCount() + ")"); } for (int j = 0; j < numTraits; j++) { String parameterName = traitParameter.getParameter(j).getId(); if (parameterName.startsWith("node") || parameterName.startsWith("root")) { throw new XMLParseException( "Internal/root node trait parameters are not allowed when " + "using the integrated observed data multivariateTraitLikelihoood"); } } } } if (xo.getAttribute(SAMPLE_MISSING_TRAITS, false) || xo.hasChildNamed(MISSING)) { missingIndices = new ArrayList<Integer>(); // return empty } return new TraitsAndMissingIndices(traitParameter, missingIndices, traitName, sampleMissingParameter); } private Parameter getTraitParameterByName(CompoundParameter traits, String name) { for (int i = 0; i < traits.getParameterCount(); i++) { Parameter found = traits.getParameter(i); if (found.getStatisticName().compareTo(name) == 0) return found; } return null; } private Map<Integer, Integer> drawRandomSample(int total, int length) { Map<Integer, Integer> thisMap = new HashMap<Integer, Integer>(total); for (int i = 0; i < total; i++) { int item = MathUtils.nextInt(length); if (thisMap.containsKey(item)) { thisMap.put(item, thisMap.get(item) + 1); } else { thisMap.put(item, 1); } } return thisMap; } }