package org.wikibrain.matrix;
import gnu.trove.map.hash.TIntLongHashMap;
import java.io.File;
import java.io.IOException;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A wrapper around a file channel that contains a matrix in row major order.
 * Returns rows at a particular offset in the form of ByteBuffers backed by a memory mapped file.
 *
 * <p>The file is divided into pages whose boundaries fall on row offsets. Each page is
 * memory mapped lazily, the first time a row that lives inside it is requested.
 */
public class MemoryMappedMatrix {
    public static final Logger LOG = LoggerFactory.getLogger(MemoryMappedMatrix.class);

    /**
     * Page size threshold in bytes. A page is closed as soon as a row starts more than this
     * many bytes after the beginning of the page.
     */
    public static final int PAGE_SIZE = 1024 * 1024 * 1024;    // 1GB

    private FileChannel channel;

    /** Pages of the file, in the order they were created (increasing file position). */
    protected List<MappedBufferWrapper> buffers = new ArrayList<MappedBufferWrapper>();

    private File path;

    private final IntBuffer rowIds;         // row ids in sorted order (binary searched by getRowOffset)
    private final LongBuffer rowOffsets;    // rowOffsets.get(i) is the file offset of row rowIds.get(i)
    private final int numRows;

    /** Weakly cached result of {@link #getRowIdsInDiskOrder()}; may be cleared by the GC at any time. */
    private WeakReference<int[]> rowIdsInDiskOrder = null;

    /**
     * Unsupported constructor variant.
     *
     * @throws UnsupportedOperationException always
     */
    public MemoryMappedMatrix(File path, FileChannel channel, TIntLongHashMap rowOffsets) throws IOException {
        throw new UnsupportedOperationException(
                "constructing a MemoryMappedMatrix from a TIntLongHashMap is not supported");
    }

    /**
     * Creates a matrix over {@code channel} and computes its page layout.
     *
     * @param path       file the channel was opened on; only used in log messages
     * @param channel    channel that pages are mapped from
     * @param rowIds     ids of the rows, in increasing order
     * @param rowOffsets file offset of each row; entry {@code i} belongs to {@code rowIds.get(i)}
     * @throws IllegalArgumentException if the two buffers have different capacities or the
     *                                  ids are out of order
     * @throws IOException              if the size of the channel cannot be determined
     */
    public MemoryMappedMatrix(File path, FileChannel channel, IntBuffer rowIds, LongBuffer rowOffsets) throws IOException {
        this.path = path;
        this.channel = channel;
        if (rowIds.capacity() != rowOffsets.capacity()) {
            throw new IllegalArgumentException(
                    "rowIds capacity (" + rowIds.capacity()
                            + ") does not match rowOffsets capacity (" + rowOffsets.capacity() + ")");
        }
        this.rowIds = rowIds;
        this.rowOffsets = rowOffsets;
        this.numRows = rowIds.capacity();

        // NOTE(review): the message says "strictly increasing" but equal neighbouring ids
        // pass this check. Left as-is so that files accepted today keep loading.
        int lastId = Integer.MIN_VALUE;
        for (int i = 0; i < numRows; i++) {
            int id = rowIds.get(i);
            if (id < lastId) {
                throw new IllegalArgumentException("Row ids must be in strictly increasing order");
            }
            lastId = id;
        }
        pageInRows();
    }

    /**
     * Drops the references to all mapped pages, asks the JVM to garbage collect them,
     * and closes the underlying channel.
     *
     * @throws IOException if closing the channel fails
     */
    public void close() throws IOException {
        for (MappedBufferWrapper buffer : buffers) {
            buffer.close();
        }
        // try to garbage collect any freed buffers
        System.gc();
        System.gc();
        System.gc();
        channel.close();
    }

    /**
     * Computes the page layout. Pages only record their start / end positions here;
     * the actual mapping happens on first access.
     */
    private void pageInRows() throws IOException {
        if (numRows == 0) {
            return;
        }

        // tricky: pages must align with row boundaries
        long[] sortedOffsets = new long[numRows];
        for (int i = 0; i < numRows; i++) {
            sortedOffsets[i] = rowOffsets.get(i);
        }
        Arrays.sort(sortedOffsets);

        long startPos = sortedOffsets[0];
        long lastPos = startPos;
        for (int i = 1; i < numRows; i++) {
            long pos = sortedOffsets[i];
            if (pos - startPos > PAGE_SIZE) {
                // The current row starts too far into the page: end the page at the start
                // of the previous row and begin the next page there.
                // If a single row is larger than PAGE_SIZE, lastPos == startPos and
                // (with assertions disabled) a zero-length page is recorded.
                assert (lastPos != startPos);
                addBuffer(startPos, lastPos);
                startPos = lastPos;
            }
            lastPos = pos;
        }
        // The final page runs to the end of the file.
        addBuffer(startPos, channel.size());
    }

    /** Records a page covering file positions {@code [startPos, endPos)}. */
    private void addBuffer(long startPos, long endPos) throws IOException {
        long length = endPos - startPos;
        debug("adding page at " + startPos + " of length " + length);
        buffers.add(new MappedBufferWrapper(channel, startPos, endPos));
    }

