package gr.iti.mklab.visual.quantization;
import gr.iti.mklab.visual.aggregation.AbstractFeatureAggregator;
import gr.iti.mklab.visual.datastructures.Linear;
import gr.iti.mklab.visual.utilities.RandomPermutation;
import gr.iti.mklab.visual.utilities.RandomRotation;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;
import weka.clusterers.AbstractClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.SelectedTag;
/**
* This class is used to learn a Product Quantizer from a set of vectors that are stored in a {@link Linear}
* index (BDB store).
*
* @author Eleftherios Spyromitros-Xioufis
*
*/
public class ProductQuantizationLearning {

    /** the number of k-means executions for each sub-quantizer */
    public static int numKmeansRepeats = 1;

    /**
     * Learns a product quantizer and writes it to a single csv file inside the directory given by
     * <code>-path</code>. The various options can be given from the command line:
     * <ul>
     * <li><code>-path</code> (required): path to the Linear BDB index with the training vectors</li>
     * <li><code>-d</code> (required): dimensionality of the training vectors</li>
     * <li><code>-m</code> (required): number of sub-vectors, d must be a multiple of m</li>
     * <li><code>-c</code> (required): number of centroids of each sub-quantizer</li>
     * <li><code>-transform</code> (required): one of <code>no</code>, <code>rr</code> (random rotation) or
     * <code>rp</code> (random permutation)</li>
     * <li><code>-samples</code>: number of learning vectors (default 50000)</li>
     * <li><code>-i</code>: maximum number of k-means iterations (default 100)</li>
     * <li><code>-s</code>: number of k-means execution slots (default 1)</li>
     * <li><code>-ivf</code>, <code>-cqfile</code>, <code>-cqcentroids</code>: learn on residual vectors
     * computed with the given coarse quantizer</li>
     * </ul>
     * For every sub-quantizer, k-means is executed {@link #numKmeansRepeats} times with different seeds
     * and the clustering with the lowest squared error is written to the file. If fewer centroids than
     * requested are produced, the file is padded with fake centroids whose coordinates are all 1000.
     *
     * @param args
     *            the command line arguments
     * @throws Exception
     *             if a required option is missing or invalid, or if loading, clustering or writing fails
     */
    public static void main(String[] args) throws Exception {
        // create Options object
        Options options = new Options();
        // add options
        options.addOption("path", true,
                "path to the Linear BDB index that contains the training vectors: e.g. C:/lef/BDB_1024/");
        options.addOption("d", true, "the dimensionality of the training vectors: e.g. 128 or 1024");
        options.addOption("m", true,
                "the number of subvectors to be created (d should be divided exactly by m)");
        options.addOption("c", true, "the number of centroids of each subquantizer: e.g. 256 or 1024");
        options.addOption("transform", true,
                "the type of transformation to apply before learning the product quantizer");
        options.addOption("samples", true, "how many learning vectors to use: e.g. 20000");
        options.addOption("i", true, "the maximum number of clustering iterations (default 100).");
        options.addOption("s", true, "the number of parallel execution slots to use in k-means clustering");
        options.addOption("split", false,
                "whether to also split the training vectors before learning the sub-quantizers");
        options.addOption("ivf", false,
                "whether product quantization will be combined with an inverted file (ivf)");
        options.addOption("cqfile", true, "path to the coarse quantizer file");
        options.addOption("cqcentroids", true, "number of coarse quantizer centroids: e.g. 8192");

        CommandLineParser parser = new PosixParser();
        HelpFormatter formatter = new HelpFormatter();
        CommandLine cmd = null;
        try {
            cmd = parser.parse(options, args);
        } catch (ParseException e) {
            System.out.println(e.getMessage());
            formatter.printHelp("Product quantization learning", options, true);
            // malformed arguments are a failure, so report it with a non-zero exit status
            System.exit(1);
        }
        formatter.printHelp("Product quantization learning", options, true);

        // the following two flags get the default value of false if they are not passed in the console
        boolean split = cmd.hasOption("split");
        if (split) {
            System.out.println("Vectors will be splitted");
        } else {
            System.out.println("Splitting will not be performed");
        }

        // fail with a clear message instead of a NullPointerException when the option is missing
        String transformationType = cmd.getOptionValue("transform");
        if (transformationType == null) {
            throw new Exception("The transformation type is undefined! Use one of: no, rr, rp");
        }
        if (transformationType.equals("no")) {
            System.out.println("No transformation will be applied");
        } else if (transformationType.equals("rr")) {
            System.out.println("Random Rotation will be applied");
        } else if (transformationType.equals("rp")) {
            System.out.println("Random Permutation will be applied");
        } else {
            throw new Exception("Unsupported transformation type!");
        }

        boolean ivf = cmd.hasOption("ivf");
        String coarseQuantizerFile = "";
        int numCoarseCentroids = 0;
        if (ivf) {
            System.out.println("IVF will be used");
            // parse the other ivf related options
            coarseQuantizerFile = cmd.getOptionValue("cqfile");
            if (coarseQuantizerFile == null) {
                throw new Exception("IVF selected but coarse quantizer file not given.");
            }
            numCoarseCentroids = parseRequiredInt(cmd.getOptionValue("cqcentroids"),
                    "IVF selected but number of coarse quantizer centroids not given.");
        }

        int numVectors;
        if (cmd.getOptionValue("samples") == null) {
            System.out.println("Using the default 50000 samples for learning.");
            numVectors = 50000;
        } else {
            numVectors = Integer.parseInt(cmd.getOptionValue("samples"));
            System.out.println("Using " + numVectors + " samples for learning.");
        }

        int maxIterations;
        if (cmd.getOptionValue("i") == null) {
            System.out.println("Using the default 100 max iterations.");
            maxIterations = 100;
        } else {
            maxIterations = Integer.parseInt(cmd.getOptionValue("i"));
            System.out.println("Using " + maxIterations + " maximum iterations.");
        }

        String BDBpath = cmd.getOptionValue("path");
        if (BDBpath == null) {
            throw new Exception("The path to the training vectors is undefined!");
        }

        // only parse the option when it was given, otherwise keep the default of 1 slot
        int numSlots = 1;
        if (cmd.getOptionValue("s") == null) {
            System.out.println("Using 1 execution slot!");
        } else {
            numSlots = Integer.parseInt(cmd.getOptionValue("s"));
        }

        int vectorLength = parseRequiredInt(cmd.getOptionValue("d"),
                "The dimensionality of the training vectors is undefined!");
        int numProductCentroids = parseRequiredInt(cmd.getOptionValue("c"),
                "The number of centroids for each sub-quantizer is undefined!");
        int m = parseRequiredInt(cmd.getOptionValue("m"), "The number of sub-vectors is undefined!");

        // checking that d and m are valid (m = 0 would otherwise cause a division by zero below)
        if (vectorLength <= 0 || m <= 0) {
            throw new Exception("d and m must be positive (d=" + vectorLength + ", m=" + m + ")");
        }
        if (vectorLength % m != 0) { // d is not a multiple of m
            throw new Exception("d is not a multiple of m");
        }
        int subVectorLength = vectorLength / m; // the dimensionality of the subvectors

        RandomRotation rr = null;
        RandomPermutation rp = null;
        if (transformationType.equals("rr")) {
            rr = new RandomRotation(1, vectorLength);
        } else if (transformationType.equals("rp")) {
            rp = new RandomPermutation(1, vectorLength);
        }

        System.out.println("== Creating subquantizers using " + numVectors + " vectors ==");
        System.out.println("Vector dimensionality: " + vectorLength);
        System.out.println("Sub vector dimensionality: " + subVectorLength);
        System.out.println("Num centroids: " + numProductCentroids);
        System.out.println("Max iterations: " + maxIterations);

        // create a single file to store all the sub-quantizers of the product quantizer
        // construct the filename
        String subquantizersFilename = BDBpath + "/pq_" + vectorLength + "_" + m + "x"
                + (int) (Math.log(numProductCentroids) / Math.log(2)) + "_" + numVectors;
        if (transformationType.equals("rr")) {
            subquantizersFilename += "_rr";
        } else if (transformationType.equals("rp")) {
            subquantizersFilename += "_rp";
        }
        if (ivf) {
            subquantizersFilename += "_ivf_c" + numCoarseCentroids;
        }
        subquantizersFilename += ".csv";

        BufferedWriter out = new BufferedWriter(new FileWriter(subquantizersFilename));
        try {
            // load the vectors from the Linear index (BDB store)
            Linear vectors = new Linear(vectorLength, numVectors, true, BDBpath, false, true, 0);

            // in case of ivf, load the coarse quantizer so that residual vectors can be computed.
            ResidualVectorComputation res = null;
            if (ivf) {
                double[][] coarseQuantizer = AbstractFeatureAggregator.readQuantizer(coarseQuantizerFile,
                        numCoarseCentroids, vectorLength);
                res = new ResidualVectorComputation(coarseQuantizer, vectorLength, numCoarseCentroids);
            }

            // creating weka attributes
            ArrayList<Attribute> attributes = new ArrayList<Attribute>();
            for (int i = 0; i < subVectorLength; i++) {
                attributes.add(new Attribute("feature" + (i + 1)));
            }
            // create one Instances object for learning each sub-quantizer
            Instances[] datasets = new Instances[m];
            for (int i = 0; i < m; i++) {
                datasets[i] = new Instances("subvectors", attributes, numVectors);
            }

            // load the sub-vectors into the corresponding Instances objects
            for (int i = 0; i < numVectors; i++) {
                double[] vector = vectors.getVector(i);
                // compute the residual vector if needed
                if (ivf) {
                    vector = res.ComputeResidualVector(vector);
                }
                // apply the appropriate transformation
                if (rr != null) {
                    vector = rr.rotate(vector);
                } else if (rp != null) {
                    vector = rp.permute(vector);
                }
                // load each sub-vector into the corresponding Instances object
                for (int j = 0; j < m; j++) {
                    int from = j * subVectorLength;
                    double[] subVector = Arrays.copyOfRange(vector, from, from + subVectorLength);
                    datasets[j].add(new DenseInstance(1.0, subVector));
                }
            }

            // learning the sub-quantizers with k-means clustering
            for (int i = 0; i < m; i++) {
                System.out.println("Learning sub-quantizer " + (i + 1));
                double minSSE = Double.MAX_VALUE;
                SimpleKMeans bestClusterer = null;
                // try k-means using numKmeansRepeats different random seeds and keep the one with the
                // lowest SSE
                for (int j = 0; j < numKmeansRepeats; j++) {
                    // Create a new k-means instance
                    SimpleKMeans clusterer = new SimpleKMeans();
                    clusterer.setInitializationMethod(new SelectedTag(SimpleKMeans.KMEANS_PLUS_PLUS,
                            SimpleKMeans.TAGS_SELECTION));
                    clusterer.setNumExecutionSlots(numSlots);
                    clusterer.setNumClusters(numProductCentroids);
                    clusterer.setMaxIterations(maxIterations);
                    clusterer.setSeed(j + 1);
                    // build the clusterer
                    clusterer.buildClusterer(datasets[i]);
                    double SSE = clusterer.getSquaredError();
                    if (SSE < minSSE) {
                        minSSE = SSE;
                        bestClusterer = (SimpleKMeans) AbstractClusterer.makeCopy(clusterer);
                    }
                }
                // no run was kept, e.g. because numKmeansRepeats < 1 or no run had an SSE below
                // Double.MAX_VALUE
                if (bestClusterer == null) {
                    throw new IllegalStateException("No usable clustering for sub-quantizer " + (i + 1)
                            + " (numKmeansRepeats=" + numKmeansRepeats + ")");
                }
                System.out.println("Mininum SSE: " + minSSE + " Seed: " + bestClusterer.getSeed());
                System.out.println("Saving best sub-quantizer in file..");

                // write the results of the clustering to the new file (csv formatted)
                Instances clusterCentroids = bestClusterer.getClusterCentroids();
                for (int j = 0; j < clusterCentroids.numInstances(); j++) {
                    Instance centroid = clusterCentroids.instance(j);
                    int lastAttribute = centroid.numAttributes() - 1;
                    for (int k = 0; k < lastAttribute; k++) {
                        out.write(centroid.value(k) + ",");
                    }
                    out.write(centroid.value(lastAttribute) + "\n");
                }

                // check whether fewer than the desired clusters were generated and add fake, distant
                // centroids in that case, so that no vector is quantized in them
                int numCentroidsMissing = numProductCentroids - clusterCentroids.numInstances();
                if (numCentroidsMissing > 0) {
                    System.out
                            .println("Problem! Number of generated clusters is smaller that the desired one. Use more samples!");
                    System.out.println("Non empty clusters: " + clusterCentroids.numInstances()
                            + " instead of: " + numProductCentroids);
                    for (int k = 0; k < numCentroidsMissing; k++) {
                        for (int f = 0; f < subVectorLength - 1; f++) {
                            out.write("1000,");
                        }
                        out.write("1000\n");
                    }
                }
                out.flush();
            }
        } finally {
            // always release the file handle, even when loading or clustering fails
            out.close();
        }
    }

    /**
     * Parses the value of a required integer option.
     *
     * @param value
     *            the raw option value, null when the option was not given
     * @param missingMessage
     *            the message of the exception thrown when the value is null
     * @return the parsed integer
     * @throws Exception
     *             if the value is null
     * @throws NumberFormatException
     *             if the value is not a valid integer
     */
    private static int parseRequiredInt(String value, String missingMessage) throws Exception {
        if (value == null) {
            throw new Exception(missingMessage);
        }
        return Integer.parseInt(value);
    }
}