/*
* DiscreteAntigenicTraitLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.antigenic;
import dr.inference.model.*;
import dr.math.MathUtils;
import dr.util.*;
import dr.xml.*;
import java.io.*;
import java.util.*;
import java.util.logging.Logger;
/**
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id$
*/
@Deprecated
public class DiscreteAntigenicTraitLikelihood extends AntigenicTraitLikelihood implements Citable {
/** Name passed to the superclass constructor and returned by the XML parser as its parser name. */
public final static String DISCRETE_ANTIGENIC_TRAIT_LIKELIHOOD = "discreteAntigenicTraitLikelihood";
/**
* Fixed maximum number of clusters. When this is not positive (as here), the number of
* locations is used instead, both for the maximum cluster count and for the number of rows
* of the locations parameter.
*/
private static final int CLUSTER_COUNT = -1;
/**
* Builds the discrete (clustered) antigenic trait likelihood from an assay table.
* <p>
* Every virus (table row) becomes a location, and every distinct serum (table columns, after
* repeated assays have been mapped onto their serum) that is not already a location becomes one
* as well. The non-missing table entries become the observations handed to {@code initialize}.
* Each location initially gets a cluster of its own.
*
* @param mdsDimension dimension of the mds space
* @param mdsPrecision parameter which gives the precision of the bmds
* @param clusterIndexParameter parameter holding the cluster index of each location; its
*                              dimension, bounds and initial values are set here
* @param locationsParameter a parameter of locations of viruses/sera
* @param tipTraitParameter optional tip trait parameter (may be null); when present, the
*                          entries of tips that match a location are initialised from the
*                          location of that location's cluster
* @param dataTable the assay table (virus in rows, serum assays in columns)
* @param virusAntiserumMap a map of viruses to corresponding sera (may be null)
* @param assayAntiserumMap a map of repeated assays for a given sera (may be null)
* @param virusLocationStatisticList names of the viruses for which a location statistic is
*                                   added (may be null, in which case none is added)
* @param log2Transform transform the data into log 2 space
*/
public DiscreteAntigenicTraitLikelihood(
        int mdsDimension,
        Parameter mdsPrecision,
        Parameter clusterIndexParameter,
        MatrixParameter locationsParameter,
        CompoundParameter tipTraitParameter,
        DataTable<String[]> dataTable,
        Map<String, String> virusAntiserumMap,
        Map<String, String> assayAntiserumMap,
        List<String> virusLocationStatisticList,
        final boolean log2Transform) {

    super(DISCRETE_ANTIGENIC_TRAIT_LIKELIHOOD);

    String[] virusNames = dataTable.getRowLabels();
    String[] assayNames = dataTable.getColumnLabels();

    // the total number of viruses is the number of rows in the table
    int virusCount = dataTable.getRowCount();
    int assayCount = dataTable.getColumnCount();

    int[] assayToSerumIndices = new int[assayNames.length];

    double[][] observationValueTable = new double[virusCount][assayCount];
    ObservationType[][] observationTypeTable = new ObservationType[virusCount][assayCount];
    initalizeTable(dataTable, observationValueTable, observationTypeTable, log2Transform);

    // Labels of the tip trait sub-parameters with a trailing ".antigenic" removed, so that
    // they can be matched against the location labels further down.
    List<String> tipLabels = null;
    if (tipTraitParameter != null) {
        final String suffix = ".antigenic";
        tipLabels = new ArrayList<String>();
        int tipCount = tipTraitParameter.getParameterCount();
        for (int i = 0; i < tipCount; i++) {
            String label = tipTraitParameter.getParameter(i).getParameterName();
            if (label.endsWith(suffix)) {
                // cut exactly the trailing suffix, even if the same text also occurs earlier in the name
                label = label.substring(0, label.length() - suffix.length());
            }
            tipLabels.add(label);
        }
    }

    // locations are either viruses or sera (or both)
    List<String> locationLabelsList = new ArrayList<String>();
    int[] virusToLocationIndices = new int[virusCount];
    int count = 0;
    for (String virusName : virusNames) {
        String name = null;
        if (virusAntiserumMap != null) {
            name = virusAntiserumMap.get(virusName);
        }
        if (name == null) {
            name = virusName;
        }

        virusToLocationIndices[count] = locationLabelsList.size();
        locationLabelsList.add(name);
        count++;
    }

    // collapse repeated assays onto their serum
    List<String> serumNamesList = new ArrayList<String>();
    count = 0;
    for (String assayName : assayNames) {
        String name = null;
        if (assayAntiserumMap != null) {
            name = assayAntiserumMap.get(assayName);
        }
        if (name == null) {
            name = assayName;
        }

        int index = serumNamesList.indexOf(name);
        if (index == -1) {
            index = serumNamesList.size();
            serumNamesList.add(name);
        }
        assayToSerumIndices[count] = index;
        count++;
    }

    String[] serumNames = new String[serumNamesList.size()];
    serumNamesList.toArray(serumNames);
    int serumCount = serumNames.length;

    // a serum shares the location of the virus with the same label, otherwise it gets a new one
    int[] serumToLocationIndices = new int[serumCount];
    count = 0;
    for (String serumName : serumNames) {
        int index = locationLabelsList.indexOf(serumName);
        if (index == -1) {
            index = locationLabelsList.size();
            locationLabelsList.add(serumName);
        }
        serumToLocationIndices[count] = index;
        count++;
    }

    String[] locationLabels = new String[locationLabelsList.size()];
    locationLabelsList.toArray(locationLabels);
    int locationCount = locationLabels.length;

    List<Double> observationList = new ArrayList<Double>();
    List<ObservationType> observationTypeList = new ArrayList<ObservationType>();
    int[] virusObservationCounts = new int[virusCount];
    int[] serumObservationCounts = new int[serumCount];
    List<Pair> locationPairs = new ArrayList<Pair>();

    // Build a sparse matrix of non-missing assay values
    for (int i = 0; i < virusCount; i++) {
        // a location index of minus one would mark an excluded virus; the loop above
        // currently gives every virus a location, so this is only a safeguard
        if (virusToLocationIndices[i] != -1) {
            for (int j = 0; j < assayCount; j++) {
                int k = assayToSerumIndices[j];

                double value = observationValueTable[i][j];
                ObservationType type = observationTypeTable[i][j];

                if (type != ObservationType.MISSING) {
                    observationList.add(value);
                    observationTypeList.add(type);
                    locationPairs.add(new Pair(virusToLocationIndices[i], serumToLocationIndices[k]));
                    virusObservationCounts[i]++;
                    serumObservationCounts[k]++;
                }
            }
        }
    }

    // check that all the viruses and sera have observations
    for (int i = 0; i < virusCount; i++) {
        if (virusToLocationIndices[i] != -1 && virusObservationCounts[i] == 0) {
            System.err.println("WARNING: Virus " + virusNames[i] + " has 0 observations");
        }
    }
    for (int j = 0; j < serumCount; j++) {
        if (serumObservationCounts[j] == 0) {
            System.err.println("WARNING: Antisera " + serumNames[j] + " has 0 observations");
        }
    }

    // Convert into arrays
    double[] observations = new double[observationList.size()];
    for (int i = 0; i < observationList.size(); i++) {
        observations[i] = observationList.get(i);
    }

    int[] rowLocationIndices = new int[locationPairs.size()];
    for (int i = 0; i < rowLocationIndices.length; i++) {
        rowLocationIndices[i] = locationPairs.get(i).location1;
    }

    int[] columnLocationIndices = new int[locationPairs.size()];
    for (int i = 0; i < columnLocationIndices.length; i++) {
        columnLocationIndices[i] = locationPairs.get(i).location2;
    }

    ObservationType[] observationTypes = new ObservationType[observationTypeList.size()];
    observationTypeList.toArray(observationTypes);

    int thresholdCount = 0;
    for (int i = 0; i < observations.length; i++) {
        thresholdCount += (observationTypes[i] != ObservationType.POINT ? 1 : 0);
    }

    if (tipTraitParameter != null) {
        // the location -> tip map (-1 where a location has no matching tip)
        tipIndices = new int[locationCount];
        for (int i = 0; i < locationCount; i++) {
            tipIndices[i] = tipLabels.indexOf(locationLabels[i]);
        }

        for (String tipLabel : tipLabels) {
            if (!locationLabelsList.contains(tipLabel)) {
                System.err.println("Tip, " + tipLabel + ", not found in location list");
            }
        }
    } else {
        tipIndices = null;
    }

    this.tipTraitParameter = tipTraitParameter;

    StringBuilder sb = new StringBuilder();
    sb.append("\tDiscreteAntigenicTraitLikelihood:\n");
    sb.append("\t\t" + virusNames.length + " viruses\n");
    sb.append("\t\t" + assayNames.length + " assays\n");
    sb.append("\t\t" + serumNames.length + " antisera\n");
    sb.append("\t\t" + locationLabels.length + " locations\n");
    sb.append("\t\t" + locationPairs.size() + " distances\n");
    sb.append("\t\t" + observations.length + " observations\n");
    sb.append("\t\t" + thresholdCount + " threshold observations\n");
    Logger.getLogger("dr.evomodel").info(sb.toString());

    initialize(
            mdsDimension,
            false,
            mdsPrecision,
            locationsParameter,
            locationLabels,
            observations,
            observationTypes,
            rowLocationIndices,
            columnLocationIndices);

    // initial locations: evenly spaced along the first dimension, random in the others
    for (int i = 0; i < locationsParameter.getParameterCount(); i++) {
        for (int j = 0; j < mdsDimension; j++) {
            double r = (j == 0) ? (double) i * 0.05 : MathUtils.nextGaussian();
            locationsParameter.getParameter(i).setParameterValueQuietly(j, r);
        }
    }

    // Start off with a 1-to-1 correspondence between location and cluster
    if (CLUSTER_COUNT > 0) {
        maxClusterCount = CLUSTER_COUNT;
    } else {
        maxClusterCount = getLocationCount();
    }

    this.clusterIndexParameter = clusterIndexParameter;
    clusterIndexParameter.setDimension(getLocationCount());
    clusterSizes = new int[maxClusterCount];

    // Force the bounds of the cluster index parameter to match the cluster count
    Parameter.DefaultBounds bound = new Parameter.DefaultBounds(maxClusterCount - 1, 0, getLocationCount());
    clusterIndexParameter.addBounds(bound);

    for (int i = 0; i < getLocationCount(); i++) {
        int r = i;
        clusterIndexParameter.setParameterValue(i, r);

        // Only locations that correspond to a tip have a tip trait to initialise; without this
        // guard a missing tip trait parameter caused a NullPointerException and an unmatched
        // location (tip index -1) produced a negative parameter index.
        if (tipTraitParameter != null && tipIndices[r] != -1) {
            for (int dim = 0; dim < mdsDimension; dim++) {
                tipTraitParameter.setParameterValue((tipIndices[r] * mdsDimension) + dim,
                        locationsParameter.getParameterValue((r * mdsDimension) + dim));
            }
        }
    }

    updateClusterSizes();

    addVariable(clusterIndexParameter);

    addStatistic(new ClusterMask());
    addStatistic(new ClusterIndices());
    addStatistic(new ClusterCount());
    addStatistic(new ClusterSizes());
    addStatistic(new ClusteredLocations());

    // the list is optional: a null list simply means no per-virus location statistics
    if (virusLocationStatisticList != null) {
        int i = 0;
        for (String virusName : virusNames) {
            if (virusLocationStatisticList.contains(virusName)) {
                addStatistic(new VirusLocation(virusName + "." + "location", i));
            }
            i++;
        }
    }
}
/**
* Sizes the locations matrix: one column per MDS dimension and one row per cluster, where the
* number of clusters is CLUSTER_COUNT if that is at least 1 and the number of locations
* otherwise. Each row is given the id "cluster_N" (1-based).
*
* @param locationsParameter the matrix parameter to dimension and label
*/
@Override
protected void setupLocationsParameter(MatrixParameter locationsParameter) {
    locationsParameter.setColumnDimension(getMDSDimension());

    final int rowCount = (CLUSTER_COUNT < 1) ? getLocationCount() : CLUSTER_COUNT;
    locationsParameter.setRowDimension(rowCount);

    int row = 0;
    while (row < rowCount) {
        final int oneBased = row + 1;
        locationsParameter.getParameter(row).setId("cluster_" + oneBased);
        row++;
    }
}
/**
* Reacts to a change of the cluster assignments or of the cluster locations.
* <p>
* A change of the cluster index parameter invalidates all cached distances, residuals,
* thresholds and the cluster mask. If a tip trait parameter is present it is updated so that
* the tips matching the affected locations carry the coordinates of their cluster. The event
* is then passed on to the superclass.
*
* @param variable the variable that changed
* @param index    the index of the changed element within that variable
* @param type     the kind of change
*/
@Override
protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
    final boolean clusterAssignmentChanged = (variable == clusterIndexParameter);

