package com.spbsu.region;
import com.spbsu.commons.math.vectors.Vec;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.math.vectors.impl.vectors.ArrayVec;
import com.spbsu.commons.random.FastRandom;
import com.spbsu.commons.util.ThreadTools;
import com.spbsu.commons.math.Func;
import com.spbsu.ml.GridTools;
import com.spbsu.ml.ProgressHandler;
import com.spbsu.commons.math.Trans;
import com.spbsu.ml.data.set.VecDataSet;
import com.spbsu.ml.data.tools.DataTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.func.Ensemble;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.L2GreedyTDRegion;
import com.spbsu.ml.methods.GradientBoosting;
import com.spbsu.ml.methods.greedyRegion.RegionForest;
import gnu.trove.list.array.TDoubleArrayList;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;
/**
* Created by noxoomo on 11/07/14.
*/
class ScoresCalcer {
// Source of the per-run seeds drawn in run(); constructed with the constant 10
// (presumably a fixed seed for reproducibility - confirm against FastRandom).
static FastRandom random = new FastRandom(10);
// Iteration count handed to GradientBoosting for every run.
public int iterations = 12000;
// Step value handed to GradientBoosting for every run.
public double step = 0.004;
// Pool whose vecData()/target() the boosting is fitted on.
private Pool<?> learn;
// Pool whose vecData()/target() the validation listener scores against.
private Pool<?> validate;
// Executor on which run() schedules each individual boosting run.
// NOTE(review): the meaning of the -1 argument is defined by ThreadTools.createBGExecutor - confirm there.
static ThreadPoolExecutor pool = ThreadTools.createBGExecutor("Boosting thread", -1);
/**
 * Creates a calcer bound to a training pool and a validation pool.
 *
 * @param learn    pool every boosting run is fitted on
 * @param validate pool every boosting run is scored on
 */
public ScoresCalcer(Pool<?> learn, Pool<?> validate) {
    this.validate = validate;
    this.learn = learn;
}
/**
 * Runs {@code tries} boosting fits with the built-in parameters alpha = 0.02 and beta = 0.5.
 *
 * @param tries number of independent runs
 * @return the scores produced by {@link #run(int, double, double)} for those parameters
 */
public double[] run(int tries) {
    // NOTE(review): the legacy comment here read "alpha 0.075652 and beta 1.164409 s";
    // those numbers differ from the values actually passed below - confirm which are intended.
    final double defaultAlpha = 0.02;
    final double defaultBeta = 0.5;
    return run(tries, defaultAlpha, defaultBeta);
}
/**
 * Schedules {@code tries} boosting runs on the shared executor and waits for all of them.
 * <p>
 * Each run draws its own seed from the shared random source, fits a gradient boosting
 * model over a {@code RegionForest} on the learn pool, and tracks the loss on the
 * validate pool through a {@link ScoreCalcer} listener.
 *
 * @param tries number of independent runs to execute
 * @param alpha passed through to the {@code RegionForest} constructor
 * @param beta  passed through to the {@code RegionForest} constructor
 * @return array of length {@code tries}; entry {@code i} is the smallest validation loss seen
 *         during run {@code i}, or {@code NaN} if that run threw or had not finished when the
 *         waiting thread was interrupted
 */
public double[] run(int tries, final double alpha, final double beta) {
    final double[] scores = new double[tries];
    // A run that dies must not leave 0.0 behind: that would read as a perfect loss.
    Arrays.fill(scores, Double.NaN);
    final CountDownLatch latch = new CountDownLatch(tries);
    for (int i = 0; i < tries; ++i) {
        // NOTE(review): the random source is a shared static and this method is also called
        // from several threads at once - confirm FastRandom tolerates concurrent use.
        final long seed = random.nextLong();
        final int index = i;
        pool.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    // Raw constructor call kept as in the original; the generic signature of
                    // GradientBoosting is not visible from this file.
                    final GradientBoosting<L2> boosting = new GradientBoosting
                            (new RegionForest<>(GridTools.medianGrid(learn.vecData(), 32), new FastRandom(seed), 5, alpha, beta), L2GreedyTDRegion.class, iterations, step);
                    final ScoreCalcer validateListener = new ScoreCalcer("\t", validate.vecData(), validate.target(L2.class));
                    boosting.addListener(validateListener);
                    // Only the listener's running minimum is needed; the fitted model is discarded.
                    boosting.fit(learn.vecData(), learn.target(L2.class));
                    scores[index] = validateListener.min;
                    System.out.println("Score for run " + validateListener.min);
                } finally {
                    // Always release the latch, otherwise one failing run blocks await() forever.
                    latch.countDown();
                }
            }
        });
    }
    try {
        latch.await();
    } catch (InterruptedException e) {
        e.printStackTrace();
        // Restore the flag so callers further up can see the interruption.
        Thread.currentThread().interrupt();
    }
    return scores;
}
/**
 * Progress listener that keeps running predictions for a data set and remembers the
 * smallest loss seen so far in {@link #min}.
 * <p>
 * The loss is {@code VecTools.distance(current, target.target)} divided by the square
 * root of the number of rows in the data set.
 */
