/* * HiddenLinkageModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.tree; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Set; import dr.evolution.MetagenomeData; import dr.evolution.alignment.Alignment; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.DataType; import dr.evolution.util.Taxon; import dr.evolution.util.TaxonList; import dr.oldevomodel.substmodel.SubstitutionModel; import dr.oldevomodel.treelikelihood.GeneralLikelihoodCore; import dr.oldevomodel.treelikelihood.LikelihoodCore; import dr.oldevomodel.treelikelihood.NativeAminoAcidLikelihoodCore; import dr.oldevomodel.treelikelihood.NativeNucleotideLikelihoodCore; import dr.evomodel.tipstatesmodel.TipStatesModel; import dr.inference.model.Model; import dr.inference.model.Variable; import dr.inference.model.Variable.ChangeType; import dr.math.MathUtils; /** * A model of linkage among sets of metagenome sequence reads * Reads are assumed to derive from a fixed number of organisms, where * the identity of the 
organisms is unknown, and the assignment of reads
 * to those organisms is also unknown.
 *
 * @author Aaron Darling (koadman)
 */
public class HiddenLinkageModel extends TipStatesModel implements PatternList {

    /** Fixed number of hidden linkage groups (putative organisms). */
    int linkageGroupCount = 0;
    /** groups.get(i) holds the reads currently assigned to linkage group i. */
    ArrayList<HashSet<Taxon>> groups = null;
    /** Reference taxa, read taxa and their alignment. */
    MetagenomeData data = null;
    /** Reference taxa followed by one pseudo-taxon per linkage group. */
    ArrayList<Taxon> alignmentTaxa;
    /** Tip partials per alignment taxon, laid out [taxon][site * stateCount + state]. */
    double[][] tipPartials;
    /** Backup copies used by store/restore; swapped with tipPartials, never copied. */
    double[][] storedTipPartials;
    /** True when tipPartials[i] was recomputed since the last store/accept. */
    boolean[] dirtyTipPartials;
    /** Scratch likelihood core used to peel the reads within a linkage group. */
    LikelihoodCore core;
    /** Pseudo branch length for the read-tip transition matrix. */
    double blen = 0.001;
    SubstitutionModel substitutionModel;

    /**
     * Creates the model and randomly assigns every read to one of the linkage groups.
     *
     * @param linkageGroupCount number of hidden linkage groups (organisms)
     * @param data              reference taxa, read taxa and their alignment
     */
    public HiddenLinkageModel(int linkageGroupCount, MetagenomeData data) {
        super("HiddenLinkageModel", data.getReferenceTaxa(), data.getReadsTaxa());
        this.linkageGroupCount = linkageGroupCount;
        this.data = data;

        // initial state: randomly assign reads to groups
        groups = new ArrayList<HashSet<Taxon>>(linkageGroupCount);
        for (int i = 0; i < linkageGroupCount; i++) {
            groups.add(new HashSet<Taxon>());
        }
        TaxonList reads = data.getReadsTaxa();
        for (int i = 0; i < reads.getTaxonCount(); i++) {
            int g = MathUtils.nextInt(linkageGroupCount);
            groups.get(g).add(reads.getTaxon(i));
        }

        // create an alignment taxa list with reference + hidden groups
        alignmentTaxa = new ArrayList<Taxon>(data.getReferenceTaxa().asList());
        for (int i = 0; i < linkageGroupCount; i++) {
            alignmentTaxa.add(new Taxon("LinkageGroup_" + i));
        }

        int plen = data.getAlignment().getSiteCount() * data.getAlignment().getStateCount();
        tipPartials = new double[alignmentTaxa.size()][plen];
        storedTipPartials = new double[alignmentTaxa.size()][plen];
        dirtyTipPartials = new boolean[alignmentTaxa.size()];
        initCore();
        setupMatrices();
        // compute initial partials for every reference taxon and linkage group
        for (int i = 0; i < tipPartials.length; i++) {
            computeTipPartials(i);
        }
    }

    /** Transition matrix applied on the pseudo-branch leading to each read tip. */
    double[] tipMatrix;
    /** Near-identity transition matrix applied on internal peeling nodes. */
    double[] internalMatrix;

    @Override
    public boolean areUnique() {
        return false;
    }