    if (clusterAssignmentChanged) {
        Arrays.fill(distanceUpdated, true);
        residualsKnown = false;
        thresholdsKnown = false;
        clusterMaskKnown = false;
    }

    if (tipTraitParameter != null) {
        final Parameter locations = getLocationsParameter();
        final int dimensionCount = getMDSDimension();

        if (variable == locations) {
            // one coordinate of a cluster has moved: copy it to every tip assigned to that cluster
            final int movedCluster = index / dimensionCount;
            final int axis = index % dimensionCount;

            for (int member = 0; member < clusterIndexParameter.getDimension(); member++) {
                final boolean inMovedCluster =
                        ((int) clusterIndexParameter.getParameterValue(member)) == movedCluster;
                if (!inMovedCluster || tipIndices[member] == -1) {
                    continue;
                }
                final int tipOffset = (tipIndices[member] * dimensionCount) + axis;
                tipTraitParameter.setParameterValue(tipOffset, locations.getParameterValue(index));
            }
        } else if (clusterAssignmentChanged && tipIndices[index] != -1) {
            // a location has moved cluster: give its tip all coordinates of the new cluster
            final int newCluster = (int) clusterIndexParameter.getParameterValue(index);
            int axis = 0;
            while (axis < dimensionCount) {
                final int tipOffset = (tipIndices[index] * dimensionCount) + axis;
                final int clusterOffset = (newCluster * dimensionCount) + axis;
                tipTraitParameter.setParameterValue(tipOffset, locations.getParameterValue(clusterOffset));
                axis++;
            }
        }
    }

