package edu.stanford.rsl.conrad.optimization;
import edu.stanford.rsl.conrad.numerics.DecompositionSVD;
import edu.stanford.rsl.conrad.numerics.SimpleMatrix;
import edu.stanford.rsl.conrad.numerics.SimpleOperators;
import edu.stanford.rsl.conrad.numerics.SimpleVector;
/**
* This class implements a solver for a linear system of equations with additional regularization of the solution vector favoring smaller norms.
* The regularization is controlled using a Lagrange multiplier lambda. If lambda is set to 0, this class will produce a non-regularized least squares
* solution.
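* Solves: min_x |A * x - b|^2 + lambda * |x|^2,
* which can be reformulated as the ordinary least squares problem min_x |A' * x - b'|^2 using
* A' = (A , sqrt(lambda) * I_n)^T, b' = (b , 0_n)^T,
* where n is the number of columns of A, I_n the n x n identity matrix and 0_n the zero vector of length n.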
* Uses a singular value decomposition to solve the final system of equations.
* @author Mathias Unberath
*
*/
public class LSqMinNorm {

    /**
     * Lagrange multiplier used when the caller does not supply one.
     */
    private static final double DEFAULT_LAMBDA = 0.005;

    /**
     * The coefficients multiplied to the columns of the matrix in the linear system of equations.
     * Replaced by every call to {@link #solve()}.
     */
    private SimpleVector x;

    /**
     * The right hand side of the system of equations.
     */
    private final SimpleVector b;

    /**
     * The matrix containing the parameters.
     */
    private final SimpleMatrix a;

    /**
     * Lagrange Multiplier for the minimal norm regularization.
     */
    private double lambda = DEFAULT_LAMBDA;

    /**
     * Number of columns in the SimpleMatrix.
     */
    private final int nCol;

    /**
     * Number of rows in the SimpleMatrix.
     */
    private final int nRow;

    //==========================================================================================
    // METHODS
    //==========================================================================================

    /**
     * Constructs the solver object and calls the solving method on the linear system of equations
     * A * x = b subject to |x|^2 -> min, using the default Lagrange multiplier.
     * @param a The matrix containing all coefficients.
     * @param b The right hand side of the system of equations.
     * @throws IllegalArgumentException If the number of rows of a differs from the length of b.
     */
    public LSqMinNorm(SimpleMatrix a, SimpleVector b){
        this(a, b, DEFAULT_LAMBDA);
    }

    /**
     * Constructs the solver object with an explicit Lagrange multiplier and solves the system immediately.
     * @param a The matrix containing all coefficients.
     * @param b The right hand side of the system of equations.
     * @param lambda The Lagrange multiplier; must be finite and not negative.
     * @throws IllegalArgumentException If the number of rows of a differs from the length of b,
     * or if lambda is negative, NaN or infinite.
     */
    public LSqMinNorm(SimpleMatrix a, SimpleVector b, double lambda){
        // An explicit check instead of assert: assertions are disabled by default at runtime.
        if(a.getRows() != b.getLen()){
            throw new IllegalArgumentException("Dimension mismatch: matrix has " + a.getRows()
                    + " rows but right hand side has " + b.getLen() + " entries.");
        }
        checkLambda(lambda);
        this.a = a;
        this.b = b;
        this.nCol = a.getCols();
        this.nRow = a.getRows();
        this.lambda = lambda;
        solve();
    }

    /**
     * Rejects multipliers for which sqrt(lambda) is not a finite real number.
     * @param lambda The multiplier to validate.
     * @throws IllegalArgumentException If lambda is negative, NaN or infinite.
     */
    private static void checkLambda(double lambda){
        // !(lambda >= 0) is also true for NaN.
        if(!(lambda >= 0) || Double.isInfinite(lambda)){
            throw new IllegalArgumentException("lambda must be finite and >= 0, but was " + lambda);
        }
    }

    /**
     * Solves the system of equations under the constraint of minimal norm of the solution vector.
     * Builds the augmented system A' * x = b' and multiplies b' with the pseudo-inverse of A'
     * assembled from the singular value decomposition. The result is stored in {@link #x}.
     */
    private void solve(){
        // construct new matrix containing the original matrix and then a nCol x nCol diagonal block sqrt(lambda) * I
        SimpleMatrix aPrime = new SimpleMatrix(nRow + nCol, nCol);
        aPrime.setSubMatrixValue(0, 0, a);
        final double sqrtLambda = Math.sqrt(lambda);
        for(int i = 0; i < nCol; i++){
            aPrime.setElementValue(nRow + i, i, sqrtLambda);
        }
        // construct new right hand side using original right hand side; the trailing nCol entries are not
        // written and are expected to keep their initial value of 0
        SimpleVector bPrime = new SimpleVector(nRow + nCol);
        bPrime.setSubVecValue(0, b);
        // pseudo-inverse: V * S^-1 * U^T
        DecompositionSVD svd = new DecompositionSVD(aPrime);
        SimpleMatrix aPrInv = SimpleOperators.multiplyMatrixProd(svd.getV(), SimpleOperators.multiplyMatrixProd(svd.getreciprocalS(), svd.getU().transposed()));
        this.x = SimpleOperators.multiply(aPrInv, bPrime);
    }

    /**
     * Calculates the root mean square error of the fit, i.e. |A * x - b|_2 / sqrt(n) with n the
     * number of equations. The residual is computed against the original (non-augmented) system.
     * @return The error.
     */
    public double getRmsError(){
        // accumulate A * x column by column
        SimpleVector residual = new SimpleVector(b.getLen());
        for(int i = 0; i < nCol; i++){
            residual.add(a.getCol(i).multipliedBy(x.getElement(i)));
        }
        residual.subtract(b);
        // RMS = sqrt(sum(r_i^2) / n) = |r|_2 / sqrt(n); dividing by n would underestimate the error.
        return residual.normL2() / Math.sqrt(residual.getLen());
    }

    /**
     * Getter for the coefficients solving the system of equations.
     * @return A copy of the coefficients solving the system.
     */
    public double[] getSolution(){
        return this.x.copyAsDoubleArray();
    }

    /**
     * Setter for the Lagrange multiplier used for regularization. The system is solved again so that
     * {@link #getSolution()} and {@link #getRmsError()} reflect the new multiplier.
     * @param lambda The multiplier; must be finite and not negative.
     * @throws IllegalArgumentException If lambda is negative, NaN or infinite.
     */
    public void setLambda(double lambda){
        checkLambda(lambda);
        this.lambda = lambda;
        // the stored solution was computed with the previous multiplier and is stale now
        solve();
    }

    /**
     * Getter for the Lagrange Multiplier used for regularization.
     * @return The multiplier.
     */
    public double getLambda(){
        return this.lambda;
    }
}
/*
* Copyright (C) 2010-2014 Mathias Unberath
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/