package com.matrobot.gha.insights.regression;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;

import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;

import com.matrobot.gha.insights.ml.Dataset;
/**
 * Linear model {@code y = c0 + c1*x1 + ... + cn*xn} whose predictions are clamped to be
 * non-negative.
 *
 * <p>Coefficient {@code 0} is the intercept; coefficient {@code i} (for {@code i >= 1})
 * multiplies input feature {@code i-1}. A model can be fitted with batch gradient descent
 * ({@link #train(double[][], double[])}) or ordinary least squares
 * ({@link #trainByNormalEquation(Dataset)}), and persisted as a single comma-separated line
 * of coefficients ({@link #save(String)} / {@link #createFromFile(String)}).
 */
public class MultivariableRegression implements IRegression{

    /** Learning rate used by {@link #train(double[][], double[])}. */
    private static final double DEFAULT_ALPHA = 1;

    /** Convergence threshold on the largest absolute gradient component. */
    private static final double DEFAULT_TOLERANCE = .0001;

    /** Intercept first, then one weight per feature. May be null after a failed load. */
    private final double[] coefficients;

    /**
     * Creates a model from explicit coefficients.
     *
     * @param coefficients intercept followed by one weight per feature; copied, so later
     *     changes to the caller's array do not affect the model. {@code null} is accepted
     *     (this is what {@link #createFromFile(String)} passes when loading fails), but such
     *     a model fails with a NullPointerException on {@link #predict} / {@link #printModel}.
     */
    public MultivariableRegression(double[] coefficients){
        this.coefficients = (coefficients == null) ? null : coefficients.clone();
    }

    /**
     * Evaluates the linear model on {@code input} and clamps the result at zero.
     *
     * @param input feature values; extra features beyond the model's width are ignored and
     *     missing trailing features contribute nothing
     * @return {@code max(c0 + sum(ci * input[i-1]), 0)}
     */
    @Override
    public double predict(double[] input) {
        return Math.max(functionValue(input, coefficients), 0);
    }

    /**
     * Fits a model with batch gradient descent using the default learning rate (1) and
     * convergence tolerance (0.0001).
     *
     * @param inputs one row of features per sample; must be non-empty
     * @param outputs target value per sample; same length as {@code inputs}
     * @return the fitted model
     * @throws IllegalArgumentException if the arguments are null, empty or mismatched
     */
    public static MultivariableRegression train(double[][] inputs, double[] outputs){
        return train(inputs, outputs, DEFAULT_ALPHA, DEFAULT_TOLERANCE);
    }

    /**
     * Fits a model with batch gradient descent, starting from all-zero coefficients.
     *
     * <p>Iterates until the largest absolute gradient component is at most
     * {@code tolerance}, or until a step would increase the squared-error cost. In the
     * latter case a message is printed and the coefficients from before that step are kept.
     *
     * @param inputs one row of features per sample; must be non-empty. The feature count
     *     is taken from {@code inputs[0]}.
     *     NOTE(review): rows are assumed to all have that length — not validated.
     * @param outputs target value per sample; same length as {@code inputs}
     * @param alpha learning rate; must be positive
     * @param tolerance gradient convergence threshold; must be positive
     * @return the fitted model
     * @throws IllegalArgumentException if the arguments are null, empty or mismatched, or
     *     {@code alpha}/{@code tolerance} are not positive
     */
    public static MultivariableRegression train(double[][] inputs, double[] outputs,
            double alpha, double tolerance){
        if(inputs == null || outputs == null){
            throw new IllegalArgumentException("inputs and outputs must not be null");
        }
        if(inputs.length == 0){
            throw new IllegalArgumentException("at least one training sample is required");
        }
        if(inputs.length != outputs.length){
            throw new IllegalArgumentException("inputs has " + inputs.length
                    + " rows but outputs has " + outputs.length + " values");
        }
        // Written as !(x > 0) so that NaN is rejected too.
        if(!(alpha > 0) || !(tolerance > 0)){
            throw new IllegalArgumentException("alpha and tolerance must be positive, got alpha="
                    + alpha + ", tolerance=" + tolerance);
        }
        long start = System.currentTimeMillis();
        int sampleCount = inputs.length;
        double[] coefficients = new double[inputs[0].length+1]; // zero-initialized by Java
        double[] candidate = new double[coefficients.length];
        double[] residuals = new double[sampleCount];
        double oldCost = calculateCost(inputs, coefficients, outputs);
        double maxGradient;
        do{
            // Residuals depend only on the current coefficients, so compute them once per
            // iteration rather than once per coefficient.
            for(int j = 0; j < sampleCount; j++){
                residuals[j] = functionValue(inputs[j], coefficients) - outputs[j];
            }
            maxGradient = 0;
            for(int i = 0; i < coefficients.length; i++){
                double sum = 0;
                for(int j = 0; j < sampleCount; j++){
                    double x = (i==0)? 1 : inputs[j][i-1];
                    sum += residuals[j]*x;
                }
                double gradient = sum/sampleCount;
                candidate[i] = coefficients[i] - alpha*gradient;
                maxGradient = Math.max(maxGradient, Math.abs(gradient));
            }
            double newCost = calculateCost(inputs, candidate, outputs);
            if(newCost > oldCost){
                // Reject the step so the returned model is the best one seen, not the
                // one that just made the fit worse.
                System.out.println("Cost function increasing. Probably alpha too big");
                break;
            }
            System.arraycopy(candidate, 0, coefficients, 0, coefficients.length);
            oldCost = newCost;
        }while(maxGradient > tolerance);
        // Whole seconds (integer division), matching the original log format.
        System.out.println("Learning time: " + (System.currentTimeMillis()-start)/1000);
        return new MultivariableRegression(coefficients);
    }

