/**
* Copyright (c) 2013 Oculus Info Inc.
* http://www.oculusinfo.com/
*
* Released under the MIT License.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
* of the Software, and to permit persons to whom the Software is furnished to do
* so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package spimedb.cluster;
import spimedb.cluster.feature.Feature;
import spimedb.cluster.feature.numeric.NumericVectorFeature;
import java.io.Serializable;
import java.util.*;
/***
* DataSet represents a structured collection of data instances that is the input
* to oculus machine learning tools.
*
* @author slangevin
*
*/
public class DataSet implements Serializable, Iterable<Instance> {
	private static final long serialVersionUID = -544003860939601075L;

	/** Instances keyed by id; LinkedHashMap preserves insertion order for iteration. */
	private final Map<String, Instance> map = new LinkedHashMap<>();

	/***
	 * Return the ids of all Instances in this DataSet.
	 * @return a live key-set view backed by this DataSet
	 */
	public Set<String> getKeys() {
		return map.keySet();
	}

	/***
	 * Add an Instance to the DataSet, replacing any existing Instance with the same id.
	 * @param inst the Instance to add
	 * @return true if an existing Instance with the same id was replaced
	 */
	public boolean add(Instance inst) {
		return (map.put(inst.getId(), inst) != null);
	}

	/***
	 * Remove an Instance from the DataSet (matched by id).
	 * @param inst the Instance to remove
	 * @return the removed Instance, or null if no Instance with that id was present
	 */
	public Instance remove(Instance inst) {
		return map.remove(inst.getId());
	}

	/***
	 * Get an Instance by id.
	 * @param id the id of the Instance to return
	 * @return the Instance with the given id, or null if no such Instance exists
	 */
	public Instance get(String id) {
		return map.get(id);
	}

	/***
	 * Return the number of Instances in the DataSet.
	 * @return the Instance count
	 */
	public int size() {
		return map.size();
	}

	/***
	 * Return a DataSet containing a random fraction of the Instances in this DataSet.
	 * @param fraction fraction of Instances to return, clamped to the range [0, 1]
	 * @return the resulting sample DataSet
	 */
	public DataSet sample(double fraction) {
		DataSet sample = new DataSet();
		// no samples requested - return an empty data set
		if (fraction <= 0) return sample;
		// max fraction is 1
		if (fraction > 1) fraction = 1;
		long numToSample = Math.round(map.size() * fraction);
		// BUG FIX: the original built a candidate list with a loop bounded by the
		// (still empty) list's own size, so it never filled and any positive
		// fraction threw IndexOutOfBoundsException. Shuffle a copy of the keys
		// directly instead.
		List<String> keys = new ArrayList<>(map.keySet());
		Collections.shuffle(keys); // random permutation of the ids
		for (int i = 0; i < numToSample; i++) {
			sample.add(get(keys.get(i)));
		}
		return sample;
	}

	/***
	 * Shuffle an array of Instances in place.
	 * BUG FIX: the original swap(Instance, Instance) helper only exchanged its
	 * local references (Java is pass-by-value), so the old hand-rolled shuffle
	 * was a no-op and nFolds() never actually randomized its folds.
	 * Arrays.asList returns a fixed-size list view backed by the array, so
	 * shuffling the view shuffles the array itself.
	 */
	private static void shuffle(Instance[] array) {
		Collections.shuffle(Arrays.asList(array));
	}

	/***
	 * Randomly split this DataSet into n similarly sized DataSets.
	 * @param n the number of folds; must be at least 1 and no greater than the
	 *          number of Instances in the DataSet (each fold needs one Instance)
	 * @return a list of n DataSets, or null if n is out of range
	 */
	public List<DataSet> nFolds(int n) {
		// Make sure n is valid: each fold must have at least one instance!
		if (n > size() || n < 1) return null;
		List<DataSet> folds = new LinkedList<>();
		Instance[] instances = map.values().toArray(new Instance[size()]);
		shuffle(instances);
		int sliceSize = size() / n;
		int extra = size() % n;
		// create n folds of sliceSize instances each
		for (int i = 0; i < n; i++) {
			DataSet fold = new DataSet();
			folds.add(fold);
			int start = i * sliceSize;
			int end = start + sliceSize;
			for (int j = start; j < end; j++) {
				fold.add(instances[j]);
			}
		}
		// evenly distribute the remaining (size % n) instances, one per fold
		for (int i = 0; i < extra; i++) {
			folds.get(i).add(instances[instances.length - extra + i]);
		}
		return folds;
	}

