/* * DiscretizedLocationOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.operators; import dr.evolution.tree.MultivariateTraitTree; import dr.evolution.tree.NodeRef; import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood; import dr.inference.operators.AbstractCoercableOperator; import dr.inference.operators.CoercionMode; import dr.inference.operators.MCMCOperator; import dr.math.MathUtils; import dr.xml.*; import java.awt.geom.Point2D; import java.util.*; import java.util.logging.Logger; /** * @author Marc Suchard */ public class DiscretizedLocationOperator extends AbstractCoercableOperator { public static final String GIBBS_OPERATOR = "discretizedLocationOperator"; public static final String INTERNAL_ONLY = "onlyInternalNodes"; public static final String DISK = "neighborhoodSize"; public static final String RANDOMIZE = "randomize"; public DiscretizedLocationOperator(AbstractMultivariateTraitLikelihood traitModel, boolean onlyInternalNodes, int disk, CoercionMode mode) { super(mode); this.treeModel = traitModel.getTreeModel(); this.traitName = traitModel.getTraitName(); this.onlyInternalNodes = onlyInternalNodes; allLocations = makeLocationList(); nearestNeighborMap = makeNearestNeighborMap(); this.disk = disk; this.autoOptimize = convertToAutoOptimizeValue(disk); if (disk > allLocations.size() - 2) throw new RuntimeException("Neighborhood size is too large"); printInfo(); } private Map<Point2D, List<WeightedPoint2D>> makeNearestNeighborMap() { Map<Point2D, List<WeightedPoint2D>> map = new HashMap<Point2D, List<WeightedPoint2D>>(); for (Point2D location : allLocations) { List<WeightedPoint2D> weightedNeighbors = new ArrayList<WeightedPoint2D>(); for (Point2D neighbor : allLocations) { double distance = location.distance(neighbor); if (distance > 0) weightedNeighbors.add( new WeightedPoint2D(neighbor.getX(), neighbor.getY(), distance) ); } Collections.sort(weightedNeighbors); map.put(location, weightedNeighbors); } // for (Point2D location : map.keySet()) { // List<WeightedPoint2D> neighbors = map.get(location); // System.err.println("Location: "+location+"\n"); // System.err.println("\t"); // int count = 0; // for (WeightedPoint2D neighbor : neighbors) { // count++; // if (count < 3) // System.err.println(" "+neighbor); // } // System.err.println("\n"); // } return map; } private void recursivelySetTrait(NodeRef node, double[] trait, NodeRef fromNode) { treeModel.setMultivariateTrait(node, traitName, trait); for (int i = 0; i < treeModel.getChildCount(node); i++) { NodeRef child = treeModel.getChild(node, i); if (child != fromNode && treeModel.getBranchLength(child) == 0) { // System.err.println("recursing down"); recursivelySetTrait(child, trait, node); } } if (!treeModel.isRoot(node) && treeModel.getBranchLength(node) == 0) { // System.err.println("recursing up"); recursivelySetTrait(treeModel.getParent(node), trait, node); } } public void randomizeNodes() { List<Point2D> listLocations = new ArrayList<Point2D>(); listLocations.addAll(allLocations); for (int i = 0; i < treeModel.getInternalNodeCount(); i++) { NodeRef node = treeModel.getInternalNode(i); double[] trait = treeModel.getMultivariateNodeTrait(node, traitName); Point2D newPt = listLocations.get(MathUtils.nextInt(listLocations.size())); trait[0] = newPt.getX(); trait[1] = newPt.getY(); recursivelySetTrait(node, trait, null); // treeModel.setMultivariateTrait(node, traitName, trait); } System.err.println("Done with randomization"); // System.exit(-1); } private void printInfo() { StringBuffer sb = new StringBuffer(); sb.append("\nCreating a discretized location sampler:\n"); sb.append("\tTip count: " + treeModel.getExternalNodeCount() + "\n"); sb.append("\tUnique locations: " + allLocations.size() + "\n"); sb.append("\tNeighborhood size: " + disk + "\n"); Logger.getLogger("dr.evomodel.operators").info(sb.toString()); } private Set<Point2D> makeLocationList() { Set<Point2D> uniquePoints = new HashSet<Point2D>(); for (int i = 0; i < treeModel.getExternalNodeCount(); i++) { NodeRef node = treeModel.getExternalNode(i); double[] leafTrait = treeModel.getMultivariateNodeTrait(node, traitName); Point2D.Double point = new Point2D.Double(leafTrait[0], leafTrait[1]); if (!uniquePoints.contains(point)) { uniquePoints.add(point); savedPt = point; } } return uniquePoints; } public double doOperation() { NodeRef node; if (onlyInternalNodes) node = treeModel.getInternalNode(MathUtils.nextInt( treeModel.getInternalNodeCount())); else node = treeModel.getNode(MathUtils.nextInt( treeModel.getNodeCount())); double[] trait = treeModel.getMultivariateNodeTrait(node, traitName); Point2D currentPt = new Point2D.Double(trait[0], trait[1]); List<WeightedPoint2D> neighbors = nearestNeighborMap.get(currentPt); if (neighbors == null) throw new RuntimeException("Node location outside allowable values: " + currentPt); // Point2D newPt = neighbors.get(MathUtils.nextInt(disk)); Point2D newPt = neighbors.get(MathUtils.nextInt(convertFromAutoOptimizeValue(autoOptimize))); trait[0] = newPt.getX(); trait[1] = newPt.getY(); // treeModel.setMultivariateTrait(node, traitName, trait); recursivelySetTrait(node, trait, null); return 0; } private int convertFromAutoOptimizeValue(double value) { return 1 + (int) Math.exp(autoOptimize); } private double convertToAutoOptimizeValue(int value) { return Math.log(value - 1); } public double getCoercableParameter() { return autoOptimize; } public void setCoercableParameter(double value) { autoOptimize = value; } public double getRawParameter() { return convertFromAutoOptimizeValue(autoOptimize); } // public double getScaleFactor() { // return scaleFactor; // } public double getTargetAcceptanceProbability() { return 0.50; } public final String getPerformanceSuggestion() { // double prob = MCMCOperator.Utils.getAcceptanceProbability(this); // double targetProb = getTargetAcceptanceProbability(); // dr.util.NumberFormatter formatter = new dr.util.NumberFormatter(5); // double sf = OperatorUtils.optimizeScaleFactor(autoOptimize, prob, targetProb); // if (prob < getMinimumGoodAcceptanceLevel()) { // return "Try setting scaleFactor to about " + formatter.format(sf); // } else if (prob > getMaximumGoodAcceptanceLevel()) { // return "Try setting scaleFactor to about " + formatter.format(sf); // } else return ""; return "I have no idea."; } public String getOperatorName() { return GIBBS_OPERATOR; } public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return GIBBS_OPERATOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { CoercionMode mode = CoercionMode.parseMode(xo); double weight = xo.getDoubleAttribute(WEIGHT); boolean onlyInternalNodes = xo.getAttribute(INTERNAL_ONLY, true); int disk = xo.getAttribute(DISK, 4); AbstractMultivariateTraitLikelihood traitModel = (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class); DiscretizedLocationOperator operator = new DiscretizedLocationOperator(traitModel, onlyInternalNodes, disk, mode); operator.setWeight(weight); // Get node priors // for (int i = 0; i < xo.getChildCount(); i++) { // if (xo.getChild(i) instanceof MultivariateDistributionLikelihood) { // MultivariateDistribution dist = ((MultivariateDistributionLikelihood) xo.getChild(i)).getDistribution(); // if (dist instanceof GeoSpatialDistribution) { // GeoSpatialDistribution prior = (GeoSpatialDistribution) dist; // String nodeLabel = prior.getLabel(); // TreeModel treeModel = traitModel.getTreeModel(); // // // Get taxon node from tree // int index = treeModel.getTaxonIndex(nodeLabel); // if (index == -1) { // throw new XMLParseException("taxon '" + nodeLabel + "' not found for geoSpatialDistribution element in traitGibbsOperator element"); // } // NodeRef node = treeModel.getExternalNode(index); //// operator.setTaxonPrior(node, prior); // System.err.println("Adding truncated prior for " + node); // } // } // } boolean randomize = xo.getAttribute(RANDOMIZE, false); if (randomize) operator.randomizeNodes(); return operator; } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element returns a multivariate Gibbs operator on traits for possible all nodes."; } public Class getReturnType() { return MCMCOperator.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{ AttributeRule.newDoubleRule(WEIGHT), AttributeRule.newBooleanRule(AUTO_OPTIMIZE, true), AttributeRule.newBooleanRule(INTERNAL_ONLY, true), new ElementRule(AbstractMultivariateTraitLikelihood.class), AttributeRule.newIntegerRule(DISK, true), AttributeRule.newBooleanRule(RANDOMIZE, true), }; }; public class WeightedPoint2D extends Point2D.Double implements Comparable { public double weight; public WeightedPoint2D(double x, double y, double weight) { super(x, y); this.weight = weight; } public int compareTo(Object o) { WeightedPoint2D pt = (WeightedPoint2D) o; if (weight > pt.weight) return 1; if (weight < pt.weight) return -1; return 0; } public String toString() { return super.toString() + "(" + weight + ")"; } } private Map<Point2D, List<WeightedPoint2D>> nearestNeighborMap; private Set<Point2D> allLocations; private final MultivariateTraitTree treeModel; private String traitName; private double autoOptimize; private boolean onlyInternalNodes = true; private int disk = 4; private Point2D savedPt; }