    /**
     * Returns the sum (not the mean) of squared errors of {@code coeffs} over all samples.
     */
    private static double calculateCost(double[][] inputs, double[] coeffs, double[] outputs) {
        double sum = 0;
        for(int j = 0; j < inputs.length; j++){
            double error = functionValue(inputs[j], coeffs) - outputs[j];
            sum += error*error;
        }
        return sum;
    }

    /**
     * Evaluates {@code c0 + c1*params[0] + ...} without clamping.
     *
     * <p>Only the overlapping prefix of coefficients and features is used.
     */
    private static double functionValue(double[] params, double[] coefficients) {
        double sum = 0;
        for(int i = 0; i < coefficients.length && i < params.length+1; i++){
            double x = (i==0)? 1: params[i-1];
            sum += coefficients[i]*x;
        }
        return sum;
    }

    /** Prints each coefficient as {@code index: value}, one per line, to standard output. */
    @Override
    public void printModel() {
        for(int i = 0; i < coefficients.length; i++){
            System.out.println(i + ": " + coefficients[i]);
        }
    }

    /**
     * Fits a model by ordinary least squares using commons-math
     * {@link OLSMultipleLinearRegression}.
     *
     * <p>With its default settings, commons-math includes an intercept as the first
     * estimated parameter, which matches this class's coefficient layout.
     *
     * @param dataset source of the target values and the feature matrix
     * @return the fitted model
     */
    public static MultivariableRegression trainByNormalEquation(Dataset dataset){
        OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
        regression.newSampleData(dataset.getOutputs(), dataset.getFeatures());
        return new MultivariableRegression(regression.estimateRegressionParameters());
    }

    /**
     * Writes the coefficients to {@code filename} as one UTF-8, comma-separated line,
     * overwriting any existing file.
     *
     * <p>Best effort: failures are reported on standard error and not rethrown.
     */
    public void save(String filename){
        // Both resources are declared so the stream is closed even if the writer cannot be built.
        try(FileOutputStream fos = new FileOutputStream(filename, false);
                Writer writer = new OutputStreamWriter(fos, "UTF-8")){
            for(int i = 0; i < coefficients.length; i++){
                if(i > 0){
                    writer.write(",");
                }
                writer.write(Double.toString(coefficients[i]));
            }
        }catch (Exception e){
            System.err.println("Error: " + e.getMessage());
        }
    }

    /**
     * Loads a model written by {@link #save(String)}: the first line of the UTF-8 file,
     * split on commas and parsed as doubles.
     *
     * <p>Best effort: failures (missing file, empty file, unparsable number) are reported on
     * standard error, and a model with {@code null} coefficients is returned.
     * NOTE(review): such a model throws on first use; callers cannot detect this up front.
     */
    public static MultivariableRegression createFromFile(String filename){
        double[] coeff = null;
        try(FileInputStream fis = new FileInputStream(filename);
                BufferedReader reader = new BufferedReader(new InputStreamReader(fis, "UTF-8"))){
            String content = reader.readLine();
            if(content == null){
                System.err.println("Error: no coefficients found in " + filename);
            }else{
                String[] tokens = content.split(",");
                double[] parsed = new double[tokens.length];
                for(int i = 0; i < tokens.length; i++){
                    parsed[i] = Double.parseDouble(tokens[i]);
                }
                // Publish only a fully parsed array, never a partially filled one.
                coeff = parsed;
            }
        }catch (Exception e){
            System.err.println("Error: " + e.getMessage());
        }
        return new MultivariableRegression(coeff);
    }
}