    super.handleVariableChangedEvent(variable, index, type);
}
/**
* Returns the tip trait parameter that was passed to the constructor.
*
* @return the tip trait parameter, or null if none was supplied
*/
public CompoundParameter getTipTraitParameter() {
return tipTraitParameter;
}
/**
* Returns the location-to-tip map built in the constructor.
*
* @return for each location the index of the tip with the same label (-1 where no tip
*         matches), or null if no tip trait parameter was supplied; the internal array is
*         returned directly, not a copy
*/
public int[] getTipIndices() {
return tipIndices;
}
// Optional tip trait parameter; null when none was supplied to the constructor.
private CompoundParameter tipTraitParameter;
// Location -> tip index map; null exactly when tipTraitParameter is null.
private int[] tipIndices;
/**
* Marks the model dirty via the superclass and additionally invalidates the cached cluster
* sizes and count so that they are recomputed on next use.
*/
@Override
public void makeDirty() {
super.makeDirty();
clusterMaskKnown = false;
}
/**
* Delegates to the superclass. The cached cluster sizes are not stored; restoreState
* invalidates them instead.
*/
@Override
protected void storeState() {
super.storeState();
}
/**
* Restores the superclass state and invalidates the cached cluster sizes and count, which
* are recomputed from the cluster index parameter on next use.
*/
@Override
protected void restoreState() {
super.restoreState();
clusterMaskKnown = false;
}
/**
* Returns the cluster currently assigned to the given entry of the cluster index parameter.
*
* @param index index into the cluster index parameter (one entry per location)
* @return the value of the cluster index parameter at that index, truncated to an int
*/
@Override
protected int getLocationIndex(final int index) {
return (int)clusterIndexParameter.getParameterValue(index);
}
/**
* Recounts how many locations are assigned to each cluster and how many clusters are
* non-empty, then marks the cached counts as up to date.
*/
private void updateClusterSizes() {
    Arrays.fill(clusterSizes, 0, maxClusterCount, 0);

    // a cluster becomes occupied the moment its size leaves zero
    int occupied = 0;
    for (int location = 0; location < getLocationCount(); location++) {
        final int cluster = (int) clusterIndexParameter.getParameterValue(location);
        if (clusterSizes[cluster]++ == 0) {
            occupied++;
        }
    }

    clusterCount = occupied;
    clusterMaskKnown = true;
}
// Upper limit on the number of clusters: CLUSTER_COUNT if positive, otherwise the number of locations.
private int maxClusterCount;
// For each location, the index of the cluster it is currently assigned to.
private final Parameter clusterIndexParameter;
// Cached number of locations per cluster; only valid while clusterMaskKnown is true.
private final int[] clusterSizes;
// Cached number of non-empty clusters; only valid while clusterMaskKnown is true.
private int clusterCount;
// False whenever clusterSizes/clusterCount may be stale and must be recomputed by updateClusterSizes().
private boolean clusterMaskKnown;
/**
* Statistic with one entry per potential cluster: 1.0 if at least one location is currently
* assigned to that cluster, 0.0 otherwise.
*/
public class ClusterMask extends Statistic.Abstract {

