package edu.stanford.rsl.tutorial.dmip;
import edu.stanford.rsl.conrad.numerics.SimpleMatrix;
import edu.stanford.rsl.conrad.numerics.SimpleMatrix.InversionType;
import edu.stanford.rsl.conrad.numerics.SimpleOperators;
import edu.stanford.rsl.conrad.numerics.SimpleVector;
import ij.gui.Plot;
import java.util.ArrayList;
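/**
* Exercise 4 of Diagnostic Medical Image Processing (DMIP):
* fitting a line y = m * x + c to a 2D point cloud, once as a regression
* line through all points and once with the RANSAC algorithm.
*
* @author Bastian Bier
*/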
public class RANSAC {

    /** Largest point-to-line distance at which a point still counts as an inlier. */
    private static final double INLIER_THRESHOLD = 0.2;

    /** Left end of the x range shown in the plots of {@link #main(String[])}. */
    private static final double PLOT_X_MIN = 0;

    /** Right end of the x range shown in the plots of {@link #main(String[])}. */
    private static final double PLOT_X_MAX = 11;

    /**
     * Estimates a line through a point cloud with the RANSAC algorithm.
     *
     * <p>In every iteration {@code mn} distinct points are drawn at random, a line is
     * fitted through them with {@link #fitline(SimpleMatrix)}, and that line is rated
     * against the whole cloud with {@link #lineError(SimpleVector, SimpleMatrix)}.
     * The line with the smallest error is returned; on ties the first one found is kept.
     * The number of iterations is
     * {@code log(1 - p_opt) / log(1 - (1 - p_out)^mn) + 1}, truncated to an int.
     *
     * @param points The point cloud the line should be fitted through, one point per row
     *               (column 0 = x, column 1 = y)
     * @param mn     The number of points drawn per iteration to build the model
     * @param p_opt  The probability of picking only the right points for the model at least
     *               once during the iterations, in [0, 1)
     * @param p_out  The probability of an outlier, in [0, 1)
     *
     * @return the parameters of the resulting RANSAC line in a SimpleVector: [m,c]
     *
     * @throws IllegalArgumentException if {@code mn} is smaller than 1 or larger than the
     *         number of points (no set of distinct points of that size exists), or if
     *         one of the probabilities lies outside [0, 1)
     */
    public SimpleVector commonRansac(SimpleMatrix points, int mn, double p_opt, double p_out) {
        final int numPoints = points.getRows();
        if (mn < 1 || mn > numPoints) {
            throw new IllegalArgumentException(
                    "mn must be between 1 and the number of points (" + numPoints + "), but was " + mn);
        }
        // Written as negated range checks so that NaN is rejected as well.
        if (!(p_opt >= 0.0 && p_opt < 1.0)) {
            throw new IllegalArgumentException("p_opt must be in [0, 1), but was " + p_opt);
        }
        if (!(p_out >= 0.0 && p_out < 1.0)) {
            throw new IllegalArgumentException("p_out must be in [0, 1), but was " + p_out);
        }

        // Number of required iterations; at least 1 for the validated parameter ranges
        final int it = (int) (Math.log(1 - p_opt) / Math.log(1 - Math.pow(1 - p_out, mn)) + 1);

        // Error of the best line found so far
        double error = Double.POSITIVE_INFINITY;
        // Parameters [m,c] of the best line found so far
        SimpleVector solution = new SimpleVector(2);

        for (int i = 0; i < it; i++) {
            // Draw mn distinct row indexes. Scaling Math.random() (always < 1) by the
            // number of points keeps every index strictly below numPoints.
            ArrayList<Integer> indexes = new ArrayList<Integer>(mn);
            while (indexes.size() < mn) {
                int candidate = (int) (Math.random() * numPoints);
                if (!indexes.contains(candidate)) {
                    indexes.add(candidate);
                }
            }

            // Copy the selected points into their own matrix; its width follows the
            // point cloud so that the copied rows fit for any mn.
            SimpleMatrix sample = new SimpleMatrix(mn, points.getCols());
            for (int n = 0; n < mn; n++) {
                sample.setRowValue(n, points.getRow(indexes.get(n)));
            }

            // Estimate the line through the selected points
            SimpleVector lineParams = fitline(sample);

            // Rate the line against the whole point cloud and keep it if it is better
            double cur_err = lineError(lineParams, points);
            if (cur_err < error) {
                error = cur_err;
                solution = lineParams;
            }
        }
        return solution;
    }

    /**
     * Calculates a line y = m * x + c through a point cloud.
     *
     * <p>The system {@code [x_i 1] * [m c]^T = y_i} is built with one row per point and
     * solved by multiplying the right-hand side with the inverse of the measurement
     * matrix requested from CONRAD as {@code InversionType.INVERT_SVD} (pseudo-inverse).
     *
     * @param points The point cloud the line should be fitted through, one point per row
     *               (column 0 = x, column 1 = y)
     *               2 points result in an exact line
     *               &gt;2 points result in a regression line
     *
     * @return the parameters of the line in a SimpleVector: [m,c]
     *
     * @throws IllegalArgumentException if {@code points} has no rows
     */
    public SimpleVector fitline(SimpleMatrix points) {
        final int numPoints = points.getRows();
        if (numPoints == 0) {
            throw new IllegalArgumentException("At least one point is required to fit a line");
        }

        // Build up the measurement matrix [x 1] and the right-hand side y
        SimpleMatrix m = new SimpleMatrix(numPoints, 2);
        SimpleVector b = new SimpleVector(numPoints);
        for (int i = 0; i < numPoints; i++) {
            m.setElementValue(i, 0, points.getElement(i, 0));
            m.setElementValue(i, 1, 1.0);
            b.setElementValue(i, points.getElement(i, 1));
        }

        // Calculate the parameters m and c using the pseudo-inverse
        SimpleMatrix pseudoInverse = m.inverse(InversionType.INVERT_SVD);
        return SimpleOperators.multiply(pseudoInverse, b);
    }

    /**
     * Calculates the error of a line with respect to a point cloud.
     *
     * <p>The error is the fraction of points whose perpendicular distance to the line
     * y = m * x + c is larger than {@link #INLIER_THRESHOLD}. The distance is evaluated
     * with the line's unit normal n = (-m, 1) / sqrt(m^2 + 1) and its distance to the
     * origin d = c / sqrt(m^2 + 1) as |n * p - d|.
     *
     * @param line_params Parameters of the line: [m,c]
     * @param points The point cloud the line is rated against, one point per row
     *               (column 0 = x, column 1 = y)
     *
     * @return the fraction of points farther from the line than the threshold, in [0, 1];
     *         0 if {@code points} has no rows
     */
    public double lineError(SimpleVector line_params, SimpleMatrix points) {
        final int numPoints = points.getRows();
        if (numPoints == 0) {
            return 0;
        }

        // Line parameters
        final double m = line_params.getElement(0);
        final double c = line_params.getElement(1);

        // The direction of the line is (1, m), so (-m, 1) is perpendicular to it
        final double normalLength = Math.sqrt(m * m + 1);
        final double normalX = -m / normalLength;
        final double normalY = 1 / normalLength;

        // Distance of the line to the origin: projection of the point (0, c) on the normal
        final double originDistance = normalY * c;

        // Count the points that are farther away from the line than the threshold
        int outliers = 0;
        for (int i = 0; i < numPoints; i++) {
            double projection = normalX * points.getElement(i, 0) + normalY * points.getElement(i, 1);
            if (Math.abs(projection - originDistance) > INLIER_THRESHOLD) {
                outliers++;
            }
        }
        return (double) outliers / numPoints;
    }

    /**
     * Creates a scatter plot of a point cloud with the axis limits used in this exercise.
     *
     * @param title Title of the plot
     * @param pts   The point cloud, one point per row (column 0 = x, column 1 = y)
     *
     * @return the plot, not yet shown
     */
    private static Plot createScatterPlot(String title, SimpleMatrix pts) {
        Plot plot = new Plot(title, "X", "Y", Plot.DEFAULT_FLAGS);
        plot.setLimits(PLOT_X_MIN, PLOT_X_MAX, 0, 5);
        plot.addPoints(pts.getCol(0).copyAsDoubleArray(), pts.getCol(1).copyAsDoubleArray(), Plot.BOX);
        return plot;
    }

    /**
     * Draws the line y = m * x + c across the whole x range of a plot.
     *
     * @param plot The plot the line is added to
     * @param line Parameters of the line: [m,c]
     */
    private static void drawLine(Plot plot, SimpleVector line) {
        double yLeft = line.getElement(0) * PLOT_X_MIN + line.getElement(1);
        double yRight = line.getElement(0) * PLOT_X_MAX + line.getElement(1);
        plot.drawLine(PLOT_X_MIN, yLeft, PLOT_X_MAX, yRight);
    }

    /**
     * Fits a regression line and a RANSAC line to a fixed example point cloud and shows
     * each result in its own plot.
     *
     * @param args not used
     */
    public static void main(String[] args) {
        RANSAC ransac = new RANSAC();

        //
        // The point cloud is defined
        //
        SimpleMatrix pts = new SimpleMatrix(7,2);
        pts.setRowValue(0, new SimpleVector(0,0));
        pts.setRowValue(1, new SimpleVector(1,1));
        pts.setRowValue(2, new SimpleVector(2,2));
        pts.setRowValue(3, new SimpleVector(3,3));
        pts.setRowValue(4, new SimpleVector(3.2,1.9));
        pts.setRowValue(5, new SimpleVector(4,4));
        pts.setRowValue(6, new SimpleVector(10,1.8));

        //
        // Regression Line
        //
        // The line is drawn before the plot is shown, in the same order as for the
        // RANSAC plot below, so it is part of the plot when the window opens.
        Plot scatterPlot = createScatterPlot("Regression Line", pts);
        SimpleVector regressionLine = ransac.fitline(pts);
        drawLine(scatterPlot, regressionLine);
        scatterPlot.show();

        //
        // RANSAC
        //
        // Parameters for RANSAC
        double p_opt = 0.9999; // probability how likely it is to pick the right mn points
        double p_out = 0.2;    // probability of an outlier
        int min_number = 2;    // minimum number of datapoints required to build the model

        // Compute a line using the RANSAC algorithm and plot it
        Plot ransac_plot = createScatterPlot("Ransac Line", pts);
        SimpleVector ransacLine = ransac.commonRansac(pts, min_number, p_opt, p_out);
        drawLine(ransac_plot, ransacLine);
        ransac_plot.show();
    }
}