package com.spbsu.exp; import com.google.gson.JsonArray; import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.google.gson.JsonParser; import com.spbsu.commons.func.Action; import com.spbsu.commons.func.Computable; import com.spbsu.commons.math.vectors.Mx; import com.spbsu.commons.math.vectors.MxTools; import com.spbsu.commons.math.vectors.Vec; import com.spbsu.commons.math.vectors.VecTools; import com.spbsu.commons.math.vectors.impl.mx.RowsVecArrayMx; import com.spbsu.commons.math.vectors.impl.vectors.ArrayVec; import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx; import com.spbsu.commons.random.FastRandom; import com.spbsu.commons.util.Pair; import com.spbsu.commons.util.logging.Interval; import com.spbsu.ml.data.set.DataSet; import com.spbsu.ml.data.set.VecDataSet; import com.spbsu.ml.data.set.impl.VecDataSetImpl; import com.spbsu.ml.loss.LLLogit; import com.spbsu.ml.methods.PGMEM; import com.spbsu.ml.models.pgm.ProbabilisticGraphicalModel; import com.spbsu.ml.models.pgm.Route; import com.spbsu.ml.models.pgm.SimplePGM; import com.spbsu.ml.testUtils.TestResourceLoader; import junit.framework.TestCase; import java.io.*; import java.util.ArrayList; import java.util.zip.GZIPInputStream; /** * Created by inikifor on 12.04.14. */ public class PGMEMLogDataTest extends TestCase { private static final int SIZE = 11; private FastRandom rng; private Action<Pair<Integer, Double>> modelValidationListener = new Action<Pair<Integer, Double>>() { @Override public void invoke(Pair<Integer, Double> data) { System.out.println("For top = " + data.first + " average probability: " + data.second); } }; private LogsData learn; private LogsData validate; protected void setUp() throws Exception { rng = new FastRandom(0); learn = new LogsData(new GZIPInputStream(TestResourceLoader.loadResourceAsStream("pgmem/ses_100k_simple_rand1.dat.gz"))); validate = new LogsData(new GZIPInputStream(TestResourceLoader.loadResourceAsStream("pgmem/ses_100k_simple_rand2.dat.gz"))); } public void testMostProbable() throws IOException { VecDataSet dataSet = new VecDataSetImpl(new RowsVecArrayMx(learn.getRoutes()), null); ProbabilisticGraphicalModel model = getModel(dataSet, 100, PGMEM.MOST_PROBABLE_PATH, true); checkModel(model, 3, modelValidationListener); } public void testLaplacePrior() throws IOException { VecDataSet dataSet = new VecDataSetImpl(new RowsVecArrayMx(learn.getRoutes()), null); ProbabilisticGraphicalModel model = getModel(dataSet, 100, PGMEM.LAPLACE_PRIOR_PATH, true); checkModel(model, 3, modelValidationListener); } public void testFreqDensityPrior() throws IOException { VecDataSet dataSet = new VecDataSetImpl(new RowsVecArrayMx(learn.getRoutes()), null); ProbabilisticGraphicalModel model = getModel(dataSet, 100, PGMEM.FREQ_DENSITY_PRIOR_PATH, true); checkModel(model, 3, modelValidationListener); } public void testMostProbablePartial() throws IOException { testPartial(learn, 10, 100, PGMEM.MOST_PROBABLE_PATH); } public void testLaplacePriorPartial() throws IOException { testPartial(learn, 10, 100, PGMEM.LAPLACE_PRIOR_PATH); } public void testFreqDensityPriorPartial() throws IOException { testPartial(learn, 10, 100, PGMEM.FREQ_DENSITY_PRIOR_PATH); } private void testPartial(LogsData ld, int partitionCount, int iterations, Computable<ProbabilisticGraphicalModel, PGMEM.Policy> policy) throws IOException { int stepSize = ld.getRoutes().length / partitionCount; for (int i = 1; i <= partitionCount; i++) { Vec[] part = new Vec[i == partitionCount ? ld.getRoutes().length : i * stepSize]; System.arraycopy(ld.getRoutes(), 0, part, 0, part.length); VecDataSet dataSet = new VecDataSetImpl(new RowsVecArrayMx(part), null); System.out.println("\nData set size: " + dataSet.length() + ":"); ProbabilisticGraphicalModel model = getModel(dataSet, iterations, policy, false); checkModel(model, 3, modelValidationListener); } } private ProbabilisticGraphicalModel getModel(VecDataSet dataSet, int iterations, Computable<ProbabilisticGraphicalModel, PGMEM.Policy> policy, boolean listen) throws IOException { final Mx original = new VecBasedMx(SIZE, VecTools.fill(new ArrayVec(SIZE * SIZE), 1.)); PGMEM pgmem = new PGMEM(original, 0.2, iterations, rng, policy); if (listen) { final Action<SimplePGM> listener = new Action<SimplePGM>() { int iteration = 0; @Override public void invoke(SimplePGM pgm) { Interval.stopAndPrint("Iteration " + ++iteration); System.out.println(); System.out.print(VecTools.distance(pgm.topology, original)); for (int i = 0; i < pgm.topology.columns(); i++) { System.out.print(" " + VecTools.distance(pgm.topology.row(i), original.row(i))); } System.out.println(); Interval.start(); } }; pgmem.addListener(listener); Interval.start(); } SimplePGM fit = pgmem.fit(dataSet, new LLLogit(VecTools.fill(new ArrayVec(dataSet.length()), 1.), dataSet)); VecTools.fill(fit.topology.row(fit.topology.rows() - 1), 0); System.out.println(MxTools.prettyPrint(fit.topology)); return fit; } private void checkModel(ProbabilisticGraphicalModel model, int accuracyLimit, Action<Pair<Integer, Double>> listener) throws IOException { VecDataSet check = new VecDataSetImpl(new RowsVecArrayMx(validate.getRoutes()), null); for (int i = 0; i < accuracyLimit; i++) { listener.invoke(Pair.create(i, checkModel(check, (SimplePGM) model, i))); } } private Route[] knownRoutes(SimplePGM model) { ArrayList<Route> routes = new ArrayList<Route>(); for(int i=0; i<model.knownRoutesCount(); i++) { routes.add(model.knownRoute(i)); } return routes.toArray(new Route[0]); } private double checkModel(DataSet check, SimplePGM fit, int accuracy) { final int[][] cpds = new int[check.length()][]; final Mx data = ((VecDataSet) check).data(); for (int j = 0; j < data.rows(); j++) { cpds[j] = fit.extractControlPoints(data.row(j)); } Route[] knownRoutes = knownRoutes(fit); int count = 0; for (int i = 0; i < cpds.length; i++) { for (int j = 0; j <= Math.min(knownRoutes.length - 1, accuracy); j++) { if (checkRoute(knownRoutes[j], cpds[i])) { count++; break; } } } return ((double) count) / cpds.length; } private boolean checkRoute(Route route, int... controlPoints) { int index = 0; for (int t = 0; t < route.length() && index < controlPoints.length; t++) { if (route.dst(t) == controlPoints[index]) index++; } return index == controlPoints.length; } public static final class LogsData { private static final int TYPE_OTHER = 0; private static final int TYPE_DIRECT = 1; private static final int TYPE_IMAGES = 2; private static final int TYPE_NEWS = 3; private static final int TYPE_VIDEO = 4; private static final int TYPE_WEB = 5; private static final int REL_NOT_ASED = 0; private static final int REL_IRRELEVANT = 1; private static final int REL_RELEVANT_MINUS = 2; private static final int REL_RELEVANT_PLUS = 3; private static final int REL_USEFUL = 4; private static final int REL_VITAL = 5; private Vec[] routes = new Vec[0]; private Vec[] relevances = new Vec[0]; private Vec[] types = new Vec[0]; public LogsData(InputStream in) throws IOException { BufferedReader reader = null; ArrayList<Vec> lrouts = new ArrayList<Vec>(); ArrayList<Vec> lrels = new ArrayList<Vec>(); ArrayList<Vec> ltypes = new ArrayList<Vec>(); try { reader = new BufferedReader(new InputStreamReader(in)); String line; while ((line = reader.readLine()) != null) { JsonObject obj = new JsonParser().parse(line.split("\t")[1]).getAsJsonObject(); JsonArray clicks = obj.getAsJsonArray("clicks"); Vec vclicks = new ArrayVec(0); if (clicks != null && !clicks.isJsonNull()) { vclicks = new ArrayVec(clicks.size() + 1); int i = 0; for(JsonElement el: clicks) { vclicks.set(i++, el.getAsInt() + 1); } } lrouts.add(vclicks); Vec vrel = new ArrayVec(0); JsonArray rels = obj.getAsJsonArray("rel"); if (rels != null && !rels.isJsonNull()) { vrel = new ArrayVec(rels.size()); int i = 0; for(JsonElement el: rels) { vrel.set(i++, parseRelevance(el.getAsString())); } } lrels.add(vrel); Vec vtype = new ArrayVec(0); JsonArray types = obj.getAsJsonArray("sntypes"); if (types != null && !types.isJsonNull()) { vtype = new ArrayVec(types.size()); int i = 0; for(JsonElement el: types) { vtype.set(i++, parseType(el.getAsString())); } } ltypes.add(vtype); } routes = lrouts.toArray(routes); relevances = lrels.toArray(relevances); types = ltypes.toArray(types); } finally { if (reader != null) { reader.close(); } } } private int parseType(String type) { int result = TYPE_OTHER; if (type.equals("DIRECT")) { result = TYPE_DIRECT; } if (type.equals("IMAGES")) { result = TYPE_IMAGES; } if (type.equals("NEWS")) { result = TYPE_NEWS; } if (type.equals("VIDEO")) { result = TYPE_VIDEO; } if (type.equals("WEB")) { result = TYPE_WEB; } return result; } private int parseRelevance(String relevance) { int result = REL_NOT_ASED; if (relevance.equals("IRRELEVANT")) { result = REL_IRRELEVANT; } if (relevance.equals("RELEVANT_MINUS")) { result = REL_RELEVANT_MINUS; } if (relevance.equals("RELEVANT_PLUS")) { result = REL_RELEVANT_PLUS; } if (relevance.equals("USEFUL")) { result = REL_USEFUL; } if (relevance.equals("VITAL")) { result = REL_VITAL; } return result; } public Vec[] getRoutes() { return routes; } public Vec[] getRelevances() { return relevances; } public Vec[] getTypes() { return types; } } }