package org.wikibrain.matrix;
import org.apache.commons.io.FileUtils;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Implementation of a sparse matrix.
* The rows are memory mapped, so they can be immediately read from disk.
*/
public class SparseMatrix implements Matrix<SparseMatrixRow> {
public static final Logger LOG = LoggerFactory.getLogger(SparseMatrix.class);
// default header page size is 100MB, will be expanded if necessary
public static final int DEFAULT_HEADER_SIZE = 100 * 1024 * 1024;
public static final int FILE_HEADER = 0xabcdef;
MemoryMappedMatrix rowBuffers;
private int numRows = 0;
private IntBuffer rowIds; // row ids in sorted order
private LongBuffer rowOffsets; // file offsets associated with sorted row ids
private FileChannel channel;
private File path;
private ValueConf vconf;
public SparseMatrix(File path) throws IOException {
this.path = path;
if (!path.isFile()) {
throw new IOException("File does not exist: " + path);
}
info("initializing sparse matrix with file length " + FileUtils.sizeOf(path));
this.channel = (new FileInputStream(path)).getChannel();
readHeaders();
rowBuffers = new MemoryMappedMatrix(path, channel, rowIds, rowOffsets);
}
public long lastModified() {
return path.lastModified();
}
private void readHeaders() throws IOException {
long size = Math.min(channel.size(), DEFAULT_HEADER_SIZE);
MappedByteBuffer buffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, size);
if (buffer.getInt(0) != FILE_HEADER) {
throw new IOException("invalid file header: " + buffer.getInt(0));
}
this.vconf = new ValueConf(buffer.getFloat(4), buffer.getFloat(8));
this.numRows = buffer.getInt(12);
int headerSize = 16 + 12*numRows;
if (headerSize > DEFAULT_HEADER_SIZE) {
info("maxPageSize not large enough for entire header. Resizing to " + headerSize);
buffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, headerSize);
}
debug("preparing buffer for " + numRows + " rows");
buffer.position(16);
buffer.limit(buffer.position() + 4 * numRows);
rowIds = buffer.slice().asIntBuffer();
if (rowIds.capacity() != numRows) {
throw new IllegalStateException();
}
buffer.position(16 + 4 * numRows);
buffer.limit(buffer.position() + 8 * numRows);
rowOffsets = buffer.slice().asLongBuffer();
if (rowOffsets.capacity() != numRows) {
throw new IllegalStateException();
}
}
@Override
public SparseMatrixRow getRow(int rowId) throws IOException {
ByteBuffer bb = rowBuffers.getRow(rowId);
if (bb == null) {
return null;
} else {
return new SparseMatrixRow(vconf, bb);
}
}
@Override
public int[] getRowIds() {
return rowBuffers.getRowIdsInDiskOrder();
}
@Override
public int getNumRows() {
return numRows;
}
public ValueConf getValueConf() {
return vconf;
}
@Override
public Iterator<SparseMatrixRow> iterator() {
return new SparseMatrixIterator();
}
public class SparseMatrixIterator implements Iterator<SparseMatrixRow> {
private AtomicInteger i = new AtomicInteger();
private int[] rowIds = rowBuffers.getRowIdsInDiskOrder();
@Override
public boolean hasNext() {
return i.get() < numRows;
}
@Override
public SparseMatrixRow next() {
try {
return (SparseMatrixRow)getRow(rowIds[i.getAndIncrement()]);
} catch (IOException e) {
LOG.error("getRow failed", e);
return null;
}
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}
public void close() throws IOException {
rowBuffers.close();
}
@Override
public File getPath() {
return path;
}
private void info(String message) {
LOG.info("sparse matrix " + path + ": " + message);
}
private void debug(String message) {
LOG.debug("sparse matrix " + path + ": " + message);
}
}