package org.wikibrain.matrix;
import gnu.trove.map.TIntFloatMap;
import org.apache.commons.lang3.ArrayUtils;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.ShortBuffer;
import java.util.LinkedHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* A single sparse matrix row backed by a byte buffer. The row contains:
* - a row id (int),
* - a set of n columns, each with an id (int) and value (float packed into two bytes)
*
* The row can either be created from the component data, or from a byte buffer.
* This means that the object can wrap data from an mmap'd file in the correct format.
*
* Newly created rows are reordered so that the columns appear in sorted order.
*/
public final class SparseMatrixRow extends BaseMatrixRow implements MatrixRow {
    private static final Logger LOG = LoggerFactory.getLogger(SparseMatrixRow.class);

    // NOTE(review): the four range constants below are not referenced inside this
    // class; they are kept unchanged because they are part of the public interface.
    public static final Float MIN_SCORE = -1.1f;
    public static final Float MAX_SCORE = 1.1f;
    public static final Float SCORE_RANGE = (MAX_SCORE - MIN_SCORE);
    public static final int PACKED_RANGE = (Short.MAX_VALUE - Short.MIN_VALUE);

    /**
     * Marker written into the first int of every row buffer and checked
     * when an existing buffer is wrapped.
     */
    public static final int HEADER = 0xfefefefe;

    /** Size in bytes of the fixed prefix: header marker, row index and column count. */
    private static final int PREFIX_BYTES = 3 * 4;

    /**
     * The main "source" buffer.
     */
    private ByteBuffer buffer;

    /**
     * A view buffer that points to the header
     * (index 0: marker, index 1: row index, index 2: number of columns).
     */
    private IntBuffer headerBuffer;

    /**
     * A view buffer that points to the ids.
     */
    private IntBuffer idBuffer;

    /**
     * A view buffer that points to the values.
     */
    private ShortBuffer valBuffer;

    /** Packs floats into shorts and unpacks them again. */
    private final ValueConf vconf;

    /**
     * Creates a row from a map of column id to value.
     *
     * @param vconf    configuration used to pack the float values into shorts
     * @param rowIndex id of the row
     * @param row      column ids mapped to their values
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, TIntFloatMap row) {
        this(vconf, rowIndex, row.keys(), row.values());
    }

    /**
     * Creates a row from a map of column id to value.
     *
     * @param vconf    configuration used to pack the float values into shorts
     * @param rowIndex id of the row
     * @param row      column ids mapped to their values
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, LinkedHashMap<Integer, Float> row) {
        this(vconf, rowIndex,
                ArrayUtils.toPrimitive(row.keySet().toArray(new Integer[] {})),
                ArrayUtils.toPrimitive(row.values().toArray(new Float[]{}))
        );
    }

    /**
     * Creates a row from parallel arrays of column ids and float values.
     * Each value is packed with {@code vconf} before it is stored.
     *
     * @param vconf    configuration used to pack the float values into shorts
     * @param rowIndex id of the row
     * @param colIds   column ids; reordered in place if they are not already sorted
     * @param colVals  column values, {@code colVals[i]} belonging to {@code colIds[i]}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, int colIds[], float colVals[]) {
        this.vconf = vconf;
        short packed[] = new short[colVals.length];
        for (int i = 0; i < colVals.length; i++) {
            packed[i] = vconf.pack(colVals[i]);
        }
        createBuffer(rowIndex, colIds, packed);
    }

    /**
     * Creates a row from parallel arrays of column ids and already packed values.
     *
     * @param vconf    configuration used to unpack the values when they are read
     * @param rowIndex id of the row
     * @param colIds   column ids; reordered in place if they are not already sorted
     * @param colVals  packed column values; reordered in place together with {@code colIds}
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public SparseMatrixRow(ValueConf vconf, int rowIndex, int colIds[], short colVals[]) {
        this.vconf = vconf;
        createBuffer(rowIndex, colIds, colVals);
    }

    /**
     * Allocates a fresh backing buffer and fills it with the given row data.
     * If the column ids are not in non-decreasing order, both arrays are sorted
     * in place by column id before they are copied into the buffer.
     *
     * @param rowIndex id of the row
     * @param colIds   column ids; may be reordered in place
     * @param colVals  packed column values; may be reordered in place
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws IllegalStateException    if the ids are still unsorted after sorting
     */
    public void createBuffer(int rowIndex, int colIds[], short colVals[]) {
        // An explicit check instead of an assert: assertions are normally disabled,
        // and mismatched arrays would otherwise produce a silently corrupt row.
        if (colIds.length != colVals.length) {
            throw new IllegalArgumentException(
                    "Column id and value arrays differ in length: "
                            + colIds.length + " ids vs " + colVals.length + " values");
        }
        if (!isNonDecreasing(colIds)) {
            quickSort(colIds, colVals, 0, colIds.length - 1);
            if (!isNonDecreasing(colIds)) {
                throw new IllegalStateException("Column ids are not sorted after sorting");
            }
        }
        int numCols = colVals.length;
        buffer = ByteBuffer.allocate(
                PREFIX_BYTES +      // header, row index, num cols
                        4 * numCols +       // col indexes
                        2 * numCols         // col values
        );
        createViewBuffers(numCols);
        headerBuffer.put(0, HEADER);
        headerBuffer.put(1, rowIndex);
        headerBuffer.put(2, numCols);
        idBuffer.put(colIds, 0, colIds.length);
        valBuffer.put(colVals, 0, colVals.length);
    }

