package com.spbsu.exp.multiclass;

import com.spbsu.commons.math.Trans;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.MxTools;
import com.spbsu.commons.math.vectors.Vec;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx;
import com.spbsu.commons.math.vectors.impl.vectors.ArrayVec;
import com.spbsu.commons.random.FastRandom;
import com.spbsu.commons.util.Pair;
import com.spbsu.ml.GridTools;
import com.spbsu.ml.cli.output.printers.MulticlassProgressPrinter;
import com.spbsu.ml.data.set.VecDataSet;
import com.spbsu.ml.data.tools.DataTools;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.data.tools.SubPool;
import com.spbsu.ml.factorization.Factorization;
import com.spbsu.ml.factorization.impl.ALS;
import com.spbsu.ml.factorization.impl.ElasticNetFactorization;
import com.spbsu.ml.factorization.impl.SVDAdapterEjml;
import com.spbsu.ml.factorization.impl.StochasticALS;
import com.spbsu.ml.func.Ensemble;
import com.spbsu.ml.func.FuncJoin;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.LogL2;
import com.spbsu.ml.loss.SatL2;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.meta.FeatureMeta;
import com.spbsu.ml.meta.impl.fake.FakeTargetMeta;
import com.spbsu.ml.methods.GradientBoosting;
import com.spbsu.ml.methods.MultiClass;
import com.spbsu.ml.methods.multiclass.gradfac.FMCBoosting;
import com.spbsu.ml.methods.multiclass.gradfac.GradFacMulticlass;
import com.spbsu.ml.methods.multiclass.gradfac.GradFacSvdNMulticlass;
import com.spbsu.ml.methods.multiclass.gradfac.MultiClassColumnBootstrapOptimization;
import com.spbsu.ml.methods.trees.GreedyObliviousTree;
import com.spbsu.ml.models.MultiClassModel;
import com.spbsu.ml.testUtils.TestResourceLoader;
import junit.framework.TestCase;
import java.io.IOException;

/**
* User: qdeee
* Date: 27.02.15
*/
public class GradFacTest extends TestCase {
private static Pool<?> learn;
private static Pool<?> test;
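
/**
 * Lazily loads the shared "letter" multiclass pool, attaches an integer
 * class-label target derived from the original real-valued one, and splits
 * the pool at random (fixed seed 100500) into learn and test sub-pools.
 * Synchronized so parallel test runs initialize the pools only once.
 */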
private static synchronized void init() throws IOException {
if (learn == null || test == null) {
final Pool<?> pool = TestResourceLoader.loadPool("multiclass/ds_letter/letter.tsv.gz");
pool.addTarget(new FakeTargetMeta(pool.vecData(), FeatureMeta.ValueType.INTS),
VecTools.toIntSeq(pool.target(L2.class).target)
);
final int[][] idxs = DataTools.splitAtRandom(pool.size(), new FastRandom(100500), 0.8, 0.5);
learn = new SubPool<>(pool, idxs[0]);
test = new SubPool<>(pool, idxs[1]);
}
}
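
/** JUnit 3 hook: makes sure the shared pools are loaded before every test. */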
@Override
protected void setUp() throws Exception {
init();
}
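
/**
 * Approximates the full (samples x classes) gradient matrix of the multiclass
 * log-likelihood at the zero point with a rank-1 SVD (via EJML) and prints the
 * factor norms, timing, and the l1/l2 reconstruction errors.
 */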
public void testGradMxApproxSVDN() throws Exception {
final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
final Mx gradient = (Mx) globalLoss.gradient(new ArrayVec(globalLoss.dim()));
final long start = System.currentTimeMillis();
final int factorDim = 1;
// To sweep all ranks, replace the single value above with:
// for (int factorDim = gradient.columns(); factorDim >= 1; factorDim--) { ... }
final Pair<Vec, Vec> pair = new SVDAdapterEjml(factorDim).factorize(gradient);
final Mx h = (Mx) pair.getFirst();
final Mx b = (Mx) pair.getSecond();
System.out.println("factor dim: " + factorDim);
System.out.println("time: " + ((System.currentTimeMillis() - start) / 1000.0) + " s");
final Mx afterFactor = MxTools.multiply(h, MxTools.transpose(b));
System.out.println("||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", l2 = " + VecTools.distance(gradient, afterFactor) + ", l1 = " + VecTools.distanceL1(gradient, afterFactor));
System.out.println();
}
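
/**
 * Same rank-1 gradient approximation as above, but computed with the
 * elastic-net factorization instead of SVD.
 */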
public void testElasticNetGradFac() throws Exception {
final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
final Mx gradient = (Mx) globalLoss.gradient(new ArrayVec(globalLoss.dim()));
final ElasticNetFactorization elasticNetFactorization = new ElasticNetFactorization(20, 1e-2, 0.95, 0.15 * 1e-6);
final Pair<Vec, Vec> pair = elasticNetFactorization.factorize(gradient);
final Vec h = pair.getFirst();
final Vec b = pair.getSecond();
final Mx afterFactor = VecTools.outer(h, b);
System.out.println("||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", l2 = " + VecTools.distance(gradient, afterFactor) + ", l1 = " + VecTools.distanceL1(gradient, afterFactor));
System.out.println();
}
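
/** Value holder for a (lambda, alpha) grid point and its l1/l2 reconstruction errors. */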
private static class ParameterCollector {
final double lambda;
final double alpha;
final double l2;
final double l1;
public ParameterCollector(final double lambda, final double alpha, final double l2, final double l1) {
this.lambda = lambda;
this.alpha = alpha;
this.l2 = l2;
this.l1 = l1;
}
@Override
public String toString() {
return "ParameterCollector{" +
"lambda=" + lambda +
", alpha=" + alpha +
", l2=" + l2 +
", l1=" + l1 +
'}';
}
}
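
/**
 * Grid search over the elastic-net regularization strength lambda and the
 * l1/l2 mixing coefficient alpha, tracking the grid points that minimize the
 * l2 and the l1 reconstruction error separately.
 */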
public void testElasticNetGradFacGridSearch() throws Exception {
final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
final Mx gradient = (Mx) globalLoss.gradient(new ArrayVec(globalLoss.dim()));
ParameterCollector minL1ParameterCollector = new ParameterCollector(0, 0, Double.MAX_VALUE, Double.MAX_VALUE);
ParameterCollector minL2ParameterCollector = new ParameterCollector(0, 0, Double.MAX_VALUE, Double.MAX_VALUE);
for (double lambda = 0.15 * 1e-7; lambda < 1e-4; lambda += 1e-6) {
for (double alpha = 0.1; alpha < 1.0; alpha += 0.01) {
// factorize the gradient with the current (alpha, lambda) grid point
final ElasticNetFactorization elasticNetFactorization = new ElasticNetFactorization(20, 1e-2, alpha, lambda);
final Pair<Vec, Vec> pair = elasticNetFactorization.factorize(gradient);
final Vec h = pair.getFirst();
final Vec b = pair.getSecond();
final Mx afterFactor = VecTools.outer(h, b);
final double l2 = VecTools.distance(gradient, afterFactor);
final double l1 = VecTools.distanceL1(gradient, afterFactor);
if (l2 < minL2ParameterCollector.l2) {
minL2ParameterCollector = new ParameterCollector(lambda, alpha, l2, l1);
}
if (l1 < minL1ParameterCollector.l1) {
minL1ParameterCollector = new ParameterCollector(lambda, alpha, l2, l1);
}
}
}
System.out.println("best by l2: " + minL2ParameterCollector);
System.out.println("best by l1: " + minL1ParameterCollector);
}
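
/** Sanity check: rank-1 elastic-net factorization of a small fixed-seed random matrix. */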
public void testSimpleMx() throws Exception {
final Mx mx = genUniformRandMx(5, 3, 100500);
final ElasticNetFactorization elasticNetFactorization = new ElasticNetFactorization(1, 1e-1, 0.0, 0.0);
final Pair<Vec, Vec> pair = elasticNetFactorization.factorize(mx);
final Vec h = pair.getFirst();
final Vec b = pair.getSecond();
final Mx afterFactor = VecTools.outer(h, b);
System.out.println("||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", l2 = " + VecTools.distance(mx, afterFactor) + ", l1 = " + VecTools.distanceL1(mx, afterFactor));
}
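
/** Fills an m x n matrix with uniform [0, 1) values from a fixed-seed RNG for reproducibility. */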
private static Mx genUniformRandMx(final int m, final int n, final int seed) {
final Mx mx = new VecBasedMx(m, n);
final FastRandom fastRandom = new FastRandom(seed);
for (int i = 0; i < mx.dim(); i++) {
mx.set(i, fastRandom.nextDouble());
}
return mx;
}
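
/** Compares ALS, SVD, and elastic-net factorizations of the same random matrix. */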
public void testDifferentMethods() throws Exception {
final Mx mx = genUniformRandMx(500, 300, 100500);
applyFactorMethod(mx, new ALS(15));
applyFactorMethod(mx, new SVDAdapterEjml());
final double lambda = 0.0015;
applyFactorMethod(mx, new ElasticNetFactorization(20, 1e-4, 0.5, lambda));
}
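
/**
 * GradFac baseline: each boosting step factorizes the (samples x classes)
 * gradient matrix into a rank-1 product h * b^T (SVD here), fits a single
 * oblivious tree to the vector h under the LogL2 local loss, and recovers
 * the per-class outputs by scaling with b.
 */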
public void testGradFacBaseline() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new GradFacMulticlass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
new SVDAdapterEjml(1),
LogL2.class
),
L2.class,
400,
7
);
fitModel(boosting);
}
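
/** Same as the GradFac baseline, but the rank-1 factor is found with stochastic ALS. */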
public void testGradFacSALS() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new GradFacMulticlass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
new StochasticALS(new FastRandom(0), 100),
LogL2.class
),
L2.class,
400,
7
);
fitModel(boosting);
}
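
/** FMCBoosting: the factorized multiclass boosting variant, again driven by stochastic ALS. */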
public void testFMCBoostSALS() throws Exception {
final FMCBoosting boosting = new FMCBoosting(
new StochasticALS(new FastRandom(0), 100),
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
L2.class,
400,
7
);
fitModel(boosting);
}
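
/** GradFac with an elastic-net factorization and SatL2 as the weak-learner loss. */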
public void testGradFacElasticNet() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new GradFacMulticlass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
new ElasticNetFactorization(1, 1., 1., 1.),
SatL2.class
),
L2.class,
400,
7
);
fitModel(boosting);
}
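
/**
 * GradFac with the SVD-N factorization (the trailing 2 presumably sets the
 * rank) wrapped into column bootstrap, which resamples the class columns of
 * the gradient matrix on every boosting step.
 */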
public void testGradFacSVDNColumnsBootstrap() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new MultiClassColumnBootstrapOptimization(
new GradFacSvdNMulticlass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
LogL2.class,
2
),
new FastRandom(100500),
1.
),
L2.class,
5000,
5
);
fitModel(boosting);
}
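
/** Rank-1 SVD GradFac under column bootstrap; note the unseeded FastRandom, so this run is not reproducible. */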
public void testGradFacColumnsBootstrap() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new MultiClassColumnBootstrapOptimization(
new GradFacMulticlass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
new SVDAdapterEjml(1),
SatL2.class
), new FastRandom(),
1.
),
L2.class,
7500,
7
);
fitModel(boosting);
}
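
/** Elastic-net GradFac under column bootstrap. */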
public void testGradFacElasticNetColumnsBootstrap() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new MultiClassColumnBootstrapOptimization(
new GradFacMulticlass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
new ElasticNetFactorization(20, 1e-2, 0.95, 0.15 * 1e-6),
LogL2.class,
true
),
new FastRandom(100500),
1.
),
L2.class,
5000,
7
);
fitModel(boosting);
}
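
/** Plain multiclass boosting baseline without gradient factorization, run with a conventional step of 0.3. */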
public void testBaseline() throws Exception {
final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(
new MultiClass(
new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5),
LogL2.class
),
L2.class,
400,
0.3
);
fitModel(boosting);
}
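
/**
 * Fits the boosting on the learn pool under the multiclass log-likelihood,
 * joins the per-class components of the resulting ensemble into a single
 * multiclass model, and prints evaluation summaries for both pools.
 */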
private void fitModel(final GradientBoosting<BlockwiseMLLLogit> boosting) {
final VecDataSet vecDataSet = learn.vecData();
final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
final MulticlassProgressPrinter multiclassProgressPrinter = new MulticlassProgressPrinter(learn, test);
boosting.addListener(multiclassProgressPrinter);
final Ensemble ensemble = boosting.fit(vecDataSet, globalLoss);
final FuncJoin joined = MCTools.joinBoostingResult(ensemble);
final MultiClassModel multiclassModel = new MultiClassModel(joined);
final String learnResult = MCTools.evalModel(multiclassModel, learn, "[LEARN] ", false);
final String testResult = MCTools.evalModel(multiclassModel, test, "[TEST] ", false);
System.out.println(learnResult);
System.out.println(testResult);
}
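
/**
 * Same fit/evaluate pipeline for FMCBoosting, whose ensemble may already be a
 * joined multiclass trans rather than a FuncJoin per weak learner.
 */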
private void fitModel(final FMCBoosting boosting) {
final VecDataSet vecDataSet = learn.vecData();
final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
final MulticlassProgressPrinter multiclassProgressPrinter = new MulticlassProgressPrinter(learn, test);
boosting.addListener(multiclassProgressPrinter);
final Ensemble ensemble = boosting.fit(vecDataSet, globalLoss);
final Trans joined = ensemble.last() instanceof FuncJoin ? MCTools.joinBoostingResult(ensemble) : ensemble;
final MultiClassModel multiclassModel = new MultiClassModel(joined);
final String learnResult = MCTools.evalModel(multiclassModel, learn, "[LEARN] ", false);
final String testResult = MCTools.evalModel(multiclassModel, test, "[TEST] ", false);
System.out.println(learnResult);
System.out.println(testResult);
}
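
/**
 * Factorizes x into a rank-1 pair (h, b), rescales so that ||b|| = 1 (moving
 * the magnitude into h) to make methods comparable, and prints the l1/l2
 * reconstruction errors of the outer product h * b^T.
 */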
private static void applyFactorMethod(final Mx x, final Factorization method) {
final Pair<Vec, Vec> pair = method.factorize(x);
final Vec h = pair.getFirst();
final Vec b = pair.getSecond();
final double normB = VecTools.norm(b);
VecTools.scale(b, 1 / normB);
VecTools.scale(h, normB);
final Mx afterFactor = VecTools.outer(h, b);
System.out.println(method.getClass().getSimpleName() + ": ||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", l2 = " + VecTools.distance(x, afterFactor) + ", l1 = " + VecTools.distanceL1(x, afterFactor));
}
}