/*
* Procrustes.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.math;
import org.apache.commons.math.linear.*;
/**
 * Procrustination function based on procrustes.r.
 * <p>
 * Aligns an n x m configuration {@code X} (n points, m coordinates per point) onto a target
 * configuration {@code Xstar} of the same shape. The alignment uses:
 * <ul>
 *   <li>an orthogonal matrix {@code R};</li>
 *   <li>optionally a dilation (scale) {@code s};</li>
 *   <li>optionally a translation vector {@code T}.</li>
 * </ul>
 * The transformed configuration is {@code s * X %*% R + matrix(T, n, m, byrow = TRUE)}.
 * <p>
 * {@code R} is taken directly as {@code V %*% t(U)} from the SVD of
 * {@code t(Xstar) %*% J %*% X}. No determinant correction is applied, so {@code R} may
 * include a reflection.
 *
 * @author Andrew Rambaut
 * @author Marc Suchard
 * @version $Id$
 */
public class Procrustes {

    /**
     * Fits the Procrustes transformation that maps {@code X} onto {@code Xstar}.
     *
     * @param X                the n x m configuration to be transformed
     * @param Xstar            the n x m target configuration
     * @param allowTranslation if true, both configurations are centred and a translation is
     *                         fitted; otherwise the translation is zero
     * @param allowDilation    if true, a scale factor is fitted; otherwise the scale is 1
     * @throws IllegalArgumentException if {@code X} and {@code Xstar} differ in shape
     */
    public Procrustes(RealMatrix X, RealMatrix Xstar, boolean allowTranslation, boolean allowDilation) {
        rowDimension = X.getRowDimension();
        columnDimension = X.getColumnDimension();

        if (Xstar.getRowDimension() != rowDimension) {
            throw new IllegalArgumentException("X and Xstar do not have the same number of rows");
        }
        if (Xstar.getColumnDimension() != columnDimension) {
            throw new IllegalArgumentException("X and Xstar do not have the same number of columns");
        }

        // J <- diag(n) - 1/n * matrix(1, n, n)   (translation allowed)
        // J <- diag(n)                            (no translation)
        RealMatrix J = centeringMatrix(rowDimension, allowTranslation);

        // J %*% X is needed both for C and for the dilation denominator, so form it once.
        RealMatrix JX = J.multiply(X);

        // C <- t(Xstar) %*% J %*% X
        RealMatrix C = Xstar.transpose().multiply(JX);

        // svd.out <- svd(C)
        // R <- svd.out$v %*% t(svd.out$u)
        // NB: Apache math does a different SVD from R. TODO Should use Colt library
        SingularValueDecomposition svd = new SingularValueDecompositionImpl(C);
        R = svd.getV().multiply(svd.getUT());

        // s <- 1 (scale is 1 unless dilation is being used)
        double scale = 1.0;
        if (allowDilation) {
            // s <- trace(t(Xstar) %*% J %*% X %*% R) / trace(t(X) %*% J %*% X)
            // t(Xstar) %*% J %*% X is exactly C, so the numerator matrix is simply C %*% R.
            RealMatrix mat1 = C.multiply(R);
            RealMatrix mat2 = X.transpose().multiply(JX);
            double numer = 0.0;
            double denom = 0.0;
            // both mat1 and mat2 are m x m
            for (int i = 0; i < columnDimension; i++) {
                numer += mat1.getEntry(i, i);
                denom += mat2.getEntry(i, i);
            }
            // NOTE(review): denom is 0 when X has no spread (e.g. all rows identical with
            // translation allowed), which yields an infinite or NaN dilation.
            scale = numer / denom;
        }
        s = scale;

        if (allowTranslation) {
            // tt <- 1/n * t(Xstar - s * X %*% R) %*% matrix(1, n, 1)
            RealMatrix ones = new Array2DRowRealMatrix(rowDimension, 1);
            for (int i = 0; i < rowDimension; i++) {
                ones.setEntry(i, 0, 1.0);
            }
            T = Xstar.subtract(X.multiply(R).scalarMultiply(s))
                    .transpose()
                    .scalarMultiply(1.0 / rowDimension)
                    .multiply(ones);
        } else {
            // tt <- matrix(0, m, 1): no translation
            T = new Array2DRowRealMatrix(columnDimension, 1);
        }
    }