    public ClusterMask() {
        super("clusterMask");
    }

    public int getDimension() {
        return maxClusterCount;
    }

    public double getStatisticValue(int i) {
        // recount lazily if the assignments changed since the last count
        if (!clusterMaskKnown) {
            updateClusterSizes();
        }
        final boolean occupied = clusterSizes[i] > 0;
        if (occupied) {
            return 1.0;
        }
        return 0.0;
    }
}
/**
* Statistic exposing the cluster index parameter: one entry per location, named after the
* location label, whose value is the index of the cluster the location is assigned to.
*/
public class ClusterIndices extends Statistic.Abstract {

    public ClusterIndices() {
        super("clusterIndices");
    }

    public int getDimension() {
        final int entries = clusterIndexParameter.getDimension();
        return entries;
    }

    @Override
    public String getDimensionName(final int i) {
        final String[] labels = getLocationLabels();
        return labels[i];
    }

    public double getStatisticValue(int i) {
        final double assignedCluster = clusterIndexParameter.getParameterValue(i);
        return assignedCluster;
    }
}
/**
* One-dimensional statistic reporting the number of clusters that currently contain at least
* one location.
*/
public class ClusterCount extends Statistic.Abstract {

    public ClusterCount() {
        super("clusterCount");
    }

    public int getDimension() {
        return 1;
    }

