/* * MultidimensionalScalingLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.antigenic; import dr.inference.model.*; import dr.math.distributions.NormalDistribution; import dr.util.DataTable; import dr.xml.*; import java.io.FileReader; import java.io.IOException; /** * @author Andrew Rambaut * @author Marc Suchard * @version $Id$ */ public class MultidimensionalScalingLikelihood extends AbstractModelLikelihood { public enum ObservationType { POINT, UPPER_BOUND, LOWER_BOUND, MISSING } public final static String MULTIDIMENSIONAL_SCALING_LIKELIHOOD = "multidimensionalScalingLikelihood"; public MultidimensionalScalingLikelihood(String name) { super(name); } /** * A simple constructor for a fully specified symmetrical data matrix * @param mdsDimension * @param mdsPrecision * @param locationsParameter * @param dataTable */ public MultidimensionalScalingLikelihood( int mdsDimension, boolean isLeftTruncated, Parameter mdsPrecision, MatrixParameter locationsParameter, DataTable<double[]> dataTable) { super(MULTIDIMENSIONAL_SCALING_LIKELIHOOD); // construct a compact data table String[] rowLabels = dataTable.getRowLabels(); String[] columnLabels = dataTable.getRowLabels(); int rowCount = dataTable.getRowCount(); int observationCount = ((rowCount - 1) * rowCount) / 2; double[] observations = new double[observationCount]; ObservationType[] observationTypes = new ObservationType[observationCount]; // int[] distanceIndices = new int[observationCount]; int[] rowLocationIndices = new int[observationCount]; int[] columnLocationIndices = new int[observationCount]; int u = 0; for (int i = 0; i < rowCount; i++) { double[] dataRow = dataTable.getRow(i); for (int j = i + 1; j < rowCount; j++) { observations[u] = dataRow[j]; observationTypes[u] = ObservationType.POINT; // distanceIndices[u] = u; rowLocationIndices[u] = i; columnLocationIndices[u] = j; u++; } } initialize(mdsDimension, isLeftTruncated, mdsPrecision, locationsParameter, rowLabels, observations, observationTypes, rowLocationIndices, columnLocationIndices); } protected void initialize( final int mdsDimension, final boolean isLeftTruncated, final Parameter mdsPrecision, final MatrixParameter locationsParameter, final String[] locationLabels, final double[] observations, final ObservationType[] observationTypes, // final int[] distanceIndices, final int[] rowLocationIndices, final int[] columnLocationIndices) { this.mdsDimension = mdsDimension; locationCount = locationLabels.length; // upper triangular matrix this.distanceCount = (locationCount * (locationCount - 1)) / 2; this.locationLabels = locationLabels; this.observations = observations; this.observationTypes = observationTypes; // this.distanceIndices = distanceIndices; this.rowLocationIndices = rowLocationIndices; this.columnLocationIndices = columnLocationIndices; this.observationCount = observations.length; this.upperThresholdCount = 0; this.lowerThresholdCount = 0; for (ObservationType type : observationTypes) { upperThresholdCount += (type == ObservationType.UPPER_BOUND ? 1 : 0); lowerThresholdCount += (type == ObservationType.LOWER_BOUND ? 1 : 0); } thresholdCount = upperThresholdCount + lowerThresholdCount; pointObservationCount = observationCount - thresholdCount; upperThresholdIndices = new int[upperThresholdCount]; lowerThresholdIndices = new int[lowerThresholdCount]; pointObservationIndices = new int[pointObservationCount]; int ut = 0; int lt = 0; int po = 0; for (int i = 0; i < observationCount; i++) { switch (observationTypes[i]) { case POINT: pointObservationIndices[po] = i; po++; break; case UPPER_BOUND: upperThresholdIndices[ut] = i; ut++; break; case LOWER_BOUND: lowerThresholdIndices[lt] = i; lt++; break; } } this.locationsParameter = locationsParameter; setupLocationsParameter(this.locationsParameter); addVariable(locationsParameter); locationUpdated = new boolean[locationsParameter.getParameterCount()]; // a cache of row to column distances (column indices given by array above). distances = new double[distanceCount]; storedDistances = new double[distanceCount]; distanceUpdated = new boolean[distanceCount]; // a cache of individual truncations truncations = new double[distanceCount]; storedTruncations = new double[distanceCount]; // a cache of threshold calcs thresholds = new double[thresholdCount]; storedThresholds = new double[thresholdCount]; this.mdsPrecisionParameter = mdsPrecision; addVariable(mdsPrecision); this.isLeftTruncated = isLeftTruncated; // Re-normalize likelihood for strictly positive distances // make sure everything is calculated on first evaluation makeDirty(); addStatistic(new Distances()); } protected void setupLocationsParameter(MatrixParameter locationsParameter) { if (locationsParameter.getColumnDimension() > 0){ boolean proceed = true; //TODO: allow for a generic tipTrait/location specification with mismatches: tipTraits which are not represented by the pairwise distances (tipTraits > locations) and more locations based on pairwise distances than represented by the tipTraits if (locationsParameter.getColumnDimension() != locationCount){ System.err.println("locationsParameter column dimension ("+locationsParameter.getColumnDimension()+") is not equal to the locationCount ("+locationCount+")"); proceed = false; } if (locationsParameter.getRowDimension() != mdsDimension){ System.err.println("locationsParameter row dimension ("+locationsParameter.getRowDimension()+") is not equal to the mdsDimension ("+mdsDimension+")"); proceed = false; } if (!proceed) { System.exit(-1); } } else{ locationsParameter.setColumnDimension(mdsDimension); locationsParameter.setRowDimension(locationCount); } for (int i = 0; i < locationLabels.length; i++) { locationsParameter.getParameter(i).setId(locationLabels[i]); } for (int i = 0; i < locationsParameter.getParameterCount(); ++i) { Parameter param = locationsParameter.getParameter(i); try { if (param.getBounds() != null) { // Do nothing } } catch (NullPointerException exception) { param.addBounds(new Parameter.DefaultBounds( Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, param.getDimension())); } } } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) { // TODO Flag which cachedDistances or mdsPrecision need to be updated if (variable == locationsParameter) { int location = index / mdsDimension; locationUpdated[location] = true; distancesKnown = false; residualsKnown = false; thresholdsKnown = false; truncationsKnown = false; } else if (variable == mdsPrecisionParameter) { for (int i = 0; i < distanceUpdated.length; i++) { distanceUpdated[i] = true; } residualsKnown = false; thresholdsKnown = false; truncationsKnown = false; } else { // could be a derived class's parameter // throw new IllegalArgumentException("Unknown parameter"); } likelihoodKnown = false; } @Override protected void storeState() { System.arraycopy(distances, 0, storedDistances, 0, distances.length); System.arraycopy(truncations, 0, storedTruncations, 0, truncations.length); System.arraycopy(thresholds, 0, storedThresholds, 0, thresholds.length); storedLogLikelihood = logLikelihood; storedTruncationSum = truncationSum; storedThresholdSum = thresholdSum; storedSumOfSquaredResiduals = sumOfSquaredResiduals; } @Override protected void restoreState() { double[] tmp = storedDistances; storedDistances = distances; distances = tmp; distancesKnown = true; tmp = storedTruncations; storedTruncations = truncations; truncations = tmp; tmp = storedThresholds; storedThresholds = thresholds; thresholds = tmp; logLikelihood = storedLogLikelihood; likelihoodKnown = true; truncationSum = storedTruncationSum; truncationsKnown = true; thresholdSum = storedThresholdSum; thresholdsKnown = true; sumOfSquaredResiduals = storedSumOfSquaredResiduals; residualsKnown = true; } @Override protected void acceptState() { // do nothing } public void makeDirty() { distancesKnown = false; likelihoodKnown = false; residualsKnown = false; truncationsKnown = false; thresholdsKnown = false; for (int i = 0; i < locationUpdated.length; i++) { locationUpdated[i] = true; } for (int i = 0; i < distanceUpdated.length; i++) { distanceUpdated[i] = true; } } public Model getModel() { return this; } public double getLogLikelihood() { if (!likelihoodKnown) { if (!distancesKnown) { calculateDistances(); residualsKnown = false; } logLikelihood = computeLogLikelihood(); for (int i = 0; i < locationUpdated.length; i++) { locationUpdated[i] = false; } for (int i = 0; i < distanceUpdated.length; i++) { distanceUpdated[i] = false; } } return logLikelihood; } // This function can be overwritten to implement other sampling densities, i.e. discrete ranks protected double computeLogLikelihood() { double precision = mdsPrecisionParameter.getParameterValue(0); if (!residualsKnown) { sumOfSquaredResiduals = calculateSumOfSquaredResiduals(); } // totalNonMissingCount should be totalObservedCount (not > or < threshold) double logLikelihood = (0.5 * Math.log(precision) * pointObservationCount) - (0.5 * precision * sumOfSquaredResiduals); if (thresholdCount > 0) { if (!thresholdsKnown) { thresholdSum = calculateThresholdObservations(precision); } logLikelihood += thresholdSum; } if (isLeftTruncated) { if (!truncationsKnown) { calculateTruncations(precision); } truncationSum = calculateTruncationSum(); logLikelihood -= truncationSum; } likelihoodKnown = true; return logLikelihood; } protected double calculateThresholdObservations(double precision) { double sum = 0.0; double sd = 1.0 / Math.sqrt(precision); int j = 0; for (int i = 0; i < upperThresholdCount; i++) { int observationIndex = upperThresholdIndices[i]; int dist = getDistanceIndexForObservation(observationIndex); if (dist != -1) { if (distanceUpdated[dist]) { // double cdf = NormalDistribution.cdf(observations[observationIndex], distances[dist], sd, false); // double tail = 1.0 - cdf; // using special tail function of NormalDistribution (see main() in NormalDistribution for test) double tail = NormalDistribution.tailCDF(observations[observationIndex], distances[dist], sd); thresholds[j] = Math.log(tail); } } else { // -1 denotes a distance to self (i.e., 0) double tail = NormalDistribution.tailCDF(observations[observationIndex], 0.0, sd); thresholds[j] = Math.log(tail); } if (Double.isInfinite(thresholds[j])) { System.out.println("Error calculation threshold probability"); } sum += thresholds[j]; j++; } for (int i = 0; i < lowerThresholdCount; i++) { int observationIndex = lowerThresholdIndices[i]; int dist = getDistanceIndexForObservation(observationIndex); if (dist != -1) { if (distanceUpdated[dist]) { thresholds[j] = NormalDistribution.cdf(observations[observationIndex], distances[dist], sd, true); } } else { // -1 denotes a distance to self (i.e., 0) thresholds[j] = NormalDistribution.cdf(observations[observationIndex], 0.0, sd, true); } if (Double.isInfinite(thresholds[j])) { System.out.println("Error calculation threshold probability"); } sum += thresholds[j]; j++; } thresholdsKnown = true; return sum; } protected void calculateTruncations(double precision) { double sd = 1.0 / Math.sqrt(precision); for (int i = 0; i < distanceCount; i++) { if (distanceUpdated[i]) { truncations[i] = NormalDistribution.cdf(distances[i], 0.0, sd, true); } } truncationsKnown = true; } protected double calculateTruncationSum() { double sum = 0.0; for (int i = 0; i < observationCount; i++) { int dist = getDistanceIndexForObservation(i); if (dist != -1) { sum += truncations[dist]; } else { sum += Math.log(0.5); } } return sum; } protected double calculateSumOfSquaredResiduals() { double sum = 0.0; for (int i = 0; i < observationCount; i++) { if (observationTypes[i] == ObservationType.POINT) { // Only increment sum if dataTable[i][j] is observed (not > or < threshold) double residual; int dist = getDistanceIndexForObservation(i); if (dist == -1) { // -1 denotes a distance to self (i.e., 0) residual = - observations[i]; } else { residual = distances[dist] - observations[i]; } sum += residual * residual; } } residualsKnown = true; return sum; } protected void calculateDistances() { int k = 0; for (int x = 0; x < locationCount; x++) { for (int y = x + 1; y < locationCount; y++) { // the diagonal (x=y) is always zero so don't update it if (locationUpdated[x] || locationUpdated[y]) { distances[k] = calculateDistance( locationsParameter.getParameter(x), locationsParameter.getParameter(y)); distanceUpdated[k] = true; } k++; } } distancesKnown = true; } private int getDistanceIndexForObservation(final int observationIndex) { int row = getLocationIndex(rowLocationIndices[observationIndex]); int col = getLocationIndex(columnLocationIndices[observationIndex]); if (row == col) { return -1; } // make sure row index is less than column index if (row > col) { int tmp = row; row = col; col = tmp; } // get the index of [row, col] in an unrolled upper triangular matrix int index = 0; for (int x = 0; x < row; x++) { index += locationCount - x - 1; } index += col - row - 1; return index; } /** * Returns the index of the given by index * Overriding this allows the clustering of locations * @param index * @return */ protected int getLocationIndex(int index) { return index; } public String[] getLocationLabels() { return locationLabels; } protected double calculateDistance(Parameter X, Parameter Y) { double sum = 0.0; for (int i = 0; i < mdsDimension; i++) { double difference = X.getParameterValue(i) - Y.getParameterValue(i); sum += difference * difference; } return Math.sqrt(sum); } public int getMDSDimension() { return mdsDimension; } public int getLocationCount() { return locationCount; } public MatrixParameter getLocationsParameter() { return locationsParameter; } public class Distances extends Statistic.Abstract { public Distances() { super("distances"); } public int getDimension() { return distanceCount; } public double getStatisticValue(int i) { if (!distancesKnown) { calculateDistances(); } return distances[i]; } } // ************************************************************** // XMLObjectParser // ************************************************************** public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public final static String FILE_NAME = "fileName"; public final static String TIP_TRAIT = "tipTrait"; public final static String LOCATIONS = "locations"; public static final String MDS_DIMENSION = "mdsDimension"; public static final String MDS_PRECISION = "mdsPrecision"; public String getParserName() { return MULTIDIMENSIONAL_SCALING_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { String fileName = xo.getStringAttribute(FILE_NAME); DataTable<double[]> distanceTable; try { distanceTable = DataTable.Double.parse(new FileReader(fileName)); } catch (IOException e) { throw new XMLParseException("Unable to read assay data from file: " + e.getMessage()); } if (distanceTable.getRowCount() != distanceTable.getColumnCount()) { throw new XMLParseException("Data table is not symmetrical."); } int mdsDimension = xo.getIntegerAttribute(MDS_DIMENSION); MatrixParameter locationsParameter = (MatrixParameter) xo.getElementFirstChild(LOCATIONS); Parameter mdsPrecision = (Parameter) xo.getElementFirstChild(MDS_PRECISION); boolean isLeftTruncated = false; return new MultidimensionalScalingLikelihood(mdsDimension, isLeftTruncated, mdsPrecision, locationsParameter, distanceTable); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "Provides the likelihood of pairwise distance given vectors of coordinates" + "for points according to the multidimensional scaling scheme of XXX & Rafferty (xxxx)."; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule(FILE_NAME, false, "The name of the file containing the assay table"), AttributeRule.newIntegerRule(MDS_DIMENSION, false, "The dimension of the space for MDS"), new ElementRule(LOCATIONS, MatrixParameter.class), new ElementRule(MDS_PRECISION, Parameter.class) }; public Class getReturnType() { return MultidimensionalScalingLikelihood.class; } }; private int distanceCount; private int observationCount; private int upperThresholdCount; private int lowerThresholdCount; private int pointObservationCount; private int thresholdCount; private String[] locationLabels; private int locationCount; private double[] observations; private ObservationType[] observationTypes; // protected int[] distanceIndices; private int[] rowLocationIndices; private int[] columnLocationIndices; private int[] upperThresholdIndices; private int[] lowerThresholdIndices; private int[] pointObservationIndices; private MatrixParameter locationsParameter; private Parameter mdsPrecisionParameter; private boolean likelihoodKnown = false; private double logLikelihood; private double storedLogLikelihood; protected boolean distancesKnown = false; private double sumOfSquaredResiduals; private double storedSumOfSquaredResiduals; private double[] distances; private double[] storedDistances; protected boolean[] locationUpdated; protected boolean[] distanceUpdated; protected boolean residualsKnown = false; protected boolean truncationsKnown = false; private double truncationSum; private double storedTruncationSum; private double[] truncations; private double[] storedTruncations; protected boolean thresholdsKnown = false; private double thresholdSum; private double storedThresholdSum; private double[] thresholds; private double[] storedThresholds; private boolean isLeftTruncated; private int mdsDimension; }