protected static class ScoreCalcer implements ProgressHandler {
    // Stored label; invoke() does not read it.
    final String message;
    // Current prediction for every row of ds.
    final Vec current;
    private final VecDataSet ds;
    private final L2 target;
    // Smallest loss observed so far; starts at the sentinel 1e10.
    double min = 1e10;

    /**
     * @param message label kept on the instance
     * @param ds      data set the predictions are computed for
     * @param target  target whose {@code target} vector the predictions are compared with
     */
    public ScoreCalcer(String message, VecDataSet ds, L2 target) {
        this.message = message;
        this.ds = ds;
        this.target = target;
        current = new ArrayVec(ds.length());
    }

    /**
     * Updates the running predictions from {@code partial} and refreshes {@link #min}.
     * For an {@code Ensemble} only its last member, scaled by {@code wlast()}, is added to
     * the existing predictions; any other model overwrites the predictions outright.
     */
    @Override
    public void invoke(Trans partial) {
        if (!(partial instanceof Ensemble)) {
            for (int row = 0; row < ds.length(); row++) {
                current.set(row, ((Func) partial).value(ds.data().row(row)));
            }
        } else {
            final Ensemble ensemble = (Ensemble) partial;
            final Trans lastModel = ensemble.last();
            // The type of the last member does not change while iterating over rows.
            final boolean lastIsEnsemble = lastModel instanceof Ensemble;
            for (int row = 0; row < ds.length(); row++) {
                final double weight = ensemble.wlast();
                final double delta;
                if (lastIsEnsemble) {
                    delta = lastModel.trans(ds.data().row(row)).get(0);
                } else {
                    delta = ((Func) lastModel).value(ds.data().row(row));
                }
                current.adjust(row, weight * delta);
            }
        }
        final double loss = VecTools.distance(current, target.target) / Math.sqrt(ds.length());
        min = Math.min(loss, min);
    }
}
}
public class ExperimentsRunner {
// Random source for the alpha/beta draws of the "RandomSearch" mode; constructed with the
// constant 0 (presumably a fixed seed - confirm against FastRandom).
static FastRandom random = new FastRandom(0);
// Executor on which main() schedules one task per sampled (alpha, beta) pair.
// NOTE(review): the meaning of the -1 argument is defined by ThreadTools.createBGExecutor - confirm there.
private static Executor pool = ThreadTools.createBGExecutor("Boosting thread", -1);
/**
 * Command-line entry point. The first argument selects the experiment:
 * <ul>
 *   <li>{@code Region} - 16 runs with the default parameters of {@code ScoresCalcer};</li>
 *   <li>{@code RandomSearch} - 200 random (alpha, beta) pairs, 5 runs each.</li>
 * </ul>
 * Data is read from {@code features.txt.gz} and {@code featuresTest.txt.gz}. A missing or
 * unknown mode prints a usage line to stderr and returns before any file is read.
 *
 * @param args command-line arguments; only {@code args[0]} is used
 */
public static void main(String[] args) {
    final String mode = (args == null || args.length == 0) ? null : args[0];
    // Validate first: previously a missing argument failed with an index error only after
    // both data files had been loaded, and an unknown mode silently did nothing.
    if (!"Region".equals(mode) && !"RandomSearch".equals(mode)) {
        System.err.println("Usage: ExperimentsRunner <Region|RandomSearch>");
        return;
    }
    try {
        String learnPath = "features.txt.gz";
        String testPath = "featuresTest.txt.gz";
        final Pool<?> learn = DataTools.loadFromFeaturesTxt(learnPath);
        final Pool<?> validate = DataTools.loadFromFeaturesTxt(testPath);
        System.err.println(String.format("Learn size %d\nValidation size: %d", learn.data().length(), validate.data().length()));
        if ("Region".equals(mode)) {
            runRegion(learn, validate);
        } else {
            runRandomSearch(learn, validate);
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}

// Runs 16 tries with the default parameters and prints the per-run minimum scores.
private static void runRegion(final Pool<?> learn, final Pool<?> validate) {
    final ScoresCalcer calcer = new ScoresCalcer(learn, validate);
    final double[] scores = calcer.run(16);
    System.out.println(String.format("Min scores for runs with parametrs iter = %d, step = %f", calcer.iterations, calcer.step));
    System.out.println(Arrays.toString(scores));
}

// Samples 200 (alpha, beta) pairs in [0, 2), scores each with 5 runs in parallel and prints the summary arrays.
private static void runRandomSearch(final Pool<?> learn, final Pool<?> validate) {
    final int tries = 200;
    final int bootstrapRuns = 5;
    final double[] meanScores = new double[tries];
    final double[] minScores = new double[tries];
    final double[] maxScores = new double[tries];
    final double[] alphas = new double[tries];
    final double[] betas = new double[tries];
    // A task that throws leaves NaN instead of a 0.0 that would look like a real score.
    Arrays.fill(meanScores, Double.NaN);
    Arrays.fill(minScores, Double.NaN);
    Arrays.fill(maxScores, Double.NaN);
    final CountDownLatch latch = new CountDownLatch(tries);
    for (int i = 0; i < tries; ++i) {
        final int index = i;
        alphas[i] = random.nextDouble() * 2;
        betas[i] = random.nextDouble() * 2;
        pool.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    final ScoresCalcer calcer = new ScoresCalcer(learn, validate);
                    final double[] scores = calcer.run(bootstrapRuns, alphas[index], betas[index]);
                    final double[] stat = stats(scores);
                    String msg = String.format("For alpha %f and beta %f scores are :\n", alphas[index], betas[index]);
                    System.out.println(msg + Arrays.toString(scores));
                    meanScores[index] = stat[0];
                    minScores[index] = stat[1];
                    maxScores[index] = stat[2];
                } finally {
                    // Always release the latch, otherwise one failing task blocks await() forever.
                    latch.countDown();
                }
            }
        });
    }
    try {
        latch.await();
    } catch (InterruptedException e) {
        e.printStackTrace();
        // Restore the flag so the interruption stays visible after the stack trace is printed.
        Thread.currentThread().interrupt();
    }
    // Labels are kept byte-for-byte (including "maxcores") so existing output consumers keep working.
    System.out.println("alphas: ");
    System.out.println(Arrays.toString(alphas));
    System.out.println("betas: ");
    System.out.println(Arrays.toString(betas));
    System.out.println("meanScores: ");
    System.out.println(Arrays.toString(meanScores));
    System.out.println("minScores: ");
    System.out.println(Arrays.toString(minScores));
    System.out.println("maxcores: ");
    System.out.println(Arrays.toString(maxScores));
}
/**
 * Summarises a sample as {@code [mean, min, max]}.
 *
 * @param sample values to summarise
 * @return a three-element array holding the arithmetic mean, the minimum and the maximum,
 *         in that order; all zeros when {@code sample} is empty
 */
