package com.matrobot.gha.insights.ml;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
import com.matrobot.gha.insights.ml.Dataset;
import com.matrobot.gha.insights.ml.GradientDescentLogistic;
/**
 * Smoke test for {@link GradientDescentLogistic}: trains on a tiny,
 * hand-written two-class dataset and checks which side of 0.5 the
 * predictions for two unseen points fall on.
 */
public class GradientDescentLogisticTest {

    /**
     * Trains on three samples labelled 0 (small coordinates) followed by
     * three samples labelled 1 (larger coordinates), then asserts that an
     * unseen large point scores above 0.5 and an unseen small point scores
     * below 0.5.
     */
    @Test
    public void testRegression1() {
        // Training points grouped by label; the second and third label-0
        // rows are intentionally identical, as in the original fixture.
        double[][] labelledZero = {{0.1, 0.2}, {0.1, 0.3}, {0.1, 0.3}};
        double[][] labelledOne = {{2, 2}, {1, 3}, {4, 1}};

        // Unseen points used only for prediction.
        double[] testPositive = {5, 6};
        double[] testNegative = {0.5, 0.6};

        // NOTE(review): the constructor argument 2 presumably is the number
        // of features per sample - confirm against Dataset.
        Dataset trainingSet = new Dataset(2);

        // Insertion order is preserved: all label-0 samples first, then all
        // label-1 samples.
        for (double[] features : labelledZero) {
            trainingSet.addSample(features, 0);
        }
        for (double[] features : labelledOne) {
            trainingSet.addSample(features, 1);
        }

        GradientDescentLogistic model = new GradientDescentLogistic();
        // NOTE(review): alpha is presumably the learning rate - confirm.
        model.setAlpha(1);
        model.train(trainingSet);

        double positiveScore = model.predict(testPositive);
        assertTrue(positiveScore > 0.5);

        double negativeScore = model.predict(testNegative);
        assertTrue(negativeScore < 0.5);
    }
}