    /**
     * Builds the n x n matrix J used to (optionally) centre the configurations.
     * <p>
     * When centring, J = I - (1/n) * 11': diagonal entries are 1 - 1/n and off-diagonal
     * entries are -1/n. For n = 3:
     * {@code J = {{2/3, -1/3, -1/3}, {-1/3, 2/3, -1/3}, {-1/3, -1/3, 2/3}}}.
     * Without centring, J is the identity.
     *
     * @param n      the number of rows (points)
     * @param centre whether to build the centring matrix rather than the identity
     * @return a freshly allocated n x n matrix
     */
    private static RealMatrix centeringMatrix(int n, boolean centre) {
        RealMatrix J = new Array2DRowRealMatrix(n, n);
        if (centre) {
            for (int i = 0; i < n; i++) {
                J.setEntry(i, i, 1.0 - (1.0 / n));
                for (int j = i + 1; j < n; j++) {
                    J.setEntry(i, j, -1.0 / n);
                    J.setEntry(j, i, -1.0 / n);
                }
            }
        } else {
            for (int i = 0; i < n; i++) {
                J.setEntry(i, i, 1.0);
            }
        }
        return J;
    }

    /**
     * @return a copy of the fitted m x 1 translation vector (NB: a column matrix, unlike R's tt
     *         usage in the byrow expansion)
     */
    public final RealMatrix getTranslation() {
        return T.copy(); // NB Different from R
    }

    /**
     * @return the fitted dilation (scale) factor; 1.0 when dilation was not allowed
     */
    public final double getDilation() {
        return s;
    }

    /**
     * @return a copy of the fitted m x m orthogonal matrix R
     */
    public final RealMatrix getR() {
        return R.copy(); // NB Different from R
    }

    /**
     * Procrustinates the complete matrix of coordinates.
     *
     * @param X the matrix containing coordinates (same dimensions as X in the constructor)
     * @return the transformed matrix {@code s * X %*% R + matrix(T, n, m, byrow = TRUE)}
     * @throws IllegalArgumentException if {@code X} does not have the fitted dimensions
     */
    public final RealMatrix procrustinate(RealMatrix X) {
        if (X.getRowDimension() != rowDimension) {
            throw new IllegalArgumentException("X does not have the expected number of rows");
        }
        if (X.getColumnDimension() != columnDimension) {
            throw new IllegalArgumentException("X does not have the expected number of columns");
        }

        // X.new <- s * X %*% R + matrix(tt, nrow(X), ncol(X), byrow = TRUE)
        RealMatrix translationRow = T.transpose(); // 1 x m, identical for every row
        RealMatrix tt = new Array2DRowRealMatrix(rowDimension, columnDimension);
        for (int i = 0; i < rowDimension; i++) {
            tt.setRowMatrix(i, translationRow);
        }

        // rotate, scale and translate
        return X.multiply(R).scalarMultiply(s).add(tt);
    }

    /**
     * Procrustinates a single set of coordinates (one point).
     *
     * @param X the m coordinates of one point
     * @return the m transformed coordinates {@code s * X %*% R + t(T)}
     * @throws IllegalArgumentException if {@code X.length} is not the fitted column dimension
     */
    public double[] procrustinate(double[] X) {
        if (X.length != columnDimension) {
            throw new IllegalArgumentException("X does not have the expected number of elements");
        }

        // The point must be a 1 x m ROW matrix: Array2DRowRealMatrix(double[]) would build an
        // m x 1 column matrix, which cannot be right-multiplied by the m x m matrix R.
        RealMatrix row = new Array2DRowRealMatrix(new double[][] {X});

        // rotate, scale and translate (T is m x 1, so transpose it to match the 1 x m row)
        RealMatrix Xnew = row.multiply(R).scalarMultiply(s).add(T.transpose());
        return Xnew.getRow(0);
    }

    /**
     * Fits a Procrustes transformation of {@code X} onto {@code Xstar} and applies it to
     * {@code X}.
     *
     * @param X                the n x m configuration to be transformed
     * @param Xstar            the n x m target configuration
     * @param allowTranslation whether to fit a translation
     * @param allowDilation    whether to fit a scale factor
     * @return the transformed copy of {@code X}
     * @throws IllegalArgumentException if {@code X} and {@code Xstar} differ in shape
     */
    public static final RealMatrix procrustinate(RealMatrix X, RealMatrix Xstar, boolean allowTranslation, boolean allowDilation) {
        return new Procrustes(X, Xstar, allowTranslation, allowDilation).procrustinate(X);
    }

    /** Fitted m x m orthogonal matrix. */
    private final RealMatrix R;
    /** Fitted m x 1 translation vector (all zeros when translation is not allowed). */
    private final RealMatrix T;
    /** Fitted dilation (1.0 when dilation is not allowed). */
    private final double s;
    /** Number of points n (rows) in the fitted configurations. */
    private final int rowDimension;
    /** Number of coordinates m (columns) per point. */
    private final int columnDimension;
}