public static double[] stats(double[] sample) {
    final double[] summary = new double[3];
    final int count = sample.length;
    if (count == 0) {
        return summary;
    }
    double lowest = sample[0];
    double highest = sample[0];
    double total = 0;
    for (int i = 0; i < count; i++) {
        final double value = sample[i];
        // Negated strict comparisons keep exactly the original treatment of ties and NaN.
        if (!(lowest < value)) {
            lowest = value;
        }
        if (!(highest > value)) {
            highest = value;
        }
        total += value;
    }
    summary[0] = total / count;
    summary[1] = lowest;
    summary[2] = highest;
    return summary;
}
/**
 * Arithmetic mean of the sample.
 *
 * @param sample values to average
 * @return the sum of all elements divided by their count; {@code NaN} for an empty
 *         sample, since that is a floating-point 0/0
 */
public static double mean(TDoubleArrayList sample) {
    final int count = sample.size();
    double total = 0;
    int i = 0;
    while (i < count) {
        total += sample.get(i);
        i++;
    }
    return total / count;
}
/**
 * Population variance of the sample, computed as E[x^2] - (E[x])^2.
 *
 * @param sample values whose variance is wanted
 * @return the variance, never negative; {@code NaN} for an empty sample
 */
public static double var(TDoubleArrayList sample) {
    final int count = sample.size();
    double secondMoment = 0;
    for (int i = 0; i < count; ++i) {
        double d = sample.get(i);
        secondMoment += d * d;
    }
    secondMoment /= count;
    double m = mean(sample);
    // For near-constant samples the subtraction can round to a tiny negative number,
    // which is not a valid variance; clamp it to zero. NaN passes through Math.max unchanged.
    return Math.max(0.0, secondMoment - m * m);
}
/**
 * Summarises a sample as {@code [mean, standard deviation, min, max]}.
 * The standard deviation is the population one, the square root of E[x^2] - (E[x])^2.
 *
 * @param sample values to summarise
 * @return a four-element array in the order above; all zeros when {@code sample} is empty
 */
public static double[] stats(TDoubleArrayList sample) {
    final int count = sample.size();
    if (count == 0) {
        return new double[4];
    }
    double min;
    double max;
    min = max = sample.get(0);
    double mean = 0;
    double secondMoment = 0;
    for (int i = 0; i < count; ++i) {
        double d = sample.get(i);
        secondMoment += d * d;
        mean += d;
        if (d > max) {
            max = d;
        }
        if (d < min) {
            min = d;
        }
    }
    mean /= count;
    secondMoment /= count;
    // Rounding can push the difference slightly below zero for near-constant samples;
    // without the clamp Math.sqrt would then return NaN instead of (almost) zero.
    final double variance = Math.max(0.0, secondMoment - mean * mean);
    double result[] = new double[4];
    result[0] = mean;
    result[1] = Math.sqrt(variance);
    result[2] = min;
    result[3] = max;
    return result;
}
}