/* * HiddenMarkovRatesTest.java * * Copyright (c) 2002-2014 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package test.dr.app.beagle; import dr.evomodel.substmodel.nucleotide.HKY; import dr.evolution.datatype.Nucleotides; import dr.evolution.datatype.TwoStates; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.GeneralSubstitutionModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.evomodel.substmodel.UniformizedSubstitutionModel; import dr.inference.markovjumps.MarkovJumpsType; import dr.inference.model.Parameter; import dr.math.matrixAlgebra.Vector; import test.dr.inference.trace.TraceCorrelationAssert; import java.util.List; /** * @author Marc Suchard */ public class HiddenMarkovRatesTest extends TraceCorrelationAssert { public HiddenMarkovRatesTest(String name) { super(name); } private double[] getBinaryFreqs(int index) { return new double[]{0.7, 0.3}; } private double[] getHKYFreqs(int index) { return new double[]{0.25, 0.25, 0.25, 0.25}; } private double getKappa(int index) { return 10.0; } private int getNumberReplicates(int index) { return 1000000; } private double getLength(int index) { return 1.5; } public void testHiddenRates() { final int index = 0; double[] freqs = getBinaryFreqs(index); FrequencyModel binaryFreqModel = new FrequencyModel(TwoStates.INSTANCE, freqs); int relativeTo = 0; Parameter ratesParameter = new Parameter.Default(0); GeneralSubstitutionModel binaryModel = new GeneralSubstitutionModel("binary", TwoStates.INSTANCE, binaryFreqModel, ratesParameter, relativeTo); UniformizedSubstitutionModel uSM = new UniformizedSubstitutionModel(binaryModel, MarkovJumpsType.REWARDS); uSM.setSaveCompleteHistory(true); double[] rewardRegister = new double[]{0.0, 1.0}; uSM.setRegistration(rewardRegister); final double[] hkyFreqs = getHKYFreqs(index); FrequencyModel hkyFreqModel = new FrequencyModel(Nucleotides.INSTANCE, hkyFreqs); final double kappa = getKappa(index); final HKY hky = new HKY(kappa, hkyFreqModel); final double length = getLength(index); double[] resultCompleteHistory = new double[16]; final int replicates = getNumberReplicates(index); double result = 0.0; for (int r = 0; r < replicates; ++r) { result += oneCompleteHistoryReplicate(resultCompleteHistory, hky, uSM, length); } result /= replicates; normalize(resultCompleteHistory, replicates); System.out.println("Averaged probabilities"); System.out.println(result); System.out.println(new Vector(resultCompleteHistory)); System.out.println(); double[] intermediate = new double[16]; hky.getTransitionProbabilities(result, intermediate); System.out.println("Intermediate using above average reward"); System.out.println(result); System.out.println(new Vector(intermediate)); System.out.println(); double[] resultExpected = new double[16]; UniformizedSubstitutionModel expectedUSM = new UniformizedSubstitutionModel(binaryModel, MarkovJumpsType.REWARDS, replicates); expectedUSM.setRegistration(rewardRegister); result = oneCompleteHistoryReplicate(resultExpected, hky, expectedUSM, length); System.out.println("Averaged reward"); System.out.println(result); System.out.println(new Vector(resultExpected)); System.out.println(); double[] originalProbs = new double[16]; hky.getTransitionProbabilities(length, originalProbs); System.out.println("Original probabilities"); System.out.println(new Vector(originalProbs)); System.out.println(); } double oneCompleteHistoryReplicate(double[] result, final SubstitutionModel hky, final UniformizedSubstitutionModel uSM, final double length) { double reward = uSM.computeCondStatMarkovJumps(1, 1, length); if (DEBUG) reward = DEBUG_REWARD; double[] tmp = new double[hky.getDataType().getStateCount() * hky.getDataType().getStateCount()]; hky.getTransitionProbabilities(reward, tmp); increment(result, tmp); return reward; } private static final boolean DEBUG = false; private static final double DEBUG_REWARD = 2.0; private void normalize(double[] a, final int count) { for (int i = 0; i < a.length; ++i) { a[i] /= count; } } private void increment(double[] out, final double[] in) { for (int i = 0; i < out.length; ++i) { out[i] += in[i]; } } private double[] getTransitionProbabilities(SubstitutionModel substModel, List<Double> times) { final int stateCount = substModel.getDataType().getStateCount(); double[] result = getIdentityMatrix(stateCount); double[] tmp = new double[stateCount * stateCount]; for (Double time : times) { substModel.getTransitionProbabilities(time, tmp); result = matrixMultiplication(result, tmp, stateCount); } return result; } private double[] matrixMultiplication(double[] a, double[] b, int dim) { double[] result = new double[dim * dim]; for (int i = 0; i < dim; ++i) { for (int j = 0; j < dim; ++j) { for (int k = 0; k < dim; ++k) { result[i * dim + j] += a[i * dim + k] * b[k * dim + j]; } } } return result; } private double[] getIdentityMatrix(int stateCount) { double[] result = new double[stateCount * stateCount]; for (int i = 0; i < stateCount; ++i) { result[i * stateCount + i] = 1.0; } return result; } }