	/***
	 * Normalize (z-score) the specified Feature for all Instances in this DataSet:
	 * each vector component becomes (value - mean) / stdev, with mean and sample
	 * stdev computed across all Instances carrying the feature. Values are
	 * modified in place.
	 *
	 * Currently only NumericVectorFeature types are supported; any other feature
	 * type is silently ignored.
	 *
	 * @param featureName the name of the feature to normalize
	 */
	public void normalizeInstanceFeature(String featureName) {
		List<Feature> allFeatures = new ArrayList<>();
		// gather up all matching features
		for (Instance inst : this) {
			if (inst.containsFeature(featureName)) {
				allFeatures.add(inst.getFeature(featureName));
			}
		}
		if (allFeatures.isEmpty()) return;
		// currently only support normalizing numeric vector features
		if (!(allFeatures.get(0) instanceof NumericVectorFeature)) return;
		double N = allFeatures.size();
		double[] meanVector = ((NumericVectorFeature) allFeatures.get(0)).getValue().clone();
		// compute mean of feature
		for (int i = 1; i < allFeatures.size(); i++) {
			double[] vals = ((NumericVectorFeature) allFeatures.get(i)).getValue();
			for (int j = 0; j < meanVector.length; j++) {
				meanVector[j] += vals[j];
			}
		}
		for (int i = 0; i < meanVector.length; i++) {
			meanVector[i] /= N;
		}
		// compute sample stdev of feature
		double[] stdevVector = new double[meanVector.length];
		for (int i = 0; i < allFeatures.size(); i++) {
			double[] vals = ((NumericVectorFeature) allFeatures.get(i)).getValue();
			for (int j = 0; j < meanVector.length; j++) {
				double d = vals[j] - meanVector[j];
				stdevVector[j] += d * d;
			}
		}
		for (int i = 0; i < stdevVector.length; i++) {
			// BUG FIX: guard against division by zero. With a single sample
			// (N == 1, denominator N-1 == 0) or a zero-variance component the
			// original produced NaN; fall back to stdev 1 so such components
			// (which always equal the mean) normalize to exactly 0.
			double stdev = (N > 1) ? Math.sqrt(stdevVector[i] / (N - 1)) : 0;
			stdevVector[i] = (stdev > 0) ? stdev : 1;
		}
		// normalize each feature vector in place
		for (int i = 0; i < allFeatures.size(); i++) {
			double[] vals = ((NumericVectorFeature) allFeatures.get(i)).getValue();
			for (int j = 0; j < vals.length; j++) {
				vals[j] = (vals[j] - meanVector[j]) / stdevVector[j];
			}
		}
	}

	@Override
	public Iterator<Instance> iterator() {
		return new DataSetIterator(map);
	}

	/***
	 * Add a collection of Instances to the DataSet.
	 * @param c the Instances to add
	 * @return true if any existing Instance with a matching id in the DataSet was replaced
	 */
	public boolean addAll(Collection<Instance> c) {
		boolean altered = false;
		for (Instance i : c) {
			if (add(i)) altered = true;
		}
		return altered;
	}

	/***
	 * Remove all Instances from this DataSet.
	 */
	public void clear() {
		map.clear();
	}

	/***
	 * Return true if the DataSet contains an Instance with the same id.
	 * @param inst the Instance to test
	 * @return true if an Instance with inst's id is a member of the DataSet
	 */
	public boolean contains(Instance inst) {
		// CONSISTENCY FIX: key on id like containsAll/add/remove do; the old
		// containsValue(inst) was also O(n) versus O(1) for a key lookup.
		return map.containsKey(inst.getId());
	}

	/***
	 * Return true if the DataSet contains all the specified Instances (matched by id).
	 * @param c the Instances to test
	 * @return true if all the Instances' ids are members of the DataSet
	 */
	public boolean containsAll(Collection<Instance> c) {
		for (Instance i : c) {
			if (!map.containsKey(i.getId())) return false;
		}
		return true;
	}

	/***
	 * Returns true if this DataSet contains no Instances.
	 * @return true if this DataSet contains no Instances
	 */
	public boolean isEmpty() {
		return map.isEmpty();
	}
}