/*
* TwoPhaseOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.operators;
import java.util.ArrayList;
import java.util.List;
import dr.inference.model.Parameter;
/**
* This class allows to use two different sets of operators.
*
* @author Guy Baele
*/
public class TwoPhaseOperator extends AbstractCoercableOperator {
public static final boolean DEBUG = false;
public static final boolean PROVIDE_SAMPLES = false;
/*private AbstractCoercableOperator[] phaseOneOperators;
private AbstractCoercableOperator[] phaseTwoOperators;
double phaseOneTotalWeight, phaseTwoTotalWeight;*/
private List<AbstractCoercableOperator> phaseOneOperators;
private List<AdaptableVarianceMultivariateNormalOperator> phaseTwoOperators;
//private List<AbstractCoercableOperator> phaseTwoOperators;
private List<AbstractCoercableOperator> currentOperators;
private SimpleOperatorSchedule phaseOneScheduler;
private SimpleOperatorSchedule phaseTwoScheduler;
private SimpleOperatorSchedule currentOperatorScheduler;
private List<Parameter> parameters;
private ArrayList<ArrayList<Double>> storedValues;
private int initial;
private int burnin;
private int numberOfCalls;
private int currentOperatorIndex;
private boolean switchOperators;
public TwoPhaseOperator(List<AbstractCoercableOperator> phaseOneOperators, List<AdaptableVarianceMultivariateNormalOperator> phaseTwoOperators, List<Parameter> parameters, int initial, int burnin, double weight, CoercionMode mode) {
super(mode);
/*this.phaseOneOperators = phaseOneOperators;
this.phaseTwoOperators = phaseTwoOperators;
this.phaseOneTotalWeight = 0.0;
for (int i = 0; i < phaseOneOperators.length; i++) {
this.phaseOneTotalWeight += phaseOneOperators[i].getWeight();
}
this.phaseTwoTotalWeight = 0.0;
for (int i = 0; i < phaseTwoOperators.length; i++) {
this.phaseTwoTotalWeight += phaseTwoOperators[i].getWeight();
}*/
if (DEBUG) {
System.err.println("\nConstructing TwoPhaseOperator");
}
setWeight(weight);
this.initial = initial;
this.burnin = burnin;
this.numberOfCalls = 0;
this.switchOperators = false;
this.phaseOneOperators = phaseOneOperators;
this.phaseTwoOperators = phaseTwoOperators;
this.parameters = parameters;
this.storedValues = new ArrayList<ArrayList<Double>>();
for (int i = 0; i < phaseOneOperators.size(); i++) {
this.storedValues.add(new ArrayList<Double>());
}
phaseOneScheduler = new SimpleOperatorSchedule();
for (MCMCOperator operator : phaseOneOperators) {
phaseOneScheduler.addOperator(operator);
}
if (DEBUG) {
System.err.println("Phase One Scheduler initiated with size: " + phaseOneScheduler.getOperatorCount());
}
phaseTwoScheduler = new SimpleOperatorSchedule();
for (MCMCOperator operator : phaseTwoOperators) {
phaseTwoScheduler.addOperator(operator);
}
if (DEBUG) {
System.err.println("Phase Two Scheduler initiated with size: " + phaseTwoScheduler.getOperatorCount());
}
currentOperatorScheduler = phaseOneScheduler;
currentOperators = phaseOneOperators;
}
public double doOperation() {
if (DEBUG) {
System.err.println("\nTwoPhaseOperator: doOperation() called");
}
numberOfCalls++;
if (DEBUG) {
System.err.println("Number of times called: " + numberOfCalls);
}
currentOperatorIndex = currentOperatorScheduler.getNextOperatorIndex();
if (DEBUG) {
System.err.println("current operator index: " + currentOperatorIndex);
}
//don't store anything in the first set of operators themselves
//store everything in this class to not clutter AVMVN operator with excessive code
if (numberOfCalls > burnin && !switchOperators) {
//assume a 1-on-1 relationship between the parameter list and the first phase of operators
//i.e. each parameter has 1 operator acting on it
//now we can use currentOperatorIndex to help with the bookkeeping
//first decide to which of the phase two operators the parameter value needs to be written to
int phaseTwoCounter = 0;
//at the same time decide where it actually came from in order to determine its actual value
//i.e. Parameter might be a CompoundParameter, which complicates things
int parameterIndex = currentOperatorIndex;
for (int i = 0; i < phaseTwoOperators.size(); i++) {
//TODO: this may rely on the AVMVN operator only having 1 CompoundParameter
if (currentOperatorIndex < phaseTwoOperators.get(i).getParameter().getSize()) {
break;
} else {
parameterIndex -= phaseTwoOperators.get(i).getParameter().getSize();
phaseTwoCounter++;
}
}
storedValues.get(currentOperatorIndex).add(parameters.get(phaseTwoCounter).getParameterValue(parameterIndex));
if (DEBUG) {
System.err.println("Storing values in TwoPhaseOperator");
System.err.println("currentOperatorIndex: " + currentOperatorIndex);
System.err.println("parameterIndex: " + parameterIndex);
System.err.print("storage dimensions: " + storedValues.size());
for (int i = 0; i < storedValues.size(); i++) {
System.err.print(" -> " + storedValues.get(i).size());
}
System.err.println();
}
/*if (DEBUG) {
System.err.println("Passing values to phase two operator(s)");
System.err.println("currentOperatorIndex: " + currentOperatorIndex);
System.err.println("AVMVN operator assigned: " + phaseTwoCounter);
System.err.println("parameterIndex: " + parameterIndex);
}
phaseTwoOperators.get(phaseTwoCounter).setSample(parameterIndex, parameters.get(phaseTwoCounter).getParameterValue(parameterIndex));
*/
}
if (numberOfCalls > initial && !switchOperators) {
if (DEBUG) {
System.err.println("Switch from phase one scheduler to phase two scheduler");
}
currentOperatorScheduler = phaseTwoScheduler;
//TODO: fix Java type safety problem below
currentOperators = (List<AbstractCoercableOperator>)(List<?>) phaseTwoOperators;
//an extra draw is needed here
currentOperatorIndex = currentOperatorScheduler.getNextOperatorIndex();
if (PROVIDE_SAMPLES) {
//call methods to calculate means and covariance matrix and pass them on to AVMVN operator(s)
//need to create the appropriate list of lists to pass on to AVMVN operator
for (int i = 0; i < phaseTwoOperators.size(); i++) {
int listSize = phaseTwoOperators.get(i).getParameter().getDimension();
ArrayList<ArrayList<Double>> temp = new ArrayList<ArrayList<Double>>();
for (int j = 0; j < listSize; j++) {
temp.add(new ArrayList<Double>());
temp.set(j, storedValues.get(i*phaseTwoOperators.get(i).getParameter().getSize()+j));
}
phaseTwoOperators.get(i).provideSamples(temp);
}
}
switchOperators = true;
}
double logJacobian = (currentOperators.get(currentOperatorIndex)).doOperation();
return logJacobian;
}
/*@Override
public void accept(double deviation) {
if (DEBUG) {
System.err.println("TwoPhaseOperator: accept(double deviation) called");
}
currentOperators.get(currentOperatorIndex).accept(deviation);
}
@Override
public void reject() {
if (DEBUG) {
System.err.println("TwoPhaseOperator: reject() called");
}
currentOperators.get(currentOperatorIndex).reject();
}
@Override
public void reset() {
if (DEBUG) {
System.err.println("TwoPhaseOperator: reset() called");
}
currentOperators.get(currentOperatorIndex).reset();
}*/
public double getCoercableParameter() {
return currentOperators.get(currentOperatorIndex).getCoercableParameter();
}
public void setCoercableParameter(double value) {
currentOperators.get(currentOperatorIndex).setCoercableParameter(value);
}
public double getRawParameter() {
return currentOperators.get(currentOperatorIndex).getRawParameter();
}
public String getPerformanceSuggestion() {
return currentOperators.get(currentOperatorIndex).getPerformanceSuggestion();
}
@Override
public double getTargetAcceptanceProbability() {
return currentOperators.get(currentOperatorIndex).getTargetAcceptanceProbability();
}
public String getOperatorName() {
return "twoPhaseOperator(use at own risk)";
}
}