package org.wikibrain.matrix;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntLongHashMap;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import java.io.*;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Writes a matrix to a file in the sparse matrix format.
 *
 * <p>Rows are streamed to a temporary "body" file as they arrive. {@link #finish()} then writes
 * the final file in this order:
 * <ol>
 *   <li>a 16-byte fixed header: {@code SparseMatrix.FILE_HEADER} (int), min score (float),
 *       max score (float), and row count (int);</li>
 *   <li>the row indexes in ascending order (4 bytes each);</li>
 *   <li>for each of those rows, in the same order, its absolute offset in the final file
 *       (8 bytes each);</li>
 *   <li>the body, copied from the temporary file.</li>
 * </ol>
 * All numbers are written big-endian (the {@link ByteBuffer} default).
 *
 * <p>{@link #writeRow} and {@link #finish} are {@code synchronized}.
 */
public class SparseMatrixWriter {
    /**
     * Byte appended after each row so that the next row starts at a multiple of 8 bytes from
     * the start of the body section.
     */
    public static final byte ROW_PADDING = Byte.MIN_VALUE;

    private static final Logger LOG = LoggerFactory.getLogger(SparseMatrixWriter.class);

    /** Size in bytes of the fixed header: marker int, min float, max float, row-count int. */
    private static final long FIXED_HEADER_BYTES = 16;

    /** Header bytes per row: one int row index plus one long offset. */
    private static final long HEADER_BYTES_PER_ROW = 4 + 8;

    /** Destination of the finished matrix file. */
    private final File path;

    /** Row index -> offset of that row relative to the start of the body. */
    private final TIntLongHashMap rowOffsets = new TIntLongHashMap();

    /** Row indexes in the order they were written (kept in sync with rowOffsets' keys). */
    private final TIntArrayList rowIndexes = new TIntArrayList();

    /** Temporary file holding the body while rows are being written. */
    private final File bodyPath;

    /** Stream to {@link #bodyPath}; closed by {@link #finish()}. */
    private final BufferedOutputStream body;

    /** Number of bytes (including padding) written to the body so far. */
    private long bodyOffset = 0;

    private final ValueConf vconf;

    /** True once {@link #finish()} has completed successfully. */
    private boolean finished = false;

    /**
     * Creates a writer targeting {@code path}. Rows are buffered in a temporary file until
     * {@link #finish()} is called.
     *
     * @param path  destination file for the finished matrix
     * @param conf  value configuration every written row must (almost) match
     * @throws IOException if the temporary body file cannot be created or opened
     */
    public SparseMatrixWriter(File path, ValueConf conf) throws IOException {
        this.path = path;
        this.vconf = conf;
        info("writing matrix to " + path);
        // write tmp matrix file
        this.bodyPath = File.createTempFile("matrix", null);
        // Fallback only: finish() deletes the file explicitly once it has been copied.
        this.bodyPath.deleteOnExit();
        this.body = new BufferedOutputStream(new FileOutputStream(bodyPath));
        info("writing body to tmp file at " + bodyPath);
    }

    /**
     * Appends a row to the matrix body.
     *
     * @param row  the row to write; its value conf must be almost equal to this writer's
     * @throws IllegalArgumentException if the row's value conf does not match, or if a row with
     *                                  the same index has already been written
     * @throws IllegalStateException    if {@link #finish()} has already completed
     * @throws IOException              if writing to the temporary body file fails
     */
    public synchronized void writeRow(SparseMatrixRow row) throws IOException {
        if (finished) {
            throw new IllegalStateException(
                    "cannot write row " + row.getRowIndex() + ": writer for " + path + " is already finished");
        }
        if (!row.getValueConf().almostEquals(vconf)) {
            throw new IllegalArgumentException("Value conf for row does not match the writer's value conf");
        }
        int rowIndex = row.getRowIndex();
        // A duplicate index would make the header's row count disagree with the number of
        // index/offset entries written, corrupting every offset in the file.
        if (rowOffsets.containsKey(rowIndex)) {
            throw new IllegalArgumentException("Row " + rowIndex + " has already been written");
        }

        row.getBuffer().rewind();
        byte[] bytes = new byte[row.getBuffer().remaining()];
        row.getBuffer().get(bytes, 0, bytes.length);

        long rowStart = bodyOffset;
        body.write(bytes);
        rowOffsets.put(rowIndex, rowStart);
        rowIndexes.add(rowIndex);
        bodyOffset += bytes.length;

        // pad rows to 8 byte offsets (relative to the body start) to speed things up.
        while (bodyOffset % 8 != 0) {
            bodyOffset++;
            body.write(ROW_PADDING);
        }
    }

    /**
     * Writes the header followed by the buffered body to the destination file, then deletes
     * the temporary body file. Calling this again after it has succeeded is a no-op.
     *
     * @throws IOException if reading the body or writing the destination fails
     */
    public synchronized void finish() throws IOException {
        if (finished) {
            return;
        }
        body.close();
        info("wrote " + bodyOffset + " bytes in body of matrix");

        // write offset file
        info("generating header");
        int[] sortedIndexes = rowIndexes.toArray();
        Arrays.sort(sortedIndexes);
        // long arithmetic: 16 + 12 * n overflows an int for very large row counts.
        long sizeHeader = FIXED_HEADER_BYTES + HEADER_BYTES_PER_ROW * sortedIndexes.length;

        // Open the body first so a missing temp file does not truncate the destination.
        // Resources close in reverse order, so the output is flushed and closed first.
        try (InputStream in = new FileInputStream(bodyPath);
             OutputStream out = new BufferedOutputStream(new FileOutputStream(path))) {
            out.write(intToBytes(SparseMatrix.FILE_HEADER));
            out.write(floatToBytes(vconf.minScore));
            out.write(floatToBytes(vconf.maxScore));
            out.write(intToBytes(sortedIndexes.length));

            // Next write row indexes in sorted order (4 bytes per row)
            for (int rowIndex : sortedIndexes) {
                out.write(intToBytes(rowIndex));
            }
            // Next write absolute file offsets for the sorted indexes (8 bytes per row)
            for (int rowIndex : sortedIndexes) {
                out.write(longToBytes(rowOffsets.get(rowIndex) + sizeHeader));
            }
            // append the body
            IOUtils.copyLarge(in, out);
        }

        if (!bodyPath.delete()) {
            LOG.warn("sparse matrix writer " + path + ": could not delete temporary body file " + bodyPath);
        }
        finished = true;
        info("wrote " + FileUtils.sizeOf(path) + " bytes to " + path);
    }

    /** Logs {@code message} at INFO level, prefixed with this writer's destination path. */
    private void info(String message) {
        LOG.info("sparse matrix writer " + path + ": " + message);
    }

    /**
     * Writes all {@code rows} to {@code file} using a default {@link ValueConf}.
     *
     * @param file  destination file
     * @param rows  rows to write; each row index must appear at most once
     * @throws IOException if writing fails
     */
    public static void write(File file, Iterator<SparseMatrixRow> rows) throws IOException {
        write(file, rows, new ValueConf());
    }

    /**
     * Writes all {@code rows} to {@code file} using the given value configuration.
     *
     * @param file   destination file
     * @param rows   rows to write; each row index must appear at most once
     * @param vconf  value configuration every row must match
     * @throws IOException if writing fails
     */
    public static void write(File file, Iterator<SparseMatrixRow> rows, ValueConf vconf) throws IOException {
        SparseMatrixWriter w = new SparseMatrixWriter(file, vconf);
        while (rows.hasNext()) {
            w.writeRow(rows.next());
        }
        w.finish();
    }

    /** @return the value configuration this writer enforces on every row */
    public ValueConf getValueConf() {
        return vconf;
    }

    /** Encodes {@code i} as 4 big-endian bytes. */
    private static byte[] intToBytes(int i) {
        return ByteBuffer.allocate(4).putInt(i).array();
    }

    /** Encodes {@code i} as 8 big-endian bytes. */
    private static byte[] longToBytes(long i) {
        return ByteBuffer.allocate(8).putLong(i).array();
    }

    /** Encodes {@code f} as 4 big-endian bytes. */
    private static byte[] floatToBytes(float f) {
        return ByteBuffer.allocate(4).putFloat(f).array();
    }

    /**
     * Writes a matrix in sparse matrix format.
     * If the matrix itself is a sparse matrix formatted matrix, its value configuration is
     * reused; otherwise one pass over the matrix computes the min and max column values.
     *
     * @param matrix  matrix to write; each row index must appear at most once
     * @param output  destination file
     * @throws IOException if writing fails
     */
    public static void write(Matrix<? extends MatrixRow> matrix, File output) throws IOException {
        ValueConf vconf;
        if (matrix instanceof SparseMatrix) {
            vconf = ((SparseMatrix) matrix).getValueConf();
        } else {
            float min = Float.MAX_VALUE;
            float max = -Float.MAX_VALUE;
            for (MatrixRow r : matrix) {
                for (int i = 0; i < r.getNumCols(); i++) {
                    float value = r.getColValue(i);
                    min = Math.min(min, value);
                    max = Math.max(max, value);
                }
            }
            vconf = new ValueConf(min, max);
        }
        SparseMatrixWriter writer = new SparseMatrixWriter(output, vconf);
        for (MatrixRow r : matrix) {
            if (r instanceof SparseMatrixRow) {
                writer.writeRow((SparseMatrixRow) r);
            } else {
                writer.writeRow(new SparseMatrixRow(vconf, r.getRowIndex(), r.asTroveMap()));
            }
        }
        writer.finish();
    }
}