package dr.inference.operators; import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.MomentDistributionModel; import dr.inference.model.*; import dr.math.MathUtils; import dr.math.distributions.NormalDistribution; import dr.math.matrixAlgebra.SymmetricMatrix; /** * Created by max on 2/4/16. */ public class LoadingsGibbsTruncatedOperator extends SimpleMCMCOperator implements GibbsOperator{ Likelihood prior; LatentFactorModel LFM; double[][] precisionArray; double[] meanMidArray; double[] meanArray; boolean randomScan; double pathParameter=1.0; final Parameter missingIndicator; double priorPrecision; double priorMeanPrecision; MatrixParameterInterface loadings; DistributionLikelihood cutoffPrior; public LoadingsGibbsTruncatedOperator(LatentFactorModel LFM, Likelihood prior, double weight, boolean randomScan, MatrixParameterInterface loadings, DistributionLikelihood cutoffPrior) { setWeight(weight); this.loadings=loadings; this.prior = prior; this.LFM = LFM; if(prior instanceof MomentDistributionModel){ priorPrecision = (((MomentDistributionModel) this.prior).getScaleMatrix()[0][0]); priorMeanPrecision = ((MomentDistributionModel) this.prior).getMean()[0] * priorPrecision;} else if (prior instanceof DistributionLikelihood){ priorPrecision = 1 / ((DistributionLikelihood) this.prior).getDistribution().variance(); priorMeanPrecision = ((DistributionLikelihood) this.prior).getDistribution().mean() * priorPrecision; } this.cutoffPrior = cutoffPrior; missingIndicator = LFM.getMissingIndicator(); } private void getPrecisionOfTruncated(MatrixParameterInterface full, int newRowDimension, int row, double[][] answer) { int p = full.getColumnDimension(); for (int i = 0; i < newRowDimension; i++) { for (int j = i; j < newRowDimension; j++) { double sum = 0; for (int k = 0; k < p; k++) sum += full.getParameterValue(i, k) * full.getParameterValue(j, k); answer[i][j] = sum * LFM.getColumnPrecision().getParameterValue(row, row); if (i == j) { answer[i][j] =answer[i][j] * pathParameter + priorPrecision; } else { answer[i][j] *= pathParameter; answer[j][i] = answer[i][j]; } } } } private void getTruncatedMean(int newRowDimension, int dataColumn, double[][] variance, double[] midMean, double[] mean) { MatrixParameterInterface data = LFM.getScaledData(); MatrixParameterInterface Left = LFM.getFactors(); int p = data.getColumnDimension(); for (int i = 0; i < newRowDimension; i++) { double sum = 0; for (int k = 0; k < p; k++) { if(missingIndicator == null || missingIndicator.getParameterValue(k * LFM.getScaledData().getRowDimension() + dataColumn) != 1) sum += Left.getParameterValue(i, k) * data.getParameterValue(dataColumn, k); } sum = sum * LFM.getColumnPrecision().getParameterValue(dataColumn, dataColumn); sum += priorMeanPrecision; midMean[i] = sum; } for (int i = 0; i < newRowDimension; i++) { double sum = 0; for (int k = 0; k < newRowDimension; k++) sum += variance[i][k] * midMean[k]; mean[i] = sum; } } private void getPrecision(int i, double[][] answer) { int size = LFM.getFactorDimension(); getPrecisionOfTruncated(LFM.getFactors(), size, i, answer); } private void getMean(int i, double[][] variance, double[] midMean, double[] mean) { int size = LFM.getFactorDimension(); getTruncatedMean(size, i, variance, midMean, mean); for (int j = 0; j <mean.length ; j++) {//TODO implement for generic prior mean[j]*=pathParameter; } } private void copy(int i, double[] random) { MatrixParameterInterface changing = loadings; for (int j = 0; j < random.length; j++) { changing.setParameterValueQuietly(i, j, random[j]); } } private void getTruncatedDraw(int row, int column, NormalDistribution truncated){ double lowCutoff; double highCutoff; MatrixParameterInterface cutoff = (MatrixParameterInterface) ((MomentDistributionModel) prior).getCutoff(); lowCutoff = - Math.sqrt(cutoff.getParameterValue(row, column)); highCutoff = - lowCutoff; double low = truncated.cdf(lowCutoff); double high = truncated.cdf(highCutoff); double split = low / (low + (1-high)); double draw = 0; int count = 0; while((draw < highCutoff && draw > lowCutoff || Double.isNaN(draw)) && count < 10000) { double rand = MathUtils.nextDouble(); if (rand < split) { draw = MathUtils.nextDouble() * low; draw = truncated.quantile(draw); } else { draw = MathUtils.nextDouble() * (1 - high) + high; draw = truncated.quantile(draw); } count++; } if(count < 10000){ loadings.setParameterValue(row, column, draw); } } private void drawI(int i, int column) { double[] draws = null; precisionArray = new double[loadings.getColumnDimension()][loadings.getColumnDimension()]; double[][] variance; meanMidArray = new double[loadings.getColumnDimension()]; meanArray = new double[loadings.getColumnDimension()]; double[][] cholesky = null; NormalDistribution conditioned; getPrecision(i, precisionArray); if(LFM.getLoadings().getParameterValue(i, column) != 0){ variance = (new SymmetricMatrix(precisionArray)).inverse().toComponents(); // try { // cholesky = new CholeskyDecomposition(variance).getL(); // } catch (IllegalDimension illegalDimension) { // illegalDimension.printStackTrace(); // } getMean(i, variance, meanMidArray, meanArray); if(LFM.getFactorDimension() != 1) conditioned = getConditionalDistribution(meanArray, variance, column, i); else conditioned = new NormalDistribution(meanArray[0], Math.sqrt(variance[0][0])); } else conditioned = new NormalDistribution(0, Math.sqrt(1 / priorPrecision)); //TODO generify if(prior instanceof MomentDistributionModel){ if(MathUtils.nextDouble() < .5) { getTruncatedDraw(i, column, conditioned); getCutoffDraw(i, column, conditioned); } else{ getCutoffDraw(i, column, conditioned); getTruncatedDraw(i, column, conditioned); } } else{ loadings.setParameterValue(i, column, conditioned.quantile(MathUtils.nextDouble())); } } private NormalDistribution getConditionalDistribution(double[] meanArray, double[][] variance, int column, int row) { double[][] newVariance = new double[meanArray.length - 1][meanArray.length - 1]; for (int i = 0; i < meanArray.length; i++) { for (int j = 0; j < meanArray.length; j++) { if(i < column && j < column){ newVariance[i][j] = variance[i][j]; } else if(i < column && j > column){ newVariance[i][j - 1] = variance[i][j]; } else if(i > column && j < column){ newVariance[i - 1][j] = variance[i][j]; } else if(i > column && j > column){ newVariance[i - 1][j - 1] = variance[i][j]; } else{} } } double[][] smallPrecision = (new SymmetricMatrix(newVariance)).inverse().toComponents(); double[] meanStore1 = new double[meanArray.length - 1]; double[] meanStore2 = new double[meanArray.length - 1]; double[] precStore = new double[meanArray.length - 1]; for (int i = 0; i < meanArray.length; i++) { if(i < column){ meanStore1[i] = LFM.getLoadings().getParameterValue(row, i) - meanArray[i]; } else if (i > column){ meanStore1[i - 1] = LFM.getLoadings().getParameterValue(row, i) - meanArray[i]; } else{} } for (int i = 0; i < meanArray.length - 1; i++) { for (int j = 0; j < meanArray.length - 1; j++) { meanStore2[i] += smallPrecision[i][j] * meanStore1[j]; } } double mean = meanArray[column]; for (int i = 0; i < meanArray.length - 1; i++) { if(i < column){ mean += meanStore2[i] * variance[i][column]; } else{ mean += meanStore2[i] * variance[i + 1][column]; } } for (int i = 0; i < meanArray.length - 1; i++) { for (int j = 0; j < meanArray.length - 1; j++) { if(i < column) precStore[i] += smallPrecision[i][j] * variance[j][column]; else precStore[i] += smallPrecision[i][j] * variance[j+1][column]; } } double varianceElement = variance[column][column]; for (int i = 0; i < meanArray.length - 1; i++) { if(i < column) varianceElement -= precStore[i] * variance[i][column]; else varianceElement -= precStore[i] * variance[i+1][column]; } return new NormalDistribution(mean, Math.sqrt(varianceElement)); } void getCutoffDraw(int row, int column, NormalDistribution posteriorLoadings){ double loadingsCutoff = Math.abs(loadings.getParameterValue(row, column)); double draw = MathUtils.nextDouble() * loadingsCutoff; double cutoffVal = Math.sqrt(((MatrixParameterInterface) ((MomentDistributionModel) prior).getCutoff()).getParameterValue(row, column)); double top = cutoffPrior.getDistribution().pdf(Math.pow(draw,2)) / (1 - (posteriorLoadings.cdf(draw) - posteriorLoadings.cdf(-draw))); double bottom = cutoffPrior.getDistribution().pdf(Math.pow(cutoffVal, 2)) / (1 - (posteriorLoadings.cdf(cutoffVal) - posteriorLoadings.cdf(-cutoffVal))); // double stopperCDF = Math.pow(cutoffPrior.getDistribution().cdf(loadingsCutoff), 2); // double randQuant = MathUtils.nextDouble() * stopperCDF; if(MathUtils.nextDouble() < top / bottom){ ((MatrixParameterInterface) ((MomentDistributionModel) prior).getCutoff()).setParameterValue(row, column, Math.pow(draw, 2));} } @Override public int getStepCount() { return 0; //To change body of implemented methods use File | Settings | File Templates. } @Override public String getPerformanceSuggestion() { return null; //To change body of implemented methods use File | Settings | File Templates. } @Override public String getOperatorName() { return "loadingsGibbsTruncatedOperator"; //To change body of implemented methods use File | Settings | File Templates. } @Override public double doOperation() { int size = LFM.getLoadings().getRowDimension(); int column = MathUtils.nextInt(LFM.getLoadings().getColumnDimension()); for (int i = 0; i < size; i++) { drawI(i, column); } ((Parameter) loadings).fireParameterChangedEvent(); return 0; } public void setPathParameter(double beta){ pathParameter=beta; } }