    public double getStatisticValue(int i) {
        final boolean stale = !clusterMaskKnown;
        if (stale) {
            // assignments changed since the last count
            updateClusterSizes();
        }
        return clusterCount;
    }
}
/**
* Statistic with one entry per potential cluster giving the number of locations currently
* assigned to it.
*/
public class ClusterSizes extends Statistic.Abstract {

    public ClusterSizes() {
        super("clusterSizes");
    }

    public int getDimension() {
        return maxClusterCount;
    }

    public double getStatisticValue(int i) {
        final boolean stale = !clusterMaskKnown;
        if (stale) {
            // assignments changed since the last count
            updateClusterSizes();
        }
        final int size = clusterSizes[i];
        return size;
    }
}
/**
* Statistic reporting, for every location and every MDS dimension, the coordinate of the
* cluster that the location is currently assigned to. Entries are laid out location by
* location, with the dimensions of one location adjacent to each other.
*/
public class ClusteredLocations extends Statistic.Abstract {

    public ClusteredLocations() {
        super("clusteredLocations");
    }

    @Override
    public String getDimensionName(final int i) {
        final int location = i / getMDSDimension();
        final int dim = i % getMDSDimension();
        final String label = getLocationLabels()[location];

        // two-dimensional maps use X/Y, anything else uses 1-based dimension numbers
        final String axis;
        if (getMDSDimension() != 2) {
            axis = String.valueOf(dim + 1);
        } else if (dim == 0) {
            axis = "X";
        } else {
            axis = "Y";
        }
        return label + "_" + axis;
    }

