/*
 * GeneralizedLinearModel.java
 *
 * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.inference.glm;

import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import dr.inference.distribution.DensityModel;
import dr.inference.distribution.ParametricDistributionModel;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.*;
import dr.inferencexml.glm.GeneralizedLinearModelParser;
import dr.math.MultivariateFunction;
import dr.math.distributions.Distribution;
import dr.util.Transform;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

/**
 * Generalized linear model likelihood.
 * <p>
 * The linear predictor is the sum over fixed effects of X_j * (beta_j [* delta_j]) plus any
 * random-effect parameters. Its image under the inverse link function is written into the
 * density's location variable (index 0), one observation at a time, and the density's log-pdf
 * of the corresponding dependent value Y[i] is summed.
 *
 * @author Marc Suchard
 * @author Andrew Rambaut
 */
public final class GeneralizedLinearModel extends AbstractModelLikelihood {

    /**
     * Link functions supported by the model. Each wraps a {@link Transform}; its
     * {@code inverse} maps the linear predictor onto the density's location scale.
     */
    public enum LinkFunction {
        IDENTITY(new Transform.NoTransform()),
        LOG(new Transform.LogTransform()),
        LOGIT(new Transform.LogitTransform());

        LinkFunction(Transform transform) {
            this.transform = transform;
        }

        public Transform getTransform() {
            return transform;
        }

        private final Transform transform;
    }

    private final Transform linkFunction;
    private final DensityModel density;
    private final boolean isMultivariateDensity;

    private final Parameter dependentParameter;
    private final List<Parameter> independentParameter = new ArrayList<Parameter>();
    private final List<Parameter> independentParameterDelta = new ArrayList<Parameter>();
    private final List<DesignMatrix> designMatrix = new ArrayList<DesignMatrix>();

    private int numIndependentVariables = 0;
    private int numRandomEffects = 0;

    // Number of observations; taken from the dependent parameter, or from the first design
    // matrix when no dependent parameter is supplied.
    private int N;

    protected List<Parameter> randomEffects = null;

    // Cache of linkInverse(xBeta), with a copy kept for MCMC store/restore.
    private double[] transformedXBeta;
    private double[] storedTransformedXBeta;
    private boolean transformedXBetaKnown = false;
    private boolean storedTransformedXBetaKnown = false;

    // Current values of the dependent parameter (empty if there is none).
    private double[] Y;

    private double storedLogLikelihood;
    private double logLikelihood;
    private boolean likelihoodKnown = false;
    private boolean storedLikelihoodKnown = false;

    /**
     * Creates a GLM likelihood.
     *
     * @param dependentParameter observed values Y; may be {@code null}, in which case the number
     *                           of observations is taken from the first design matrix added
     * @param density            density used to score each observation; registered as a sub-model
     * @param linkFunction       link function relating the linear predictor to the location
     */
    public GeneralizedLinearModel(Parameter dependentParameter, DensityModel density, LinkFunction linkFunction) {
        super(GeneralizedLinearModelParser.GLM_LIKELIHOOD);
        this.dependentParameter = dependentParameter;
        this.linkFunction = linkFunction.getTransform();
        this.density = density;
        isMultivariateDensity = density instanceof ParametricMultivariateDistributionModel;

        addModel(density);

        if (dependentParameter != null) {
            addVariable(dependentParameter);
            N = dependentParameter.getDimension();
            Y = dependentParameter.getParameterValues();
        } else {
            // Previously this path dereferenced the null parameter; use an empty observation
            // vector and let addIndependentParameter fix N.
            N = 0;
            Y = new double[0];
        }

        allocateXBetaBuffers();
        likelihoodKnown = false;
    }

    /** (Re)allocates the linear-predictor caches for the current N and marks them invalid. */
    private void allocateXBetaBuffers() {
        transformedXBeta = new double[N];
        storedTransformedXBeta = new double[N];
        transformedXBetaKnown = false;
        storedTransformedXBetaKnown = false;
    }

    /**
     * Adds a random-effect parameter whose i-th value is added directly to the i-th linear predictor.
     *
     * @param effect random-effect parameter; must have dimension N once N is known
     * @throws RuntimeException if N is known and the dimension does not match
     */
    public void addRandomEffectsParameter(Parameter effect) {
        if (randomEffects == null) {
            randomEffects = new ArrayList<Parameter>();
        }
        if (N != 0 && effect.getDimension() != N) {
            throw new RuntimeException("Random effects have the wrong dimension");
        }
        addVariable(effect);
        randomEffects.add(effect);
        numRandomEffects++;
    }

