/* * Transform.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.util; import dr.inference.model.Parameter; import java.util.ArrayList; import java.util.List; /** * interface for the one-to-one transform of a continuous variable. * A static member Transform.LOG provides an instance of LogTransform * * @author Andrew Rambaut * @author Guy Baele * @author Marc Suchard * @version $Id: Transform.java,v 1.5 2005/05/24 20:26:01 rambaut Exp $ */ public interface Transform { /** * @param value evaluation point * @return the transformed value */ double transform(double value); /** * overloaded transformation that takes and returns an array of doubles * @param values evaluation points * @param from start transformation at this index * @param to end transformation at this index * @return the transformed values */ double[] transform(double[] values, int from, int to); /** * @param value evaluation point * @return the inverse transformed value */ double inverse(double value); /** * overloaded transformation that takes and returns an array of doubles * @param values evaluation points * @param from start transformation at this index * @param to end transformation at this index * @return the transformed values */ double[] inverse(double[] values, int from, int to); double updateGradientLogDensity(double gradient, double value); double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to); double gradientInverse(double value); double[] gradientInverse(double[] values, int from, int to); /** * @return the transform's name */ String getTransformName(); /** * @param value evaluation point * @return the log of the transform's jacobian */ double getLogJacobian(double value); /** * @param values evaluation points * @param from start calculation at this index * @param to end calculation at this index * @return the log of the transform's jacobian */ double getLogJacobian(double[] values, int from, int to); abstract class UnivariableTransform implements Transform { public abstract double transform(double value); public double[] transform(double[] values, int from, int to) { double[] result = values.clone(); for (int i = from; i < to; ++i) { result[i] = transform(values[i]); } return result; } public abstract double inverse(double value); public double[] inverse(double[] values, int from, int to) { double[] result = values.clone(); for (int i = from; i < to; ++i) { result[i] = inverse(values[i]); } return result; } public abstract double gradientInverse(double value); public double[] gradientInverse(double[] values, int from, int to) { double[] result = values.clone(); for (int i = from; i < to; ++i) { result[i] = gradientInverse(values[i]); } return result; } public abstract double updateGradientLogDensity(double gradient, double value); public double[] updateGradientLogDensity(double[] gradient, double[] value , int from, int to) { double[] result = value.clone(); for (int i = from; i < to; ++i) { result[i] = updateGradientLogDensity(gradient[i], value[i]); } return result; } public abstract double getLogJacobian(double value); public double getLogJacobian(double[] values, int from, int to) { double sum = 0.0; for (int i = from; i < to; ++i) { sum += getLogJacobian(values[i]); } return sum; } } abstract class MultivariableTransform implements Transform { public double transform(double value) { throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); } public double inverse(double value) { throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); } public double updateGradientLogDensity(double gradient, double value) { throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); } public double gradientInverse(double value) { throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); } public double getLogJacobian(double value) { throw new RuntimeException("Transformation not permitted for this type of parameter, exiting ..."); } } class LogTransform extends UnivariableTransform { public double transform(double value) { return Math.log(value); } public double inverse(double value) { return Math.exp(value); } public double gradientInverse(double value) { return Math.exp(value); } public double updateGradientLogDensity(double gradient, double value) { // value == gradient of inverse() // 1.0 == gradient of log Jacobian of inverse() return gradient * value + 1.0; } public String getTransformName() { return "log"; } public double getLogJacobian(double value) { return -Math.log(value); } } class LogConstrainedSumTransform extends MultivariableTransform { public LogConstrainedSumTransform() { } public double[] transform(double[] values, int from, int to) { double[] transformedValues = new double[to - from + 1]; int counter = 0; for (int i = from; i <= to; i++) { transformedValues[counter] = Math.log(values[i]); counter++; } return transformedValues; } //inverse transformation assumes a sum of elements equal to the number of elements public double[] inverse(double[] values, int from, int to) { double sum = (double)(to - from + 1); double[] transformedValues = new double[to - from + 1]; int counter = 0; double newSum = 0.0; for (int i = from; i <= to; i++) { transformedValues[counter] = Math.exp(values[i]); newSum += transformedValues[counter]; counter++; } for (int i = 0; i < sum; i++) { transformedValues[i] = (transformedValues[i] / newSum) * sum; } return transformedValues; } public String getTransformName() { return "logConstrainedSum"; } public double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to) { throw new RuntimeException("Not yet implemented"); } public double[] gradientInverse(double[] values, int from, int to) { throw new RuntimeException("Not yet implemented"); } public double getLogJacobian(double[] values, int from, int to) { double sum = 0.0; for (int i = from; i <= to; i++) { sum -= Math.log(values[i]); } return sum; } public static void main(String[] args) { //specify starting values double[] startValues = {1.5, 0.6, 0.9}; System.err.print("Starting values: "); double startSum = 0.0; for (int i = 0; i < startValues.length; i++) { System.err.print(startValues[i] + " "); startSum += startValues[i]; } System.err.println("\nSum = " + startSum); //perform transformation double[] transformedValues = LOG_CONSTRAINED_SUM.transform(startValues, 0, startValues.length-1); System.err.print("Transformed values: "); for (int i = 0; i < transformedValues.length; i++) { System.err.print(transformedValues[i] + " "); } System.err.println(); //add draw for normal distribution to transformed elements for (int i = 0; i < transformedValues.length; i++) { transformedValues[i] += 0.20 * Math.random(); } //perform inverse transformation transformedValues = LOG_CONSTRAINED_SUM.inverse(transformedValues, 0, transformedValues.length-1); System.err.print("New values: "); double endSum = 0.0; for (int i = 0; i < transformedValues.length; i++) { System.err.print(transformedValues[i] + " "); endSum += transformedValues[i]; } System.err.println("\nSum = " + endSum); if (startSum != endSum) { System.err.println("Starting and ending constraints differ!"); } } } class LogitTransform extends UnivariableTransform { public LogitTransform() { range = 1.0; lower = 0.0; } public double transform(double value) { return Math.log(value / (1.0 - value)); } public double inverse(double value) { return 1.0 / (1.0 + Math.exp(-value)); } public double gradientInverse(double value) { throw new RuntimeException("Not yet implemented"); } public double updateGradientLogDensity(double gradient, double value) { throw new RuntimeException("Not yet implemented"); } public String getTransformName() { return "logit"; } public double getLogJacobian(double value) { return -Math.log(1.0 - value) - Math.log(value); } private final double range; private final double lower; } class FisherZTransform extends UnivariableTransform { public double transform(double value) { return 0.5 * (Math.log(1.0 + value) - Math.log(1.0 - value)); } public double inverse(double value) { return (Math.exp(2 * value) - 1) / (Math.exp(2 * value) + 1); } public double gradientInverse(double value) { throw new RuntimeException("Not yet implemented"); } public double updateGradientLogDensity(double gradient, double value) { throw new RuntimeException("Not yet implemented"); } public String getTransformName() { return "fisherz"; } public double getLogJacobian(double value) { return -Math.log(1 - value) - Math.log(1 + value); } } class NegateTranform extends UnivariableTransform { public double transform(double value) { return -value; } public double inverse(double value) { return -value; } public double updateGradientLogDensity(double gradient, double value) { // -1 == gradient of inverse() // 0.0 == gradient of log Jacobian of inverse() return -gradient; } public double gradientInverse(double value) { return -1.0; } public String getTransformName() { return "negate"; } public double getLogJacobian(double value) { return 0.0; } } class NoTransform extends UnivariableTransform { public double transform(double value) { return value; } public double inverse(double value) { return value; } public double updateGradientLogDensity(double gradient, double value) { return gradient; } public double gradientInverse(double value) { return 1.0; } public String getTransformName() { return "none"; } public double getLogJacobian(double value) { return 0.0; } } class Compose extends UnivariableTransform { public Compose(UnivariableTransform outer, UnivariableTransform inner) { this.outer = outer; this.inner = inner; } @Override public String getTransformName() { return "compose." + outer.getTransformName() + "." + inner.getTransformName(); } @Override public double transform(double value) { final double outerValue = inner.transform(value); final double outerTransform = outer.transform(outerValue); // System.err.println(value + " " + outerValue + " " + outerTransform); // System.exit(-1); return outerTransform; // return outer.transform(inner.transform(value)); } @Override public double inverse(double value) { return inner.inverse(outer.inverse(value)); } @Override public double gradientInverse(double value) { return inner.gradientInverse(value) * outer.gradientInverse(inner.transform(value)); } @Override public double updateGradientLogDensity(double gradient, double value) { // final double innerGradient = inner.updateGradientLogDensity(gradient, value); // final double outerValue = inner.transform(value); // final double outerGradient = outer.updateGradientLogDensity(innerGradient, outerValue); // return outerGradient; return outer.updateGradientLogDensity(inner.updateGradientLogDensity(gradient, value), inner.transform(value)); } @Override public double getLogJacobian(double value) { return inner.getLogJacobian(value) + outer.getLogJacobian(inner.transform(value)); } private final UnivariableTransform outer; private final UnivariableTransform inner; } class Inverse extends UnivariableTransform { public Inverse(UnivariableTransform inner) { this.inner = inner; } @Override public String getTransformName() { return "inverse." + inner.getTransformName(); } @Override public double transform(double value) { return inner.inverse(value); // Purposefully switched } @Override public double updateGradientLogDensity(double gradient, double value) { throw new RuntimeException("Not yet implemented"); } @Override public double inverse(double value) { return inner.transform(value); // Purposefully switched } @Override public double gradientInverse(double value) { throw new RuntimeException("Not yet implemented"); } @Override public double getLogJacobian(double value) { return -inner.getLogJacobian(value); } private final UnivariableTransform inner; } class Collection extends MultivariableTransform { private final List<ParsedTransform> segments; private final Parameter parameter; public Collection(List<ParsedTransform> segments, Parameter parameter) { this.parameter = parameter; this.segments = ensureContiguous(segments); } public Parameter getParameter() { return parameter; } private List<ParsedTransform> ensureContiguous(List<ParsedTransform> segments) { final List<ParsedTransform> contiguous = new ArrayList<ParsedTransform>(); int current = 0; for (ParsedTransform segment : segments) { if (current < segment.start) { contiguous.add(new ParsedTransform(NONE, current, segment.start)); } contiguous.add(segment); current = segment.end; } if (current < parameter.getDimension()) { contiguous.add(new ParsedTransform(NONE, current, parameter.getDimension())); } System.err.println("Segments:"); for (ParsedTransform transform : contiguous) { System.err.println(transform.transform.getTransformName() + " " + transform.start + " " + transform.end); } // System.exit(-1); return contiguous; } @Override public double[] transform(double[] values, int from, int to) { final double[] result = values.clone(); for (ParsedTransform segment : segments) { if (from < segment.end && to >= segment.start) { final int begin = Math.max(segment.start, from); final int end = Math.min(segment.end, to); for (int i = begin; i < end; ++i) { result[i] = segment.transform.transform(values[i]); } } } return result; } @Override public double[] inverse(double[] values, int from, int to) { final double[] result = values.clone(); for (ParsedTransform segment : segments) { if (from < segment.end && to >= segment.start) { final int begin = Math.max(segment.start, from); final int end = Math.min(segment.end, to); for (int i = begin; i < end; ++i) { result[i] = segment.transform.inverse(values[i]); } } } return result; } @Override public double[] gradientInverse(double[] values, int from, int to) { final double[] result = values.clone(); for (ParsedTransform segment : segments) { if (from < segment.end && to >= segment.start) { final int begin = Math.max(segment.start, from); final int end = Math.min(segment.end, to); for (int i = begin; i < end; ++i) { result[i] = segment.transform.gradientInverse(values[i]); } } } return result; } @Override public double[] updateGradientLogDensity(double[] gradient, double[] values, int from, int to) { final double[] result = values.clone(); for (ParsedTransform segment : segments) { if (from < segment.end && to >= segment.start) { final int begin = Math.max(segment.start, from); final int end = Math.min(segment.end, to); for (int i = begin; i < end; ++i) { result[i] = segment.transform.updateGradientLogDensity(gradient[i], values[i]); } } } return result; } @Override public String getTransformName() { return "collection"; } @Override public double getLogJacobian(double[] values, int from, int to) { double sum = 0.0; for (ParsedTransform segment : segments) { if (from < segment.end && to >= segment.start) { final int begin = Math.max(segment.start, from); final int end = Math.min(segment.end, to); for (int i = begin; i < end; ++i) { sum += segment.transform.getLogJacobian(values[i]); } } } // System.err.println("Log: " + sum + " " + segments.size()); return sum; } // class Segment { // // public Segment(Transform transform, int start, int end) { // this.transform = transform; // this.start = start; // this.end = end; // } // public Transform transform; // public int start; // public int end; // } } class ParsedTransform { public Transform transform; public int start; // zero-indexed public int end; // zero-indexed, i.e, i = start; i < end; ++i public int every = 1; public List<Parameter> parameters = null; public ParsedTransform() { } public ParsedTransform(Transform transform, int start, int end) { this.transform = transform; this.start = start; this.end = end; } public ParsedTransform clone() { ParsedTransform clone = new ParsedTransform(); clone.transform = transform; clone.start = start; clone.end = end; clone.every = every; clone.parameters = parameters; return clone; } public boolean equivalent(ParsedTransform other) { if (start == other.start && end == other.end && every == other.every && parameters == parameters) { return true; } else { return false; } } } class Util { public static Transform[] getListOfNoTransforms(int size) { Transform[] transforms = new Transform[size]; for (int i = 0; i < size; ++i) { transforms[i] = NONE; } return transforms; } } NoTransform NONE = new NoTransform(); LogTransform LOG = new LogTransform(); NegateTranform NEGATE = new NegateTranform(); LogConstrainedSumTransform LOG_CONSTRAINED_SUM = new LogConstrainedSumTransform(); LogitTransform LOGIT = new LogitTransform(); FisherZTransform FISHER_Z = new FisherZTransform(); enum Type { NONE("none", new NoTransform()), LOG("log", new LogTransform()), NEGATE("negate", new NegateTranform()), LOG_CONSTRAINED_SUM("logConstrainedSum", new LogConstrainedSumTransform()), LOGIT("logit", new LogitTransform()), FISHER_Z("fisherZ",new FisherZTransform()); Type(String name, Transform transform) { this.name = name; this.transform = transform; } public Transform getTransform() { return transform; } public String getName() { return name; } private Transform transform; private String name; } // String TRANSFORM = "transform"; // String TYPE = "type"; // String START = "start"; // String END = "end"; // String EVERY = "every"; // String INVERSE = "inverse"; }