/**
*
*/
package org.streaminer.stream.classifier;
import org.streaminer.stream.learner.LearnerUtils;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.model.Distribution;
import org.streaminer.stream.model.NominalDistributionModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Baseline classifier that always predicts the most frequent label value
 * observed during training.
 * <p>
 * Label values seen by {@link #learn(Data)} are recorded in a
 * {@link NominalDistributionModel}. The label attribute can be set explicitly
 * via {@link #setLabelAttribute(String)}. Otherwise, {@link #learn(Data)} tries
 * {@link LearnerUtils#detectLabelAttribute} on each item until detection
 * returns a non-null attribute name.
 * </p>
 *
 * @author chris
 */
public class MajorityClass extends AbstractClassifier<Data, String> {

    /** The unique class ID */
    private static final long serialVersionUID = 2843128554725059166L;

    static final Logger log = LoggerFactory.getLogger(MajorityClass.class);

    /** Distribution of the label values seen so far. */
    Distribution<String> dist = new NominalDistributionModel<String>();

    /** Name of the attribute holding the class label; {@code null} until set or detected. */
    String labelAttribute = null;

    /**
     * Records the label value of the given item.
     * <p>
     * If no label attribute is known yet, one is detected from {@code item}.
     * Items are ignored (with a warning) when no label attribute can be
     * determined or when the item has no value for it.
     * </p>
     *
     * @param item the training example
     */
    @Override
    public void learn(Data item) {
        if (labelAttribute == null) {
            labelAttribute = LearnerUtils.detectLabelAttribute(item);
        }
        if (labelAttribute == null) {
            log.warn("Ignoring unlabeled example (no label defined for MajorityClass learner)!");
            return;
        }
        Object label = item.get(labelAttribute);
        if (label == null) {
            // Skip rather than dereference: one item lacking the label must not abort training.
            log.warn("Ignoring example without a value for label attribute '{}'!", labelAttribute);
            return;
        }
        dist.update(label.toString());
    }

    /**
     * Returns the label value with the highest count observed so far.
     * <p>
     * The item itself is not inspected. On ties, the first element returned by
     * the distribution's element iteration wins (strict {@code >} comparison).
     * </p>
     *
     * @param item the example to classify (unused)
     * @return the majority label, or {@code "?"} if nothing has been learned yet
     */
    @Override
    public String predict(Data item) {
        if (labelAttribute == null) {
            log.error("No label-attribute defined!");
        }
        if (dist.getCount() == 0) {
            return "?";
        }
        String major = null;
        Integer max = null;
        for (String key : dist.getElements()) {
            Integer count = dist.getCount(key);
            if (major == null || count > max) {
                major = key;
                max = count;
            }
        }
        return major;
    }

    /**
     * @return the name of the label attribute, or {@code null} if not yet set or detected
     */
    public String getLabelAttribute() {
        return labelAttribute;
    }

    /**
     * @param labelAttribute the name of the attribute holding the class label
     */
    public void setLabelAttribute(String labelAttribute) {
        this.labelAttribute = labelAttribute;
    }
}