package org.seqcode.ml.clustering.hierarchical; import java.util.Collection; import java.util.Vector; import org.seqcode.ml.clustering.Cluster; import org.seqcode.ml.clustering.ClusterRepresentative; import org.seqcode.ml.clustering.ClusteringMethod; import org.seqcode.ml.clustering.PairwiseElementMetric; import org.seqcode.ml.clustering.SingletonCluster; /** * @author Timothy Danford * */ public class HierarchicalClustering<X> implements ClusteringMethod<X> { private ClusterRepresentative<X> repr; private PairwiseElementMetric<X> metric; private double maxDistanceToAccept; public HierarchicalClustering(ClusterRepresentative<X> rep, PairwiseElementMetric<X> m) { repr = rep; metric = m; maxDistanceToAccept = Double.MAX_VALUE; } public void setMaxDistanceToAccept(double d) { maxDistanceToAccept = d; } /* (non-Javadoc) * @see org.seqcode.gse.clustering.ClusteringMethod#clusterElements(java.util.Collection) */ public Collection<Cluster<X>> clusterElements(Collection<X> elmts) { Vector<Cluster<X>> clusters = new Vector<Cluster<X>>(); Vector<X> reps = new Vector<X>(); Double distances[][] = new Double[elmts.size()][elmts.size()]; for(X ce : elmts) { Cluster<X> c= new SingletonCluster<X>(ce); clusters.add(c); X repMember = repr.getRepresentative(c); reps.add(repMember); } for (int i = 0; i < elmts.size(); i++) { for (int j = 0; j < elmts.size(); j++) { X r1 = reps.get(i); X r2 = reps.get(j); distances[i][j] = metric.evaluate(r1,r2); } } int nclusters = clusters.size(); while(nclusters > 1) { int mini = -1, minj = -1; double mindist = Double.MAX_VALUE; for(int i = 0; i < clusters.size() - 1; i++) { if (clusters.get(i) == null) {continue;} for(int j = i + 1; j < clusters.size(); j++) { if (clusters.get(j) == null) {continue;} double d; if (Double.isNaN(distances[i][j])) { X r1 = reps.get(i); X r2 = reps.get(j); d = metric.evaluate(r1, r2); distances[i][j] = d; } else { d = distances[i][j]; } if(!Double.isNaN(d) && d < mindist) { mindist = d; mini = i; minj = j; } } } if (mini == -1) { break; } if (mindist > maxDistanceToAccept) { break; } Cluster<X> left = clusters.get(mini), right = clusters.get(minj); if (left == null || right == null) { throw new NullPointerException("left is " + left + " + from " + mini +". right is " + right + " from " + minj); } ClusterNode<X> node = new ClusterNode<X>(left, right); clusters.set(minj,null); reps.set(minj,null); for (int i = 0; i < elmts.size(); i++) { distances[i][minj] = Double.NaN; distances[minj][i] = Double.NaN; distances[i][mini] = Double.NaN; distances[mini][i] = Double.NaN; } clusters.set(mini,node); reps.set(mini,repr.getRepresentative(node)); nclusters--; //System.out.println("# Clusters: " + nclusters + "(" + mindist + ")"); } Vector<Cluster<X>> output = new Vector<Cluster<X>>(); for (int i = 0; i < clusters.size(); i++) { if (clusters.get(i) != null) { output.add(clusters.get(i)); } } return output; } }