package com.matrobot.gha.insights.app.repo;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Properties;
import com.matrobot.gha.insights.filter.RegressionRepositoryFilter;
import com.matrobot.gha.insights.ml.Dataset;
import com.matrobot.gha.insights.ml.Sample;
import com.matrobot.gha.insights.regression.IRegression;
import com.matrobot.gha.insights.regression.MultivariableRegression;
import com.matrobot.gha.insights.regression.StaticRegression;
/**
 * Command-line tool that builds a regression dataset from three monthly data directories
 * (resolved against the {@code data_path} entry of {@code config.properties}) and prints
 * accuracy statistics for a static baseline model and a multivariable regression model.
 */
public class RegressionEvaluatorApp {

    /**
     * Tolerance, as a percentage of the magnitude of the expected value, within which a
     * forecast is counted as "correct".
     */
    private static final int PREDICTION_RANGE = 140;

    /** Configuration loaded from {@code config.properties}; must define {@code data_path}. */
    Properties prop = new Properties();

    /** Samples produced by {@link RegressionRepositoryFilter}, normalized once at construction. */
    private final Dataset dataset;

    /**
     * Loads the configuration and builds the normalized evaluation dataset.
     *
     * @param firstPath  first data directory, relative to {@code data_path}
     * @param secondPath second data directory, relative to {@code data_path}
     * @param thirdPath  third data directory, relative to {@code data_path}
     * @throws IOException if {@code config.properties} cannot be read or does not define
     *                     {@code data_path}
     */
    protected RegressionEvaluatorApp(String firstPath, String secondPath, String thirdPath) throws IOException {
        // Properties.load does not close the stream, so close it explicitly.
        FileInputStream in = new FileInputStream("config.properties");
        try {
            prop.load(in);
        } finally {
            in.close();
        }

        String dataPath = prop.getProperty("data_path");
        if (dataPath == null) {
            // Without this check the paths would silently become "null<path>".
            throw new IOException("config.properties does not define 'data_path'");
        }

        RegressionRepositoryFilter filter = new RegressionRepositoryFilter(
                dataPath + firstPath,
                dataPath + secondPath,
                dataPath + thirdPath);
        dataset = filter.getDataset();
        // NOTE(review): whether normalize() also rescales the outputs is defined by Dataset;
        // outputs may therefore be negative, which isInRange handles.
        dataset.normalize();
    }

    /**
     * Runs {@code model} over every sample of the dataset and prints three statistics:
     * the mean forecast, the root of the mean squared error, and the percentage of
     * forecasts within {@link #PREDICTION_RANGE} percent of the expected output.
     * Prints a notice instead of statistics when the dataset has no samples.
     *
     * @param model the regression model to evaluate
     */
    private void evaluate(IRegression model) {
        int count = 0;
        int correctCount = 0;
        double sumOfSquaredErrors = 0;
        double sumOfForecasts = 0;

        for (Sample sample : dataset.getData()) {
            double forecast = model.predict(sample.features);
            double diff = sample.output - forecast;
            sumOfSquaredErrors += diff * diff;
            sumOfForecasts += forecast;
            if (isInRange(sample.output, forecast, PREDICTION_RANGE)) {
                correctCount++;
            }
            count++;
        }

        if (count == 0) {
            // Avoid printing NaN statistics from a 0/0 division.
            System.out.println(" No samples to evaluate.");
            return;
        }

        System.out.println(" Forecast mean: " + (sumOfForecasts / count));
        // The value labelled "SE" is the root-mean-squared error.
        System.out.println(" SE: " + Math.sqrt(sumOfSquaredErrors / count));
        System.out.println(" Correct: " + (correctCount * 100.0) / count);
    }

    /**
     * Checks whether {@code forecast} lies strictly inside the open interval
     * {@code expected ± |expected| * rangePercentage / 100}.
     *
     * <p>The magnitude of {@code expected} is used so that negative expected values still
     * yield a non-empty interval. For {@code expected == 0} the interval is empty, so this
     * always returns false.
     *
     * @param expected        the true output
     * @param forecast        the predicted output
     * @param rangePercentage allowed deviation as a percentage of {@code |expected|}
     * @return true if the forecast is within the tolerance band
     */
    private boolean isInRange(double expected, double forecast, int rangePercentage) {
        double tolerance = Math.abs(expected) / 100 * rangePercentage;
        return forecast > expected - tolerance && forecast < expected + tolerance;
    }

    /**
     * Entry point: loads the 2012-9, 2012-10 and 2012-11 data and evaluates a static
     * baseline and a normal-equation multivariable regression. The multivariable model is
     * saved to {@code models/mr1.model} and reloaded before evaluation.
     *
     * @param args unused
     * @throws IOException if configuration or data files cannot be read
     */
    public static void main(String[] args) throws IOException {
        System.out.println("Load datasets");
        RegressionEvaluatorApp app = new RegressionEvaluatorApp("2012-9/", "2012-10/", "2012-11/");

        // Baseline model; the meaning of 0.8 is defined by StaticRegression.
        System.out.println("Static model: ");
        app.evaluate(new StaticRegression(0.8));
        System.out.println();

        // NOTE(review): the model is trained and evaluated on the same dataset, so the
        // reported figures are in-sample accuracy, not generalization accuracy.
        MultivariableRegression trained = MultivariableRegression.trainByNormalEquation(app.dataset);
        trained.save("models/mr1.model");
        // Evaluate the reloaded copy so the save/load round trip is exercised too.
        IRegression regression = MultivariableRegression.createFromFile("models/mr1.model");
        System.out.println("Multivariable: ");
        app.evaluate(regression);
        System.out.println();
    }
}