/*
* MatrixMatrixProduct.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.model;
/*
@author Max Tolkoff
*/
// Designed to return the product as a data matrix, post computation, if asked. Intended for latent liabilities.
/**
 * Deterministic matrix parameter whose entries are the product {@code left * right}.
 * <p>
 * Entries are computed lazily, one cell at a time, and written into the
 * {@code inPlace} matrix, which acts as the value store. {@code areValuesStored}
 * records which cells of {@code inPlace} currently hold an up-to-date product;
 * cells are invalidated when {@code left} or {@code right} report a change.
 * Columns whose {@code columnMask} entry is {@code 0} are never computed: the
 * value already held by {@code inPlace} is returned unchanged for them.
 * <p>
 * Flat indices are decoded as {@code row = index % rowDimension} and
 * {@code column = index / rowDimension} throughout this class.
 */
public class MatrixMatrixProduct extends MatrixParameter implements VariableListener {

    MatrixParameter left;
    MatrixParameter right;
    /** Matrix that receives the computed products and serves reads of cached or masked cells. */
    MatrixParameter inPlace;

    private final int leftDim;   // rows of the product (rows of left)
    private final int rightDim;  // columns of the product (columns of right)
    private final int midDim;    // shared inner dimension (columns of left)

    Parameter columnMask;

    /** Snapshot of {@link #areValuesStored} taken by {@link #storeValues()}. */
    boolean[][] oldStoredValues;
    // Allocated for compatibility with package-level users; this class itself never reads it.
    double[][] storedValues;
    /** {@code areValuesStored[i][j]} is true when {@code inPlace(i, j)} holds the current product. */
    boolean[][] areValuesStored;

    private Bounds bounds = null;

    /**
     * @param params     exactly three matrices: {@code params[0]} is the left factor,
     *                   {@code params[1]} the right factor and {@code params[2]} the
     *                   matrix the product is written into
     * @param columnMask one entry per product column; a value of {@code 0} disables
     *                   computation of that column
     * @throws IllegalArgumentException if {@code params} does not hold exactly three matrices
     */
    public MatrixMatrixProduct(MatrixParameter[] params, Parameter columnMask) {
        super(null, params);

        // The in-place matrix is dereferenced unconditionally by every read/store/restore
        // path, so it is mandatory. Fail with a clear message instead of a bare
        // NullPointerException further down.
        if (params.length != 3) {
            throw new IllegalArgumentException(
                    "MatrixMatrixProduct requires exactly three matrices (left, right, inPlace) but received "
                            + params.length);
        }

        this.columnMask = columnMask;
        this.left = params[0];
        this.right = params[1];
        this.inPlace = params[2];

        leftDim = left.getRowDimension();
        midDim = left.getColumnDimension();
        rightDim = right.getColumnDimension();

        // Java zero-initialises arrays, so every cell starts out as "not stored".
        storedValues = new double[leftDim][rightDim];
        areValuesStored = new boolean[leftDim][rightDim];
        oldStoredValues = new boolean[leftDim][rightDim];

        // Register on each operand exactly once.
        left.addVariableListener(this);
        right.addVariableListener(this);
        inPlace.addVariableListener(this);
    }

    /**
     * Invalidates the cached cells affected by a change in one of the factors and
     * forwards the event (same index and type) to this parameter's own listeners.
     * A change in {@code right} invalidates one product column, a change in
     * {@code left} invalidates one product row.
     */
    public void variableChangedEvent(Variable variable, int index, ChangeType type) {
        if (variable == right) {
            if (index < 0) {
                // NOTE(review): a negative index presumably means "all values changed";
                // it cannot be mapped to a single column, so drop the whole cache.
                invalidateAll();
            } else {
                final int column = index / right.getRowDimension();
                for (int row = 0; row < leftDim; row++) {
                    areValuesStored[row][column] = false;
                }
            }
        }
        if (variable == left) {
            if (index < 0) {
                // See note above; index % rows would be negative here and index the array out of range.
                invalidateAll();
            } else {
                final int row = index % left.getRowDimension();
                for (int column = 0; column < rightDim; column++) {
                    areValuesStored[row][column] = false;
                }
            }
        }
        fireParameterChangedEvent(index, type);
    }

    /** Marks every cell of the product as needing recomputation. */
    private void invalidateAll() {
        for (boolean[] row : areValuesStored) {
            for (int column = 0; column < row.length; column++) {
                row[column] = false;
            }
        }
    }

