/*
* LoadingsGibbsOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.operators;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.model.LatentFactorModel;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.NormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import java.util.ArrayList;
import java.util.List;
import java.util.ListIterator;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
* Created with IntelliJ IDEA.
* User: max
* Date: 5/23/14
* Time: 2:23 PM
* To change this template use File | Settings | File Templates.
*/
public class LoadingsGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {
NormalDistribution prior;
NormalDistribution workingPrior;
LatentFactorModel LFM;
ArrayList<double[][]> precisionArray;
ArrayList<double[]> meanMidArray;
ArrayList<double[]> meanArray;
boolean randomScan;
double pathParameter=1.0;
final Parameter missingIndicator;
double priorPrecision;
double priorMeanPrecision;
double priorPrecisionWorking;
double priorMeanPrecisionWorking;
private double a;
public LoadingsGibbsOperator(LatentFactorModel LFM, DistributionLikelihood prior, double weight, boolean randomScan, DistributionLikelihood workingPrior, boolean multiThreaded, int numThreads) {
setWeight(weight);
this.prior = (NormalDistribution) prior.getDistribution();
if (workingPrior != null) {
this.workingPrior = (NormalDistribution) workingPrior.getDistribution();
}
this.LFM = LFM;
precisionArray = new ArrayList<double[][]>();
double[][] temp;
this.randomScan = randomScan;
meanArray = new ArrayList<double[]>();
meanMidArray = new ArrayList<double[]>();
double[] tempMean;
if (!randomScan) {
for (int i = 0; i < LFM.getFactorDimension(); i++) {
temp = new double[i + 1][i + 1];
precisionArray.add(temp);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[i + 1];
meanArray.add(tempMean);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[i + 1];
meanMidArray.add(tempMean);
}
} else {
for (int i = 0; i < LFM.getFactorDimension(); i++) {
temp = new double[LFM.getFactorDimension() - i][LFM.getFactorDimension() - i];
precisionArray.add(temp);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[LFM.getFactorDimension() - i];
meanArray.add(tempMean);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[LFM.getFactorDimension() - i];
meanMidArray.add(tempMean);
}
}
// vectorProductAnswer=new MatrixParameter[LFM.getLoadings().getRowDimension()];
// for (int i = 0; i <vectorProductAnswer.length ; i++) {
// vectorProductAnswer[i]=new MatrixParameter(null);
// vectorProductAnswer[i].setDimensions(i+1, 1);
// }
// priorMeanVector=new MatrixParameter[LFM.getLoadings().getRowDimension()];
// for (int i = 0; i <priorMeanVector.length ; i++) {
// priorMeanVector[i]=new MatrixParameter(null, i+1, 1, this.prior.getMean()/(this.prior.getSD()*this.prior.getSD()));
//
//
// }
priorPrecision = 1 / (this.prior.getSD() * this.prior.getSD());
priorMeanPrecision = this.prior.getMean() * priorPrecision;
if (workingPrior == null) {
priorMeanPrecisionWorking = priorMeanPrecision;
priorPrecisionWorking = priorPrecision;
} else {
priorPrecisionWorking = 1 / (this.workingPrior.getSD() * this.workingPrior.getSD());
priorMeanPrecisionWorking = this.workingPrior.getMean() * priorPrecisionWorking;
}
if (multiThreaded) {
for (int i = 0; i < LFM.getLoadings().getRowDimension(); i++) {
if (i < LFM.getFactorDimension())
drawCallers.add(new DrawCaller(i, new double[i + 1][i + 1], new double[i + 1], new double[i + 1]));
else
drawCallers.add(new DrawCaller(i, new double[LFM.getFactorDimension()][LFM.getFactorDimension()], new double[LFM.getFactorDimension()], new double[LFM.getFactorDimension()]));
}
int threads = numThreads;
// Integer.parseInt(System.getProperty("thread.count"));
pool = Executors.newFixedThreadPool(threads);
}
else{
pool = null;
}
missingIndicator = LFM.getMissingIndicator();
}
private void getPrecisionOfTruncated(MatrixParameterInterface full, int newRowDimension, int row, double[][] answer) {
// MatrixParameter answer=new MatrixParameter(null);
// answer.setDimensions(this.getRowDimension(), Right.getRowDimension());
// System.out.println(answer.getRowDimension());
// System.out.println(answer.getColumnDimension());
int p = full.getColumnDimension();
for (int i = 0; i < newRowDimension; i++) {
for (int j = i; j < newRowDimension; j++) {
double sum = 0;
for (int k = 0; k < p; k++)
sum += full.getParameterValue(i, k) * full.getParameterValue(j, k);
answer[i][j] = sum * LFM.getColumnPrecision().getParameterValue(row, row);
if (i == j) {
answer[i][j] = answer[i][j] * pathParameter + getAdjustedPriorPrecision();
} else {
answer[i][j] *= pathParameter;
answer[j][i] = answer[i][j];
}
}
}
}
private void getTruncatedMean(int newRowDimension, int dataColumn, double[][] variance, double[] midMean, double[] mean) {
// MatrixParameter answer=new MatrixParameter(null);
// answer.setDimensions(this.getRowDimension(), Right.getRowDimension());
// System.out.println(answer.getRowDimension());
// System.out.println(answer.getColumnDimension());
MatrixParameterInterface data = LFM.getScaledData();
MatrixParameterInterface Left = LFM.getFactors();
int p = data.getColumnDimension();
for (int i = 0; i < newRowDimension; i++) {
double sum = 0;
for (int k = 0; k < p; k++)
{
if(missingIndicator == null || missingIndicator.getParameterValue(k * LFM.getScaledData().getRowDimension() + dataColumn) != 1)
sum += Left.getParameterValue(i, k) * data.getParameterValue(dataColumn, k);
}
sum = sum * LFM.getColumnPrecision().getParameterValue(dataColumn, dataColumn);
sum += priorMeanPrecision;
midMean[i] = sum;
}
for (int i = 0; i < newRowDimension; i++) {
double sum = 0;
for (int k = 0; k < newRowDimension; k++)
sum += variance[i][k] * midMean[k];
mean[i] = sum;
}
}
private void getPrecision(int i, double[][] answer) {
int size = LFM.getFactorDimension();
if (i < size) {
getPrecisionOfTruncated(LFM.getFactors(), i + 1, i, answer);
} else {
getPrecisionOfTruncated(LFM.getFactors(), size, i, answer);
}
}
private void getMean(int i, double[][] variance, double[] midMean, double[] mean) {
// Matrix factors=null;
int size = LFM.getFactorDimension();
// double[] scaledDataColumn=LFM.getScaledData().getRowValues(i);
// Vector dataColumn=null;
// Vector priorVector=null;
// Vector temp=null;
// Matrix data=new Matrix(LFM.getScaledData().getParameterAsMatrix());
if (i < size) {
getTruncatedMean(i + 1, i, variance, midMean, mean);
// dataColumn=new Vector(data.toComponents()[i]);
// try {
// answer=precision.inverse().product(new Matrix(priorMeanVector[i].add(vectorProductAnswer[i]).getParameterAsMatrix()));
// } catch (IllegalDimension illegalDimension) {
// illegalDimension.printStackTrace(); //To change body of catch statement use File | Settings | File Templates.
// }
} else {
getTruncatedMean(size, i, variance, midMean, mean);
// dataColumn=new Vector(data.toComponents()[i]);
// try {
// answer=precision.inverse().product(new Matrix(priorMeanVector[size-1].add(vectorProductAnswer[size-1]).getParameterAsMatrix()));
// } catch (IllegalDimension illegalDimension) {
// illegalDimension.printStackTrace(); //To change body of catch statement use File | Settings | File Templates.
// }
}
for (int j = 0; j < mean.length ; j++) {//TODO implement for generic prior
mean[j] *= pathParameter;
}
}
private void copy(int i, double[] random) {
MatrixParameterInterface changing = LFM.getLoadings();
for (int j = 0; j < random.length; j++) {
changing.setParameterValueQuietly(i, j, random[j]);
}
}
private void drawI(int i, double[][] precision, double[] midMean, double[] mean) {
double[] draws = null;
double[][] variance;
double[][] cholesky = null;
getPrecision(i, precision);
variance = (new SymmetricMatrix(precision)).inverse().toComponents();
try {
cholesky = new CholeskyDecomposition(variance).getL();
} catch (IllegalDimension illegalDimension) {
illegalDimension.printStackTrace();
}
getMean(i, variance, midMean, mean);
draws = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky);
// if(i<draws.length)
//
// {
// while (draws[i] < 0) {
// draws = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, cholesky);
// }
// }
if (i < draws.length) {
//if (draws[i] > 0) { TODO implement as option
copy(i, draws);
//}
} else {
copy(i, draws);
}
// copy(i, draws);
}
@Override
public int getStepCount() {
return 0; //To change body of implemented methods use File | Settings | File Templates.
}
@Override
public String getPerformanceSuggestion() {
return null; //To change body of implemented methods use File | Settings | File Templates.
}
@Override
public String getOperatorName() {
return "loadingsGibbsOperator"; //To change body of implemented methods use File | Settings | File Templates.
}
@Override
public double doOperation() {
int size = LFM.getLoadings().getRowDimension();
if(LFM.getFactorDimension() != precisionArray.listIterator().next().length){
precisionArray.clear();
meanArray.clear();
meanMidArray.clear();
double[] tempMean;
double[][] temp;
if (!randomScan) {
for (int i = 0; i < LFM.getFactorDimension(); i++) {
temp = new double[i + 1][i + 1];
precisionArray.add(temp);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[i + 1];
meanArray.add(tempMean);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[i + 1];
meanMidArray.add(tempMean);
}
} else {
for (int i = 0; i < LFM.getFactorDimension(); i++) {
temp = new double[LFM.getFactorDimension() - i][LFM.getFactorDimension() - i];
precisionArray.add(temp);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[LFM.getFactorDimension() - i];
meanArray.add(tempMean);
}
for (int i = 0; i < LFM.getFactorDimension(); i++) {
tempMean = new double[LFM.getFactorDimension() - i];
meanMidArray.add(tempMean);
}
}
}
if(pool != null){
try {
pool.invokeAll(drawCallers);
LFM.getLoadings().fireParameterChangedEvent();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
else {
if (!randomScan) {
ListIterator<double[][]> currentPrecision = precisionArray.listIterator();
ListIterator<double[]> currentMidMean = meanMidArray.listIterator();
ListIterator<double[]> currentMean = meanArray.listIterator();
double[][] precision = null;
double[] midMean = null;
double[] mean = null;
for (int i = 0; i < size; i++) {
if(i < LFM.getFactorDimension())
{precision = currentPrecision.next();
midMean = currentMidMean.next();
mean = currentMean.next();
}
drawI(i, precision, midMean, mean);
}
LFM.getLoadings().fireParameterChangedEvent();
} else {
int i = MathUtils.nextInt(LFM.getLoadings().getRowDimension());
ListIterator<double[][]> currentPrecision;
ListIterator<double[]> currentMidMean;
ListIterator<double[]> currentMean;
if (i < LFM.getFactorDimension()) {
currentPrecision = precisionArray.listIterator(LFM.getFactorDimension() - i - 1);
currentMidMean = meanMidArray.listIterator(LFM.getFactorDimension() - i - 1);
currentMean = meanArray.listIterator(LFM.getFactorDimension() - i - 1);
} else {
currentPrecision = precisionArray.listIterator();
currentMidMean = meanMidArray.listIterator();
currentMean = meanArray.listIterator();
}
drawI(i, currentPrecision.next(), currentMidMean.next(), currentMean.next());
LFM.getLoadings().fireParameterChangedEvent(i, null);
// LFM.getLoadings().fireParameterChangedEvent();
}
}
return 0;
}
public void setPathParameter(double beta){
pathParameter=beta;
}
public double getAdjustedPriorPrecision() {
return priorPrecision * pathParameter + (1 - pathParameter) * priorPrecisionWorking;
}
class DrawCaller implements Callable<Double> {
int i;
double[][] precision;
double[] midMean;
double[] mean;
public DrawCaller(int i, double[][] precision, double[] midMean, double [] mean) {
this.i = i;
this.precision = precision;
this.midMean = midMean;
this.mean = mean;
}
private final boolean DEBUG_PARALLEL_EVALUATION = false;
public Double call() throws Exception {
if (DEBUG_PARALLEL_EVALUATION) {
System.err.print("Invoking thread #" + i + " for " + ": ");
}
drawI(i, precision, midMean, mean);
return null;
}
}
private final List<Callable<Double>> drawCallers = new ArrayList<Callable<Double>>();
private final ExecutorService pool;
}