/* * LoadingsIndependenceOperator.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.operators; import dr.inference.distribution.DistributionLikelihood; import dr.inference.model.LatentFactorModel; import dr.inference.model.MatrixParameterInterface; import dr.inference.model.Parameter; import dr.math.MathUtils; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.distributions.NormalDistribution; import dr.math.matrixAlgebra.CholeskyDecomposition; import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.SymmetricMatrix; import java.util.ArrayList; import java.util.ListIterator; /** * Created with IntelliJ IDEA. * User: max * Date: 5/23/14 * Time: 2:23 PM * To change this template use File | Settings | File Templates. */ public class LoadingsIndependenceOperator extends AbstractCoercableOperator { NormalDistribution prior; LatentFactorModel LFM; ArrayList<double[][]> precisionArray; ArrayList<double[]> meanMidArray; ArrayList<double[]> meanArray; boolean randomScan; double scaleFactor; double priorPrecision; double priorMeanPrecision; public LoadingsIndependenceOperator(LatentFactorModel LFM, DistributionLikelihood prior, double weight, boolean randomScan, double scaleFactor, CoercionMode mode) { super(mode); setWeight(weight); this.scaleFactor = scaleFactor; this.prior = (NormalDistribution) prior.getDistribution(); this.LFM = LFM; precisionArray = new ArrayList<double[][]>(); double[][] temp; this.randomScan = randomScan; meanArray = new ArrayList<double[]>(); meanMidArray = new ArrayList<double[]>(); double[] tempMean; if (!randomScan) { for (int i = 0; i < LFM.getFactorDimension(); i++) { temp = new double[i + 1][i + 1]; precisionArray.add(temp); } for (int i = 0; i < LFM.getFactorDimension(); i++) { tempMean = new double[i + 1]; meanArray.add(tempMean); } for (int i = 0; i < LFM.getFactorDimension(); i++) { tempMean = new double[i + 1]; meanMidArray.add(tempMean); } } else { for (int i = 0; i < LFM.getFactorDimension(); i++) { temp = new double[LFM.getFactorDimension() - i][LFM.getFactorDimension() - i]; precisionArray.add(temp); } for (int i = 0; i < LFM.getFactorDimension(); i++) { tempMean = new double[LFM.getFactorDimension() - i]; meanArray.add(tempMean); } for (int i = 0; i < LFM.getFactorDimension(); i++) { tempMean = new double[LFM.getFactorDimension() - i]; meanMidArray.add(tempMean); } } // vectorProductAnswer=new MatrixParameter[LFM.getLoadings().getRowDimension()]; // for (int i = 0; i <vectorProductAnswer.length ; i++) { // vectorProductAnswer[i]=new MatrixParameter(null); // vectorProductAnswer[i].setDimensions(i+1, 1); // } // priorMeanVector=new MatrixParameter[LFM.getLoadings().getRowDimension()]; // for (int i = 0; i <priorMeanVector.length ; i++) { // priorMeanVector[i]=new MatrixParameter(null, i+1, 1, this.prior.getMean()/(this.prior.getSD()*this.prior.getSD())); // // // } priorPrecision = 1 / (this.prior.getSD() * this.prior.getSD()); priorMeanPrecision = this.prior.getMean() * priorPrecision; } private void getPrecisionOfTruncated(MatrixParameterInterface full, int newRowDimension, int row, double[][] answer) { // MatrixParameter answer=new MatrixParameter(null); // answer.setDimensions(this.getRowDimension(), Right.getRowDimension()); // System.out.println(answer.getRowDimension()); // System.out.println(answer.getColumnDimension()); int p = full.getColumnDimension(); for (int i = 0; i < newRowDimension; i++) { for (int j = i; j < newRowDimension; j++) { double sum = 0; for (int k = 0; k < p; k++) sum += full.getParameterValue(i, k) * full.getParameterValue(j, k); answer[i][j] = sum * LFM.getColumnPrecision().getParameterValue(row, row); if (i == j) { answer[i][j] += priorPrecision; } else { answer[j][i] = answer[i][j]; } } } } private void getTruncatedMean(int newRowDimension, int dataColumn, double[][] variance, double[] midMean, double[] mean) { // MatrixParameter answer=new MatrixParameter(null); // answer.setDimensions(this.getRowDimension(), Right.getRowDimension()); // System.out.println(answer.getRowDimension()); // System.out.println(answer.getColumnDimension()); MatrixParameterInterface data = LFM.getScaledData(); MatrixParameterInterface Left = LFM.getFactors(); int p = data.getColumnDimension(); for (int i = 0; i < newRowDimension; i++) { double sum = 0; for (int k = 0; k < p; k++) sum += Left.getParameterValue(i, k) * data.getParameterValue(dataColumn, k); sum = sum * LFM.getColumnPrecision().getParameterValue(dataColumn, dataColumn); sum += priorMeanPrecision; midMean[i] = sum; } for (int i = 0; i < newRowDimension; i++) { double sum = 0; for (int k = 0; k < newRowDimension; k++) sum += variance[i][k] * midMean[k]; mean[i] = sum; } } private void getPrecision(int i, double[][] answer) { int size = LFM.getFactorDimension(); if (i < size) { getPrecisionOfTruncated(LFM.getFactors(), i + 1, i, answer); } else { getPrecisionOfTruncated(LFM.getFactors(), size, i, answer); } } private void getMean(int i, double[][] variance, double[] midMean, double[] mean) { // Matrix factors=null; int size = LFM.getFactorDimension(); // double[] scaledDataColumn=LFM.getScaledData().getRowValues(i); // Vector dataColumn=null; // Vector priorVector=null; // Vector temp=null; // Matrix data=new Matrix(LFM.getScaledData().getParameterAsMatrix()); if (i < size) { getTruncatedMean(i + 1, i, variance, midMean, mean); // dataColumn=new Vector(data.toComponents()[i]); // try { // answer=precision.inverse().product(new Matrix(priorMeanVector[i].add(vectorProductAnswer[i]).getParameterAsMatrix())); // } catch (IllegalDimension illegalDimension) { // illegalDimension.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. // } } else { getTruncatedMean(size, i, variance, midMean, mean); // dataColumn=new Vector(data.toComponents()[i]); // try { // answer=precision.inverse().product(new Matrix(priorMeanVector[size-1].add(vectorProductAnswer[size-1]).getParameterAsMatrix())); // } catch (IllegalDimension illegalDimension) { // illegalDimension.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. // } } } private void copy(int i, double[] random) { Parameter changing = LFM.getLoadings().getParameter(i); for (int j = 0; j < random.length; j++) { changing.setParameterValueQuietly(j, random[j]); } } private void drawI(int i, ListIterator<double[][]> currentPrecision, ListIterator<double[]> currentMidMean, ListIterator<double[]> currentMean) { double[] draws = null; double[][] precision = null; double[][] variance; double[] midMean = null; double[] mean = null; double[][] cholesky = null; if (currentPrecision.hasNext()) { precision = currentPrecision.next(); } if (currentMidMean.hasNext()) { midMean = currentMidMean.next(); } if (currentMean.hasNext()) { mean = currentMean.next(); } getPrecision(i, precision); variance = (new SymmetricMatrix(precision)).inverse().toComponents(); try { cholesky = new CholeskyDecomposition(variance).getL(); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); } getMean(i, variance, midMean, mean); draws = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky, scaleFactor); // if(i<draws.length) // // { // while (draws[i] < 0) { // draws = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky); // } // } if (i < draws.length) { if (draws[i] > 0) { copy(i, draws); // LFM.computeResiduals(); } } else { copy(i, draws); // LFM.computeResiduals(); } // copy(i, draws); } // @Override // public int getStepCount() { // return 0; //To change body of implemented methods use File | Settings | File Templates. // } @Override public String getPerformanceSuggestion() { return null; //To change body of implemented methods use File | Settings | File Templates. } @Override public String getOperatorName() { return "loadingsGibbsOperator"; //To change body of implemented methods use File | Settings | File Templates. } @Override public double doOperation() { int size = LFM.getLoadings().getColumnDimension(); if (!randomScan) { ListIterator<double[][]> currentPrecision = precisionArray.listIterator(); ListIterator<double[]> currentMidMean = meanMidArray.listIterator(); ListIterator<double[]> currentMean = meanArray.listIterator(); for (int i = 0; i < size; i++) { drawI(i, currentPrecision, currentMidMean, currentMean); } LFM.getLoadings().fireParameterChangedEvent(); } else { int i = MathUtils.nextInt(LFM.getLoadings().getColumnDimension()); ListIterator<double[][]> currentPrecision; ListIterator<double[]> currentMidMean; ListIterator<double[]> currentMean; if (i < LFM.getFactorDimension()) { currentPrecision = precisionArray.listIterator(LFM.getFactorDimension() - i - 1); currentMidMean = meanMidArray.listIterator(LFM.getFactorDimension() - i - 1); currentMean = meanArray.listIterator(LFM.getFactorDimension() - i - 1); } else { currentPrecision = precisionArray.listIterator(); currentMidMean = meanMidArray.listIterator(); currentMean = meanArray.listIterator(); } drawI(i, currentPrecision, currentMidMean, currentMean); LFM.getLoadings().fireParameterChangedEvent(); } return 0; } public double getCoercableParameter() { return Math.log(scaleFactor); } @Override public void setCoercableParameter(double value) { scaleFactor = Math.exp(value); } @Override public double getRawParameter() { return scaleFactor; } }