package org.seqcode.math.diff; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import org.seqcode.deepseq.events.BindingEvent; import org.seqcode.deepseq.events.BindingManager; import org.seqcode.deepseq.experiments.ControlledExperiment; import org.seqcode.deepseq.experiments.ExperimentCondition; import org.seqcode.deepseq.experiments.ExperimentManager; import org.seqcode.gseutils.Pair; import org.seqcode.viz.scatter.ScatterPlotMaker; import Jama.Matrix; /** * CountsDataset is a class used to hold a matrix of counts associated with a set of experiments. * The class is used as the basis for pairwise statistical testing. As such, the class also * maintains scaling factors (for each pair of experiments), and p-value matrices (for each pair * of conditions). * * @author Shaun Mahony * @version %I%, %G% */ public class CountsDataset { protected Matrix counts; //Counts matrix. Indexed by unit and sample protected double [] totals; //Count totals. Indexed by sample protected double [] scaling; //Scaling factors. Indexed by sample protected String [] unitNames; //Unit names (gene names or peak coords) protected int numSamples=1; protected int numUnits = 1; protected int numConds = 1; protected int focalCondition=0; protected int [] design; //Design array. Sample index --> Condition index protected Matrix DEpval; //P-values of differential expression. Indexed by unit and condition protected Matrix condFold;//Condition Fold (logFold from EdgeR). Indexed by unit and condition protected Matrix condMean;//Condition means (logCPM from EdgeR). Indexed by unit and condition protected Matrix condRawVar; //Condition variances. Indexed by unit and condition protected HashMap<Integer, Pair<String,String>> sampleToExptName; protected HashMap<Integer, String> condToName; protected HashMap<String, Integer> unitToIndex; protected final double LOG_2 = Math.log(2.0); /** * Constructor: preformed data * @param counts : count matrix * @param units : Names of each unit (i.e. gene name or peak coordinate) * @param design : design array (which index in counts corresponds to which condition, indexed by integer) * @param exptNameMap: translation between sample index and condition/replicate names * @param condNameMap: translation between condition index and name * @param focalCond: focal condition (sample index of) */ public CountsDataset(Matrix counts, String [] units, int [] design, HashMap<Integer, Pair<String,String>> exptNameMap, HashMap<Integer, String> condNameMap){this(counts, units, design, exptNameMap, condNameMap, 0);} public CountsDataset(Matrix counts, String [] units, int [] design, HashMap<Integer, Pair<String,String>> exptNameMap, HashMap<Integer, String> condNameMap, int focalCond){ this.counts = counts; this.unitNames = units; this.design = design; this.focalCondition = focalCond; this.sampleToExptName = exptNameMap; this.numSamples = counts.getColumnDimension(); this.numUnits = counts.getRowDimension(); this.condToName = condNameMap; this.numConds = condToName.size(); this.totals = new double[numSamples]; this.scaling = new double[numSamples]; this.DEpval = new Matrix(numUnits,numConds); this.condFold = new Matrix(numUnits,numConds); this.condMean = new Matrix(numUnits,numConds); this.condRawVar = new Matrix(numUnits,numConds); //Initialize sample totals for(int j=0; j<numSamples; j++){ double total =0; for(int i=0; i<numUnits; i++) total+= counts.get(i,j); totals[j]=total; scaling[j]=1; } //Initialize the unit -> index map unitToIndex = new HashMap<String, Integer>(); for(int i=0; i<numUnits; i++){ String un = unitNames[i]; unitToIndex.put(un, i); } } /** * Constructor: initialize data from events * @param expts */ public CountsDataset(ExperimentManager manager, List<BindingEvent> events, int focalCond){ this.focalCondition = focalCond; this.numConds = manager.getNumConditions(); //Count points & samples int sampleCount = manager.getReplicates().size(); int numPoints=events.size(); //Set up counts array this.counts = new Matrix(numPoints,sampleCount); this.unitNames = new String[numPoints]; int d=0; for(BindingEvent be : events){ for(ExperimentCondition c : manager.getConditions()){ for(ControlledExperiment rep : c.getReplicates()){ counts.set(d, rep.getIndex(), be.getRepSigHits(rep)); } } unitNames[d] = be.getPoint().getLocationString(); d++; } //Name translators & design this.condToName = new HashMap<Integer, String>(); this.sampleToExptName = new HashMap<Integer, Pair<String,String>>(); this.design = new int[sampleCount]; for(ExperimentCondition c : manager.getConditions()){ for(ControlledExperiment rep : c.getReplicates()){ design[rep.getIndex()] = c.getIndex(); sampleToExptName.put(rep.getIndex(), new Pair<String,String>(c.getName(), rep.getRepName())); }condToName.put(c.getIndex(), c.getName()); } this.numSamples = counts.getColumnDimension(); this.numUnits = counts.getRowDimension(); this.totals = new double[numSamples]; this.scaling = new double[numSamples]; this.DEpval = new Matrix(numUnits,numConds); this.condFold = new Matrix(numUnits,numConds); this.condMean = new Matrix(numUnits,numConds); this.condRawVar = new Matrix(numUnits,numConds); //Initialize sample totals for(int j=0; j<numSamples; j++){ double total =0; for(int i=0; i<numUnits; i++) total+= counts.get(i,j); totals[j]=total; scaling[j]=1; } //Initialize the unit -> index map unitToIndex = new HashMap<String, Integer>(); for(int i=0; i<numUnits; i++){ String un = unitNames[i]; unitToIndex.put(un, i); } } /** * Calculcate the scaled mean and fold for all points in reference to the focal condition */ public void calcScMeanAndFold(){ for(int i=0; i<numUnits; i++){ //Average for focal condition double focalScTotal=0; double focalSamps=0; for(int s=0; s<numSamples; s++){ if(design[s]==focalCondition){ focalScTotal+= counts.get(i,s)/scaling[s]; focalSamps++; } } double focalScMean = focalScTotal/focalSamps; condMean.set(i, focalCondition, focalScMean); condFold.set(i, focalCondition, 1.0); //Averages for other conditions for(int c=0; c<numConds; c++){ if(c!=focalCondition){ double otherScTotal=0; double otherSamps=0; for(int s=0; s<numSamples; s++){ if(design[s]==c){ otherScTotal+= counts.get(i,s)/scaling[s]; otherSamps++; } } double otherScMean = otherScTotal/otherSamps; condMean.set(i, c, (otherScMean+focalScMean)/2); double fold = focalScMean>0 ? (otherScMean>0 ? (otherScMean/focalScMean) : 1/focalScMean) : otherScMean; condFold.set(i, c, fold); } } } } /** * Update the provided events with differential expression values in the current dataset. * @param events * @return */ public List<BindingEvent> updateEvents( List<BindingEvent> events, ExperimentManager manager){ for(int e=0; e<events.size(); e++){ BindingEvent ev = events.get(e); ExperimentCondition ref = manager.getIndexedCondition(focalCondition); for(ExperimentCondition c : manager.getConditions()){ if(c!=ref){ //Only update the p-value since this should only be called after EdgeR as EdgeR's fold and mean are weird. ev.setInterCondP(c, ref, DEpval.get(e, c.getIndex())); ev.setInterCondFold(c, ref, condFold.get(e, c.getIndex())); ev.setInterCondScMean(c, ref, condMean.get(e, c.getIndex())); } } } return events; } /** * Get the counts matrix * @return */ public Matrix getCounts(){return counts;} /** * Get an individual count * @param u Unit index (i.e. gene or peak) * @param s Sample index * @return */ public double getCount(int u, int s){return counts.get(u, s);} /** * Get the sum of counts for all samples * @return */ public double[] getTotals(){return totals;} /** * Get the sum of counts for sample s * @param s * @return */ public double getTotal(int s){return totals[s];} /** * Get the name of unit u * @param u * @return */ public String getUnitName(int u){return unitNames[u];} /** * Get the name of an experiment (Condition:Replicate) * @param i * @return */ public Pair<String,String> getExptName(int i){ return sampleToExptName.get(i);} /** * Get the name of a condition (Condition:Replicate) * @param i * @return */ public String getCondName(int i){ return condToName.get(i);} /** * Set the scaling factors array * @param s double[] */ public void setScalingFactors(double[] s){ if(s.length==scaling.length) scaling = s; else{ System.err.println("Error: scaling array has incompatible length."); System.exit(1); } } /** * Get the scaling array for all samples * @return */ public double[] getScalingFactors(){return scaling;} /** * Get the scaling factor for sample s * @param s * @return */ public double getScalingFactor(int s){return scaling[s];} /** * Set the index of the condition that all comparisons are in reference to. * @param x */ public void setFocalCondition(int x){focalCondition=x;} /** * Get the focal condition * @param x * @return */ public int getFocalCondition(){return focalCondition;} /** * Get the number of samples * @return */ public int getNumSamples(){return numSamples;} /** * Get the number of units (i.e. genes or peaks) * @return */ public int getNumUnits(){return numUnits;} /** * Get the number of conditions * @return */ public int getNumConditions(){return numConds;} /** * Get the design array (all samples) * @return */ public int[] getDesignArray(){return design;} /** * Get the condition corresponding to sample s * @param i * @return */ public int getSampleCondition(int s){return design[s];} /** * Get the differential expression p-value matrix * @return */ public Matrix getDEpvals(){return DEpval;} /** * Get an individual p-value * @param u Unit index (i.e. gene or peak) * @param c Condition index * @return */ public double getDEpval(int u, int c){return DEpval.get(u, c);} /** * Set the DE p-values matrix * @param p */ public void setDEpvals(Matrix p){DEpval=p;} /** * Set an individual p-value * @param u * @param c * @param p */ public void setDEpval(int u, int c, double p){DEpval.set(u, c,p);} /** * Get the differential expression p-value matrix * @return */ public Matrix getCondMeans(){return condMean;} /** * Get an individual p-value * @param u Unit index (i.e. gene or peak) * @param c Condition index * @return */ public double getCondMean(int u, int c){return condMean.get(u, c);} /** * Get the differential expression fold matrix * @return */ public Matrix getCondFolds(){return condFold;} /** * Get an individual fold value * @param u Unit index (i.e. gene or peak) * @param c Condition index * @return */ public double getCondFold(int u, int c){return condFold.get(u, c);} /** * Set the conditions mean matrix * @param c */ public void setCondMeans(Matrix c){condMean=c;} /** * Set an individual p-value * @param u * @param c * @param p */ public void setCondMean(int u, int c, double p){condMean.set(u, c,p);} /** * Set the conditions mean matrix * @param c */ public void setCondFolds(Matrix c){condFold=c;} /** * Set an individual p-value * @param u * @param c * @param p */ public void setCondFold(int u, int c, double f){condFold.set(u, c, f);} /** * Get the differential expression p-value matrix * @return */ public Matrix getCondRawVar(){return condRawVar;} /** * Get an individual p-value * @param u Unit index (i.e. gene or peak) * @param c Condition index * @return */ public double getCondRawVar(int u, int c){return condRawVar.get(u, c);} /** * Set the conditions vars matrix * @param c */ public void setCondRawVars(Matrix c){condRawVar=c;} /** * Set an individual p-value * @param u * @param c * @param p */ public void setCondRawVar(int u, int c, double p){condRawVar.set(u, c,p);} /** * Return the scaling factors * If tableFormat is true, print on a single line * @param tableFormat * @return */ public String scalingFactorString(boolean tableFormat){ String str=""; if(tableFormat){ Pair<String,String> name = sampleToExptName.get(focalCondition); str = name.car()+":"+name.cdr(); for(int s=0; s<numSamples; s++) str = str+"\t"+scaling[s]; }else{ for(int s=0; s<numSamples; s++){ Pair<String,String> name = sampleToExptName.get(s); str = str+name.car()+":"+name.cdr()+"\t"+scaling[s]+"\n"; } } return str; } /** * Return the experiment names (condition:replicate) * If tableFormat is true, print on a single line * @param tableFormat * @return */ public String experimentNamesString(boolean tableFormat){ String str=""; if(tableFormat){ for(int s=0; s<numSamples; s++){ Pair<String,String> name = sampleToExptName.get(s); str = str+"\t"+name.car()+":"+name.cdr(); } }else{ for(int s=0; s<numSamples; s++){ Pair<String,String> name = sampleToExptName.get(s); str = str+name.car()+":"+name.cdr()+"\n"; } } return str; } /** * Return the index number for a given unit name * @param s * @return */ public Integer getUnitID(String s){ if(unitToIndex.containsKey(s)) return unitToIndex.get(s); else return -Integer.MAX_VALUE; } /** * Scatters of each sample in the focal condition against one another * @param rasterImage */ public void savePairwiseFocalSampleMAPlots(String directory, boolean rasterImage){ double A_min=1; int ref = 0; //Set one sample as the reference (the deepest sequenced sample in the focal condition) double maxTotal=0; for(int s=0; s<numSamples; s++) if(design[s] == focalCondition) if(totals[s]>maxTotal){ ref=s; maxTotal=totals[s]; } //Scale all focal condition samples against the reference for(int x=0; x<numSamples; x++){ if(design[x]==focalCondition && x!=ref){ List<Pair<Double,Double>> highlightMA = new ArrayList<Pair<Double,Double>>(); List<Pair<Double,Double>> otherMA = new ArrayList<Pair<Double,Double>>(); for(int d=0; d<numUnits; d++){ double fold=0, avg=0; if(counts.get(d,x)>0 || counts.get(d,ref)>0) avg = ((counts.get(d,x)/scaling[x])+(counts.get(d,ref)/scaling[ref]))/2; if(counts.get(d,x)>0 && counts.get(d,ref)>0) fold = Math.log((counts.get(d,x)/scaling[x])/(counts.get(d,ref)/scaling[ref]))/LOG_2; else if(counts.get(d,x)>0) fold = Math.log((counts.get(d,x)/scaling[x]))/LOG_2; else if(counts.get(d,ref)>0) fold = Math.log((counts.get(d,ref)/scaling[ref]))/LOG_2; if(avg<A_min) avg = A_min; otherMA.add(new Pair<Double,Double>(fold,avg)); } //Make the MA matrices Matrix maMatrixHighlight = new Matrix(highlightMA.size(),2); Matrix maMatrixOther = new Matrix(otherMA.size(),2); int count=0; for(Pair<Double,Double> v : highlightMA){ maMatrixHighlight.set(count, 0, v.cdr()); maMatrixHighlight.set(count, 1, v.car()); count++; } count=0; for(Pair<Double,Double> v : otherMA){ maMatrixOther.set(count, 0, v.cdr()); maMatrixOther.set(count, 1, v.car()); count++; } //Image name Pair<String,String> refName = getExptName(ref); Pair<String,String> currName = getExptName(x); String fileName = currName.car()+"-"+currName.cdr()+"_vs_"+refName.car()+"-"+refName.cdr()+".MA"; if(rasterImage) fileName = directory+fileName+".png"; else fileName = directory+fileName+".svg"; //Generate image ScatterPlotMaker plotter = new ScatterPlotMaker(currName.car()+":"+currName.cdr()+" vs "+refName.car()+":"+refName.cdr()+" MA plot"); plotter.saveMAplot(maMatrixOther, maMatrixHighlight, 0.0, fileName, rasterImage); } } } /** * Scatters of each condition against the focal condition * @param rasterImage */ public void savePairwiseConditionMAPlots(double pValThreshold, String directory, boolean rasterImage){ double A_min=1; //Scale all focal condition samples against the reference for(int c=0; c<numConds; c++){ if(c!=focalCondition){ List<Pair<Double,Double>> highlightMA = new ArrayList<Pair<Double,Double>>(); List<Pair<Double,Double>> otherMA = new ArrayList<Pair<Double,Double>>(); for(int d=0; d<numUnits; d++){ double fold=condFold.get(d,c), avg = condMean.get(d,c); if(avg<A_min) avg = A_min; if(getDEpval(d, c)<pValThreshold) highlightMA.add(new Pair<Double,Double>(fold,avg)); else otherMA.add(new Pair<Double,Double>(fold,avg)); } //Make the MA matrices Matrix maMatrixHighlight = new Matrix(highlightMA.size(),2); Matrix maMatrixOther = new Matrix(otherMA.size(),2); int count=0; for(Pair<Double,Double> v : highlightMA){ maMatrixHighlight.set(count, 0, v.cdr()); maMatrixHighlight.set(count, 1, v.car()); count++; } count=0; for(Pair<Double,Double> v : otherMA){ maMatrixOther.set(count, 0, v.cdr()); maMatrixOther.set(count, 1, v.car()); count++; } //Image name String refName = getCondName(focalCondition); String currName = getCondName(c); String fileName = currName+"_vs_"+refName+".MA"; if(rasterImage) fileName = directory+fileName+".png"; else fileName = directory+fileName+".svg"; //Generate image ScatterPlotMaker plotter = new ScatterPlotMaker(currName+" vs "+refName+" MA plot"); plotter.saveMAplot(maMatrixOther, maMatrixHighlight, 0.0, fileName, rasterImage); } } } /** * XY scatters of each sample in the focal condition against one another * @param rasterImage */ public void savePairwiseFocalSampleXYPlots(String directory, boolean rasterImage){ int ref = 0; //Set one sample as the reference (the deepest sequenced sample in the focal condition) double maxTotal=0; for(int s=0; s<numSamples; s++) if(design[s] == focalCondition) if(totals[s]>maxTotal){ ref=s; maxTotal=totals[s]; } //Scale all focal condition samples against the reference for(int x=0; x<numSamples; x++){ if(design[x]==focalCondition && x!=ref){ List<Pair<Double,Double>> highlightMA = new ArrayList<Pair<Double,Double>>(); List<Pair<Double,Double>> otherMA = new ArrayList<Pair<Double,Double>>(); for(int d=0; d<numUnits; d++){ otherMA.add(new Pair<Double,Double>(counts.get(d,x),counts.get(d,ref))); } //Make the MA matrices Matrix maMatrixHighlight = new Matrix(highlightMA.size(),2); Matrix maMatrixOther = new Matrix(otherMA.size(),2); int count=0; for(Pair<Double,Double> v : highlightMA){ maMatrixHighlight.set(count, 0, v.car()); maMatrixHighlight.set(count, 1, v.cdr()); count++; } count=0; for(Pair<Double,Double> v : otherMA){ maMatrixOther.set(count, 0, v.car()); maMatrixOther.set(count, 1, v.cdr()); count++; } //Image name Pair<String,String> refName = getExptName(ref); Pair<String,String> currName = getExptName(x); String fileName = currName.car()+"-"+currName.cdr()+"_vs_"+refName.car()+"-"+refName.cdr()+".XY"; if(rasterImage) fileName = directory+fileName+".png"; else fileName = directory+fileName+".svg"; //Generate image ScatterPlotMaker plotter = new ScatterPlotMaker(currName.car()+":"+currName.cdr()+" vs "+refName.car()+":"+refName.cdr()+" XY plot"); plotter.saveXYplot(maMatrixOther, maMatrixHighlight, currName.car()+"-"+currName.cdr(), refName.car()+"-"+refName.cdr(), fileName, rasterImage); } } } /** * XY scatters of each condition against the focal condition * @param rasterImage */ public void savePairwiseConditionXYPlots(ExperimentManager man, BindingManager bindingMan, double pValThreshold, String directory, boolean rasterImage){ List<BindingEvent> events = bindingMan.getBindingEvents(); ExperimentCondition ref = man.getIndexedCondition(focalCondition); for(ExperimentCondition c : man.getConditions()){ if(c.getIndex()!=focalCondition){ List<Pair<Double,Double>> highlightMA = new ArrayList<Pair<Double,Double>>(); List<Pair<Double,Double>> otherMA = new ArrayList<Pair<Double,Double>>(); for(int e=0; e<events.size(); e++){ BindingEvent ev = events.get(e); if(ev.getInterCondP(c, ref) < pValThreshold) highlightMA.add(new Pair<Double,Double>(ev.getCondSigHits(c), ev.getCondSigHits(ref))); else otherMA.add(new Pair<Double,Double>(ev.getCondSigHits(c), ev.getCondSigHits(ref))); } //Make the MA matrices Matrix maMatrixHighlight = new Matrix(highlightMA.size(),2); Matrix maMatrixOther = new Matrix(otherMA.size(),2); int count=0; for(Pair<Double,Double> v : highlightMA){ maMatrixHighlight.set(count, 0, v.car()); maMatrixHighlight.set(count, 1, v.cdr()); count++; } count=0; for(Pair<Double,Double> v : otherMA){ maMatrixOther.set(count, 0, v.car()); maMatrixOther.set(count, 1, v.cdr()); count++; } //Image name String refName = ref.getName(); String currName = c.getName(); String fileName = currName+"_vs_"+refName+".XY"; if(rasterImage) fileName = directory+fileName+".png"; else fileName = directory+fileName+".svg"; //Generate image ScatterPlotMaker plotter = new ScatterPlotMaker(currName+" vs "+refName+" XY plot"); plotter.saveXYplot(maMatrixOther, maMatrixHighlight, currName, refName, fileName, rasterImage); } } } }