    /**
     * Sorts {@code colIds[low..high]} in ascending order, applying every swap to
     * {@code colVals} as well so that ids and values stay paired.
     *
     * Adapted from http://www.programcreek.com/2012/11/quicksort-array-in-java/
     */
    private static void quickSort(int colIds[], short colVals[], int low, int high) {
        // Loop on the larger partition and recurse only into the smaller one,
        // which keeps the recursion depth logarithmic in the range size.
        while (low < high) {
            // pick the pivot; the unsigned shift cannot overflow like (low + high) / 2
            int pivot = colIds[(low + high) >>> 1];

            // partition around the pivot
            int i = low, j = high;
            while (i <= j) {
                while (colIds[i] < pivot) {
                    i++;
                }
                while (colIds[j] > pivot) {
                    j--;
                }
                if (i <= j) {
                    int tempId = colIds[i];
                    short tempVal = colVals[i];
                    colIds[i] = colIds[j];
                    colVals[i] = colVals[j];
                    colIds[j] = tempId;
                    colVals[j] = tempVal;
                    i++;
                    j--;
                }
            }

            // the two sub parts [low..j] and [i..high] are disjoint
            if (j - low < high - i) {
                quickSort(colIds, colVals, low, j);
                low = i;
            } else {
                quickSort(colIds, colVals, i, high);
                high = j;
            }
        }
    }

    /**
     * Tells whether no element of the array is smaller than its predecessor.
     *
     * @param A the array to inspect
     * @return {@code true} if the array is empty or in non-decreasing order
     */
    static boolean isNonDecreasing(int A[]) {
        int lastId = Integer.MIN_VALUE;
        for (int i = 0; i < A.length; i++) {
            if (A[i] < lastId) {
                return false;
            }
            lastId = A[i];
        }
        return true;
    }

    /**
     * Creates the header, id and value views over the source buffer.
     * Each view starts at its section's offset and runs to the buffer's limit.
     * The source buffer's position is moved as a side effect and is left at the
     * start of the values section.
     *
     * @param numColumns number of columns stored in the row
     */
    private void createViewBuffers(int numColumns) {
        buffer.position(0);
        headerBuffer = buffer.asIntBuffer();
        buffer.position(PREFIX_BYTES);
        idBuffer = buffer.asIntBuffer();
        buffer.position(PREFIX_BYTES + numColumns * 4);
        valBuffer = buffer.asShortBuffer();
    }

    /**
     * Wrap an existing byte buffer that contains a row.
     * The buffer is used directly, not copied, and its position is changed.
     *
     * @param vconf  configuration used to unpack the values when they are read
     * @param buffer buffer whose first bytes hold a row in the layout written by
     *               {@link #createBuffer(int, int[], short[])}
     * @throws IllegalArgumentException if the header marker is wrong, or if the
     *         stored column count is negative or does not fit within the buffer
     */
    public SparseMatrixRow(ValueConf vconf, ByteBuffer buffer) {
        this.vconf = vconf;
        this.buffer = buffer;
        if (this.buffer.getInt(0) != HEADER) {
            throw new IllegalArgumentException("Invalid header in byte buffer");
        }
        int numColumns = buffer.getInt(8);
        // Computed as a long so that a corrupt column count cannot overflow the
        // offset arithmetic and yield overlapping views.
        long requiredBytes = PREFIX_BYTES + 6L * numColumns;
        if (numColumns < 0 || requiredBytes > buffer.limit()) {
            throw new IllegalArgumentException(
                    "Invalid column count " + numColumns
                            + " for byte buffer with limit " + buffer.limit());
        }
        createViewBuffers(numColumns);
    }

    /**
     * Returns the id of the i-th column. The index is only checked against the
     * underlying view, not against {@link #getNumCols()}.
     */
    @Override
    public final int getColIndex(int i) {
        return idBuffer.get(i);
    }

    /**
     * Returns the unpacked value of the i-th column.
     */
    @Override
    public final float getColValue(int i) {
        return vconf.unpack(valBuffer.get(i));
    }

    /**
     * Returns the value of the i-th column in its packed, two-byte form.
     */
    public final short getPackedColValue(int i) {
        return valBuffer.get(i);
    }

    /**
     * Returns the row id stored in the header.
     */
    @Override
    public final int getRowIndex() {
        return headerBuffer.get(1);
    }

    /**
     * Returns the number of columns stored in the header.
     */
    @Override
    public final int getNumCols() {
        return headerBuffer.get(2);
    }

    /**
     * Returns the backing buffer itself, not a copy. Its position is not
     * guaranteed to be zero.
     */
    public ByteBuffer getBuffer() {
        return buffer;
    }

    /**
     * Returns the configuration used to pack and unpack this row's values.
     */
    public ValueConf getValueConf() {
        return vconf;
    }
}