package org.wikibrain.matrix;
import org.junit.Before;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import static org.junit.Assert.*;
/**
 * Tests for {@link SparseMatrix}: writing rows with {@link SparseMatrixWriter}, reading them
 * back, and double-transposing with {@link SparseMatrixTransposer}.
 */
public class TestSparseMatrix {
    /** Number of rows requested from the fixture generator. */
    private static final int NUM_ROWS = 1000;
    /** Column bound passed to the fixture generator. */
    private static final int MAX_COLS = NUM_ROWS * 2;

    /** Source rows, regenerated before each test. */
    private List<SparseMatrixRow> srcRows;

    @Before
    public void createTestData() throws IOException {
        srcRows = TestUtils.createSparseTestMatrixRows(NUM_ROWS, MAX_COLS, false);
    }

    /**
     * Creates a fresh temp file for a matrix and schedules it for deletion at JVM exit, so
     * repeated test runs do not accumulate files in the temp directory.
     */
    private static File newTempMatrixFile() throws IOException {
        File tmp = File.createTempFile("matrix", null);
        tmp.deleteOnExit();
        return tmp;
    }

    @Test
    public void testWrite() throws IOException {
        File tmp = newTempMatrixFile();
        SparseMatrixWriter.write(tmp, srcRows.iterator());
        // Writing the fixture rows must produce a non-empty file.
        assertTrue(tmp.length() > 0);
    }

    @Test
    public void testReadWrite() throws IOException {
        File tmp = newTempMatrixFile();
        SparseMatrixWriter.write(tmp, srcRows.iterator());
        // Two independent readers over the same file must each see the original rows.
        Matrix m1 = new SparseMatrix(tmp);
        Matrix m2 = new SparseMatrix(tmp);
        verifyIsSourceMatrix(m1);
        verifyIsSourceMatrix(m2);
    }

    @Test
    public void testExpandPageForHeader() throws IOException {
        List<SparseMatrixRow> shortRows = TestUtils.createSparseTestMatrixRows(1000, 100, false);
        File tmp = newTempMatrixFile();
        SparseMatrixWriter.write(tmp, shortRows.iterator());
        Matrix m1 = new SparseMatrix(tmp);
        assertEquals(1000, m1.getNumRows());
    }

    @Test
    public void testTranspose() throws IOException {
        // NOTE(review): numOpenPages is never passed to SparseMatrix, so both iterations
        // exercise the same configuration. Confirm whether a page-limit setting should be
        // applied here.
        for (int numOpenPages : new int[] {1, Integer.MAX_VALUE}) {
            File tmp1 = newTempMatrixFile();
            File tmp2 = newTempMatrixFile();
            File tmp3 = newTempMatrixFile();
            SparseMatrixWriter.write(tmp1, srcRows.iterator());
            SparseMatrix m = new SparseMatrix(tmp1);
            verifyIsSourceMatrix(m);
            // Transposing twice should reproduce the original rows, though column order
            // within a row may differ, hence the unordered comparison below.
            new SparseMatrixTransposer(m, tmp2, 1).transpose();
            SparseMatrix m2 = new SparseMatrix(tmp2);
            new SparseMatrixTransposer(m2, tmp3, 1).transpose();
            Matrix m3 = new SparseMatrix(tmp3);
            verifyIsSourceMatrixUnordered(m3, .001);
        }
    }

    @Test
    public void testRows() throws IOException {
        // NOTE(review): numOpenPages is unused here as well; see testTranspose.
        for (int numOpenPages : new int[] {1, Integer.MAX_VALUE}) {
            File tmp = newTempMatrixFile();
            SparseMatrixWriter.write(tmp, srcRows.iterator());
            Matrix m = new SparseMatrix(tmp);
            verifyIsSourceMatrix(m);
        }
    }

    /**
     * Asserts that {@code m} contains exactly the row ids of {@link #srcRows}, and that each row
     * has the same columns in the same order with values within 0.01.
     */
    private void verifyIsSourceMatrix(Matrix m) throws IOException {
        assertEquals(srcRows.size(), m.getNumRows());

        int[] actualIds = m.getRowIds();
        int[] expectedIds = new int[srcRows.size()];
        for (int i = 0; i < srcRows.size(); i++) {
            expectedIds[i] = srcRows.get(i).getRowIndex();
        }
        // Row iteration order is not part of the contract; compare as sorted sets of ids.
        Arrays.sort(actualIds);
        Arrays.sort(expectedIds);
        assertArrayEquals(expectedIds, actualIds);

        for (SparseMatrixRow srcRow : srcRows) {
            MatrixRow destRow = m.getRow(srcRow.getRowIndex());
            assertNotNull(destRow);
            assertEquals(srcRow.getRowIndex(), destRow.getRowIndex());
            assertEquals(srcRow.getNumCols(), destRow.getNumCols());
            for (int i = 0; i < destRow.getNumCols(); i++) {
                assertEquals(srcRow.getColIndex(i), destRow.getColIndex(i));
                assertEquals(srcRow.getColValue(i), destRow.getColValue(i), 0.01);
            }
        }
    }

    /**
     * Asserts that every row of {@link #srcRows} exists in {@code m} with the same column set,
     * ignoring column order, and with values within {@code delta}.
     */
    private void verifyIsSourceMatrixUnordered(Matrix m, double delta) throws IOException {
        for (SparseMatrixRow srcRow : srcRows) {
            MatrixRow destRow = m.getRow(srcRow.getRowIndex());
            // Fail with a clear assertion rather than an NPE from asMap() on a missing row.
            assertNotNull(destRow);
            LinkedHashMap<Integer, Float> destRowMap = destRow.asMap();
            assertEquals(srcRow.getRowIndex(), destRow.getRowIndex());
            assertEquals(srcRow.getNumCols(), destRow.getNumCols());
            for (int i = 0; i < srcRow.getNumCols(); i++) {
                int colId = srcRow.getColIndex(i);
                assertTrue(destRowMap.containsKey(colId));
                assertEquals(srcRow.getColValue(i), destRowMap.get(colId), delta);
            }
        }
    }
}