    /** Copies the flags row by row so that source and destination never share row arrays. */
    private static void copyFlags(boolean[][] source, boolean[][] destination) {
        for (int row = 0; row < source.length; row++) {
            System.arraycopy(source[row], 0, destination[row], 0, source[row].length);
        }
    }

    /** @return the number of cells in the product, {@code rows * columns} */
    @Override
    public int getDimension() {
        return leftDim * rightDim;
    }

    /** Replaces the bounds reported by {@link #getBounds()}. */
    public void addBounds(Bounds bounds) {
        this.bounds = bounds;
    }

    /** Snapshots the validity flags and asks the three underlying matrices to store their values. */
    protected void storeValues() {
        // A plain System.arraycopy on the outer array would only copy row references,
        // leaving the snapshot aliased to the live flags; copy each row instead.
        copyFlags(areValuesStored, oldStoredValues);
        left.storeParameterValues();
        right.storeParameterValues();
        inPlace.storeParameterValues();
    }

    /** Restores the three underlying matrices and then the validity flags from the last snapshot. */
    protected void restoreValues() {
        left.restoreParameterValues();
        right.restoreParameterValues();
        inPlace.restoreParameterValues();
        // Copy back rather than re-pointing the reference, so the snapshot and the
        // live flags stay independent for subsequent store/restore cycles.
        copyFlags(oldStoredValues, areValuesStored);
    }

    /**
     * Returns entry {@code (i, j)} of the product.
     * <p>
     * If column {@code j} is enabled in the mask and the cell is not marked as stored,
     * the dot product of row {@code i} of {@code left} and column {@code j} of
     * {@code right} is computed, written into {@code inPlace} and marked as stored.
     * Otherwise the value currently held by {@code inPlace} is returned.
     */
    public double getParameterValue(int i, int j) {
        final boolean needsComputation = columnMask.getParameterValue(j) != 0 && !areValuesStored[i][j];
        if (!needsComputation) {
            return inPlace.getParameterValue(i, j);
        }

        double sum = 0;
        for (int k = 0; k < midDim; k++) {
            sum += left.getParameterValue(i, k) * right.getParameterValue(k, j);
        }
        inPlace.setParameterValue(i, j, sum);
        areValuesStored[i][j] = true;
        return sum;
    }

    @Override
    public double[][] getParameterAsMatrix() {
        return super.getParameterAsMatrix();
    }

    /**
     * Brings every cell of column {@code PID} up to date and returns that column
     * of the in-place matrix.
     */
    public Parameter getParameter(int PID) {
        for (int row = 0; row < leftDim; row++) {
            getParameterValue(row, PID);
        }
        return inPlace.getParameter(PID);
    }

    public int getRowDimension() {
        return leftDim;
    }

    public int getColumnDimension() {
        return rightDim;
    }

    /**
     * Returns the product entry at flat index {@code dim}, decoded the same way
     * {@link #variableChangedEvent} decodes indices: row is {@code dim % rows},
     * column is {@code dim / rows}.
     */
    @Override
    public double getParameterValue(int dim) {
        return getParameterValue(dim % leftDim, dim / leftDim);
    }

    /** Rejects a mutating call: the product is derived from its factors and cannot be set directly. */
    private void throwError(String functionName) throws RuntimeException {
        throw new RuntimeException("Object " + getId() + " is a deterministic function. Calling "
                + functionName + " is not allowed");
    }

    /** Always throws; see {@link #throwError(String)}. */
    public void setParameterValue(int dim, double value) {
        throwError("setParameterValue()");
    }

    /** Always throws; see {@link #throwError(String)}. */
    public void setParameterValueQuietly(int dim, double value) {
        throwError("setParameterValueQuietly()");
    }

    /** Always throws; see {@link #throwError(String)}. */
    @Override
    public void setParameterValueNotifyChangedAll(int dim, double value) {
        throwError("setParameterValueNotifyChangedAll()");
    }

    /** @return the bounds last passed to {@link #addBounds(Bounds)}, or {@code null} if none were set */
    @Override
    public Bounds<Double> getBounds() {
        return bounds;
    }

    public void addDimension(int index, double value) {
        throw new RuntimeException("Not yet implemented.");
    }

    public double removeDimension(int index) {
        throw new RuntimeException("Not yet implemented.");
    }
}