package com.spbsu.exp; import com.spbsu.bernulli.FittedModel; import com.spbsu.bernulli.Learner; import com.spbsu.bernulli.MCMCBernoulliMixture.BernoulliPrior; import com.spbsu.bernulli.MCMCBernoulliMixture.MCMCBernoulliEstimation; import com.spbsu.bernulli.MCMCBernoulliMixture.UniformPrior; import com.spbsu.bernulli.MixtureObservations; import com.spbsu.bernulli.StochasticSearch; import com.spbsu.bernulli.betaBinomialMixture.BetaBinomialMixture; import com.spbsu.bernulli.betaBinomialMixture.BetaBinomialMixtureEM; import com.spbsu.bernulli.betaBinomialMixture.RegularizedBetaBinomialMixtureEM; import com.spbsu.bernulli.naiveMixture.BernoulliMixtureEM; import com.spbsu.bernulli.naiveMixture.NaiveMixture; import com.spbsu.commons.func.Factory; import com.spbsu.commons.math.vectors.Vec; import com.spbsu.commons.math.vectors.impl.vectors.ArrayVec; import com.spbsu.commons.random.FastRandom; import com.spbsu.commons.util.ThreadTools; import junit.framework.TestCase; import org.apache.commons.math3.special.Gamma; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadPoolExecutor; import static com.spbsu.bernulli.betaBinomialMixture.BetaBinomialMixtureEM.Type; import static com.spbsu.commons.math.MathTools.sqr; import static com.spbsu.commons.math.vectors.VecTools.*; /** * bernoulli models experiments * User: noxoomo * Date: 09.03.2015 */ public class BernoulliTest extends TestCase { FastRandom rand = new FastRandom(22); public void testBetaBinomialMixture() { final int k = 3; final int n = 200; final int count = 5000; int tries = 200; for (int i = 0; i < tries; ++i) { BetaBinomialMixture mixture = new BetaBinomialMixture(k, count, rand); MixtureObservations observation = mixture.sample(n); BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(k, observation.sums, n, rand); FittedModel<BetaBinomialMixture> fittedModel = em.fit(true); System.out.println("Real model " + mixture); System.out.println("Fitted model " + fittedModel.model); } } public void testBetaBinomialRegularizedMixture() { final int k = 3; final int n = 200; final int count = 5000; int tries = 200; int fakeObservations = 1000; for (int i = 0; i < tries; ++i) { BetaBinomialMixture mixture = new BetaBinomialMixture(k, count, rand); MixtureObservations observation = mixture.sample(n); RegularizedBetaBinomialMixtureEM em = new RegularizedBetaBinomialMixtureEM(k, observation.sums, n, fakeObservations, rand); FittedModel<BetaBinomialMixture> fittedModel = em.fit(true); System.out.println("Real model " + mixture); System.out.println("Fitted model " + fittedModel.model); } } public void testBetaBinomialMixtureEstimation() { final int k = 2; int tries = 100; final int from = 25600; final int to = 100000; final int step = 1000; for (int n = 320; n < 10000; n += 1000) for (int N = from; N < to; N += step) { double sumAvgMixture = 0; double sumAvgNaive = 0; for (int tr = 1; tr <= tries; ++tr) { BetaBinomialMixture mix = new BetaBinomialMixture(2, n, rand); final MixtureObservations observations = mix.sample(N); final int finaln = n; StochasticSearch<BetaBinomialMixtureEM> search = new StochasticSearch<>(new Factory<Learner<BetaBinomialMixtureEM>>() { @Override public Learner<BetaBinomialMixtureEM> create() { return new Learner<BetaBinomialMixtureEM>() { @Override public FittedModel<BetaBinomialMixtureEM> fit() { BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(k, observations.sums, finaln, rand); FittedModel<BetaBinomialMixture> model = em.fit(); return new FittedModel<>(model.likelihood, em); } }; } }); BetaBinomialMixtureEM em = search.fit(8); // BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(k,observations.sums,n,rand); double[] means = em.estimate(false); sumAvgMixture += observations.quality(means) / observations.thetas.length; sumAvgNaive += observations.naiveQuality() / observations.thetas.length; System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tr + "\t" + sumAvgNaive / tr); } System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tries + "\t" + sumAvgNaive / tries); } } public void testSteinBernoulli() { Random rng = new FastRandom(); for (int n = 50; n < 1001; n *= 2) { double sumAvgNaive = 0; double sumAvgJS = 0; double sumAvgBetaBin = 0; for (int k = 0; k < 1000; k++) { Vec sum = new ArrayVec(1000); Vec m = new ArrayVec(1000); for (int i = 0; i < m.dim(); ++i) m.set(i, rng.nextDouble()); double sigma = 1; for (int i = 0; i < n; i++) { for (int t = 0; t < sum.dim(); t++) { sum.adjust(t, rng.nextDouble() > m.get(t) ? 0 : 1); } } Vec naive = copy(sum); scale(naive, 1. / n); Vec js = copy(sum); scale(js, (1 - (js.dim() - 2) * sigma * sigma / sqr(norm(sum))) / n); double[] betameans; { final int[] isums = new int[sum.dim()]; for (int i = 0; i < isums.length; ++i) { isums[i] = (int) sum.get(i); } final int finalN = n; StochasticSearch<RegularizedBetaBinomialMixtureEM> search = new StochasticSearch<>(new Factory<Learner<RegularizedBetaBinomialMixtureEM>>() { @Override public Learner<RegularizedBetaBinomialMixtureEM> create() { return new Learner<RegularizedBetaBinomialMixtureEM>() { @Override public FittedModel<RegularizedBetaBinomialMixtureEM> fit() { RegularizedBetaBinomialMixtureEM em = new RegularizedBetaBinomialMixtureEM(4, isums, finalN, 100, rand); FittedModel<BetaBinomialMixture> model = em.fit(); return new FittedModel<>(model.likelihood, em); } }; } }); RegularizedBetaBinomialMixtureEM emb = search.fit(8); // BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(k,observations.sums,n,rand); betameans = emb.estimate(false); } sumAvgNaive += distance(naive, m) / Math.sqrt(m.dim()); sumAvgJS += distance(js, m) / Math.sqrt(m.dim()); sumAvgBetaBin += distance(new ArrayVec(betameans), m) / Math.sqrt(m.dim()); } System.out.println(n + "\t" + sumAvgNaive / 1000 + "\t" + sumAvgJS / 1000 + "\t" + sumAvgBetaBin / 1000); } } public void testBetaBinomialRegularizedMixtureEstimation() { final int k = 2; int tries = 100; final int from = 5000; final int to = 100001; final int step = 1000; for (int n = 30; n < 10001; n *= 10) for (int N = from; N < to; N *= 10) { double sumAvgMixture = 0; double sumAvgNaive = 0; for (int tr = 1; tr <= tries; ++tr) { BetaBinomialMixture mix = new BetaBinomialMixture(2, n, rand); final MixtureObservations observations = mix.sample(N); final int finaln = n; StochasticSearch<RegularizedBetaBinomialMixtureEM> search = new StochasticSearch<>(new Factory<Learner<RegularizedBetaBinomialMixtureEM>>() { @Override public Learner<RegularizedBetaBinomialMixtureEM> create() { return new Learner<RegularizedBetaBinomialMixtureEM>() { @Override public FittedModel<RegularizedBetaBinomialMixtureEM> fit() { RegularizedBetaBinomialMixtureEM em = new RegularizedBetaBinomialMixtureEM(k, observations.sums, finaln,200, rand); FittedModel<BetaBinomialMixture> model = em.fit(); return new FittedModel<>(model.likelihood, em); } }; } }); RegularizedBetaBinomialMixtureEM em = search.fit(8); // BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(k,observations.sums,n,rand); double[] means = em.estimate(false); sumAvgMixture += observations.quality(means); sumAvgNaive += observations.naiveQuality() ; System.out.println("Real model " + mix.toString()); System.out.println("Fitted model " + em.model().toString()); System.out.println(tr + "\t" + n + "\t" + sumAvgMixture / tr + "\t" + sumAvgNaive / tr); } System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tries + "\t" + sumAvgNaive / tries); } } public void testSpecialFunctionCache() { // BetaFunctionsProportion prop = new BetaFunctionsProportion(0.5,0.5,50); BetaBinomialMixtureEM.SpecialFunctionCache prop = new BetaBinomialMixtureEM.SpecialFunctionCache(1, 1, 10); // assertTrue((prop.calculate(10,50)-1.860013246596769 * 1e-13) < 1e-20); // prop.update(1,1); assertTrue(Math.abs(prop.calculate(2, 10) - 0.0020202) < 1e-7); assertTrue(Math.abs(prop.calculate(3, 10) - 0.000757576) < 1e-7); assertTrue(Math.abs(prop.calculate(4, 10) - 0.0004329) < 1e-7); assertTrue(Math.abs(prop.calculate(5, 10) - 0.00036075) < 1e-7); assertTrue(Math.abs(prop.calculate(9, 10) - 0.00909091) < 1e-7); assertTrue(Math.abs(prop.calculate(10, 10) - 0.0909091) < 1e-7); for (int i = 0; i < 10; ++i) { assertTrue(Math.abs(prop.digamma(Type.Alpha, i) - Gamma.digamma(1 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma(Type.Beta, i) - Gamma.digamma(1 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma(Type.AlphaBeta, i) - Gamma.digamma(2 + i)) < 1e-12); } prop.update(102.5, 10.11); // prop = new BetaFunctionsProportion(102.5,10.11,10); assertTrue(Math.abs(prop.calculate(5, 10) - 6.478713223568241e-6) < 1e-9); assertTrue(Math.abs(prop.calculate(8, 10) - 0.00369375) < 1e-8); for (int i = 0; i < 10; ++i) { assertTrue(Math.abs(prop.digamma(Type.Alpha, i) - Gamma.digamma(102.5 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma(Type.Beta, i) - Gamma.digamma(10.11 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma(Type.AlphaBeta, i) - Gamma.digamma(102.5 + 10.11 + i)) < 1e-12); } for (int i = 0; i < 10; ++i) { assertTrue(Math.abs(prop.digamma1(Type.Alpha, i) - Gamma.trigamma(102.5 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma1(Type.Beta, i) - Gamma.trigamma(10.11 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma1(Type.AlphaBeta, i) - Gamma.trigamma(102.5 + 10.11 + i)) < 1e-12); } prop = new BetaBinomialMixtureEM.SpecialFunctionCache(12.2, 55.1, 100); assertTrue(Math.abs(prop.calculate(30, 100) - 3.520721627628687e-28) < 1e-32); assertTrue(Math.abs(prop.calculate(60, 100) - 7.007620723590574e-37) < 1e-41); assertTrue(Math.abs(prop.calculate(10, 100) - 1.764175281317258e-15) < 1e-19); assertTrue(Math.abs(prop.calculate(88, 100) - 3.97033387162681e-36) < 1e-40); for (int i = 0; i < 100; ++i) { assertTrue(Math.abs(prop.digamma(Type.Alpha, i) - Gamma.digamma(12.2 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma(Type.Beta, i) - Gamma.digamma(55.1 + i)) < 1e-12); assertTrue(Math.abs(prop.digamma(Type.AlphaBeta, i) - Gamma.digamma(12.2 + 55.1 + i)) < 1e-12); } } static ThreadPoolExecutor exec = ThreadTools.createBGExecutor("bernoulli tests", -1); public void testMCMCConvergence() { final int k = 3; final int count = 10; final int n = 100; final int runIters = 100000000; NaiveMixture mixture = new NaiveMixture(k, count, rand); final MixtureObservations<NaiveMixture> observations = mixture.sample(n); final int tries = 8; final double[] scores = new double[tries]; final double naiveScore = observations.naiveQuality(); final BernoulliPrior prior = new UniformPrior(n * count + 1); // final BernoulliPrior prior = new LLPrior(); System.out.println("Naive score is " + naiveScore); // { // MCMCBernoulliEstimation estimation = new MCMCBernoulliEstimation(k, observations.n, observations.sums, prior, rand); // int currentIters = 1; // for (int i=0; i < 1000;++i) { // estimation.run(2 * currentIters); // System.out.println("Score for " + i + " is " + observations.quality(estimation.estimation())); // estimation.clear(); // currentIters *= 2; // } // } final CountDownLatch latch = new CountDownLatch(tries); for (int i = 0; i < tries; ++i) { final int ind = i; exec.submit(new Runnable() { @Override public void run() { int chainsCount = 8; double[] means = new double[observations.sums.length]; for (int ii = 0; ii < chainsCount; ++ii) { MCMCBernoulliEstimation estimation = new MCMCBernoulliEstimation(k, observations.n, observations.sums, prior, rand); estimation.run(runIters); double[] tmp = estimation.estimation(); System.out.println("Score for ind " + ind + " and chain " + ii + " is " + observations.quality(tmp)); for (int j = 0; j < tmp.length; ++j) means[j] += tmp[j]; } for (int iii = 0; iii < means.length; ++iii) { means[iii] /= chainsCount; } scores[ind] = observations.quality(means); System.out.println("Score for " + ind + " is " + scores[ind]); latch.countDown(); } }); } try { latch.await(); } catch (Exception e) { // } } public void testMCMCEstimation() { final int k = 5; final int count = 100; final int n = 100000; final int runIters = 10000000; NaiveMixture mixture = new NaiveMixture(k, count, rand); final MixtureObservations<NaiveMixture> observations = mixture.sample(n); final double naiveScore = observations.naiveQuality(); final BernoulliPrior prior = new UniformPrior(n * count + 1); // final BernoulliPrior prior = new LLPrior(); System.out.println("Naive score is " + naiveScore); MCMCBernoulliEstimation estimation = new MCMCBernoulliEstimation(k, observations.n, observations.sums, prior, rand); for (int i = 0; i < runIters; ++i) { estimation.run(10); double[] tmp = estimation.estimation(); for (double p : tmp) { assertTrue("Probabilty shoud be in [0,1]", p >= 0 && p <= 1); } } } <Mixture> double[] fitMCMCMixture(final int k, final MixtureObservations<Mixture> experiment,final int chainsCount, final int iters) { final double[][] means = new double[chainsCount][experiment.sums.length]; final BernoulliPrior prior = new UniformPrior(experiment.n * experiment.sums.length+1); final CountDownLatch latch = new CountDownLatch(chainsCount); for (int i = 0; i < chainsCount; ++i) { final int ind = i; exec.submit(new Runnable() { @Override public void run() { MCMCBernoulliEstimation estimation = new MCMCBernoulliEstimation(k, experiment.n, experiment.sums, prior, rand); estimation.run(iters); means[ind] = estimation.estimation(); latch.countDown(); } }); } try { latch.await(); } catch (Exception e) { // } for (int i=1; i < means.length-1;++i) { for (int j=0; j < means[0].length;++j) means[0][j] += means[i][j]; } final int last = means.length - 1; for (int j=0; j < means[0].length;++j) { means[0][j] = (means[0][j] + means[last][j]) / means.length; } return means[0]; } <Mixture> double[] fitBetaMixture(final int k, final MixtureObservations<Mixture> experiment) { StochasticSearch<RegularizedBetaBinomialMixtureEM> search = new StochasticSearch<>(new Factory<Learner<RegularizedBetaBinomialMixtureEM>>() { @Override public Learner<RegularizedBetaBinomialMixtureEM> create() { return new Learner<RegularizedBetaBinomialMixtureEM>() { @Override public FittedModel<RegularizedBetaBinomialMixtureEM> fit() { RegularizedBetaBinomialMixtureEM em = new RegularizedBetaBinomialMixtureEM(k, experiment.sums, experiment.n,500, rand); FittedModel<BetaBinomialMixture> model = em.fit(); return new FittedModel<>(model.likelihood, em); } }; } }); RegularizedBetaBinomialMixtureEM emb = search.fit(8); // BetaBinomialMixtureEM em = new BetaBinomialMixtureEM(k,observations.sums,n,rand); return emb.estimate(false); } <Mixture> double[] fitNaiveMixture(final int k, final MixtureObservations<Mixture> experiment) { StochasticSearch<BernoulliMixtureEM> search = new StochasticSearch<>(new Factory<Learner<BernoulliMixtureEM>>() { @Override public Learner<BernoulliMixtureEM> create() { return new Learner<BernoulliMixtureEM>() { @Override public FittedModel<BernoulliMixtureEM> fit() { BernoulliMixtureEM em = new BernoulliMixtureEM(experiment.sums, experiment.n,k, rand); FittedModel<NaiveMixture> model = em.fit(); return new FittedModel<>(model.likelihood, em); } }; } }); BernoulliMixtureEM emb = search.fit(8); return emb.estimate(false); } public void testMixture() { final int k = 50; final int from = 500; final int to = 100001; final int tries = 1000; for (int n = 20; n < 1001; n *= 2) for (int N = from; N < to; N *= 10) { double sumAvgMixture = 0; double sumAvgBetaMixture = 0; double sumAvgNaive = 0; double sumAvgMCMC = 0; final NaiveMixture mixture = new NaiveMixture(k,n,rand); for (int tr = 1; tr <= tries; ++tr) { final MixtureObservations<NaiveMixture> experiment = mixture.sample(N); double[] betameans = fitBetaMixture(k, experiment); double[] mcmcmeans = fitMCMCMixture(k, experiment, 4, 100000000); double[] means = fitNaiveMixture(k, experiment); sumAvgMixture += experiment.quality(means); sumAvgBetaMixture +=experiment.quality(betameans); sumAvgNaive += experiment.naiveQuality(); sumAvgMCMC += experiment.quality(mcmcmeans); System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tr + "\t" + sumAvgBetaMixture / tr+ "\t" +sumAvgMCMC / tr + "\t" + sumAvgNaive / tr); } System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tries + "\t" + sumAvgBetaMixture / tries + "\t" +sumAvgMCMC / tries + "\t" + sumAvgNaive / tries); } } public void testNaiveMixture() { final int k = 2; final int from = 6400; final int to = 100001; final int tries = 1000; for (int n = 320; n < 1001; n *= 2) for (int N = from; N < to; N *= 10) { double sumAvgMixture = 0; double sumAvgNaive = 0; final NaiveMixture mixture = new NaiveMixture(k,n,rand); for (int tr = 1; tr <= tries; ++tr) { final MixtureObservations<NaiveMixture> experiment = mixture.sample(N); double[] means = fitNaiveMixture(k, experiment); sumAvgMixture += experiment.quality(means); sumAvgNaive += experiment.naiveQuality(); System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tr + "\t" + "\t" + sumAvgNaive / tr); } System.out.println(N + "\t" + n + "\t" + sumAvgMixture / tries + "\t" + "\t" + sumAvgNaive / tries); } } }