/*
* RapidMiner
*
* Copyright (C) 2001-2014 by RapidMiner and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapidminer.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package liblinear;
import static liblinear.Linear.NL;
import static liblinear.Linear.atof;
import static liblinear.Linear.atoi;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
/**
 * Command-line entry point for training a linear model with this liblinear port.
 * Mirrors the C liblinear "train" tool: parses options, reads a problem file in
 * libsvm sparse format, then either runs n-fold cross-validation or trains a
 * model and saves it to disk.
 *
 * NOTE: instances are single-use and not thread-safe — parse_command_line(),
 * readProblem() and run() all mutate the fields below.
 */
public class Train {
/**
 * Program entry point; delegates to {@link #run(String[])} on a fresh instance.
 *
 * @param args command-line arguments, see {@link #exit_with_help()} for usage
 * @throws IOException if the training file cannot be read or the model cannot be saved
 */
public static void main( String[] args ) throws IOException {
new Train().run(args);
}
// Bias term appended to each instance when >= 0 (option -B); < 0 disables it. Default 1.
private double bias = 1;
// True when -v was given: run cross-validation instead of training a model.
private boolean cross_validation = false;
// Path of the training file (first non-option argument).
private String inputFilename;
// Path the trained model is written to (second non-option argument, or derived).
private String modelFilename;
// Number of folds for cross-validation (option -v); only meaningful when cross_validation is true.
private int nr_fold;
// Solver configuration built by parse_command_line().
private Parameter param = null;
// The training problem built by readProblem().
private Problem prob = null;
/**
 * Runs nr_fold-fold cross-validation on the loaded problem.
 *
 * The accuracy (total_correct) is still computed, but all reporting was
 * commented out (presumably by the RapidMiner fork), so this method currently
 * produces no output and its result is discarded.
 */
private void do_cross_validation() {
int[] target = new int[prob.l];
//long start, stop;
//start = System.currentTimeMillis();
Linear.crossValidation(prob, param, nr_fold, target);
//stop = System.currentTimeMillis();
//System.out.println("time: " + (stop - start) + " ms");
int total_correct = 0;
// Count instances whose predicted label matches the true label.
for ( int i = 0; i < prob.l; i++ )
if ( target[i] == prob.y[i] ) ++total_correct;
//System.out.printf("correct: %d" + NL, total_correct);
//System.out.printf("Cross Validation Accuracy = %g%%\n", 100.0 * total_correct / prob.l);
}
/**
 * Prints the usage message to stdout and terminates the JVM with exit code 1.
 * Never returns.
 */
private void exit_with_help() {
System.out.println("Usage: train [options] training_set_file [model_file]" + NL //
+ "options:" + NL//
+ "-s type : set type of solver (default 1)" + NL//
+ "   0 -- L2-regularized logistic regression" + NL//
+ "   1 -- L2-loss support vector machines (dual)" + NL//
+ "   2 -- L2-loss support vector machines (primal)" + NL//
+ "   3 -- L1-loss support vector machines (dual)" + NL//
+ "   4 -- multi-class support vector machines by Crammer and Singer" + NL//
+ "-c cost : set the parameter C (default 1)" + NL//
+ "-e epsilon : set tolerance of termination criterion" + NL//
+ "   -s 0 and 2" + NL//
+ "       |f'(w)|_2 <= eps*min(pos,neg)/l*|f'(w0)|_2," + NL//
+ "       where f is the primal function, (default 0.01)" + NL//
+ "   -s 1, 3, and 4" + NL//
+ "       Dual maximal violation <= eps; similar to libsvm (default 0.1)" + NL//
+ "-B bias : if bias >= 0, instance x becomes [x; bias]; if < 0, no bias term added (default 1)" + NL//
+ "-wi weight: weights adjust the parameter C of different classes (see README for details)" + NL//
+ "-v n: n-fold cross validation mode" + NL//
);
System.exit(1);
}
/** @return the problem built by {@link #readProblem(String)}, or null before it ran */
Problem getProblem() {
return prob;
}
/** @return the bias value (option -B; default 1) */
double getBias() {
return bias;
}
/** @return the parameter object built by {@link #parse_command_line(String[])}, or null before it ran */
Parameter getParameter() {
return param;
}
/**
 * Parses the command line, filling {@link #param}, {@link #bias},
 * {@link #cross_validation}/{@link #nr_fold}, {@link #inputFilename} and
 * {@link #modelFilename}. On any malformed input it prints usage and exits
 * the JVM via {@link #exit_with_help()}.
 *
 * @param argv raw command-line arguments: options first, then the training
 *             file and an optional model file
 */
void parse_command_line( String argv[] ) {
int i;
// eps: see setting below
// eps starts as +infinity so we can detect later whether -e was given.
param = new Parameter(SolverType.L2LOSS_SVM_DUAL, 1, Double.POSITIVE_INFINITY);
// default values
bias = 1;
cross_validation = false;
int nr_weight = 0;
// parse options
// Each option is a "-x" flag whose value is the NEXT argv slot; after ++i
// below, argv[i] is the value and argv[i - 1] is the flag itself. After the
// loop, i indexes the first non-option argument.
for ( i = 0; i < argv.length; i++ ) {
if ( argv[i].charAt(0) != '-' ) break;
if ( ++i >= argv.length ) exit_with_help();
switch ( argv[i - 1].charAt(1) ) {
case 's':
// Solver type by ordinal; an out-of-range value throws — TODO confirm intended.
param.solverType = SolverType.values()[atoi(argv[i])];
break;
case 'c':
param.setC(atof(argv[i]));
break;
case 'e':
param.setEps(atof(argv[i]));
break;
case 'B':
bias = atof(argv[i]);
break;
case 'w':
// -wN weight: grow the two parallel arrays (weightLabel/weight) by one and
// append the pair. Assumes Parameter initializes them to empty arrays, not
// null — a null source would NPE in arraycopy; TODO confirm in Parameter.
++nr_weight;
{
int[] old = param.weightLabel;
param.weightLabel = new int[nr_weight];
System.arraycopy(old, 0, param.weightLabel, 0, nr_weight - 1);
}
{
double[] old = param.weight;
param.weight = new double[nr_weight];
System.arraycopy(old, 0, param.weight, 0, nr_weight - 1);
}
// The label N is parsed from the flag itself: "-w3" -> 3.
param.weightLabel[nr_weight - 1] = atoi(argv[i - 1].substring(2));
param.weight[nr_weight - 1] = atof(argv[i]);
break;
case 'v':
cross_validation = true;
nr_fold = atoi(argv[i]);
if ( nr_fold < 2 ) {
System.err.print("n-fold cross validation: n must >= 2\n");
exit_with_help();
}
break;
default:
System.err.println("unknown option");
exit_with_help();
}
}
// determine filenames
if ( i >= argv.length ) exit_with_help();
inputFilename = argv[i];
if ( i < argv.length - 1 )
modelFilename = argv[i + 1];
else {
// No model file given: default to "<basename of training file>.model".
int p = argv[i].lastIndexOf('/');
// Move past the '/'; when no '/' was found, lastIndexOf returned -1, so
// ++p yields 0 and substring(0) keeps the whole name.
++p;
modelFilename = argv[i].substring(p) + ".model";
}
// -e not supplied: pick the solver-specific default tolerance.
if ( param.eps == Double.POSITIVE_INFINITY ) {
if ( param.solverType == SolverType.L2_LR || param.solverType == SolverType.L2LOSS_SVM ) {
param.setEps(0.01);
} else if ( param.solverType == SolverType.L2LOSS_SVM_DUAL || param.solverType == SolverType.L1LOSS_SVM_DUAL
|| param.solverType == SolverType.MCSVM_CS ) {
param.setEps(0.1);
}
}
}
// read in a problem (in libsvm format)
/**
 * Reads a training file in libsvm sparse format ("label index:value ...")
 * and builds {@link #prob}. When {@link #bias} >= 0, an extra feature node
 * with index max_index+1 and value bias is appended to every instance.
 *
 * @param filename path to the training file
 * @throws IOException if the file cannot be read
 */
void readProblem( String filename ) throws IOException {
BufferedReader fp = new BufferedReader(new FileReader(filename));
List<Integer> vy = new ArrayList<Integer>();
List<FeatureNode[]> vx = new ArrayList<FeatureNode[]>();
int max_index = 0;
try {
while ( true ) {
String line = fp.readLine();
if ( line == null ) break;
// ':' is a delimiter too, so "index:value" yields two tokens.
StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
// First token is the (integer) class label.
String token = st.nextToken();
vy.add(atoi(token));
// m = number of index:value feature pairs remaining on the line.
int m = st.countTokens() / 2;
FeatureNode[] x;
if ( bias >= 0 ) {
// Reserve one extra slot for the bias node appended below.
x = new FeatureNode[m + 1];
} else {
x = new FeatureNode[m];
}
for ( int j = 0; j < m; j++ ) {
int index = atoi(st.nextToken());
double value = atof(st.nextToken());
x[j] = new FeatureNode(index, value);
}
if ( m > 0 ) {
// Only the last node's index is checked, i.e. this assumes feature
// indices on a line are in ascending order (libsvm format convention)
// — TODO confirm the file format guarantees this.
max_index = Math.max(max_index, x[m - 1].index);
}
vx.add(x);
}
prob = new Problem();
prob.bias = bias;
prob.l = vy.size();
prob.n = max_index;
if ( bias >= 0 ) {
// One extra dimension for the bias feature.
prob.n++;
}
prob.x = new FeatureNode[prob.l][];
for ( int i = 0; i < prob.l; i++ ) {
prob.x[i] = vx.get(i);
if ( bias >= 0 ) {
// Fill the reserved last slot with the bias node (index max_index + 1).
assert prob.x[i][prob.x[i].length - 1] == null;
prob.x[i][prob.x[i].length - 1] = new FeatureNode(max_index + 1, bias);
} else {
assert prob.x[i][prob.x[i].length - 1] != null;
}
}
prob.y = new int[prob.l];
for ( int i = 0; i < prob.l; i++ )
prob.y[i] = vy.get(i);
}
finally {
fp.close();
}
}
/**
 * Full training pipeline: parse arguments, load the problem, then either
 * cross-validate (-v) or train and save a model.
 *
 * @param args command-line arguments
 * @throws IOException if reading the problem or saving the model fails
 */
private void run( String[] args ) throws IOException {
parse_command_line(args);
readProblem(inputFilename);
if ( cross_validation )
do_cross_validation();
else {
Model model = Linear.train(prob, param);
Linear.saveModel(new File(modelFilename), model);
}
}
}