package org.wikibrain.matrix;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.IOException;
import java.util.*;
public class TestUtils {
/**
/**
* Creates a new random matrix with nRows rows.
* Each row has a random length chosen uniformly from 0 to maxRowLen.
* If sameIds is true, the column ids are chosen from the row ids.
* @param nRows
* @param maxRowLen
* @param sameIds
* @return
*/
public static List<SparseMatrixRow> createSparseTestMatrixRows(int nRows, int maxRowLen, boolean sameIds) throws IOException {
return createSparseTestMatrixRowsInternal(nRows, maxRowLen, sameIds, null);
}
public static List<DenseMatrixRow> createDenseTestMatrixRows(int nRows, int numCols) throws IOException {
return createDenseTestMatrixRowsInternal(nRows, numCols, null);
}
/**
* Creates a new random matrix with nRows rows.
* Each row has a random length chosen uniformly from 0 to maxRowLen.
* If sameIds is true, the column ids are chosen from the row ids.
* @param nRows
* @param maxRowLen
* @param sameIds
* @return
*/
public static SparseMatrix createSparseTestMatrix(int nRows, int maxRowLen, boolean sameIds) throws IOException {
File tmpFile = File.createTempFile("matrix", null);
tmpFile.deleteOnExit();
SparseMatrixWriter writer = new SparseMatrixWriter(tmpFile, new ValueConf());
createSparseTestMatrixRowsInternal(nRows, maxRowLen, sameIds, writer);
writer.finish();
return new SparseMatrix(tmpFile);
}
/*
public static DenseMatrix createDenseTestMatrix(int nRows, int numCols) throws IOException {
return createDenseTestMatrix(nRows, numCols, SparseMatrix.DEFAULT_LOAD_ALL_PAGES, SparseMatrix.DEFAULT_MAX_PAGE_SIZE);
}
public static DenseMatrix createDenseTestMatrix(int nRows, int numCols, boolean readAllRows, int pageSize) throws IOException {
File tmpFile = File.createTempFile("matrix", null);
tmpFile.deleteOnExit();
DenseMatrixWriter writer = new DenseMatrixWriter(tmpFile, new ValueConf());
createDenseTestMatrixRowsInternal(nRows, numCols, writer);
writer.finish();
return new DenseMatrix(tmpFile, readAllRows, pageSize);
}
*/
/**
* Either writes or returns the sparse matrix rows depending on whether the writer is passed.
* @param nRows
* @param maxRowLen
* @param sameIds
* @param writer
* @return if writer == null the list of rows, else null
*/
private static List<SparseMatrixRow> createSparseTestMatrixRowsInternal(
int nRows, int maxRowLen, boolean sameIds, SparseMatrixWriter writer)
throws IOException {
Random random = new Random();
List<SparseMatrixRow> rows = new ArrayList<SparseMatrixRow>();
int rowIds[] = pickIds(nRows, nRows * 2);
for (int id1 : rowIds) {
LinkedHashMap<Integer, Float> data = new LinkedHashMap<Integer, Float>();
int numCols = Math.max(1, random.nextInt(maxRowLen));
int colIds[] = sameIds ? pickIdsFrom(rowIds, numCols) : pickIds(numCols, maxRowLen * 2);
for (int id2 : colIds) {
data.put(id2, random.nextFloat());
}
SparseMatrixRow row = new SparseMatrixRow(new ValueConf(), id1, data);
if (writer == null) {
rows.add(row);
} else {
writer.writeRow(row);
}
}
return (writer == null) ? rows : null;
}
/**
* Either writes or returns the sparse matrix rows depending on whether the writer is passed.
* @param nRows
* @param numCols
* @param writer
* @return if writer == null the list of rows, else null
*/
private static List<DenseMatrixRow> createDenseTestMatrixRowsInternal(
int nRows, int numCols, DenseMatrixWriter writer)
throws IOException {
Random random = new Random();
List<DenseMatrixRow> rows = new ArrayList<DenseMatrixRow>();
int rowIds[] = pickIds(nRows, nRows * 2);
int colIds[] = pickIds(numCols, numCols * 2);
Arrays.sort(colIds);
for (int id1 : rowIds) {
LinkedHashMap<Integer, Float> data = new LinkedHashMap<Integer, Float>();
for (int id2 : colIds) {
data.put(id2, random.nextFloat());
}
DenseMatrixRow row = new DenseMatrixRow(new ValueConf(), id1, data);
if (writer == null) {
rows.add(row);
} else {
writer.writeRow(row);
}
}
return (writer == null) ? rows : null;
}
/**
* Returns a set of n unique ids from 1 through maxId in random order.
* @param n
* @param maxId
* @return
*/
public static int[] pickIds(int n, int maxId) {
assert(n < maxId);
Random random = new Random();
TIntHashSet picked = new TIntHashSet();
for (int i = 0; i < n; i++) {
while (true) {
int id = random.nextInt(maxId - 1) + 1;
if (!picked.contains(id)) {
picked.add(id);
break;
}
}
}
return picked.toArray();
}
/**
* Selects n random unique ids from the array of ids.
* @param ids
* @param n
* @return
*/
public static int[] pickIdsFrom(int ids[], int n) {
assert(ids.length >= n);
Random random = new Random();
TIntHashSet picked = new TIntHashSet();
for (int i = 0; i < n; i++) {
while (true) {
int id = ids[random.nextInt(ids.length)];
if (!picked.contains(id)) {
picked.add(id);
break;
}
}
}
return picked.toArray();
}
}