package org.wikibrain.matrix;
import gnu.trove.impl.Constants;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TShortArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
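/**
 * Transposes a {@link SparseMatrix} on disk without holding the full result in
 * memory. The transpose runs in two phases: a counting pass records how many
 * cells each column holds, then columns are processed in batches whose estimated
 * memory footprint fits within a configurable buffer, with one full scan of the
 * source matrix per batch.
 * <p>
 * A minimal usage sketch (file paths are illustrative):
 * <pre>
 *     SparseMatrix matrix = new SparseMatrix(new File("matrix.dat"));
 *     SparseMatrixTransposer transposer =
 *             new SparseMatrixTransposer(matrix, new File("transposed.dat"));
 *     transposer.transpose();
 * </pre>
 */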
public class SparseMatrixTransposer {
final static Logger LOG = LoggerFactory.getLogger(SparseMatrixTransposer.class);
private SparseMatrixWriter writer;
private SparseMatrix matrix;
    private int[] colIds;
private TIntIntHashMap colCounts = new TIntIntHashMap();
private int bufferMb;
private int numColsTransposed = 0;
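    /**
     * Creates a transposer that writes the transpose of {@code m} to {@code f},
     * sizing its buffer from the JVM's maximum heap.
     */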
public SparseMatrixTransposer(SparseMatrix m, File f) throws IOException {
this(m, f, defaultBufferSizeInMbs());
}
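    /**
     * Creates a transposer with an explicit buffer budget, in MB, which bounds
     * how many columns may be accumulated in memory per pass.
     */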
public SparseMatrixTransposer(SparseMatrix m, File f, int bufferMb) throws IOException {
this.matrix = m;
this.writer = new SparseMatrixWriter(f, m.getValueConf());
this.bufferMb = bufferMb;
this.numColsTransposed = 0;
}
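    /**
     * Transposes the entire matrix and writes the result. Columns are processed
     * in buffer-sized batches, each requiring one full pass over the source matrix.
     */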
public void transpose() throws IOException {
countCellsPerColumn();
while (numColsTransposed < colIds.length) {
Map<Integer, RowAccumulator> batch = accumulateBatch();
writeBatch(batch.values());
}
this.writer.finish();
}
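    /**
     * Makes an initial pass over the matrix to count the cells in each column.
     * The counts drive batch sizing and provide a per-batch sanity check.
     */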
private void countCellsPerColumn() throws IOException {
for (int id : matrix.getRowIds()) {
MatrixRow row = matrix.getRow(id);
for (int i = 0; i < row.getNumCols(); i++) {
colCounts.adjustOrPutValue(row.getColIndex(i), 1, 1);
}
}
colIds = colCounts.keys();
        LOG.info("found {} unique column ids in matrix", colIds.length);
Arrays.sort(colIds);
}
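    /**
     * Builds the next batch of transposed rows: greedily selects columns, in sorted
     * order, until their estimated memory footprint would exceed the buffer, then
     * scans the full matrix once to collect the cells of the selected columns.
     */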
protected Map<Integer, RowAccumulator> accumulateBatch() {
Map<Integer, RowAccumulator> transposedBatch = new LinkedHashMap<Integer, RowAccumulator>();
// figure out which columns we are tracking
double mbs = 0;
TIntHashSet colIdsInBatch = new TIntHashSet();
for (int i = numColsTransposed; i < colIds.length; i++) {
int colId = colIds[i];
int colSize = colCounts.get(colId);
double rowMbs = getSizeInMbOfRowDataStructure(colSize);
            // Admit at least one column per batch, even if it exceeds the buffer
            // estimate; otherwise an oversized column would stall progress forever.
            if (mbs + rowMbs > bufferMb && !colIdsInBatch.isEmpty()) {
                break;
            }
colIdsInBatch.add(colId);
mbs += rowMbs;
}
numColsTransposed += colIdsInBatch.size();
        LOG.info("processing {} columns in batch (total={} of {})",
                colIdsInBatch.size(), numColsTransposed, colCounts.size());
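        // Scan every row of the source matrix and route cells belonging to
        // in-batch columns into their accumulators.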
for (SparseMatrixRow row : matrix) {
int rowId = row.getRowIndex();
for (int i = 0; i < row.getNumCols(); i++) {
int colId = row.getColIndex(i);
if (!colIdsInBatch.contains(colId)) {
continue;
}
short colValue = row.getPackedColValue(i);
                RowAccumulator accumulator = transposedBatch.get(colId);
                if (accumulator == null) {
                    accumulator = new RowAccumulator(colId);
                    transposedBatch.put(colId, accumulator);
                }
                accumulator.addCol(rowId, colValue);
}
}
        // Sanity check: each accumulated row must match the first-pass cell count.
        for (Map.Entry<Integer, RowAccumulator> entry : transposedBatch.entrySet()) {
            if (colCounts.get(entry.getKey()) != entry.getValue().size()) {
                throw new IllegalStateException("unexpected row size for column " + entry.getKey());
            }
        }
return transposedBatch;
}
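    /**
     * Writes each accumulated column of the source matrix as one row of the
     * transposed matrix.
     */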
protected void writeBatch(Collection<RowAccumulator> batch) throws IOException {
for (RowAccumulator ra: batch) {
writer.writeRow(ra.toRow(matrix.getValueConf()));
}
}
    // Size of an object reference in bytes, derived from the JVM data model
    // ("32" or "64"); assumes a 64-bit JVM if the property is unavailable.
    private static final int BYTES_PER_REF =
            Integer.parseInt(System.getProperty("sun.arch.data.model", "64")) / 8;
    private static final int BYTES_PER_OBJECT = 40; // rough estimate of per-object overhead
private static final double EXPANSION_FACTOR = 1.0 / Constants.DEFAULT_LOAD_FACTOR;
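    /**
     * Estimates the memory footprint, in MB, of a RowAccumulator holding
     * numEntries cells. This is a rough estimate: object overhead plus the id
     * field and two list references, plus 4-byte ids and 2-byte packed values
     * per entry, scaled up by trove's default growth headroom.
     */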
private double getSizeInMbOfRowDataStructure(int numEntries) {
return (
// row accumulator object itself
BYTES_PER_OBJECT + 4 + 2 * BYTES_PER_REF +
// ids and values in accumulator
EXPANSION_FACTOR * numEntries * (4 + 2)
) / (1024.0 * 1024.0);
}
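    /**
     * Collects the (row id, packed value) pairs belonging to a single column of
     * the source matrix; these become one row of the transposed matrix.
     */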
private static class RowAccumulator {
int id;
TIntArrayList colIds = new TIntArrayList();
TShortArrayList colVals = new TShortArrayList();
RowAccumulator(int id) {
this.id = id;
}
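        /** Packs the accumulated ids and values into a SparseMatrixRow. */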
SparseMatrixRow toRow(ValueConf vconf) {
return new SparseMatrixRow(vconf, id, colIds.toArray(), colVals.toArray());
}
        void addCol(int colId, short val) {
            this.colIds.add(colId);
            this.colVals.add(val);
        }
int size() { return this.colIds.size(); }
}
    /**
     * Calculates a reasonable buffer size for transposing the matrix.
     * If the maximum heap size is under 1000MB, returns one third of it.
     * Otherwise returns heapsize / 6, clamped to the range [350MB, 5000MB].
     * @return The default buffer size, in MB.
     */
private static int defaultBufferSizeInMbs() {
int totalMem = (int) (Runtime.getRuntime().maxMemory() / (1024*1024));
if (totalMem < 1000) {
return totalMem / 3;
} else {
int size = totalMem / 6;
if (size < 350) size = 350;
if (size > 5000) size = 5000;
return size;
}
}
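    /**
     * Command-line entry point: transposes input_path into output_path, with an
     * optional third argument giving the buffer size in MB.
     */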
    public static void main(String[] args) throws IOException {
        int bufferMbs = 0;
        if (args.length == 2) {
            bufferMbs = defaultBufferSizeInMbs();
        } else if (args.length == 3) {
            bufferMbs = Integer.parseInt(args[2]);
        } else {
            System.err.println("usage: java " + SparseMatrixTransposer.class.getName() +
                    " input_path output_path {buffer_in_MBs}");
            System.exit(1);
        }
SparseMatrix matrix = new SparseMatrix(new File(args[0]));
SparseMatrixTransposer transposer = new SparseMatrixTransposer(matrix, new File(args[1]), bufferMbs);
transposer.transpose();
}
}