/* * SplineBasis.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.model; import dr.math.IntegrableUnivariateFunction; import dr.xml.*; import no.uib.cipr.matrix.BandMatrix; import no.uib.cipr.matrix.DenseVector; import no.uib.cipr.matrix.Matrix; import no.uib.cipr.matrix.Vector; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.logging.Logger; /** * @author Marc A. Suchard */ public class SplineBasis extends AbstractModel implements IntegrableUnivariateFunction { public static final String SPLINE_BASIS = "splineFunction"; public static final String KNOT_POINTS = "knotLocations"; public static final String KNOT_VALUES = "knotValues"; public static final String DEGREE = "degree"; public SplineBasis(String name, Variable<Double> knotLocations, Variable<Double> knotValues, int degree) { super(name); this.knotLocations = knotLocations; this.knotValues = knotValues; addVariable(knotLocations); addVariable(knotValues); this.degree = degree; updateBasis = true; n = knotValues.getSize(); h = new double[n - 1]; deltaY = new double[n - 1]; hMatrix = new BandMatrix(n, 1, 1); yByH = new DenseVector(n); z = new DenseVector(n); calculateBasis(); StringBuilder buffer = new StringBuilder(); buffer.append("Constructing spline basis:\n"); buffer.append("\tDegree: ").append(degree).append("\n"); buffer.append("\tRange: [").append(getLowerBound()).append(", ").append(getUpperBound()).append("\n"); Logger.getLogger("dr.math").info(buffer.toString()); } public int getDegree() { return degree; } public double evaluate(double location) { calculateBasis(); int i = 0; double xi = knotLocations.getValue(i); while( xi < location) { i++; xi = knotLocations.getValue(i); } // TODO Keep a sorted list of knotLocations for a O(log N) tree search. double xip1 = knotLocations.getValue(i+1); double yi = knotValues.getValue(i); double yip1 = knotValues.getValue(i+1); double zi = z.get(i); double zip1 = z.get(i+1); double hi = xip1 - xi; return zip1 * Math.pow(location - xi, 3) + zi * Math.pow(zip1 - location, 3) + (yip1/hi - hi/6*zip1)*(location - xi) + (yi/hi - hi/6*zi)*(xip1 - location); } public double getLowerBound() { return rangeMin; } public double getUpperBound() { return rangeMax; } public double evaluateIntegral(double startLocation, double endLocation) { calculateBasis(); // TODO return 0; } private void calculateBasis() { if (updateBasis) { Double[] x = knotLocations.getValues(); Double[] y = knotValues.getValues(); rangeMin = x[0]; rangeMax = x[x.length-1]; for (int i = 0; i < n - 1; i++) { h[i] = x[i + 1] - x[i]; deltaY[i] = y[i + 1] - y[i]; } hMatrix.set(0, 0, 1.0); // TODO Do not need to update yByH.set(0, 0.0); for (int i = 1; i < n - 2; i++) { hMatrix.set(i, i - 1, h[i - 1]); hMatrix.set(i, i, 2 * (h[i] + h[i - 1])); hMatrix.set(i, i + 1, h[i]); yByH.set(i, 6 * (deltaY[i]/h[i] - deltaY[i - 1]/h[i - 1]) ); } hMatrix.set(n - 1, n - 1, 1.0); // TODO Do not need to update yByH.set(n - 1, 0.0); hMatrix.solve(yByH, z); updateBasis = false; } } public void addModelListener(ModelListener listener) { } protected void handleModelChangedEvent(Model model, Object object, int index) { } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { updateBasis = true; } protected void storeState() { } protected void restoreState() { } protected void acceptState() { } /** * The XML parser */ public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return SPLINE_BASIS; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { int degree = xo.getAttribute(DEGREE, 3); Parameter knotLocations = (Parameter) xo.getElementFirstChild(KNOT_POINTS); Parameter knotValues = (Parameter) xo.getElementFirstChild(KNOT_VALUES); if (knotLocations.getDimension() != knotValues.getDimension()) throw new XMLParseException("Spline basis knot locations and values must have the same dimension"); List<XY> xyList = new ArrayList<XY>(); for(int i=0; i<knotLocations.getDimension(); i++) xyList.add(new XY(knotLocations.getParameterValue(i),knotValues.getParameterValue(i))); Collections.sort(xyList); for(int i=0; i<knotLocations.getDimension(); i++) { XY xy = xyList.get(i); knotLocations.setParameterValue(i,xy.x); knotValues.setParameterValue(i,xy.y); } return new SplineBasis(xo.getId(), knotLocations, knotValues, degree); } class XY implements Comparable { private final double x; private final double y; public XY(double x, double y) { this.x = x; this.y = y; } public int compareTo(Object o) { double z = ((XY)o).x; if (z == x) throw new RuntimeException("No ties accepted in spline basis"); return Double.compare(x, z); } } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents the a spline interpolation of discrete data."; } public Class getReturnType() { return SplineBasis.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newIntegerRule(DEGREE, true), new ElementRule(KNOT_POINTS, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) }), new ElementRule(KNOT_VALUES, new XMLSyntaxRule[]{ new ElementRule(Parameter.class) }), }; }; private final int degree; private final int n; private final Variable<Double> knotLocations; private final Variable<Double> knotValues; // private double[] splineCoefficients; // private double[] storedSplineCoefficients; private boolean updateBasis; private double rangeMax; private double rangeMin; private final double[] h; private final double[] deltaY; private final Matrix hMatrix; private final Vector yByH; private final Vector z; }