    /**
     * Adds a fixed effect: coefficients {@code effect} applied to the columns of {@code matrix},
     * optionally multiplied element-wise by the indicator/scale parameter {@code delta}.
     *
     * @param effect coefficient vector beta
     * @param matrix design matrix X (rows = observations, columns = coefficients)
     * @param delta  optional multiplier for each coefficient; may be {@code null}
     */
    public void addIndependentParameter(Parameter effect, DesignMatrix matrix, Parameter delta) {
        if (N == 0) {
            N = matrix.getRowDimension();
            // The caches were sized for N == 0 in the constructor; resize them now that N is
            // known, otherwise calculateTransformedXBeta would index past their end.
            allocateXBetaBuffers();
        }
        designMatrix.add(matrix);
        independentParameter.add(effect);
        independentParameterDelta.add(delta);

        if (designMatrix.size() != independentParameter.size()) {
            throw new RuntimeException("Independent variables and their design matrices are out of sync");
        }
        addVariable(effect);
        addVariable(matrix);
        if (delta != null) {
            addVariable(delta);
        }
        numIndependentVariables++;
        Logger.getLogger("dr.inference").info("\tAdding independent predictors '" + effect.getStatisticName()
                + "' with design matrix '" + matrix.getStatisticName() + "'");
    }

    /**
     * Checks whether the column-wise concatenation of all design matrices has full column rank,
     * i.e. whether every fixed-effect coefficient is identifiable.
     *
     * @return {@code true} if the rank equals the total number of predictor columns
     */
    public boolean getAllIndependentVariablesIdentifiable() {

        int totalColDim = 0;
        for (DesignMatrix mat : designMatrix) {
            totalColDim += mat.getColumnDimension();
        }

        if (N == 0 || totalColDim == 0) {
            // Degenerate design: the rank is 0, so it is "full" only when there are no predictors.
            // Guarding here avoids indexing grandDesignMatrix[0] on an empty array.
            Logger.getLogger("dr.inference").info("\tTotal # of predictors = " + totalColDim + " and rank = " + 0);
            return totalColDim == 0;
        }

        double[][] grandDesignMatrix = new double[N][totalColDim];

        int offset = 0;
        for (DesignMatrix mat : designMatrix) {
            final int length = mat.getColumnDimension();
            for (int i = 0; i < N; ++i) {
                for (int j = 0; j < length; ++j) {
                    grandDesignMatrix[i][offset + j] = mat.getParameterValue(i, j);
                }
            }
            offset += length;
        }

        // Colt's SVD expects rows >= columns; rank is invariant under transposition.
        double[][] mat = grandDesignMatrix;
        if (grandDesignMatrix.length < grandDesignMatrix[0].length) {
            mat = new double[grandDesignMatrix[0].length][grandDesignMatrix.length];
            for (int i = 0; i < grandDesignMatrix.length; ++i) {
                for (int j = 0; j < grandDesignMatrix[i].length; ++j) {
                    mat[j][i] = grandDesignMatrix[i][j];
                }
            }
        }

        SingularValueDecomposition svd = new SingularValueDecomposition(new DenseDoubleMatrix2D(mat));

        int rank = svd.rank();
        boolean isFullRank = (totalColDim == rank);
        Logger.getLogger("dr.inference").info("\tTotal # of predictors = " + totalColDim + " and rank = " + rank);
        return isFullRank;
    }

    /**
     * Computes the untransformed linear predictor: all fixed effects plus all random effects.
     *
     * @return a new array of length N
     */
    public double[] getXBeta() {

        double[] xBeta = new double[N];

        for (int j = 0; j < numIndependentVariables; j++) {
            addFixedEffect(j, xBeta);
        }

        for (int j = 0; j < numRandomEffects; j++) {
            Parameter effect = randomEffects.get(j);
            for (int i = 0; i < N; i++) {
                xBeta[i] += effect.getParameterValue(i);
            }
        }

        return xBeta;
    }

    /**
     * Computes the contribution of a single fixed effect to the linear predictor.
     *
     * @param j index of the fixed effect
     * @return a new array of length N
     * @throws RuntimeException if the model has random effects
     */
    public double[] getXBeta(int j) {

        // Fail before doing the work rather than after it.
        if (numRandomEffects != 0) {
            throw new RuntimeException("Attempting to retrieve fixed effects without controlling for random effects");
        }

        double[] xBeta = new double[N];
        addFixedEffect(j, xBeta);
        return xBeta;
    }

    /** Accumulates X_j * (beta_j [* delta_j]) for fixed effect {@code j} into {@code xBeta}. */
    private void addFixedEffect(int j, double[] xBeta) {
        final Parameter beta = independentParameter.get(j);
        final Parameter delta = independentParameterDelta.get(j);
        final DesignMatrix X = designMatrix.get(j);
        final int K = beta.getDimension();

        for (int k = 0; k < K; k++) {
            double betaK = beta.getParameterValue(k);
            if (delta != null) {
                betaK *= delta.getParameterValue(k);
            }
            for (int i = 0; i < N; i++) {
                xBeta[i] += X.getParameterValue(i, k) * betaK;
            }
        }
    }

