/*
* Copyright (C) 2014 Andreas Maier
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/
package edu.stanford.rsl.tutorial.weka;
import java.util.ArrayList;
import weka.classifiers.functions.LinearRegression;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
/**
* This is a short example of how to use the Weka API. We generate a random training set and a random test set,
* train a linear regression classifier on the training set, and evaluate it on the test set.
* <br><br>
* This is a general introduction to the Weka API. Do not use this class for image feature extraction.<br>
*
* @author akmaier
*
*/
public class RegressionExample {
/**
* This is the main routine
* @param args
*/
public static void main(String[] args) {
// A set of feature vectors is stored in an Instances object in Weka.
// In order to create such an object, we first have to create a list of features, called attributes in Weka.
// In this example we generate 10 random real-valued features plus a real-valued class attribute.
int numFeatures = 10;
ArrayList<Attribute> attribs = new ArrayList<Attribute>(numFeatures+1);
// generate 10 features and add them to the list of features.
for (int i = 0; i < numFeatures; i++) {
String nameString = "Feature " + i;
attribs.add(new Attribute(nameString));
}
// generate a real valued class attribute.
Attribute classAttribute = new Attribute("Class");
// add to the list of features.
attribs.add(classAttribute);
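// For comparison: a nominal (categorical) attribute would be built from a list of labels instead.
// This is only a sketch assuming the list-based Attribute constructor of recent Weka versions;
// the names "labels" and "nominalClass" are illustrative only.
// ArrayList<String> labels = new ArrayList<String>();
// labels.add("negative");
// labels.add("positive");
// Attribute nominalClass = new Attribute("Class", labels);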
// create 10000 random training vectors
int vectors = 10000;
Instances trainingSet = new Instances("Training Set", attribs, vectors);
for (int j = 0; j < vectors; j++) {
double[] vector = new double[numFeatures + 1];
for (int i = 0; i < numFeatures; i++) {
vector[i] = Math.random();
}
// the last entry is the class value; here it is a random label of 0 or 1,
// independent of the features, so the best possible prediction is about 0.5.
vector[numFeatures] = (int) (Math.random() * 1.99);
// wrap the raw values in a DenseInstance with weight 1.0 and add it to the training set
trainingSet.add(new DenseInstance(1.0, vector));
}
trainingSet.setClass(classAttribute);
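// Alternatively (a sketch assuming the standard Instances API), the class attribute can
// be selected by its index, since it is the last attribute in the list:
// trainingSet.setClassIndex(numFeatures);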
// create 10000 random test vectors
Instances testSet = new Instances("Test Set", attribs, vectors);
for (int j = 0; j < vectors; j++) {
double[] vector = new double[numFeatures + 1];
for (int i = 0; i < numFeatures; i++) {
vector[i] = Math.random();
}
// random class value of 0 or 1, as in the training set
vector[numFeatures] = (int) (Math.random() * 1.99);
testSet.add(new DenseInstance(1.0, vector));
}
testSet.setClass(classAttribute);
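// Optionally, a brief summary of the generated data can be printed. This is only a sketch
// assuming the toSummaryString() method of weka.core.Instances:
// System.out.println(trainingSet.toSummaryString());
// System.out.println(testSet.toSummaryString());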
try {
// Train the linear regression classifier on the training set.
LinearRegression regression = new LinearRegression();
regression.buildClassifier(trainingSet);
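// The fitted model can be inspected by printing it; LinearRegression's toString()
// lists the learned coefficients (uncomment to see the model):
// System.out.println(regression);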
// Evaluate the classifier on the test set:
double squaredErrorSum = 0;
for (int j = 0; j < vectors; j++) {
double prediction = regression.classifyInstance(testSet.instance(j));
double actual = testSet.instance(j).value(numFeatures);
squaredErrorSum += Math.pow(prediction - actual, 2);
}
// Report the root mean squared error. Since the class values are random and independent
// of the features, the prediction is about 0.5 and the RMSE is also about 0.5.
System.out.println("Root mean squared error (should be about 0.5): " + Math.sqrt(squaredErrorSum / vectors));
} catch (Exception e) {
// buildClassifier and classifyInstance declare a checked Exception; just report it here
e.printStackTrace();
}
}
}