package cc.mallet.cluster.iterator; import cc.mallet.cluster.Clustering; import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor; import cc.mallet.cluster.util.ClusterUtils; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.util.Randoms; /** * Sample pairs of Instances. * * @author "Aron Culotta" <culotta@degas.cs.umass.edu> * @version 1.0 * @since 1.0 * @see NeighborIterator */ public class PairSampleIterator extends NeighborIterator { protected InstanceList instances; protected Randoms random; protected double positiveProportion; protected int numberSamples; protected int positiveTarget; protected int positiveCount; protected int totalCount; protected int[] nonsingletonClusters; /** * * @param clustering True clustering. * @param random Source of randomness. * @param positiveProportion Proportion of Instances that should be positive examples. * @param numberSamples Total number of samples to generate. * @return */ public PairSampleIterator (Clustering clustering, Randoms random, double positiveProportion, int numberSamples) { super(clustering); this.random = random; this.positiveProportion = positiveProportion; this.numberSamples = numberSamples; this.positiveTarget = (int)(numberSamples * positiveProportion); this.totalCount = this.positiveCount = 0; this.instances = clustering.getInstances(); setNonSingletons(); } private void setNonSingletons () { int c = 0; for (int i = 0; i < clustering.getNumClusters(); i++) if (clustering.size(i) > 1) c++; nonsingletonClusters = new int[c]; c = 0; for (int i = 0; i < clustering.getNumClusters(); i++) if (clustering.size(i) > 1) nonsingletonClusters[c++] = i; } public boolean hasNext () { return totalCount < numberSamples; } public Instance next () { AgglomerativeNeighbor neighbor = null; if (nonsingletonClusters.length>0 && ( positiveCount < positiveTarget || clustering.getNumClusters() == 1)) { //mmwick modified positiveCount++; int label = nonsingletonClusters[random.nextInt(nonsingletonClusters.length)]; int[] instances = clustering.getIndicesWithLabel(label); int ii = instances[random.nextInt(instances.length)]; int ij = instances[random.nextInt(instances.length)]; while (ii == ij) ij = instances[random.nextInt(instances.length)]; neighbor = new AgglomerativeNeighbor(clustering, clustering, ii, ij); } else { int ii = random.nextInt(instances.size()); int ij = random.nextInt(instances.size()); while (clustering.getLabel(ii) == clustering.getLabel(ij)) ij = random.nextInt(instances.size()); neighbor = new AgglomerativeNeighbor(clustering, ClusterUtils.copyAndMergeInstances(clustering, ii, ij), ii, ij); } totalCount++; return new Instance(neighbor, null, null, null); } }