    /**
     * Initializes a likelihood core to calculate partials for the linkage-group tips.
     * Read states are loaded into core nodes [0, readCount) and scratch partial
     * buffers are created at nodes [readCount, 2*readCount) for peeling.
     *
     * @throws IllegalArgumentException if the alignment's data type is not
     *         nucleotide, amino acid or codon (previously this fell through
     *         with a null core and failed later with a NullPointerException)
     */
    private void initCore() {
        DataType dataType = data.getAlignment().getDataType();
        if (dataType instanceof dr.evolution.datatype.Nucleotides) {
            core = new NativeNucleotideLikelihoodCore();
        } else if (dataType instanceof dr.evolution.datatype.AminoAcids) {
            core = new NativeAminoAcidLikelihoodCore();
        } else if (dataType instanceof dr.evolution.datatype.Codons) {
            core = new GeneralLikelihoodCore(data.getAlignment().getStateCount());
        } else {
            throw new IllegalArgumentException(
                    "Unsupported data type for HiddenLinkageModel: " + dataType.getClass().getName());
        }

        // initialize the likelihood core: one node per read plus one scratch node per read
        core.initialize(data.getReadsTaxa().getTaxonCount() * 2,
                data.getAlignment().getSiteCount(), 1, false);
        for (int i = 0; i < data.getReadsTaxa().getTaxonCount(); i++) {
            int index = data.getAlignment().getTaxonIndex(data.getReadsTaxa().getTaxon(i));
            int[] states = new int[data.getAlignment().getSiteCount()];
            for (int j = 0; j < states.length; j++) {
                states[j] = data.getAlignment().getState(index, j);
            }
            core.setNodeStates(i, states);
        }
        for (int i = 0; i < data.getReadsTaxa().getTaxonCount(); i++) {
            core.createNodePartials(i + data.getReadsTaxa().getTaxonCount());
        }
    }

    /**
     * Sets up transition matrices for internal and external nodes.
     * Tips get a symmetric matrix derived from blen; internal peeling nodes use a
     * near-identity matrix so grouped reads are effectively required to agree.
     */
    private void setupMatrices() {
        int sc = data.getAlignment().getStateCount();
        tipMatrix = new double[sc * sc];
        internalMatrix = new double[sc * sc];
        double diag = 1.0 - blen;
        double offdiag = blen / (sc - 1);
        double internalDiag = 0.99999999999999;
        double internalOffDiag = (1.0 - internalDiag) / (sc - 1);
        for (int i = 0; i < tipMatrix.length; i++) {
            tipMatrix[i] = offdiag;
            internalMatrix[i] = internalOffDiag;
        }
        for (int i = 0; i < sc; i++) {
            tipMatrix[i * sc + i] = diag;
            internalMatrix[i * sc + i] = internalDiag;
        }
        int readCount = data.getReadsTaxa().getTaxonCount();
        for (int i = 0; i < readCount; i++) {
            core.setNodeMatrix(i, 0, tipMatrix);
        }
        for (int i = 0; i < readCount; i++) {
            core.setNodeMatrix(i + readCount, 0, internalMatrix);
        }
    }

    public int getLinkageGroupCount() {
        return linkageGroupCount;
    }

    public MetagenomeData getData() {
        return data;
    }

    /**
     * Returns the group ID to which a particular metagenomic read belongs.
     *
     * @param t the read taxon
     * @return the containing group's index, or groups.size() if the read is in no group
     */
    public int getLinkageGroupId(Taxon t) {
        int i = 0;
        for (HashSet<Taxon> h : groups) {
            if (h.contains(t)) {
                break;
            }
            i++;
        }
        return i;
    }

    /** Record of a single read reassignment, kept so restoreState() can undo it. */
    private static class Move {
        public Move(Taxon read, int fromGroup, int toGroup) {
            this.read = read;
            this.fromGroup = fromGroup;
            this.toGroup = toGroup;
        }

        Taxon read;
        int fromGroup;
        int toGroup;
    }

    /** Moves applied since the last storeState(), in application order. */
    ArrayList<Move> movesMade = new ArrayList<Move>();

    @Override
    protected void acceptState() {
        // accepted: the moves become permanent and the current partials are canonical
        movesMade.clear();
        for (int i = 0; i < dirtyTipPartials.length; i++) {
            dirtyTipPartials[i] = false;
        }
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // intentionally empty: no submodels currently feed into this model
    }

    protected void handleVariableChangedEvent(Variable variable, int index, ChangeType type) {
        // intentionally empty: no variables currently feed into this model
    }

    protected void restoreState() {
        // undo all moves in reverse order
        for (int i = movesMade.size(); i > 0; i--) {
            Move m = movesMade.get(i - 1);
            groups.get(m.toGroup).remove(m.read);
            groups.get(m.fromGroup).add(m.read);
        }
        movesMade.clear();
        // swap back the partials that were recomputed since storeState()
        for (int i = 0; i < dirtyTipPartials.length; i++) {
            if (dirtyTipPartials[i]) {
                swapTipPartials(i);
                dirtyTipPartials[i] = false;
            }
        }
    }

    protected void storeState() {
        movesMade.clear();
        for (int i = 0; i < dirtyTipPartials.length; i++) {
            dirtyTipPartials[i] = false;
        }
    }