    /**
     * Returns a buffer positioned at the start of the requested row.
     *
     * @param rowId id of the row to fetch
     * @return a slice of the containing page that begins at the row's offset and extends to
     *         the end of that page, or {@code null} if the row id is unknown
     * @throws IllegalArgumentException if the row's offset does not fall inside any page
     * @throws IOException              if mapping the page fails
     */
    public ByteBuffer getRow(int rowId) throws IOException {
        long targetOffset = getRowOffset(rowId);
        if (targetOffset < 0) {
            return null;
        }

        MappedBufferWrapper page = null;
        // TODO: binary search
        for (int i = 0; i < buffers.size(); i++) {
            MappedBufferWrapper wrapper = buffers.get(i);
            if (wrapper.start <= targetOffset && targetOffset < wrapper.end) {
                page = wrapper;
                break;  // pages built by pageInRows() do not overlap, so stop at the first hit
            }
        }
        if (page == null) {
            throw new IllegalArgumentException("did not find row " + rowId + " with offset " + targetOffset);
        }
        return page.get(targetOffset);
    }

    /**
     * Binary searches {@code rowIds} for the given id.
     *
     * @return the file offset stored for the row, or -1 if the id is not present
     */
    private long getRowOffset(int rowId) {
        int lo = 0;
        int hi = numRows - 1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;  // unsigned shift: immune to int overflow of lo + hi
            int midId = rowIds.get(mid);
            if (rowId < midId) {
                hi = mid - 1;
            } else if (rowId > midId) {
                lo = mid + 1;
            } else {
                return rowOffsets.get(mid);
            }
        }
        return -1;
    }

    /**
     * A page of the file covering positions {@code [start, end)}.
     * The region is memory mapped on the first call to {@link #get(long)}.
     */
    static class MappedBufferWrapper {
        FileChannel channel;
        MappedByteBuffer buffer;    // null until first access, and again after close()
        long start;
        long end;

        public MappedBufferWrapper(FileChannel channel, long start, long end) {
            this.channel = channel;
            this.start = start;
            this.end = end;
        }

        /**
         * Returns a slice of this page starting at the given absolute file position.
         *
         * @param position absolute file position; expected to lie within {@code [start, end)}
         * @return a buffer that begins at {@code position} and runs to the end of the page
         * @throws IOException if the page cannot be mapped
         */
        public synchronized ByteBuffer get(long position) throws IOException {
            if (buffer == null) {
                buffer = channel.map(FileChannel.MapMode.READ_ONLY, start, end - start);
            }
            // position() mutates the shared buffer; the method is synchronized so that the
            // position / slice pair is not interleaved with another caller.
            buffer.position((int) (position - start));
            return buffer.slice();
        }

        /** Drops the reference to the mapped buffer so that it becomes eligible for collection. */
        public synchronized void close() {
            buffer = null;
        }
    }

    /**
     * Returns the row ids ordered by the file offset of their rows.
     * The array is cached through a weak reference and rebuilt when the cache has been cleared.
     *
     * @return row ids sorted by increasing row offset; never {@code null}
     */
    public synchronized int[] getRowIdsInDiskOrder() {
        // Take a single strong reference: reading the WeakReference twice would allow the
        // GC to clear it between the null check and the return.
        int[] ids = (rowIdsInDiskOrder == null) ? null : rowIdsInDiskOrder.get();
        if (ids == null) {
            ids = new int[numRows];
            for (int i = 0; i < numRows; i++) {
                ids[i] = rowIds.get(i);
            }
            sortIdsByOffset(ids);
            rowIdsInDiskOrder = new WeakReference<int[]>(ids);
        }
        return ids;
    }

    /**
     * Sorts {@code colIds} in place by the file offset of each id's row.
     * Adapted from http://www.programcreek.com/2012/11/quicksort-array-in-java/
     */
    private void sortIdsByOffset(int[] colIds) {
        // Look every offset up once instead of binary searching on each comparison.
        long[] offsets = new long[colIds.length];
        for (int i = 0; i < colIds.length; i++) {
            offsets[i] = getRowOffset(colIds[i]);
        }
        sortIdsByOffset(colIds, offsets, 0, colIds.length - 1);
    }

    /**
     * Quicksorts {@code colIds[low..high]} by the parallel {@code offsets} array,
     * keeping both arrays aligned.
     */
    private void sortIdsByOffset(int[] colIds, long[] offsets, int low, int high) {
        while (low < high) {
            // pick the pivot
            long pivot = offsets[(low + high) >>> 1];

            // partition around the pivot
            int i = low;
            int j = high;
            while (i <= j) {
                while (offsets[i] < pivot) {
                    i++;
                }
                while (offsets[j] > pivot) {
                    j--;
                }
                if (i <= j) {
                    int tmpId = colIds[i];
                    colIds[i] = colIds[j];
                    colIds[j] = tmpId;
                    long tmpOffset = offsets[i];
                    offsets[i] = offsets[j];
                    offsets[j] = tmpOffset;
                    i++;
                    j--;
                }
            }

            // Recurse into the smaller part and loop on the larger one so that the
            // stack depth stays logarithmic in the number of rows.
            if (j - low < high - i) {
                sortIdsByOffset(colIds, offsets, low, j);
                low = i;
            } else {
                sortIdsByOffset(colIds, offsets, i, high);
                high = j;
            }
        }
    }

    /** Logs {@code message} at info level, prefixed with the path of this matrix. */
    private void info(String message) {
        LOG.info("sparse matrix {}: {}", path, message);
    }

    /** Logs {@code message} at debug level, prefixed with the path of this matrix. */
    private void debug(String message) {
        LOG.debug("sparse matrix {}: {}", path, message);
    }
}