package com.matrobot.gha.insights.ml;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.List;
/**
 * In-memory collection of {@link Sample}s, each a feature vector plus one numeric output.
 *
 * <p>Supports per-feature max-absolute-value scaling ({@link #normalize()}) and export
 * to CSV ({@link #saveAsCSV(String)}). No synchronization is performed; instances should
 * not be shared between threads without external locking.
 */
public class Dataset {

    /** Samples in insertion order; returned directly (not copied) by {@link #getData()}. */
    private final List<Sample> samples = new ArrayList<Sample>();
    /** Number of features each sample is expected to carry. */
    private final int featureCount;
    /**
     * Per-feature maximum absolute value computed by {@link #normalize()};
     * {@code null} until that method has been called.
     */
    private double[] maxValues;
    /** Column headers written by {@link #saveAsCSV(String)}. */
    private final List<String> featureNames = new ArrayList<String>();

    /**
     * Creates an empty dataset whose samples have {@code featureCount} features.
     * Feature names default to {@code "Feature 1"} .. {@code "Feature N"}.
     *
     * @param featureCount number of features per sample
     */
    public Dataset(int featureCount) {
        this.featureCount = featureCount;
        for (int i = 1; i <= featureCount; i++) {
            featureNames.add("Feature " + i);
        }
    }

    /**
     * Replaces the feature names used as CSV column headers.
     *
     * <p>The names are copied into this dataset. Their count is not checked against
     * {@link #getFeatureCount()}; a mismatch yields a CSV header that does not line up
     * with the data columns.
     *
     * @param names new feature names in column order; must not be {@code null}
     */
    public void addFeatureNames(List<String> names) {
        featureNames.clear();
        featureNames.addAll(names);
    }

    /**
     * Appends a new sample built from {@code features} and {@code output}.
     *
     * @param features feature values; expected length is {@link #getFeatureCount()}
     *                 (checked only when assertions are enabled)
     * @param output   target value for this sample
     */
    public void addSample(double[] features, double output) {
        assert features.length == featureCount
                : "expected " + featureCount + " features, got " + features.length;
        addSample(new Sample(features, output));
    }

    /** @return number of samples in this dataset */
    public int size() {
        return samples.size();
    }

    /** @return number of features per sample, as given to the constructor */
    public int getFeatureCount() {
        return featureCount;
    }

    /**
     * Returns the live internal sample list. Changes made through it affect this dataset.
     *
     * @return the internal list of samples
     */
    public List<Sample> getData() {
        return samples;
    }

    /**
     * Scales every sample's features in place by the per-feature maximum absolute value.
     * This maps each feature into {@code [-1, 1]}.
     *
     * <p>The computed maxima are kept so that {@link #normalize(double[])} can apply the
     * same scaling to new inputs. A feature whose maximum is 0 (all values zero) is left
     * unscaled rather than divided by zero.
     */
    public void normalize() {
        // Java zero-initializes new arrays, so every max starts at 0.
        maxValues = new double[featureCount];
        for (Sample sample : samples) {
            for (int i = 0; i < featureCount; i++) {
                double value = Math.abs(sample.features[i]);
                if (value > maxValues[i]) {
                    maxValues[i] = value;
                }
            }
        }
        for (Sample sample : samples) {
            sample.features = normalize(sample.features);
        }
    }

    /**
     * Scales {@code features} in place using the maxima computed by the last call to
     * {@link #normalize()}.
     *
     * @param features feature vector to scale; at least {@link #getFeatureCount()} long
     * @return the same array instance, after scaling
     * @throws IllegalStateException if {@link #normalize()} has not been called yet
     */
    public double[] normalize(double[] features) {
        if (maxValues == null) {
            throw new IllegalStateException(
                    "normalize() must be called before normalize(double[])");
        }
        for (int i = 0; i < featureCount; i++) {
            // A zero max means the column carried no magnitude. Dividing by it would
            // turn 0 into NaN and non-zero inputs into +/-Infinity.
            if (maxValues[i] != 0) {
                features[i] = features[i] / maxValues[i];
            }
        }
        return features;
    }

    /**
     * Appends an already-built sample.
     *
     * @param sample sample to add
     */
    public void addSample(Sample sample) {
        samples.add(sample);
    }

    /** @return a new array holding each sample's output, in sample order */
    public double[] getOutputs() {
        double[] outputs = new double[samples.size()];
        for (int i = 0; i < outputs.length; i++) {
            outputs[i] = samples.get(i).output;
        }
        return outputs;
    }

    /**
     * Extracts one feature column.
     *
     * @param index zero-based feature index
     * @return a new array holding feature {@code index} of each sample, in sample order
     */
    public double[] getFeatureFromSamples(int index) {
        double[] column = new double[samples.size()];
        for (int i = 0; i < column.length; i++) {
            column[i] = samples.get(i).features[index];
        }
        return column;
    }

    /**
     * Copies the first {@link #getFeatureCount()} features of every sample into a matrix.
     *
     * @return a new {@code [size()][getFeatureCount()]} array; rows are independent copies
     */
    public double[][] getFeatures() {
        double[][] matrix = new double[samples.size()][featureCount];
        for (int i = 0; i < matrix.length; i++) {
            System.arraycopy(samples.get(i).features, 0, matrix[i], 0, featureCount);
        }
        return matrix;
    }

    /**
     * Writes the dataset as UTF-8 CSV, overwriting {@code filepath}.
     *
     * <p>The first row holds the feature names followed by {@code output}. Each later row
     * holds one sample's features followed by its output. NOTE: values are not quoted or
     * escaped, so feature names containing commas, quotes or newlines will corrupt the file.
     *
     * @param filepath destination file path
     * @throws IOException if the file cannot be opened or written
     */
    public void saveAsCSV(String filepath) throws IOException {
        Writer writer = new OutputStreamWriter(new FileOutputStream(filepath, false), "UTF-8");
        try {
            StringBuilder line = new StringBuilder();
            for (String name : featureNames) {
                line.append(name).append(',');
            }
            line.append("output\n");
            writer.write(line.toString());

            for (Sample sample : samples) {
                // Reuse one builder per row instead of O(n^2) String concatenation.
                line.setLength(0);
                for (int i = 0; i < featureCount; i++) {
                    line.append(sample.features[i]).append(',');
                }
                line.append(sample.output).append('\n');
                writer.write(line.toString());
            }
        } finally {
            // Always release the file handle, even if a write fails part-way.
            writer.close();
        }
    }
}