    public int getDimension() {
        final int locations = getLocationCount();
        return locations * getMDSDimension();
    }

    public double getStatisticValue(final int i) {
        final int location = i / getMDSDimension();
        final int dim = i % getMDSDimension();

        final int cluster = (int) clusterIndexParameter.getParameterValue(location);
        return getLocationsParameter().getParameter(cluster).getParameterValue(dim);
    }
}
/**
* Statistic reporting the coordinates of a single virus, i.e. the location of the cluster the
* virus is currently assigned to. It has one entry per MDS dimension.
*/
public class VirusLocation extends Statistic.Abstract {

    // index of the virus in the cluster index parameter
    private final int virusIndex;

    public VirusLocation(String statisticName, int virusIndex) {
        super(statisticName);
        this.virusIndex = virusIndex;
    }

    @Override
    public String getDimensionName(final int dim) {
        // two-dimensional maps use X/Y, anything else uses 1-based dimension numbers
        final String axis;
        if (getMDSDimension() != 2) {
            axis = String.valueOf(dim + 1);
        } else if (dim == 0) {
            axis = "X";
        } else {
            axis = "Y";
        }
        return getStatisticName() + "_" + axis;
    }

    public int getDimension() {
        return getMDSDimension();
    }

    public double getStatisticValue(final int dim) {
        final int assignedCluster = (int) clusterIndexParameter.getParameterValue(virusIndex);
        return getLocationsParameter().getParameter(assignedCluster).getParameterValue(dim);
    }
}
private class Pair {
Pair(final int location1, final int location2) {
if (location1 < location2) {
this.location1 = location1;
this.location2 = location2;
} else {
this.location1 = location2;
this.location2 = location1;
}
}
int location1;
int location2;
@Override
public boolean equals(final Object o) {
return ((Pair)o).location1 == location1 && ((Pair)o).location2 == location2;
}
@Override
public String toString() {
return "" + location1 + ", " + location2;
}
};
// **************************************************************
// XMLObjectParser
// **************************************************************
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
// Attribute: name of the file holding the assay table (required).
public final static String FILE_NAME = "fileName";
// Attribute: name of a tab-separated file mapping viruses to sera (optional).
public final static String VIRUS_MAP_FILE_NAME = "virusMapFile";
// Attribute: name of a tab-separated file mapping assays to sera (optional).
public final static String ASSAY_MAP_FILE_NAME = "assayMapFile";
// Child element: parameter of cluster indices.
public static final String CLUSTER_INDICES = "clusterIndices";
// Child element: tip trait compound parameter (optional).
public final static String TIP_TRAIT = "tipTrait";
// Child element: matrix parameter of cluster locations.
public final static String LOCATIONS = "locations";
// Attribute: dimension of the MDS space (required).
public static final String MDS_DIMENSION = "mdsDimension";
// Child element: MDS precision parameter.
public static final String MDS_PRECISION = "mdsPrecision";
// Attribute: names of viruses to create location statistics for (optional).
public static final String VIRUS_LOCATIONS = "virusLocations";
// Attribute: whether to log2 transform the data (optional, defaults to false in parseXMLObject).
public static final String LOG_2_TRANSFORM = "log2Transform";
// NOTE(review): not referenced by this parser's rules or parseXMLObject in the visible code.
public static final String TITRATION_THRESHOLD = "titrationThreshold";
/**
* @return the XML element name handled by this parser
*/
public String getParserName() {
return DISCRETE_ANTIGENIC_TRAIT_LIKELIHOOD;
}
/**
* Builds a {@link DiscreteAntigenicTraitLikelihood} from its XML element.
* <p>
* The assay table is read from the file named by the {@code fileName} attribute; the optional
* virus and assay map files are read with {@code readMap}. Optional attributes are only read
* when present.
*
* @param xo the XML element to parse
* @return the constructed likelihood
* @throws XMLParseException if the assay table or one of the map files cannot be read
*/
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    String fileName = xo.getStringAttribute(FILE_NAME);

    DataTable<String[]> assayTable;
    try {
        FileReader assayReader = new FileReader(fileName);
        try {
            assayTable = DataTable.Text.parse(assayReader);
        } finally {
            // release the file handle whether or not parsing succeeded
            assayReader.close();
        }
    } catch (IOException e) {
        throw new XMLParseException("Unable to read assay data from file: " + e.getMessage());
    }

    Map<String, String> virusAntiserumMap = null;
    if (xo.hasAttribute(VIRUS_MAP_FILE_NAME)) {
        try {
            virusAntiserumMap = readMap(xo.getStringAttribute(VIRUS_MAP_FILE_NAME));
        } catch (IOException e) {
            throw new XMLParseException("Virus map file not found: " + xo.getStringAttribute(VIRUS_MAP_FILE_NAME));
        }
    }

    Map<String, String> assayAntiserumMap = null;
    if (xo.hasAttribute(ASSAY_MAP_FILE_NAME)) {
        try {
            assayAntiserumMap = readMap(xo.getStringAttribute(ASSAY_MAP_FILE_NAME));
        } catch (IOException e) {
            throw new XMLParseException("Assay map file not found: " + xo.getStringAttribute(ASSAY_MAP_FILE_NAME));
        }
    }

    int mdsDimension = xo.getIntegerAttribute(MDS_DIMENSION);

    boolean log2Transform = false;
    if (xo.hasAttribute(LOG_2_TRANSFORM)) {
        log2Transform = xo.getBooleanAttribute(LOG_2_TRANSFORM);
    }

    Parameter clusterIndicesParameter = (Parameter) xo.getElementFirstChild(CLUSTER_INDICES);

    // The virusLocations attribute is optional, so it is only read when present (like the
    // other optional attributes above). An empty list, rather than null, is passed on when
    // it is absent so that the constructor can query the list unconditionally.
    List<String> virusLocationStatisticList = Collections.<String>emptyList();
    if (xo.hasAttribute(VIRUS_LOCATIONS)) {
        String[] virusLocations = xo.getStringArrayAttribute(VIRUS_LOCATIONS);
        if (virusLocations != null) {
            virusLocationStatisticList = Arrays.asList(virusLocations);
        }
    }

    // This parameter needs to be linked to the one in the IntegratedMultivariateTreeLikelihood (I suggest that the parameter is created
    // here and then a reference passed to IMTL - which optionally takes the parameter of tip trait values, in which case it listens and
    // updates accordingly.
    CompoundParameter tipTraitParameter = null;
    if (xo.hasChildNamed(TIP_TRAIT)) {
        tipTraitParameter = (CompoundParameter) xo.getElementFirstChild(TIP_TRAIT);
    }

    MatrixParameter locationsParameter = (MatrixParameter) xo.getElementFirstChild(LOCATIONS);
    Parameter mdsPrecision = (Parameter) xo.getElementFirstChild(MDS_PRECISION);

    AntigenicTraitLikelihood likelihood = new DiscreteAntigenicTraitLikelihood(
            mdsDimension,
            mdsPrecision,
            clusterIndicesParameter,
            locationsParameter,
            tipTraitParameter,
            assayTable,
            virusAntiserumMap,
            assayAntiserumMap,
            virusLocationStatisticList,
            log2Transform);

    Logger.getLogger("dr.evomodel").info("Using Discrete Evolutionary Cartography model. Please cite:\n" + Utils.getCitationString(likelihood));

    return likelihood;
}
/**
* Reads a two-column, tab-separated file into a map from the first column to the second.
* Blank lines and lines with fewer than two columns are skipped; any further columns are
* ignored. If a key occurs more than once the last occurrence wins.
*
* @param fileName path of the file to read
* @return the map of first-column values to second-column values
* @throws IOException if the file cannot be opened or read
*/
private Map<String, String> readMap(String fileName) throws IOException {
    Map<String, String> map = new HashMap<String, String>();

    BufferedReader reader = new BufferedReader(new FileReader(fileName));
    try {
        for (String line = reader.readLine(); line != null; line = reader.readLine()) {
            if (line.trim().length() == 0) {
                continue;
            }
            String[] parts = line.split("\t");
            if (parts.length > 1) {
                map.put(parts[0], parts[1]);
            }
        }
    } finally {
        // close the file even when reading fails part-way through
        reader.close();
    }

    return map;
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
/**
* @return a human-readable description of the element handled by this parser
*/
public String getParserDescription() {
    // each fragment ends with a space so that the joined sentence reads correctly
    return "Provides the likelihood of immunological assay data such as Hemagglutinin inhibition (HI) given vectors of coordinates " +
            "for viruses and sera/antisera in some multidimensional 'antigenic' space. This is a discrete classifier form of the model " +
            "which assigns viruses to discrete antigenic classes.";
}
/**
* @return the syntax rules describing the attributes and child elements of this element
*/
public XMLSyntaxRule[] getSyntaxRules() {
return rules;
}
// Attributes and child elements accepted by this parser. The file name and MDS dimension are
// passed a false flag and are read unconditionally in parseXMLObject; the map files and the
// log2 flag are passed true and are only read there when present.
private final XMLSyntaxRule[] rules = {
AttributeRule.newStringRule(FILE_NAME, false, "The name of the file containing the assay table"),
AttributeRule.newStringRule(VIRUS_MAP_FILE_NAME, true, "The name of the file containing the virus to serum map"),
AttributeRule.newStringRule(ASSAY_MAP_FILE_NAME, true, "The name of the file containing the assay to serum map"),
AttributeRule.newIntegerRule(MDS_DIMENSION, false, "The dimension of the space for MDS"),
AttributeRule.newBooleanRule(LOG_2_TRANSFORM, true, "Whether to log2 transform the data"),
AttributeRule.newStringArrayRule(VIRUS_LOCATIONS, true, "A list of virus names to create location statistics for"),
new ElementRule(CLUSTER_INDICES, Parameter.class, "The parameter of cluster indices for each virus/serum"),
new ElementRule(TIP_TRAIT, CompoundParameter.class, "The parameter of tip locations from the tree", true),
new ElementRule(LOCATIONS, MatrixParameter.class),
new ElementRule(MDS_PRECISION, Parameter.class)
};
/**
* @return the class of the objects produced by this parser
*/
public Class getReturnType() {
return DiscreteAntigenicTraitLikelihood.class;
}
};
/**
* @return the citation category of this model (trait models)
*/
@Override
public Citation.Category getCategory() {
return Citation.Category.TRAIT_MODELS;
}
/**
* @return a short description of what is being cited
*/
@Override
public String getDescription() {
return "Bayesian Antigenic Cartography framework";
}
/**
* @return a fixed-size list holding the single citation for this model
*/
public List<Citation> getCitations() {
return Arrays.asList(CommonCitations.BEDFORD_2015_INTEGRATING);
}
}