    /**
     * Get the members of the i'th linkage group
     *
     * @param i group index
     * @return Members of the i'th group
     */
    public Set<Taxon> getGroup(int i) {
        return groups.get(i);
    }

    /**
     * Moves a read from one linkage group to another linkage group, recomputes the
     * partials of both affected groups, and fires change events for them.
     *
     * @param read     the read to reassign
     * @param oldGroup group the read currently belongs to
     * @param newGroup destination group
     * @throws RuntimeException if the read is not a member of oldGroup
     */
    public void moveReadGroup(Taxon read, int oldGroup, int newGroup) {
        boolean found = groups.get(oldGroup).remove(read);
        if (!found) {
            throw new RuntimeException("Error, could not find read " + read
                    + " in linkage group " + oldGroup);
        }
        groups.get(newGroup).add(read);
        movesMade.add(new Move(read, oldGroup, newGroup));
        computeTipPartials(data.getReferenceTaxa().getTaxonCount() + oldGroup);
        computeTipPartials(data.getReferenceTaxa().getTaxonCount() + newGroup);
        // alignmentTaxa = referenceTaxa + groups, so this indexes the group pseudo-taxa
        this.fireModelChanged(alignmentTaxa.get(alignmentTaxa.size() - groups.size() + oldGroup));
        this.fireModelChanged(alignmentTaxa.get(alignmentTaxa.size() - groups.size() + newGroup));
    }

    /** Exchanges current and stored partials for one taxon (O(1) pointer swap). */
    private void swapTipPartials(int nodeIndex) {
        double[] tmp = storedTipPartials[nodeIndex];
        storedTipPartials[nodeIndex] = tipPartials[nodeIndex];
        tipPartials[nodeIndex] = tmp;
    }

    /**
     * Recomputes the tip partials for one alignment taxon.
     * For a reference organism the resolved (0/1) partials are produced directly
     * from its alignment row; for a linkage group the member reads are peeled
     * together through the likelihood core to give per-state probabilities.
     */
    private void computeTipPartials(int nodeIndex) {
        // preserve the old partials for restoreState() before overwriting
        if (!dirtyTipPartials[nodeIndex]) {
            swapTipPartials(nodeIndex);
            dirtyTipPartials[nodeIndex] = true;
        }
        double[] tipPartials = this.tipPartials[nodeIndex];
        Alignment aln = data.getAlignment();
        int sc = aln.getStateCount();
        for (int i = 0; i < tipPartials.length; i++) {
            tipPartials[i] = 0.0;
        }
        if (nodeIndex < data.getReferenceTaxa().getTaxonCount()) {
            // reference organism: 1.0 at the observed state, all-1.0 for ambiguity codes
            int j = 0;
            for (int i = 0; i < aln.getSiteCount(); i++) {
                int s = aln.getState(nodeIndex, i);
                if (s >= sc) {
                    for (int k = 0; k < sc; k++) {
                        tipPartials[j + k] = 1.0;
                    }
                } else {
                    tipPartials[j + s] = 1.0;
                }
                j += sc;
            }
        } else {
            // linkage group: peel member reads pairwise through the scratch core nodes
            int gI = nodeIndex - data.getReferenceTaxa().getTaxonCount();
            HashSet<Taxon> group = groups.get(gI);
            int internalNum = data.getReadsTaxa().getTaxonCount();
            Taxon firstTax = null;
            boolean peeled = false;
            for (Taxon tax : group) {
                if (firstTax == null) {
                    firstTax = tax;
                    continue;
                }
                int c2 = data.getReadsTaxa().getTaxonIndex(tax);
                if (!peeled) {
                    // first peel joins the first two reads
                    int c1 = data.getReadsTaxa().getTaxonIndex(firstTax);
                    core.setNodePartialsForUpdate(internalNum);
                    core.calculatePartials(c1, c2, internalNum);
                } else {
                    // subsequent peels fold each read into the accumulated partials
                    core.setNodePartialsForUpdate(internalNum);
                    core.calculatePartials(internalNum - 1, c2, internalNum);
                }
                internalNum++;
                peeled = true;
            }
            if (group.size() == 0) {
                // empty group carries no information: all states equally likely
                for (int i = 0; i < tipPartials.length; i++) {
                    tipPartials[i] = 1.0;
                }
            } else if (!peeled) {
                getPartialsForGroupSizeOne(firstTax, tipPartials);
            } else {
                core.getPartials(internalNum - 1, tipPartials);
            }
        }
    }

    /**
     * Fills partials for a single-read group directly from the read's alignment
     * row: ambiguous sites become all-1.0, observed states take the corresponding
     * row of the internal transition matrix.
     */
    private void getPartialsForGroupSizeOne(Taxon tax, double[] tipPartials) {
        Alignment aln = data.getAlignment();
        int sc = aln.getStateCount();
        int index = aln.getTaxonIndex(tax);
        int j = 0;
        for (int i = 0; i < aln.getSiteCount(); i++) {
            int s = aln.getState(index, i);
            if (s >= sc) {
                for (int k = 0; k < sc; k++) {
                    tipPartials[j + k] = 1.0;
                }
            } else {
                System.arraycopy(internalMatrix, s * sc, tipPartials, j, sc);
            }
            j += sc;
        }
    }