    public int getNumberOfFixedEffects() {
        return numIndependentVariables;
    }

    public int getNumberOfRandomEffects() {
        return numRandomEffects;
    }

    public Parameter getFixedEffect(int j) {
        return independentParameter.get(j);
    }

    public Parameter getRandomEffect(int j) {
        return randomEffects.get(j);
    }

    public Parameter getDependentVariable() {
        return dependentParameter;
    }

    /** @return the index of {@code effect} among the fixed effects, or -1 if absent */
    public int getEffectNumber(Parameter effect) {
        return independentParameter.indexOf(effect);
    }

    public double[][] getX(int j) {
        return designMatrix.get(j).getParameterAsMatrix();
    }

    /** Refreshes the cached linkInverse(xBeta). */
    private void calculateTransformedXBeta() {

        double[] xBeta = getXBeta();

        for (int i = 0; i < N; i++) {
            transformedXBeta[i] = linkFunction.inverse(xBeta[i]);
        }

        transformedXBetaKnown = true;
    }

    /**
     * Sums the log density of each observation, with the density's location variable set to the
     * corresponding transformed linear predictor.
     */
    private double calculateLogLikelihood() {

        if (!transformedXBetaKnown) {
            calculateTransformedXBeta();
        }

        double logL = 0.0;

        if (isMultivariateDensity) {
            // todo - implement
            // NOTE(review): a multivariate density currently contributes a log-likelihood of 0.0
            // silently; consider failing loudly instead.
        } else {
            for (int i = 0; i < Y.length; i++) {
                // Mutates the shared density's location; this fires a model-changed event back
                // into handleModelChangedEvent.
                density.getLocationVariable().setValue(0, transformedXBeta[i]);
                logL += density.logPdf(new double[] { Y[i] });
            }
        }

        return logL;
    }

    // **************************************************************
    // Likelihood IMPLEMENTATION
    // **************************************************************

    /**
     * Returns the log-likelihood.
     * <p>
     * todo - cache likelihood values: {@code likelihoodKnown} is never set to {@code true}, so the
     * value is recomputed on every call (the transformed linear predictor is cached).
     *
     * @return the log-likelihood of the dependent values under the model
     */
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = calculateLogLikelihood();
        }
        return logLikelihood;
    }

    // **************************************************************
    // Model IMPLEMENTATION
    // **************************************************************

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // some aspect of the density function has changed
        likelihoodKnown = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        if (variable == dependentParameter) {
            Y = dependentParameter.getParameterValues();
        }
        transformedXBetaKnown = false;
        likelihoodKnown = false;
    }

    @Override
    protected void storeState() {
        storedLogLikelihood = logLikelihood;
        storedLikelihoodKnown = likelihoodKnown;
        // Store the validity flag with the values so a restore never marks stale values as known.
        storedTransformedXBetaKnown = transformedXBetaKnown;
        System.arraycopy(transformedXBeta, 0, storedTransformedXBeta, 0, transformedXBeta.length);
    }

    @Override
    protected void restoreState() {
        logLikelihood = storedLogLikelihood;
        likelihoodKnown = storedLikelihoodKnown;
        transformedXBetaKnown = storedTransformedXBetaKnown;
        // could use double buffering to speed this up for very large matrices...
        System.arraycopy(storedTransformedXBeta, 0, transformedXBeta, 0, transformedXBeta.length);
    }

    @Override
    protected void acceptState() {
        // Nothing to do
    }

    @Override
    public Model getModel() {
        return this;
    }

    @Override
    public String toString() {
        return super.toString() + ": " + getLogLikelihood();
    }

    /** Invalidates all cached quantities so the next evaluation recomputes from scratch. */
    @Override
    public void makeDirty() {
        transformedXBetaKnown = false;
        likelihoodKnown = false;
    }

    // **************************************************************
    // Loggable IMPLEMENTATION
    // **************************************************************

    // It probably makes more consistent sense to log the likelihood and expose the xBeta values
    // through a statistic; for now one column per observation of the linear predictor is logged.

    /** @return one column per observation, each reporting the untransformed linear predictor */
    public LogColumn[] getColumns() {
        LogColumn[] output = new LogColumn[N];
        for (int i = 0; i < N; i++) {
            output[i] = new NumberArrayColumn(getId() + i, i);
        }
        return output;
    }

    /** Log column reporting element {@code index} of {@link #getXBeta()}. */
    private class NumberArrayColumn extends NumberColumn {

        private final int index;

        public NumberArrayColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        public double getDoubleValue() {
            return getXBeta()[index];
        }
    }

    // **************************************************************
    // XMLElement IMPLEMENTATION
    // **************************************************************

    public Element createElement(Document d) {
        throw new RuntimeException("Not implemented yet!");
    }
}