package org.wikibrain.sr.normalize;
import com.typesafe.config.Config;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.interpolation.LoessInterpolator;
import org.apache.commons.math3.stat.ranking.NaNStrategy;
import org.apache.commons.math3.stat.ranking.NaturalRanking;
import org.apache.commons.math3.stat.ranking.TiesStrategy;
import org.wikibrain.conf.Configuration;
import org.wikibrain.conf.ConfigurationException;
import org.wikibrain.conf.Configurator;
import org.wikibrain.utils.WbMathUtils;
import java.text.DecimalFormat;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Normalizes in two steps:
 * 1. Creates a smoothed weighted average defined over a sample of the observed points.
 * 2. Fits an interpolating function (via Commons Math's {@link LoessInterpolator})
 *    to the smoothed points.
 *
 * Values outside the range covered by the smoothed points are extrapolated with
 * {@code WbMathUtils.toAsymptote}, so the result never moves further than 10% of
 * the observed output range beyond the fitted endpoints (see {@link #normalize}).
 *
 * The interpolating function is transient and is rebuilt lazily from the stored
 * observations. All mutations of the observation lists and of the cached
 * interpolator are guarded by this instance's monitor.
 */
public class LoessNormalizer extends BaseNormalizer {
    private static final Logger LOG = LoggerFactory.getLogger(LoessNormalizer.class);
    public static final long serialVersionUID = -34232429;

    /** Minimum gap passed to {@code WbMathUtils.makeMonotonicIncreasing}. */
    private static final double EPSILON = 1E-10;

    // Observed points. After init() has run they are pruned and sorted by x.
    private TDoubleList X = new TDoubleArrayList();
    private TDoubleList Y = new TDoubleArrayList();

    private boolean logTransform = false;
    private boolean monotonic = false;

    // Derived state: rebuilt on demand by init(), never serialized.
    transient private double interpolatorMin;
    transient private double interpolatorMax;
    transient private UnivariateFunction interpolator = null;

    /**
     * Discards all observations and the cached interpolator.
     */
    @Override
    public void reset() {
        super.reset();
        // Same monitor as observe() and init(), which touch the same fields.
        synchronized (this) {
            X.clear();
            Y.clear();
            interpolatorMin = 0;
            interpolatorMax = 0;
            interpolator = null;
        }
    }

    /**
     * Records one (x, y) training point. Points whose x is NaN or infinite are
     * not stored locally, but are still forwarded to the superclass.
     *
     * @param x raw score
     * @param y target value for the score
     */
    @Override
    public void observe(double x, double y) {
        if (!Double.isNaN(x) && !Double.isInfinite(x)) {
            // Lock on this rather than on X: init() replaces the X reference while
            // holding this instance's monitor, so a lock on X would neither stay
            // stable nor exclude init().
            synchronized (this) {
                X.add(x);
                Y.add(y);
            }
        }
        super.observe(x, y);
    }

    /**
     * Signals the end of training. The interpolator is deliberately not built here.
     */
    @Override
    public void observationsFinished() {
        // lazily initialized to overcome problems
        // with PolynomialSplineFunction serialization.
        super.observationsFinished();
    }

    /**
     * Maps a raw score onto the fitted curve.
     *
     * @param x raw score
     * @return {@code missingMean} if x is NaN or infinite; the interpolated value if
     *         the (optionally log-transformed) x lies within the fitted range;
     *         otherwise a value extrapolated beyond the nearest endpoint.
     * @throws IllegalStateException if no finite observations are available, or if
     *         the transformed x cannot be ordered against the fitted range
     */
    @Override
    public double normalize(double x) {
        if (Double.isNaN(x) || Double.isInfinite(x)) {
            return missingMean;
        }
        // Builds the interpolator on first use; fetched once to avoid re-locking.
        UnivariateFunction f = getInterpolationFunction();
        x = logIfNeeded(x);
        double sMin = interpolatorMin;
        double sMax = interpolatorMax;
        if (sMin <= x && x <= sMax) {
            return f.value(x);
        }

        double yMin = f.value(sMin);
        double yMax = f.value(sMax);
        double halfLife = (sMax - sMin) / 4.0;
        double yDelta = 0.1 * (yMax - yMin);
        // NOTE(review): toAsymptote presumably decays from the third argument toward
        // the fourth with the given half-life -- confirm against WbMathUtils.
        if (x < sMin) {
            return WbMathUtils.toAsymptote(sMin - x, halfLife, yMin, yMin - yDelta);
        } else if (x > sMax) {
            return WbMathUtils.toAsymptote(x - sMax, halfLife, yMax, yMax + yDelta);
        } else {
            // Only reachable when x is not comparable (NaN).
            throw new IllegalStateException("" + x + " not in [" + sMin + "," + sMax + "]");
        }
    }

    /**
     * @return the interpolating function, building it first if necessary
     */
    private synchronized UnivariateFunction getInterpolationFunction() {
        init();
        return interpolator;
    }

    /**
     * Builds the interpolator from the observations if it does not exist yet.
     * As a side effect, X and Y are replaced by pruned copies sorted by x.
     *
     * @throws IllegalStateException if there are no usable observations
     */
    private synchronized void init() {
        if (interpolator != null) {
            return;
        }

        // remove infinite or nan values
        TDoubleList pruned[] = WbMathUtils.removeNotNumberPoints(X, Y);
        X = pruned[0];
        Y = pruned[1];
        if (X.isEmpty()) {
            // Fail with a clear message instead of an index error further down.
            throw new IllegalStateException("cannot fit loess normalizer: no finite observations");
        }

        // sort points by X coordinate
        double ranks[] = new NaturalRanking(NaNStrategy.REMOVED, TiesStrategy.SEQUENTIAL).rank(X.toArray());
        if (ranks.length != X.size()) {
            throw new IllegalStateException("invalid sizes: " + ranks.length + " and " + X.size());
        }

        // Every slot of these copies is overwritten: sequential tie-breaking
        // gives each point a distinct 1-based rank.
        TDoubleList sortedX = new TDoubleArrayList(X);
        TDoubleList sortedY = new TDoubleArrayList(Y);
        for (int i = 0; i < X.size(); i++) {
            int r = (int) Math.round(ranks[i]) - 1;
            sortedX.set(r, X.get(i));
            sortedY.set(r, Y.get(i));
        }
        X = sortedX;
        Y = sortedY;

        // create the smoothed points.
        // NOTE(review): windowSize is 0 when fewer than 10 points were observed;
        // confirm that WbMathUtils.smooth accepts that.
        int windowSize = Math.min(20, X.size() / 10);
        double smoothed[][] = WbMathUtils.smooth(
                logIfNeeded(X.toArray()),
                Y.toArray(),
                windowSize,
                10);
        double smoothedX[] = smoothed[0];
        double smoothedY[] = smoothed[1];

        // NOTE(review): the range is recorded before smoothedX is adjusted below;
        // this order is kept from the original implementation.
        interpolatorMin = smoothedX[0];
        interpolatorMax = smoothedX[smoothedX.length - 1];

        WbMathUtils.makeMonotonicIncreasing(smoothedX, EPSILON);
        if (monotonic) {
            WbMathUtils.makeMonotonicIncreasing(smoothedY, EPSILON);
        }

        // create the interpolator
        interpolator = new LoessInterpolator().interpolate(smoothedX, smoothedY);
    }

    /**
     * Applies the optional log transform to a single value. Must only be called
     * once X is sorted, because X.get(0) is used as the smallest observation.
     *
     * @return x unchanged if the transform is disabled; 0 if x is below the smallest
     *         observation; otherwise log(1 + (x - smallest observation))
     */
    private double logIfNeeded(double x) {
        if (!logTransform) {
            return x;
        }
        double min = X.get(0);
        // Shift by the minimum so the argument of the logarithm is never below 1.
        // Adding the minimum instead produced NaN for negative minima and a jump
        // at x == min.
        return (x < min) ? 0 : Math.log1p(x - min);
    }

    /**
     * Applies the optional log transform to every element.
     *
     * @return a new transformed array, or the same array if the transform is disabled
     */
    private double[] logIfNeeded(double values[]) {
        if (!logTransform) {
            return values;
        }
        double transformed[] = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            transformed[i] = logIfNeeded(values[i]);
        }
        return transformed;
    }

    /**
     * @return a one-line summary listing 21 evenly spaced observed x values
     *         together with their normalized values
     */
    @Override
    public String dump() {
        init();
        StringBuilder buff = new StringBuilder("loess normalizer");
        if (logTransform) {
            buff.append(" (log'ed)");
        }
        DecimalFormat df = new DecimalFormat("#.##");
        int n = X.size();
        for (int i = 0; i <= 20; i++) {
            // long arithmetic: i * n can exceed the int range for very large samples
            int j = (int) Math.min(n - 1, (long) i * n / 20);
            double x = X.get(j);
            buff.append(" <")
                    .append(df.format(x)).append(",")
                    .append(df.format(normalize(x))).append(">");
        }
        return buff.toString();
    }

    /**
     * Enables or disables the log transform of x values. Changing the setting
     * discards any interpolator that was fitted under the previous setting.
     */
    public synchronized void setLogTransform(boolean b) {
        if (this.logTransform != b) {
            this.logTransform = b;
            this.interpolator = null;
        }
    }

    public boolean getLogTransform() {
        return logTransform;
    }

    /**
     * Enables or disables forcing the smoothed y values to be increasing. Changing
     * the setting discards any interpolator that was fitted under the previous setting.
     */
    public synchronized void setMonotonic(boolean b) {
        if (this.monotonic != b) {
            this.monotonic = b;
            this.interpolator = null;
        }
    }

    /**
     * Configuration provider that creates a LoessNormalizer for entries under
     * "sr.normalizer" whose "type" is "loess".
     */
    public static class Provider extends org.wikibrain.conf.Provider<LoessNormalizer> {
        public Provider(Configurator configurator, Configuration config) throws ConfigurationException {
            super(configurator, config);
        }

        @Override
        public Class getType() {
            return Normalizer.class;
        }

        @Override
        public String getPath() {
            return "sr.normalizer";
        }

        @Override
        public Scope getScope() {
            return Scope.INSTANCE;
        }

        /**
         * @return a new normalizer, with the log transform taken from the optional
         *         boolean "log" setting, or null if the configured type is not "loess"
         */
        @Override
        public LoessNormalizer get(String name, Config config, Map<String, String> runtimeParams) throws ConfigurationException {
            if (!config.getString("type").equals("loess")) {
                return null;
            }
            LoessNormalizer ln = new LoessNormalizer();
            if (config.hasPath("log")) {
                ln.setLogTransform(config.getBoolean("log"));
            }
            return ln;
        }
    }
}