    /**
     * NOT YET IMPLEMENTED. Creates a new read linkage group.
     *
     * @return the new group id
     */
    public int newGroup() {
        throw new RuntimeException("Not implemented!");
    }

    /**
     * NOT YET IMPLEMENTED. Deletes a read linkage group. Must be empty.
     */
    public void deleteGroup() {
        throw new RuntimeException("Not implemented!");
    }

    @Override
    public Type getModelType() {
        return Type.PARTIALS;
    }

    @Override
    public void getTipStates(int nodeIndex, int[] tipStates) {
        throw new IllegalArgumentException("This model emits only tip partials");
    }

    @Override
    public void getTipPartials(int nodeIndex, double[] tipPartials) {
        int n = nodeIdToMyTaxaMap[tree.getNode(nodeIndex).getNumber()];
        System.arraycopy(this.tipPartials[n], 0, tipPartials, 0, tipPartials.length);
    }

    /** Maps external tree node numbers to indices in alignmentTaxa. */
    int[] nodeIdToMyTaxaMap;

    /**
     * Rebuilds the node-number-to-taxon-index map after the tree's taxa change.
     * Taxa are matched to alignmentTaxa entries by case-insensitive id; nodes or
     * taxa without an id are skipped (the original printed debug junk and then
     * dereferenced the null anyway).
     */
    protected void taxaChanged() {
        nodeIdToMyTaxaMap = new int[tree.getNodeCount()];
        for (int i = 0; i < nodeIdToMyTaxaMap.length; i++) {
            if (tree.getTaxon(i) == null || tree.getTaxon(i).getId() == null) {
                continue;
            }
            for (int j = 0; j < alignmentTaxa.size(); j++) {
                if (alignmentTaxa.get(j) == null || alignmentTaxa.get(j).getId() == null) {
                    continue;
                }
                if (tree.getTaxon(i).getId().equalsIgnoreCase(alignmentTaxa.get(j).getId())) {
                    nodeIdToMyTaxaMap[tree.getExternalNode(i).getNumber()] = j;
                    break;
                }
            }
        }
    }

    //
    // BEGIN PatternList Implementation
    // This merely delegates to Alignment for most methods
    //
    public DataType getDataType() {
        return data.getAlignment().getDataType();
    }

    public int[] getPattern(int patternIndex) {
        return data.getAlignment().getPattern(patternIndex);
    }

    public int getPatternCount() {
        return data.getAlignment().getPatternCount();
    }

    public int getPatternLength() {
        return data.getAlignment().getPatternLength();
    }

    /**
     * Delegates to the alignment for reference taxa; linkage-group pseudo-taxa
     * have no single resolved state, so 0 is returned for them.
     */
    public int getPatternState(int taxonIndex, int patternIndex) {
        if (taxonIndex < data.getReferenceTaxa().getTaxonCount()) {
            return data.getAlignment().getPatternState(taxonIndex, patternIndex);
        }
        return 0;
    }

    public double getPatternWeight(int patternIndex) {
        return data.getAlignment().getPatternWeight(patternIndex);
    }

    public double[] getPatternWeights() {
        return data.getAlignment().getPatternWeights();
    }

    public int getStateCount() {
        return data.getAlignment().getStateCount();
    }

    public double[] getStateFrequencies() {
        return data.getAlignment().getStateFrequencies();
    }

    public List<Taxon> asList() {
        return alignmentTaxa;
    }

    public Taxon getTaxon(int taxonIndex) {
        return alignmentTaxa.get(taxonIndex);
    }

    public Object getTaxonAttribute(int taxonIndex, String name) {
        return alignmentTaxa.get(taxonIndex).getAttribute(name);
    }

    public int getTaxonCount() {
        return alignmentTaxa.size();
    }

    public String getTaxonId(int taxonIndex) {
        return alignmentTaxa.get(taxonIndex).getId();
    }

    public int getTaxonIndex(String id) {
        for (int i = 0; i < alignmentTaxa.size(); i++) {
            if (alignmentTaxa.get(i).getId().equals(id)) {
                return i;
            }
        }
        return -1;
    }

    public int getTaxonIndex(Taxon taxon) {
        for (int i = 0; i < alignmentTaxa.size(); i++) {
            if (alignmentTaxa.get(i).compareTo(taxon) == 0) {
                return i;
            }
        }
        return -1;
    }

    public Iterator<Taxon> iterator() {
        return alignmentTaxa.iterator();
    }
}