/*
* Copyright (C) 2014 Mathias Unberath
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/
package edu.stanford.rsl.conrad.geometry.shapes.activeshapemodels;
import java.util.ArrayList;
import edu.stanford.rsl.conrad.geometry.shapes.simple.PointND;
import edu.stanford.rsl.conrad.numerics.DecompositionSVD;
import edu.stanford.rsl.conrad.numerics.SimpleMatrix;
import edu.stanford.rsl.conrad.numerics.SimpleOperators;
import edu.stanford.rsl.conrad.numerics.SimpleVector;
/**
* This class implements the Generalized Procrustes Analysis (GPA) following the method proposed by J. C. Gower.
* It assumes that perfect point correspondence between the point-clouds has been established.
* @see "J. C. Gower (1975), Generalized Procrustes Analysis, Psychometrika Vol. 40, pp. 33-51"
* @author Mathias Unberath
*
*/
public class GPA {
    /**
     * Dimension of the point-clouds' vertices, i.e. the number of columns of every point-cloud matrix.
     * Set by the first added point-cloud; -1 as long as no point-cloud has been added.
     */
    public int dimension = -1;
    /**
     * Value at which the Generalized Procrustes Analysis iteration scheme is stopped
     * because convergence is assumed to be reached: iteration ends once the decrease of the
     * residual within one iteration is no longer larger than this value.
     * This value is expressed in % of the initial residual.
     */
    public float convergence = 1e-3f;
    /**
     * Maximum number of refinement iterations (rotation plus rescaling) that runGPA() performs
     * after the initial alignment. Guards against endless iteration when rounding noise keeps the
     * residual change above the convergence threshold.
     */
    public int maxIterations = 1000;
    /**
     * Number of points in one point-cloud, i.e. the number of rows of every point-cloud matrix.
     * Set by the first added point-cloud; -1 as long as no point-cloud has been added.
     */
    public int numPoints = -1;
    /**
     * Scaling factor of each point-cloud in pointList: the square root of the summed squared
     * distances of its points to its center of mass. runGPA() divides each centered point-cloud
     * by this factor; the getScaledAndShifted* methods use it to scale results back up.
     */
    public ArrayList<Float> scaling;
    /**
     * Intra-procedure scaling (rho) of each point-cloud, used during the Generalized Procrustes
     * Analysis to maintain the proper relative scaling of the point-clouds. Holds the factor that
     * was applied to each point-cloud in the most recent rescaling step.
     */
    public ArrayList<Float> rho;
    /**
     * Center of mass of each point-cloud in pointList. runGPA() subtracts it from the point-cloud
     * so that the point-cloud's mean lies at the origin.
     */
    public ArrayList<PointND> centers;
    /**
     * Rotation matrix of each point-cloud: the product of all rotations applied to it during
     * runGPA(), in the order of application. Point-clouds store one point per row and are
     * right-multiplied by the rotation.
     */
    public ArrayList<SimpleMatrix> rotations;
    /**
     * The point-clouds, one matrix per cloud with one vertex per row.
     * runGPA() overwrites them with their aligned (centered, normalized, rotated, rescaled) versions.
     */
    public ArrayList<SimpleMatrix> pointList;
    /**
     * The consensus point-cloud (mean of the aligned point-clouds) calculated during the
     * Generalized Procrustes Analysis.
     */
    public SimpleMatrix consensus;
    /**
     * The connectivity information for the case where the point-clouds are meshes and hence
     * connectivity information exists. It is not used within this class but can be passed on.
     * One single copy of the connectivity is enough, as perfect point correspondence between
     * the point-clouds is assumed.
     */
    public SimpleMatrix connectivity;
    /**
     * Number of point-clouds to be analyzed during this GPA run.
     * Note: whenever one of the addElement methods makes pointList longer than this number, it is
     * raised to the list size. addElement(int, SimpleMatrix) additionally assumes that enough
     * elements have been announced via the constructor or init(int).
     */
    public int numPc = 0;
    /**
     * Debug flag for console output.
     */
    public boolean DEBUG = true;
    /**
     * Print flag; currently not evaluated by this class.
     */
    public boolean PRINT = false;
    //==========================================================================================
    // METHODS
    //==========================================================================================
    /**
     * Creates the object and initializes the point-cloud list.
     * @param numPointclouds Number of point-clouds used for this GPA run.
     */
    public GPA(int numPointclouds){
        initializePointList(numPointclouds);
    }
    /**
     * Runs the Generalized Procrustes Alignment on the point-cloud data.
     * <p>
     * All point-clouds are first centered and normalized to unit size, and the first point-cloud
     * serves as initial consensus. Every subsequent iteration rotates each point-cloud onto the
     * consensus, rescales it with rho and recomputes the consensus. Iteration stops when the
     * decrease of the residual drops to {@link #convergence} % of the initial residual or below,
     * or after {@link #maxIterations} refinement iterations.
     * The point-clouds in {@link #pointList} are overwritten with their aligned versions.
     */
    public void runGPA(){
        System.out.println("Starting Generalized Procrustes Analysis on " + numPc + " data-sets.");
        this.rotations = new ArrayList<SimpleMatrix>(numPc);
        shiftMeanToOriginAndScale();
        // first consensus object is the first element
        this.consensus = pointList.get(0);
        // initialize the intra-procedure scalings
        initializeRho();
        int nIter = 1;
        if(DEBUG){
            System.out.println("Iteration: " + nIter);
            System.out.println("Calculating first consensus object.");
        }
        nIter++;
        // the first iteration differs from the following ones: it only rotates and the residual
        // is computed from scratch instead of being updated
        rotateAllPointClouds();
        updateConsensus();
        double residual = getInitialResidual();
        double initialRes = residual;
        if(DEBUG){
            System.out.println("Initial residual: " + residual);
        }
        // convergence is expressed in % of the initial residual
        double threshold = initialRes/100*convergence;
        // start with an infinite change so that at least one refinement iteration is performed,
        // independent of the magnitude of the threshold
        double change = Double.POSITIVE_INFINITY;
        int refinements = 0;
        while(change > threshold && refinements < maxIterations){
            double oldRes = residual;
            if(DEBUG){
                System.out.println("Iteration: " + nIter);
            }
            // first step: rotate every cloud, then update consensus and residual
            rotateAllPointClouds();
            SimpleMatrix oldConsensus = consensus;
            updateConsensus();
            residual -= getResidual(oldConsensus);
            if(DEBUG){
                System.out.println("Residual at rotation " + nIter + " is: " + residual);
            }
            // second step: re-scale every cloud, then update consensus and residual
            updateRho();
            rescalePointClouds();
            oldConsensus = consensus;
            updateConsensus();
            residual -= getResidual(oldConsensus);
            if(DEBUG){
                System.out.println("Residual at rescaling " + nIter + " is: " + residual);
            }
            // check if we have converged:
            change = oldRes - residual;
            if(DEBUG){
                System.out.println("Residual change at iteration " + nIter + " is: " + change);
            }
            nIter++;
            refinements++;
        }
        if(DEBUG && change > threshold){
            System.out.println("Stopped after " + refinements + " refinement iterations without reaching the convergence threshold.");
        }
    }
    /**
     * Rotates every point-cloud onto the current consensus, in list order.
     */
    private void rotateAllPointClouds(){
        for(int i = 0; i < numPc; i++){
            getRotationMatrixAndRotate(i);
        }
    }
    /**
     * Initializes the object and the point-cloud list.
     * @param numPointclouds Number of point-clouds used for this GPA run.
     */
    public void init(int numPointclouds){
        initializePointList(numPointclouds);
    }
    /**
     * Shared initialization of the constructor and init(int): sets the announced number of
     * point-clouds and creates an empty point-cloud list.
     * @param numPointclouds Number of point-clouds used for this GPA run.
     */
    private void initializePointList(int numPointclouds){
        assert(numPointclouds > 0) : new IllegalArgumentException("Number of point-clouds must be bigger than 0.");
        this.numPc = numPointclouds;
        this.pointList = new ArrayList<SimpleMatrix>(numPc);
    }
    /**
     * Appends a point-cloud to the list. If the list thereby becomes longer than numPc,
     * numPc is raised to the list size.
     * @param pointcloud Point-cloud to add, one vertex per row.
     */
    public void addElement(SimpleMatrix pointcloud){
        initializeOrCheckSize(pointcloud);
        this.pointList.add(pointcloud);
        growNumPcToListSize();
    }
    /**
     * Adds a point-cloud to the list at a certain index.
     * @param idx Position where the point-cloud will be inserted; must be smaller than numPc.
     * @param pointcloud Point-cloud to add, one vertex per row.
     */
    public void addElement(int idx, SimpleMatrix pointcloud){
        initializeOrCheckSize(pointcloud);
        assert(idx < numPc) : new IllegalArgumentException("Index out of bounds. Use addElement() instead.");
        this.pointList.add(idx, pointcloud);
        growNumPcToListSize();
    }
    /**
     * Takes numPoints and dimension from the first point-cloud added; for every further
     * point-cloud, asserts that its size matches them.
     * @param pointcloud The point-cloud about to be added.
     */
    private void initializeOrCheckSize(SimpleMatrix pointcloud){
        if(numPoints == -1){
            this.numPoints = pointcloud.getRows();
            this.dimension = pointcloud.getCols();
        }else{
            assert(numPoints == pointcloud.getRows()) : new IllegalArgumentException("Number of points in point-cloud does not match.");
            assert(dimension == pointcloud.getCols()) : new IllegalArgumentException("Point dimension in point-cloud does not match.");
        }
    }
    /**
     * Keeps numPc at least as large as the number of stored point-clouds. Announcing n clouds
     * and then adding n clouds therefore leaves numPc at n instead of counting them twice.
     */
    private void growNumPcToListSize(){
        if(pointList.size() > numPc){
            this.numPc = pointList.size();
        }
    }
    /**
     * Inserts the center of mass of a point-cloud at a certain index position. Needed to restore
     * original mesh positions after the shift to zero mean.
     * @param colIdx List-index to be written.
     * @param centOfMass Value to be written at list-index.
     */
    private void addCenterOfMassAtIndex(int colIdx, PointND centOfMass){
        assert(colIdx < numPc) : new IllegalArgumentException("Index out of bounds. Initialize list first.");
        this.centers.add(colIdx, centOfMass);
    }
    /**
     * Inserts the scaling factor of a point-cloud at a certain index position. Needed to restore
     * the original mesh size after the normalization.
     * @param colIdx List-index to be written.
     * @param scaling Scaling factor to be written at list-index.
     */
    private void addScalingAtIndex(int colIdx, float scaling){
        assert(colIdx < this.numPc) : new IllegalArgumentException("Index out of bounds. Initialize list first.");
        this.scaling.add(colIdx, scaling);
    }
    /**
     * Overwrites the center of mass of a point-cloud at a certain index position. Needed to
     * restore original mesh positions after the shift to zero mean.
     * Note: centers is created by runGPA(); before that it is null unless assigned by the caller.
     * @param colIdx List-index to be written.
     * @param centOfMass Value to be written at list-index.
     */
    public void setCenterOfMassAtIndex(int colIdx, PointND centOfMass){
        assert(colIdx < this.numPc) : new IllegalArgumentException("Index out of bounds. Add point-cloud first.");
        this.centers.set(colIdx, centOfMass);
    }
    /**
     * Overwrites the scaling factor of a point-cloud at a certain index position. Needed to
     * restore the original mesh size after the normalization.
     * Note: scaling is created by runGPA(); before that it is null unless assigned by the caller.
     * @param colIdx List-index to be written.
     * @param scaling Scaling factor to be written at list-index.
     */
    public void setScalingAtIndex(int colIdx, float scaling){
        assert(colIdx < this.numPc) : new IllegalArgumentException("Index out of bounds. Add point-cloud first.");
        this.scaling.set(colIdx, scaling);
    }
    /**
     * Calculates the centroid and scaling factor of every point-cloud in the list and stores them
     * in centers and scaling. Each point-cloud is then shifted so that its centroid lies at the
     * origin, and divided by its scaling factor so that it has unit (Frobenius) norm.
     * @throws IllegalArgumentException if all points of a point-cloud coincide, since such a
     *         cloud has zero size and cannot be normalized.
     */
    private void shiftMeanToOriginAndScale(){
        assert(numPoints > 0 && dimension > 0) : new Exception("No data for GPA found. Did you add the datasets?");
        checkPointList();
        this.centers = new ArrayList<PointND>(numPc);
        this.scaling = new ArrayList<Float>(numPc);
        System.out.println("Calculating centroid and scaling for each point-cloud.");
        for(int k = 0; k < numPc; k++){
            SimpleMatrix cloud = pointList.get(k);
            // centroid: one accumulator per coordinate (sized by the actual dimension), in double
            // precision to limit rounding errors for large point-clouds
            double[] mAcc = new double[dimension];
            for(int i = 0; i < numPoints; i++){
                for(int j = 0; j < dimension; j++){
                    mAcc[j] += cloud.getElement(i, j);
                }
            }
            SimpleVector mean = new SimpleVector(dimension);
            for(int j = 0; j < dimension; j++){
                mean.setElementValue(j, mAcc[j]/numPoints);
            }
            addCenterOfMassAtIndex(k, new PointND(mean));
            // scale: root of the summed squared distances of all points to the centroid
            double sAcc = 0;
            for(int i = 0; i < numPoints; i++){
                for(int j = 0; j < dimension; j++){
                    double diff = cloud.getElement(i, j) - mean.getElement(j);
                    sAcc += diff * diff;
                }
            }
            float scaling = (float)Math.sqrt(sAcc);
            if(scaling == 0){
                // dividing by zero would silently fill the cloud (and the consensus) with NaN
                throw new IllegalArgumentException("Point-cloud " + k + " is degenerate: all points coincide, so it cannot be normalized.");
            }
            addScalingAtIndex(k, scaling);
            // update the point-cloud in place: center and normalize
            for(int i = 0; i < numPoints; i++){
                for(int j = 0; j < dimension; j++){
                    double val = cloud.getElement(i, j);
                    cloud.setElementValue(i, j, (val - mean.getElement(j)) / scaling);
                }
            }
            pointList.set(k, cloud);
        }
    }
    /**
     * Checks (via assertion) that all announced list entries have been filled with point-clouds.
     */
    private void checkPointList(){
        assert (pointList.size() == numPc) : new Exception("Not all initialized point-clouds set.");
    }
    /**
     * Computes the rotation matrix for the least-squares rotational registration of the
     * point-cloud at index idx onto the consensus, then hands it to rotatePointCloudAtIndex,
     * which stores it and rotates the point-cloud. Uses a singular value decomposition.
     * @param idx The index of the point-cloud.
     */
    private void getRotationMatrixAndRotate(int idx){
        SimpleMatrix m1 = consensus;
        SimpleMatrix m2 = pointList.get(idx);
        // create matrix containing information about both point-clouds m1^T * m2
        SimpleMatrix m1Tm2 = SimpleOperators.multiplyMatrixProd(m1.transposed(), m2);
        // perform SVD such that:
        // m1^T * m2 = U sigma V^T
        DecompositionSVD svd = new DecompositionSVD(m1Tm2, true, true, true);
        // Replace sigma with a diagonal matrix s holding only +/-1 (the sign of each singular value),
        // so that H contains no scaling or shear. Zero singular values are mapped to +1 (Math.signum
        // would give 0), so H stays orthogonal for rank-deficient data, e.g. point-clouds whose
        // coordinates are exactly zero along one axis. Reflections (det(H) = -1) are not excluded.
        double[] singularValues = svd.getSingularValues();
        int nColsS = svd.getS().getCols();
        SimpleMatrix s = new SimpleMatrix(nColsS,nColsS);
        for(int i = 0; i < nColsS; i++){
            s.setElementValue(i, i, (singularValues[i] < 0) ? -1 : 1);
        }
        // calculate rotation matrix such that:
        // H = V s U^T
        SimpleMatrix h = SimpleOperators.multiplyMatrixProd(svd.getV(), SimpleOperators.multiplyMatrixProd(s, svd.getU().transposed()));
        rotatePointCloudAtIndex(idx, h);
    }
    /**
     * Rotates the point-cloud at a certain index by right-multiplying it with a pre-computed rotation
     * matrix. The rotation is accumulated in rotations, so that rotations.get(idx) always holds the
     * product of all rotations applied to this point-cloud.
     * @param idx Point-cloud index in list.
     * @param rotation Rotation matrix to be applied.
     */
    private void rotatePointCloudAtIndex(int idx, SimpleMatrix rotation){
        if(idx < rotations.size()){
            // follow-up rotation: compose with the rotations applied so far
            this.rotations.set(idx, SimpleOperators.multiplyMatrixProd(rotations.get(idx), rotation));
        }else{
            // first rotation of this point-cloud
            this.rotations.add(rotation);
        }
        // rotate the point-cloud
        this.pointList.set(idx, SimpleOperators.multiplyMatrixProd(pointList.get(idx), rotation));
    }
    /**
     * Calculates the initial residual of the consensus point-cloud in the sense of J. C. Gower:
     * numPc * (1 - tr(C^T C)) for consensus C. Since every point-cloud has unit norm after
     * normalization and rotations preserve norms, this equals the summed squared distance of
     * the point-clouds to their mean.
     * @return The initial residual.
     */
    private float getInitialResidual(){
        SimpleMatrix res = SimpleOperators.multiplyMatrixProd(consensus.transposed(), consensus);
        float residual = numPc * (1 - trace(res));
        return residual;
    }
    /**
     * Calculates the residual update for all iterations except the initial one:
     * numPc * (tr(C_new^T C_new) - tr(C_old^T C_old)).
     * @param oldConsensus The old consensus point-cloud.
     * @return The amount by which the residual decreases.
     */
    private float getResidual(SimpleMatrix oldConsensus){
        SimpleMatrix M1 = SimpleOperators.multiplyMatrixProd(consensus.transposed(), consensus);
        SimpleMatrix M2 = SimpleOperators.multiplyMatrixProd(oldConsensus.transposed(), oldConsensus);
        float residual = numPc * (trace(M1) - trace(M2));
        return residual;
    }
    /**
     * Updates the consensus point-cloud. The new consensus is the mean of all point-clouds stored
     * in the list and is always a newly allocated matrix.
     */
    private void updateConsensus(){
        SimpleMatrix cons = new SimpleMatrix(numPoints,dimension);
        for(int k = 0; k < numPc; k++){
            SimpleMatrix pc = pointList.get(k);
            for(int i = 0; i < numPoints; i++){
                for(int j = 0; j < dimension; j++){
                    cons.addToElement(i, j, pc.getElement(i, j)/numPc);
                }
            }
        }
        this.consensus = cons;
    }
    /**
     * Multiplies every point-cloud in place by its current rho.
     */
    private void rescalePointClouds(){
        for(int k = 0; k < numPc; k++){
            SimpleMatrix pc = pointList.get(k);
            float rho = this.rho.get(k);
            for(int i = 0; i < numPoints; i++){
                for(int j = 0; j < dimension; j++){
                    pc.multiplyElementBy(i, j, rho);
                }
            }
            this.pointList.set(k, pc);
        }
    }
    /**
     * Initializes the GPA scaling rho with all ones.
     */
    private void initializeRho(){
        this.rho = new ArrayList<Float>(numPc);
        for(int i = 0; i < numPc; i++){
            this.rho.add(1f);
        }
    }
    /**
     * Calculates the new scaling rho of every point-cloud X using the current consensus C:
     * rho = sqrt(tr(X^T C) / (tr(C^T C) * tr(X^T X))).
     * Because the point-clouds already carry the previously applied factors, the value acts as an
     * update relative to the old scaling and is applied multiplicatively by rescalePointClouds().
     * The normalization factors stored in scaling are not updated with rho.
     */
    private void updateRho(){
        float traceConsensus = trace(SimpleOperators.multiplyMatrixProd(consensus.transposed(), consensus));
        for(int k = 0; k < numPc; k++){
            SimpleMatrix cloud = pointList.get(k);
            float traceCloud = trace(SimpleOperators.multiplyMatrixProd(cloud.transposed(), cloud));
            float traceCloudCons = trace(SimpleOperators.multiplyMatrixProd(cloud.transposed(), consensus));
            float rhoUpdate = (float)Math.sqrt(traceCloudCons / (traceConsensus * traceCloud));
            this.rho.set(k, rhoUpdate);
        }
    }
    /**
     * Calculates the trace of a square matrix. The sum is accumulated in double precision.
     * @param m The square Matrix.
     * @return The trace of the matrix.
     */
    private float trace(SimpleMatrix m){
        assert(m.isSquare()) : new IllegalArgumentException("Trace not defined for non-square matrices.");
        double tr = 0;
        for(int i = 0; i < m.getRows(); i++){
            tr += m.getElement(i, i);
        }
        return (float) tr;
    }
    /**
     * Calculates a scaled and centroid-shifted version of the consensus object.
     * The scaling is the mean scaling of the input point-clouds.
     * The centroid is the mean of the centroids.
     * @return A new SimpleMatrix containing the consensus point-cloud.
     */
    public SimpleMatrix getScaledAndShiftedConsensus(){
        // calculate mean-scale and mean of centroids
        float upScale = 0;
        SimpleVector meanCentroid = new SimpleVector(dimension);
        for(int i = 0; i < numPc; i++){
            upScale += scaling.get(i) / numPc;
            for(int j = 0; j < dimension; j++){
                meanCentroid.addToElement(j, centers.get(i).get(j) / numPc);
            }
        }
        // apply scaling
        SimpleMatrix m = consensus.multipliedBy(upScale);
        // shift by mean centroid
        for(int j = 0; j < dimension; j++){
            double val = meanCentroid.getElement(j);
            for(int i = 0; i < numPoints; i++){
                m.addToElement(i, j, val);
            }
        }
        return m;
    }
    /**
     * Calculates the up-scaled and centroid-shifted version of the point-cloud at index idx after
     * Generalized Procrustes Alignment: the aligned point-cloud is multiplied by its stored scaling
     * factor and shifted by its stored center of mass.
     * @param idx Index of point-cloud to be processed.
     * @return The up-scaled and centroid-shifted point-cloud.
     */
    public SimpleMatrix getScaledAndShiftedPointCloud(int idx){
        // apply scaling
        SimpleMatrix m = pointList.get(idx).multipliedBy(scaling.get(idx));
        // shift by the point-cloud's own centroid
        for(int j = 0; j < dimension; j++){
            double val = centers.get(idx).get(j);
            for(int i = 0; i < numPoints; i++){
                m.addToElement(i, j, val);
            }
        }
        return m;
    }
    /**
     * Sets the connectivity information in case it exists for the point-clouds.
     * @param con The connectivity information.
     */
    public void setConnectivity(SimpleMatrix con){
        this.connectivity = con;
    }
    /**
     * Calculates and returns the mean center of mass of all point-clouds.
     * @return The mean center of mass.
     */
    public SimpleVector getMeanCenter(){
        SimpleVector mC = new SimpleVector(dimension);
        for(int i = 0; i < centers.size(); i++){
            mC.add(centers.get(i).getAbstractVector());
        }
        return mC.dividedBy(centers.size());
    }
}
/*
* Copyright (C) 2010-2014 Mathias Unberath
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/