/*
* Copyright (C) 2014 Andreas Maier
* CONRAD is developed as an Open Source project under the GNU General Public License (GPL).
*/
package edu.stanford.rsl.tutorial.weka;
import java.util.ArrayList;
import hr.irb.fastRandomForest.FastRandomForest;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;
/**
* This is a short example on how to use the Weka API. We will generate a training set and a test set.
* Then we train a Fast Random Forest classifier and test it on the test feature vector set.
* <br><br>
* This is a general introduction to Weka API. Do not use this for image feature extraction.<br>
*
* @author akmaier
*
*/
public class ClassificationExample {
	/**
	 * This is the main routine. It builds a random training set and a random test set,
	 * trains a Fast Random Forest classifier on the training set, and reports the
	 * recognition rate on the test set. As features and class labels are independent
	 * random values, the expected recognition rate is chance level (about 50 %).
	 * @param args unused
	 */
	public static void main(String[] args) {
		// A set of feature vectors is stored in an Instances object in Weka.
		// In order to create such an object, we first have to create a list of features,
		// called attributes in Weka. In this example we generate 10 random real-valued
		// features plus a nominal class attribute with the classes "Class one" and "Class two".
		int numFeatures = 10;
		ArrayList<Attribute> attribs = new ArrayList<Attribute>(numFeatures + 1);
		// Generate the numeric features and add them to the list of features.
		for (int i = 0; i < numFeatures; i++) {
			attribs.add(new Attribute("Feature " + i));
		}
		// Generate a list of nominal classes for the class attribute.
		ArrayList<String> classValues = new ArrayList<String>(2);
		classValues.add("Class one");
		classValues.add("Class two");
		// Create the class attribute and append it as the last attribute.
		Attribute classAttribute = new Attribute("Class", classValues);
		attribs.add(classAttribute);
		// Create 10000 random training vectors and 10000 random test vectors.
		int vectors = 10000;
		Instances trainingSet = createRandomSet("Training Set", attribs, classAttribute, numFeatures, vectors);
		Instances testSet = createRandomSet("Test Set", attribs, classAttribute, numFeatures, vectors);
		try {
			// Train the classifier on the training set.
			FastRandomForest frf = new FastRandomForest();
			frf.buildClassifier(trainingSet);
			// Evaluate the classifier: count how often the predicted class index
			// matches the class label stored at position numFeatures.
			int hit = 0;
			for (int j = 0; j < vectors; j++) {
				double prediction = frf.classifyInstance(testSet.instance(j));
				if (prediction == testSet.instance(j).value(numFeatures)) hit++;
			}
			// Report the recognition rate.
			System.out.println("Correct classifications (Should be about 50 %): " + (double) hit * 100 / vectors + " %");
		} catch (Exception e) {
			// Weka's buildClassifier/classifyInstance declare a generic Exception;
			// for this tutorial we simply report it and exit.
			e.printStackTrace();
		}
	}

	/**
	 * Creates a data set of random feature vectors. Each vector holds
	 * {@code numFeatures} uniformly random values in [0, 1) followed by a
	 * uniformly random binary class label (0 or 1) in the last position.
	 * @param name name of the resulting data set
	 * @param attribs attribute list (features plus the class attribute)
	 * @param classAttribute the nominal class attribute of the data set
	 * @param numFeatures number of numeric features per vector
	 * @param vectors number of vectors to generate
	 * @return the populated data set with its class attribute set
	 */
	private static Instances createRandomSet(String name, ArrayList<Attribute> attribs, Attribute classAttribute, int numFeatures, int vectors) {
		Instances set = new Instances(name, attribs, vectors);
		for (int j = 0; j < vectors; j++) {
			double[] vector = new double[numFeatures + 1];
			for (int i = 0; i < numFeatures; i++) {
				vector[i] = Math.random();
			}
			// Uniformly random class label 0 or 1. Note: the previous form
			// (int) (Math.random() * 1.99) was slightly biased towards class 0.
			vector[numFeatures] = (int) (Math.random() * 2);
			// All instances get the same weight of 1.0.
			set.add(new DenseInstance(1.0, vector));
		}
		set.setClass(classAttribute);
		return set;
	}
}