/** * Copyright (c) 2013 Oculus Info Inc. * http://www.oculusinfo.com/ * * Released under the MIT License. * * Permission is hereby granted, free of charge, to any person obtaining a copy of * this software and associated documentation files (the "Software"), to deal in * the Software without restriction, including without limitation the rights to * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies * of the Software, and to permit persons to whom the Software is furnished to do * so, subject to the following conditions: * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ package spimedb.cluster.stats; import org.eclipse.collections.impl.map.mutable.primitive.ObjectDoubleHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import spimedb.cluster.Instance; import spimedb.cluster.feature.Feature; import spimedb.cluster.feature.spatial.TrackFeature; import spimedb.cluster.unsupervised.cluster.Cluster; import spimedb.util.geom.geodesic.Track; import spimedb.util.math.statistics.StatTracker; import java.awt.*; import java.util.*; import java.util.List; import java.util.Map.Entry; /** * This class takes a cluster of tracks, and compiles the various statistics * etc. needed on it for actual use. * * It is meant to be used only after clustering is complete. * * @author nathan */ public class TrackClusterWrapper { private static final Logger LOGGER = LoggerFactory.getLogger(TrackClusterWrapper.class); private static final double EPSILON = 1E-12; // The tracks in this cluster, from the closes private List<Track> _tracks; // The mean of everything, ignoring nothing private Track _trueMean; // The distance of each track from the true mean private Map<Track, Double> _distancesFromTrueMean; // The proportion of outliers to ignore private double _outlierIgnoreRatio; // The mean, ignoring outliers private Track _practicalMean; // The distance of each track from the practical mean private Map<Track, Double> _distanceFromPracticalMean; // Standard deviation of tracks from the practical (not the true) mean. private double _standardDeviation; // A (user-set) color to be associated with this cluster private Color _clusterColor; // A (user-set) name to be associated with this cluster private String _clusterName; // Statistics kept on member tracks private final Map<String, StatTracker> _statistics; public TrackClusterWrapper (Cluster cluster) { this(Collections.singleton(cluster)); } public TrackClusterWrapper (Collection<Cluster> clusters) { _statistics = new HashMap<>(); _clusterName = null; _clusterColor = null; initializeTracks(clusters); setOutlierIgnoreRatio(0.2); compileStatistics(); } private void initializeTracks (Collection<Cluster> clusters) { _tracks = new ArrayList<>(); _trueMean = null; // Pull all tracks out of the cluster, and calculate their true mean for (Cluster cluster: clusters) { for (Instance instance : cluster.getMembers()) { for (Feature f : instance.getAllFeatures()) { if (f instanceof TrackFeature) { TrackFeature tf = (TrackFeature) f; Track track = tf.getValue(); if (null == _trueMean) { _trueMean = track; } else { _trueMean = _trueMean.weightedAverage(track, _tracks.size(), 1); } _tracks.add(track); } } } } // Sort them in order of distance from the true mean _distancesFromTrueMean = new HashMap<>(); for (Track track: _tracks) { _distancesFromTrueMean.put(track, track.getDistance(_trueMean)); } _tracks.sort(Comparator.comparing(_distancesFromTrueMean::get)); } /** * Causes the given proportion of tracks farthest from the true mean to be * ignored when calculating the practical mean. * * @param ratio * A number in the range [0.0, 1.0). Floor(proportion times * number of tracks) will be ignored, so that at least one track * is always used, so the practical mean is never null. */ public void setOutlierIgnoreRatio (double ratio) { if (ratio < 0.0) { LOGGER.warn("Illegal outlier ignore ratio {}: Must be >= 0.0", ratio); return; } if (ratio >= 1.0) { LOGGER.warn("Illegal outlier ignore ratio {}: Must be < 1.0", ratio); return; } if (Math.abs(_outlierIgnoreRatio-ratio) < EPSILON) // No change return; _outlierIgnoreRatio = ratio; // Recalculate our practical mean int toKeep = (int) Math.ceil(_tracks.size()*(1.0-_outlierIgnoreRatio)); _practicalMean = null; for (int i=0; i<toKeep; ++i) { Track track = _tracks.get(i); if (0 == i) { _practicalMean = track; } else { _practicalMean = _practicalMean.weightedAverage(track, i, 1); } } // Precalculate distances from practical mean and total standard deviation _distanceFromPracticalMean = new HashMap<>(); double variance = 0.0; for (Track track: _tracks) { double distance = track.getDistance(_practicalMean); _distanceFromPracticalMean.put(track, distance); variance += distance*distance; } variance /= _tracks.size(); _standardDeviation = Math.sqrt(variance); } private void compileStatistics () { for (Track track: _tracks) { ObjectDoubleHashMap<String> trackStats = track.getStatistics(); trackStats.forEachKeyValue((statName,statVal) -> { if (!_statistics.containsKey(statName)) _statistics.put(statName, new StatTracker()); _statistics.get(statName).addStat(statVal); }); } } public double getDistance (Track track) { if (_distanceFromPracticalMean.containsKey(track)) return _distanceFromPracticalMean.get(track); else return track.getDistance(_practicalMean); } /** * retrieves a list of all tracks in this cluster */ public List<Track> getTracks () { return Collections.unmodifiableList(_tracks); } /** * Retrieves the mean of the cluster, as calculated from those tracks most * towards the mean (@see {@link #setOutlierIgnoreRatio(double)}) */ public Track getMean () { return _practicalMean; } /** * Retrieves the standard deviation of this cluster from its mean, where the * mean is calculated ignoring the furthest tracks from the mean. */ public double getStandardDeviation () { return _standardDeviation; } public Map<String, StatTracker> getStatistics () { return _statistics; } public void setClusterName (String name) { _clusterName = name; } public String getClusterName () { return _clusterName; } public void setClusterColor (Color c) { _clusterColor = c; } public Color getClusterColor () { return _clusterColor; } /** * Gets a short description of this cluster, including name, number of * members, length of mean, and standard deviation */ public String getClusterDescription () { String description = String.format("Cluster %s: %d items, %.1f long, std. dev.=%.4f", _clusterName, _tracks.size(), _practicalMean.getLength(), _standardDeviation); for (Entry<String, StatTracker> stringStatTrackerEntry : _statistics.entrySet()) { StatTracker stat = stringStatTrackerEntry.getValue(); description += String.format("\n\t%s:%.4f\n\t [%.4f to %4f],\n\t sd=%.4f", stringStatTrackerEntry.getKey(), stat.mean(), stat.min(), stat.max(), stat.standardDeviation()); } return description; } }