package org.wikibrain.sr.dataset;
import org.apache.commons.lang3.StringUtils;
import org.wikibrain.core.lang.Language;
import org.wikibrain.sr.utils.KnownSim;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* A gold standard dataset in some language.
*
* @author Ben Hillmann
* @author Matt Lesicko
*/
public class Dataset {
private String name;
private Language language;
private List<KnownSim> data;
public Dataset(String name, Language language) {
this.name = name;
this.language = language;
this.data = new ArrayList<KnownSim>();
}
public Dataset(String name, Language language, List<KnownSim> data) {
this.name = name;
this.language = language;
this.data = data;
}
/**
* Concatenates a list of datasets into a new merged dataset.
* @param datasets
*/
public Dataset(List<Dataset> datasets) {
this(createJointName(datasets), datasets);
}
/**
* Concatenates a list of datasets into a new merged dataset.
* @param name
* @param datasets
*/
public Dataset(String name, List<Dataset> datasets) {
if (datasets==null||datasets.isEmpty()) {
throw new IllegalArgumentException("Attempted to create dataset from an empty list");
}
this.language = datasets.get(0).getLanguage();
this.data = new ArrayList<KnownSim>();
this.name = name;
for (Dataset dataset : datasets) {
if (dataset.getLanguage()!=language) {
throw new IllegalArgumentException("Dataset language was " + language + " but attempted to add " + dataset.getLanguage());
}
this.data.addAll(dataset.getData());
}
}
public Language getLanguage() {
return language;
}
public List<KnownSim> getData() {
return data;
}
public Dataset prune(double minSim, double maxSim) {
List<KnownSim> pruned = new ArrayList<KnownSim>();
for (KnownSim ks : data) {
if (minSim <= ks.similarity && ks.similarity <= maxSim) {
pruned.add(ks);
}
}
return new Dataset(name + "+pruned", language, pruned);
}
/**
* Shuffles a dataset and splits it into k equally sized subsets, and returns them all
* @param k the number of desired subsets
* @return a list of k equally sized subsets of the original dataset
*/
public List<Dataset> split(int k) {
if (k>data.size()){
k=data.size();
}
List<KnownSim> clone = new ArrayList<KnownSim>();
for (KnownSim ks : data){
clone.add(ks);
}
Collections.shuffle(clone);
List<Dataset> splitSets = new ArrayList<Dataset>();
for (int i=0; i<k; i++) {
splitSets.add(new Dataset(name + "+split-" + i, language));
}
for (int i=0; i< clone.size(); i++) {
splitSets.get(i%k).getData().add(clone.get(i));
}
return splitSets;
}
public String getName() {
return name;
}
private static String createJointName(List<Dataset> datasets) {
List<String> names = new ArrayList<String>();
for (Dataset dataset : datasets) {
names.add(dataset.getName());
}
Collections.sort(names); // makes name order insensitive
return StringUtils.join(names, '+');
}
/**
* Normalizes the range of scores to [0,1]
*/
public void normalize() {
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
for (KnownSim ks : data) {
if (!Double.isNaN(ks.similarity)) {
min = Math.min(ks.similarity, min);
max = Math.max(ks.similarity, max);
}
}
if (max == min) {
throw new IllegalStateException();
}
for (KnownSim ks : data) {
ks.similarity = (ks.similarity - min) / (max-min);
}
}
}