package com.spbsu.bernulli.betaBinomialMixture;
import com.spbsu.bernulli.EM;
import com.spbsu.bernulli.caches.BetaCache;
import com.spbsu.bernulli.caches.Digamma1Cache;
import com.spbsu.bernulli.caches.DigammaCache;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx;
import com.spbsu.commons.random.FastRandom;
import java.util.Arrays;
import static java.lang.Double.isNaN;
//Reparametrization of the beta distribution to avoid singularities:
//alpha + beta <= maxN, the maximum number of "prior" pseudo-observations.
//Each beta component is parametrized by its mean mu in [0, 1] and its precision N = alpha + beta.
//Optimization: EM for the mixture; after each E-step, numeric updates alternate between
//mu and N, shrinking N to maxN whenever it exceeds the bound.
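//Under this parametrization alpha = mu * N and beta = (1 - mu) * N (see the alpha()/beta()
//helpers below), so mu is the prior mean and N the effective prior sample size.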
public class RegularizedBetaBinomialMixtureEM extends EM<BetaBinomialMixture> {
final int k;
final int[] sums;
final int n;
final Mx dummy;
final BetaBinomialMixture model;
final FastRandom random;
final SpecialFunctionCache[] funcs;
final MeanOptimization mean;
final PrecisionOptimization precision;
final int N; //regularization bound: the maximum number of pseudo-observations, alpha + beta <= N
final double[] mu;
final double[] precisions;
public RegularizedBetaBinomialMixtureEM(int k, final int[] sums, final int n, final int N, FastRandom random) {
this.k = k; //component count
this.sums = sums;
this.n = n;
this.dummy = new VecBasedMx(sums.length, k);
this.model = new BetaBinomialMixture(k, n, random);
this.random = random;
this.funcs = new SpecialFunctionCache[k];
this.mu = new double[k];
this.precisions = new double[k];
for (int i = 0; i < k; ++i) {
final double total = this.model.alphas[i] + this.model.betas[i];
mu[i] = this.model.alphas[i] / total;
precisions[i] = Math.min(total, N);
this.funcs[i] = new SpecialFunctionCache(mu[i], precisions[i], n);
}
updateModel();
this.N = N;
this.mean = new MeanOptimization();
precision = new PrecisionOptimization(N);
}
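//Minimal usage sketch on toy data (hypothetical numbers; assumes the FastRandom(long seed)
//constructor from com.spbsu.commons and fit() from the EM base class):
public static void main(String[] args) {
final int n = 100;                         //trials per observation
final int[] sums = {3, 7, 11, 47, 52, 55}; //successes per observation
final RegularizedBetaBinomialMixtureEM em =
new RegularizedBetaBinomialMixtureEM(2, sums, n, 1000, new FastRandom(0));
final double[] estimates = em.estimate(true); //fit the mixture, then shrink each m/n towards it
System.out.println(Arrays.toString(estimates));
}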
private void updateCache() {
for (int i = 0; i < k; ++i) {
funcs[i].update(model.alphas[i], model.betas[i]);
}
}
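//E-step: dummy(i, j) <- q_j * BetaBin(m_i | alpha_j, beta_j, n), normalized over j,
//i.e. the posterior responsibility of component j for observation i.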
@Override
protected void expectation() {
double[] probs = new double[k];
updateCache();
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
double denom = 0;
for (int j = 0; j < k; ++j) {
probs[j] = model.q[j] * funcs[j].calculate(m, n);
denom += probs[j];
}
for (int j = 0; j < k; ++j) {
dummy.set(i, j, probs[j] /= denom);
}
}
}
private final int newtonIters = 3;
private final double gradientStep = 0.05;
private final double newtonStep = 0.01;
private final int gradientIters = 20;
private final int iterations = 3;
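//M-step: the weights q_j have a closed-form update from the responsibilities; the means and
//precisions are then improved by a few alternating numeric optimization passes.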
@Override
protected void maximization() {
final double[] probs = new double[k];
for (int i = 0; i < sums.length; ++i) {
for (int j = 0; j < k; ++j) {
probs[j] += dummy.get(i, j);
}
}
double total = 0;
for (int i = 0; i < k; ++i) {
total += probs[i];
}
for (int i = 0; i < k; ++i)
model.q[i] = probs[i] / total;
for (int i = 0; i < iterations; ++i) {
mean.maximize();
precision.maximize();
}
}
int count = 300;
double oldLikelihood = Double.NEGATIVE_INFINITY;
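//stop once the log-likelihood improves by less than 0.1 between iterations,
//or after at most 300 EM iterations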
@Override
protected boolean stop() {
final double currentLL = likelihood();
if (Math.abs(oldLikelihood - currentLL) < 1e-1) {
return true;
}
oldLikelihood = currentLL;
return --count < 0;
}
@Override
public BetaBinomialMixture model() {
return model;
}
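//observed-data log-likelihood: sum_i log(sum_j q_j * BetaBin(m_i | alpha_j, beta_j, n))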
@Override
protected double likelihood() {
updateCache();
double ll = 0;
for (int i = 0; i < sums.length; ++i) {
double p = 0;
final int m = sums[i];
for (int j = 0; j < model.alphas.length; ++j) {
p += model.q[j] * funcs[j].calculate(m, n);
}
ll += Math.log(p);
}
return ll;
}
//empirical Bayes estimation with the fitted beta-binomial mixture: each observation gets the
//responsibility-weighted posterior mean (m_i + alpha_j) / (n + alpha_j + beta_j)
public double[] estimate(boolean fit) {
if (fit) {
fit();
}
expectation();
final double[] result = new double[sums.length];
for (int i = 0; i < sums.length; ++i) {
for (int j = 0; j < k; ++j)
result[i] += dummy.get(i, j) * (sums[i] + model.alphas[j]) * 1.0 / (n + model.betas[j] + model.alphas[j]);
}
return result;
}
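//Optimizes the component means mu_j with the precisions held fixed, by maximizing the
//expected complete-data log-likelihood over the responsibilities in dummy. Works in
//(p, q) = (mu, 1 - mu) coordinates: gradientStep() is backtracking gradient ascent,
//newtonStep() a damped 2x2 Newton step.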
class MeanOptimization {
final double[] gradientCache;
final double[] newtonCache;
MeanOptimization() {
this.gradientCache = new double[2 * model.alphas.length];
this.newtonCache = new double[3 * model.alphas.length];
}
private boolean newtonStep(double step) {
updateCache();
Arrays.fill(newtonCache, 0.0);
fillGradient();
final double[] cache = new double[k];
final double[] cache2 = new double[k];
for (int i = 0; i < k; ++i) {
final double beta = model.betas[i];
final double alpha = model.alphas[i];
cache[i] = -2 * (funcs[i].digamma(Type.Alpha, 0) - funcs[i].digamma(Type.Beta, 0));
cache[i] += beta * (funcs[i].digamma1(Type.Alpha, 0) + funcs[i].digamma1(Type.Beta, 0));
cache2[i] = alpha * beta * (funcs[i].digamma1(Type.Alpha, 0) + funcs[i].digamma1(Type.Beta, 0));
cache2[i] += (beta - alpha) * (funcs[i].digamma(Type.Alpha, 0) - funcs[i].digamma(Type.Beta, 0));
}
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
for (int j = 0; j < k; ++j) {
final SpecialFunctionCache func = funcs[j];
final double prob = dummy.get(i, j);
final double alpha = model.alphas[j];
final double beta = model.betas[j];
final int idx0 = 3 * j; //Hessian entry a = d^2/dp^2
final int idx1 = 3 * j + 1; //Hessian entry b = d^2/dpdq
final int idx2 = 3 * j + 2; //Hessian entry d = d^2/dq^2
final double dpp;
final double dqq;
{
final double tmp1 = -2 * func.digamma(Type.Beta, n - m) - beta * func.digamma1(Type.Beta, n - m);
final double tmp2 = 2 * func.digamma(Type.Alpha, m) - beta * func.digamma1(Type.Alpha, m);
final double tmp = prob * (tmp1 + tmp2 + cache[j]);
dpp = -beta * tmp;
dqq = alpha * tmp;
}
final double dpq;
{
final double tmp1 = (beta - alpha) * (func.digamma(Type.Beta, n - m) - func.digamma(Type.Alpha, m));
final double tmp2 = -(alpha * beta) * (func.digamma1(Type.Beta, n - m) + func.digamma1(Type.Alpha, m));
final double tmp = prob * (tmp1 + tmp2 + cache2[j]);
dpq = tmp;
}
newtonCache[idx0] += dpp;
newtonCache[idx1] += dpq;
newtonCache[idx2] += dqq;
}
}
for (int i = 0; i < k; ++i) {
//gradient
final double dp = gradientCache[2 * i];
final double dq = gradientCache[2 * i + 1];
//hessian
final double a = newtonCache[3 * i];
final double b = newtonCache[3 * i + 1];
final double d = newtonCache[3 * i + 2];
final double det = a * d - b * b;
final double dirp = (d * dp - b * dq) / det;
final double dirq = (a * dq - b * dp) / det;
//backtrack until the updated (p, q) stays strictly inside the simplex
while (mu[i] - step * dirp < 1e-3 || 1 - mu[i] - step * dirq < 1e-3) {
step *= 0.5;
if (step < 1e-15)
return true;
}
double p = mu[i] - step * dirp;
double q = 1 - mu[i] - step * dirq;
p /= (p + q);
mu[i] = p;
}
updateModel();
return false;
}
//call only after updateCache(): relies on fresh digamma caches
private void fillGradient() {
Arrays.fill(gradientCache, 0.0);
final double[] cache = new double[k];
for (int i = 0; i < k; ++i) {
cache[i] = funcs[i].digamma(Type.Alpha, 0) - funcs[i].digamma(Type.Beta, 0);
}
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
for (int j = 0; j < k; ++j) {
final double p = dummy.get(i, j);
final double tmp = p * (cache[j] - funcs[j].digamma(Type.Alpha, m) + funcs[j].digamma(Type.Beta, n - m));
final double grad1 = -model.betas[j] * tmp / precisions[j];
final double grad2 = model.alphas[j] * tmp / precisions[j];
gradientCache[2 * j] += grad1;
gradientCache[2 * j + 1] += grad2;
}
}
}
private boolean gradientStep(double step) {
updateCache();
fillGradient();
for (int i = 0; i < gradientCache.length; ++i) {
if (isNaN(gradientCache[i])) {
return true;
}
}
for (int i = 0; i < k; ++i) {
final double alpha = model.alphas[i];
final double dalpha = gradientCache[2 * i];
final double beta = model.betas[i];
final double dbeta = gradientCache[2 * i + 1];
//backtrack until the updated alpha and beta stay strictly positive
while (alpha + step * dalpha < 1e-3 || beta + step * dbeta < 1e-3) {
step *= 0.5;
if (step < 1e-15)
return true;
}
final double newAlpha = alpha + step * dalpha;
final double newBeta = beta + step * dbeta;
mu[i] = newAlpha / (newAlpha + newBeta);
}
updateModel();
return false;
}
boolean maximize() {
//Newton updates (newtonStep) are implemented but disabled: plain gradient ascent is used
for (int i = 0; i < gradientIters; ++i)
gradientStep(gradientStep);
return true;
}
}
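//Optimizes the precisions N_j = alpha_j + beta_j with the means fixed. A component whose
//precision reaches maxPrecision (the regularization bound) is marked stopped and excluded
//from further updates.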
class PrecisionOptimization {
final double[] gradientCache;
final double[] newtonCache;
final int maxPrecision;
boolean[] stopped;
int stoppedCount;
PrecisionOptimization(int N) {
this.gradientCache = new double[model.alphas.length];
this.maxPrecision = N;
this.newtonCache = new double[model.alphas.length];
this.stopped = new boolean[model.alphas.length];
stoppedCount = 0;
}
private boolean newtonStep(double step) {
updateCache();
Arrays.fill(newtonCache, 0.0);
fillGradient();
final double[] cache = new double[k];
for (int i = 0; i < k; ++i) {
if (stopped[i])
continue;
final double p = mu[i];
cache[i] = -p * p * funcs[i].digamma1(Type.Alpha, 0) - (1 - p) * (1 - p) * funcs[i].digamma1(Type.Beta, 0);
cache[i] += -funcs[i].digamma1(Type.AlphaBeta, n) + funcs[i].digamma1(Type.AlphaBeta, 0);
}
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
for (int j = 0; j < k; ++j) {
if (stopped[j])
continue;
final double prob = dummy.get(i, j);
final double p = mu[j];
final double dgrad = prob * (cache[j] + (1 - p) * (1 - p) * funcs[j].digamma1(Type.Beta, n - m) + p * p * funcs[j].digamma1(Type.Alpha, m));
newtonCache[j] += dgrad;
}
}
for (int i = 0; i < k; ++i) {
if (stopped[i])
continue;
//damped Newton step on the precision; clamp the candidate at maxPrecision
final double candidate = precisions[i] - step * (gradientCache[i] / newtonCache[i]);
if (candidate > 0) {
if (candidate > maxPrecision) {
precisions[i] = maxPrecision;
stopped[i] = true;
stoppedCount++;
} else {
precisions[i] = candidate;
}
}
}
updateModel();
return false;
}
//call only after updateCache(): relies on fresh digamma caches
private void fillGradient() {
Arrays.fill(gradientCache, 0.0);
final double[] cache = new double[k];
for (int i = 0; i < k; ++i) {
if (stopped[i])
continue;
cache[i] = -mu[i] * funcs[i].digamma(Type.Alpha, 0) - (1 - mu[i]) * funcs[i].digamma(Type.Beta, 0)
+ funcs[i].digamma(Type.AlphaBeta, 0) - funcs[i].digamma(Type.AlphaBeta, n);
}
for (int i = 0; i < sums.length; ++i) {
final int m = sums[i];
for (int j = 0; j < k; ++j) {
if (stopped[j])
continue;
final double prob = dummy.get(i, j);
final double p = mu[j];
final SpecialFunctionCache func = funcs[j];
final double grad = (1 - p) * func.digamma(Type.Beta, n - m) + p * func.digamma(Type.Alpha, m) + cache[j];
gradientCache[j] += prob * grad;
}
}
}
private boolean gradientStep(double step) {
updateCache();
fillGradient();
for (int i = 0; i < gradientCache.length; ++i) {
if (isNaN(gradientCache[i])) {
return true;
}
}
for (int i = 0; i < k; ++i) {
if (stopped[i])
continue;
//backtrack until the updated precision stays strictly positive
while (precisions[i] + step * gradientCache[i] < 1e-3) {
step *= 0.5;
if (step < 1e-10)
return true;
}
precisions[i] += step * gradientCache[i];
if (precisions[i] > maxPrecision) {
precisions[i] = maxPrecision;
stopped[i] = true;
stoppedCount++;
}
}
updateModel();
return false;
}
boolean first = true;
boolean maximize() {
if (stoppedCount == k)
return false;
if (first) {
for (int i = 0; i < gradientIters; ++i)
gradientStep(gradientStep);
first = false;
}
for (int i = 0; i < 3; ++i)
gradientStep(0.01);
for (int i = 0; i < newtonIters; ++i)
newtonStep(newtonStep);
return true;
}
}
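//push the (mu, N) parametrization back into the model's natural (alpha, beta) parameters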
private void updateModel() {
for (int i = 0; i < model.alphas.length; ++i) {
model.alphas[i] = mu[i] * precisions[i];
model.betas[i] = (1 - mu[i]) * precisions[i];
}
}
private enum Type {
Alpha,
Beta,
AlphaBeta
}
private double alpha(double mu, double N) {
return mu * N;
}
private double beta(double mu, double N) {
return (1 - mu) * N;
}
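//Caches the special-function evaluations the optimizers need repeatedly: digamma and
//trigamma values for the Alpha, Beta and AlphaBeta arguments at integer offsets in [0, n]
//(presumably psi(alpha + offset) etc., given how the offsets are used above), plus the
//beta-binomial probabilities themselves via BetaCache.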
private class SpecialFunctionCache {
DigammaCache head;
DigammaCache tail;
DigammaCache alphabeta;
Digamma1Cache dalphabeta;
Digamma1Cache dhead;
Digamma1Cache dtail;
BetaCache betaCache;
public SpecialFunctionCache(double mu, double N, int n) {
betaCache = new BetaCache(alpha(mu, N), beta(mu, N), n);
head = new DigammaCache(alpha(mu, N), n);
tail = new DigammaCache(beta(mu, N), n);
dhead = new Digamma1Cache(alpha(mu, N), n);
dtail = new Digamma1Cache(beta(mu, N), n);
dalphabeta = new Digamma1Cache(N, n);
alphabeta = new DigammaCache(N, n);
}
public double calculate(int m, int n) {
return betaCache.calculate(m, n);
}
public final double digamma(Type type, int offset) {
if (type == Type.Alpha) {
return head.calculate(offset);
} else if (type == Type.Beta) {
return tail.calculate(offset);
}
return alphabeta.calculate(offset);
}
public double digamma1(Type type, int offset) {
if (type == Type.Alpha) {
return dhead.calculate(offset);
} else if (type == Type.Beta) {
return dtail.calculate(offset);
}
return dalphabeta.calculate(offset);
}
public final void update(final double alpha, final double beta) {
betaCache.update(alpha, beta);
head.update(alpha);
tail.update(beta);
dhead.update(alpha);
dtail.update(beta);
dalphabeta.update(alpha + beta);
alphabeta.update(alpha + beta);
}
}
}