package dr.evomodel.branchmodel;
import dr.evomodelxml.branchmodel.BranchAssignmentModelParser;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.inference.model.Variable.ChangeType;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
/**
 * Branch model that picks the substitution model of each branch from an integer-valued
 * node annotation of the tree.
 * <p>
 * The value of the annotation is looked up in {@code modelIndexMap}; nodes without the
 * annotation, or with a value that has no model assigned, use the base model.
 * <p>
 * The list returned by {@link #getSubstitutionModels()} holds every model of
 * {@code modelIndexMap} once, in the iteration order of that map, followed by the base
 * model. The indices reported by {@link #getBranchModelMapping(NodeRef)} are positions in
 * that list. When the map keys are {@code 0..n-1} in order, the position of a model equals
 * its annotation value and the base model sits at position {@code n}.
 * <p>
 * Branch assignments are computed once, in the constructor, and are not refreshed
 * afterwards.
 */
@SuppressWarnings("serial")
public class BranchAssignmentModel extends AbstractModel implements BranchModel {

    public static final String BRANCH_ASSIGNMENT_MODEL = "branchAssignmentModel";

    private final TreeModel treeModel;
    private final String annotation;
    private final LinkedHashMap<Integer, SubstitutionModel> modelIndexMap;
    private final SubstitutionModel baseModel;

    /** Position of the base model in {@link #substitutionModels} (always the last one). */
    private final int baseModelIndex;

    /** Annotation value -> position of the corresponding model in {@link #substitutionModels}. */
    private final LinkedHashMap<Integer, Integer> modelPositions;

    /** Non-root node -> position of its branch model in {@link #substitutionModels}. */
    private final LinkedHashMap<NodeRef, Integer> branchAssignmentMap;

    private final LinkedList<SubstitutionModel> substitutionModels;

    /**
     * Builds the model list and assigns a model to every non-root node of the tree.
     *
     * @param treeModel     tree whose nodes carry the annotation
     * @param annotation    name of the node attribute holding the model index
     * @param modelIndexMap annotation value -> substitution model
     * @param baseModel     model used when a node has no (usable) annotation
     * @throws IllegalArgumentException if a node carries an annotation value that cannot be
     *                                  read as an integer
     */
    public BranchAssignmentModel(
            TreeModel treeModel,
            String annotation,
            LinkedHashMap<Integer, SubstitutionModel> modelIndexMap,
            SubstitutionModel baseModel
    ) {
        super(BRANCH_ASSIGNMENT_MODEL);

        this.treeModel = treeModel;
        this.annotation = annotation;
        this.modelIndexMap = modelIndexMap;
        this.baseModel = baseModel;

        this.substitutionModels = new LinkedList<SubstitutionModel>();
        this.modelPositions = new LinkedHashMap<Integer, Integer>();
        this.branchAssignmentMap = new LinkedHashMap<NodeRef, Integer>();

        // base model comes last
        this.baseModelIndex = registerSubstitutionModels();
        assignBranches();
    }// END: Constructor

    /**
     * Fills {@link #substitutionModels} with one entry per mapped model plus the base model
     * and records at which position each annotation value can be found.
     *
     * @return the position of the base model in the list
     */
    private int registerSubstitutionModels() {
        for (Integer annotationValue : modelIndexMap.keySet()) {
            SubstitutionModel model = modelIndexMap.get(annotationValue);
            if (model == null) {
                // No model behind this value: leave it unregistered so lookups fall back
                // to the base model instead of handing out a null entry.
                continue;
            }
            modelPositions.put(annotationValue, substitutionModels.size());
            substitutionModels.add(model);
        }

        substitutionModels.add(baseModel);
        return substitutionModels.size() - 1;
    }// END: registerSubstitutionModels

    /** Records, for every non-root node, the list position of the model of its branch. */
    private void assignBranches() {
        for (NodeRef node : treeModel.getNodes()) {
            if (treeModel.isRoot(node)) {
                continue;
            }
            branchAssignmentMap.put(node, resolveModelPosition(node, true));
        }// END: nodes loop
    }// END: assignBranches

