package com.matrobot.gha.insights.ml; public class NaiveBayes { private int featureCount; private double positiveCount = 1; private double negativeCount = 1; private double[] positiveEvidenceCount; private double[] negativeEvidenceCount; public NaiveBayes(int featureCount){ this.featureCount = featureCount; positiveEvidenceCount = new double[featureCount]; negativeEvidenceCount = new double[featureCount]; for(int i = 0; i < featureCount; i++){ positiveEvidenceCount[i] = 1; negativeEvidenceCount[i] = 1; } } public void train(Sample sample){ if(sample.output == 1){ positiveCount++; for(int i = 0; i < featureCount; i++){ if(sample.features[i] > 0){ positiveEvidenceCount[i] += 1; } } } else{ negativeCount++; for(int i = 0; i < featureCount; i++){ if(sample.features[i] > 0){ negativeEvidenceCount[i] += 1; } } } } protected double getPositiveProb(){ return (positiveCount)/(positiveCount+negativeCount); } protected double getNegativeProb(){ return (negativeCount)/(positiveCount+negativeCount); } public double classify(double[] features) { double positiveScore = calculatePositiveScore(features); double negativeScore = calculateNegativeScore(features); if(positiveScore > negativeScore){ return 1; } else{ return 0; } } private double calculateNegativeScore(double[] features) { double hipothesisProb = getNegativeProb(); double evidenceProb = 1; for(int i = 0; i < features.length; i++){ if(features[i] > 0){ evidenceProb *= negativeEvidenceCount[i]/negativeCount; } else{ evidenceProb *= (negativeCount-negativeEvidenceCount[i])/negativeCount; } } double score = hipothesisProb * evidenceProb; return score; } private double calculatePositiveScore(double[] features) { double hipothesisProb = getPositiveProb(); double evidenceProb = 1; for(int i = 0; i < features.length; i++){ if(features[i] > 0){ evidenceProb *= positiveEvidenceCount[i]/positiveCount; } else{ evidenceProb *= (positiveCount-positiveEvidenceCount[i])/positiveCount; } } double score = hipothesisProb * evidenceProb; return score; } public void printModel() { System.out.println("Positive: " + positiveCount); for(int i = 0; i < positiveEvidenceCount.length; i++){ System.out.println("PF" + (i+1) + ": " + positiveEvidenceCount[i]); } System.out.println("Negative: " + negativeCount); for(int i = 0; i < negativeEvidenceCount.length; i++){ System.out.println("NF" + (i+1) + ": " + negativeEvidenceCount[i]); } } }