    /**
     * Reads the annotation of a node and translates it into a position in
     * {@link #substitutionModels}.
     *
     * @param node           node to inspect
     * @param reportFallback whether to print a message when the base model is used instead
     * @return position of the model for this node; the base model position when the
     *         annotation is missing or has no model assigned
     */
    private int resolveModelPosition(NodeRef node, boolean reportFallback) {
        Object nodeAttribute = treeModel.getNodeAttribute(node, annotation);

        if (nodeAttribute == null) {
            if (reportFallback) {
                System.out
                        .println("Attribute "
                                + annotation
                                + " missing from node. Using base model as branch model.");
            }
            return baseModelIndex;
        }

        Integer position = modelPositions.get(toAnnotationValue(nodeAttribute));
        if (position == null) {
            if (reportFallback) {
                System.out.println("No substitution model assigned to value "
                        + nodeAttribute + " of attribute " + annotation
                        + ". Using base model as branch model.");
            }
            return baseModelIndex;
        }

        return position;
    }// END: resolveModelPosition

    /**
     * Converts a raw node attribute into the integer used as key of {@code modelIndexMap}.
     * Accepts integral numbers and strings holding an integer.
     *
     * @throws IllegalArgumentException if the attribute does not represent an integer
     */
    private Integer toAnnotationValue(Object nodeAttribute) {
        if (nodeAttribute instanceof Integer) {
            return (Integer) nodeAttribute;
        }

        if (nodeAttribute instanceof Number) {
            Number number = (Number) nodeAttribute;
            int asInt = number.intValue();
            // Rejects fractional values, NaN and values outside the int range.
            if (number.doubleValue() != asInt) {
                throw new IllegalArgumentException("Attribute " + annotation
                        + " must hold an integer model index but was " + nodeAttribute);
            }
            return asInt;
        }

        if (nodeAttribute instanceof String) {
            try {
                return Integer.valueOf(((String) nodeAttribute).trim());
            } catch (NumberFormatException e) {
                throw new IllegalArgumentException("Attribute " + annotation
                        + " must hold an integer model index but was \"" + nodeAttribute + "\"", e);
            }
        }

        throw new IllegalArgumentException("Attribute " + annotation
                + " must hold an integer model index but was of type "
                + nodeAttribute.getClass().getName());
    }// END: toAnnotationValue

    /**
     * Returns a mapping that uses a single model, with weight 1, for the whole branch.
     *
     * @param branch node at the lower end of the branch
     * @return mapping whose only index is a position in {@link #getSubstitutionModels()};
     *         nodes without a recorded assignment (e.g. the root) map to the base model
     */
    @Override
    public Mapping getBranchModelMapping(NodeRef branch) {
        Integer assigned = branchAssignmentMap.get(branch);
        final int modelIndex = (assigned != null) ? assigned : baseModelIndex;

        return new Mapping() {
            @Override
            public int[] getOrder() {
                return new int[] { modelIndex };
            }

            @Override
            public double[] getWeights() {
                return new double[] { 1.0 };
            }
        };
    }// END: getBranchModelMapping

    /**
     * @return the mapped models in the iteration order of {@code modelIndexMap}, followed by
     *         the base model; this is the internal list, not a copy
     */
    @Override
    public List<SubstitutionModel> getSubstitutionModels() {
        return substitutionModels;
    }// END: getSubstitutionModels

    /**
     * Looks the root up under the same annotation as every other node.
     *
     * @return the model selected by the annotation of the root, or the base model when the
     *         root has no annotation or its value has no model assigned
     */
    @Override
    public SubstitutionModel getRootSubstitutionModel() {
        return substitutionModels.get(resolveModelPosition(treeModel.getRoot(), false));
    }// END: getRootSubstitutionModel

    /** @return the frequency model of {@link #getRootSubstitutionModel()} */
    @Override
    public FrequencyModel getRootFrequencyModel() {
        return getRootSubstitutionModel().getFrequencyModel();
    }

    /** Every branch uses exactly one model, so no convolution is needed. */
    @Override
    public boolean requiresMatrixConvolution() {
        return false;
    }

    /**
     * Forwards any sub-model change to the listeners of this model.
     * <p>
     * NOTE(review): this class does not register the tree or the substitution models as
     * sub-models itself, so whether this callback is ever invoked depends on code outside
     * this class - confirm.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        fireModelChanged();
    }

    /** Intentionally empty: this class does not react to variable changes. */
    @Override
    protected void handleVariableChangedEvent(@SuppressWarnings("rawtypes") Variable variable, int index,
            ChangeType type) {
    }

    /** Intentionally empty: the assignments never change after construction. */
    @Override
    protected void storeState() {
    }

    /** Intentionally empty: the assignments never change after construction. */
    @Override
    protected void restoreState() {
    }

    /** Intentionally empty: the assignments never change after construction. */
    @Override
    protected void acceptState() {
    }
}// END: class