// Copyright (C) 2010, 2011 Zeno Gantner // Copyright (C) 2011 Chris Newell // // This file is part of MyMediaLite. // // MyMediaLite is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // MyMediaLite is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.datatype; import java.util.*; /** * Class for storing sparse matrices. * The data is stored in row-major mode. * Indexes are zero-based. * T the matrix element type, must have a default constructor/value * @version 2.03 */ public class SparseMatrix<T> implements IMatrix<T> { private int numberOfColumns; /** * List that stores the rows of the matrix. */ // TODO Consider fastutil protected List<HashMap<Integer, T>> row_list = new ArrayList<HashMap<Integer, T>>(); /** * The default values for elements. */ private T d = null; /** * Create a sparse matrix with a given number of rows. * @param num_rows the number of rows * @param num_cols the number of columns */ public SparseMatrix(int num_rows, int num_cols) { this(num_rows, num_cols, null); } /** * Create a sparse matrix with a given number of rows. * @param num_rows the number of rows * @param num_cols the number of columns * @param d the default value for elements */ public SparseMatrix(int num_rows, int num_cols, T d) { for (int i = 0; i < num_rows; i++) { row_list.add(new HashMap<Integer, T>()); } this.numberOfColumns = num_cols; this.d = d; } @Override public IMatrix<T> createMatrix(int num_rows, int num_columns) { return new SparseMatrix<T>(num_rows, num_columns, null); } @Override public boolean isSymmetric() { if (numberOfRows() != numberOfColumns()) return false; for (int i = 0; i < row_list.size(); i++) for (int j : row_list.get(i).keySet()) { if (i > j) continue; // check every pair only once if (! get(i, j).equals(get(j, i))) return false; } return true; } @Override public int numberOfRows() { return row_list.size(); } @Override public int numberOfColumns() { return numberOfColumns; } @Override public IMatrix<T> transpose() { SparseMatrix<T> transpose = new SparseMatrix<T>(numberOfColumns(), numberOfRows()); for (Pair<Integer, Integer> p : nonEmptyEntryIDs()) { transpose.set(p.second, p.first, get(p.first, p.second)); } return transpose; } /** * Get a row of the matrix. * @param x the row ID */ public HashMap<Integer, T> get(int x) { if (x >= row_list.size()) return new HashMap<Integer, T>(); else return row_list.get(x); } /** * Access the elements of the sparse matrix. * @param x the row ID * @param y the column ID */ @Override public T get(int x, int y) { T result; if (x < row_list.size()) { result = row_list.get(x).get(y); if(result != null) { return result; } } return d; } @Override public void set(int x, int y, T value) { if (x >= row_list.size()) for (int i = row_list.size(); i <= x; i++) row_list.add(new HashMap<Integer, T>()); row_list.get(x).put(y, value); } /** * The non-empty rows of the matrix (the ones that contain at least one non-zero entry), * with their IDs * . */ public HashMap<Integer, HashMap<Integer, T>> nonEmptyRows() { HashMap<Integer, HashMap<Integer, T>> return_list = new HashMap<Integer, HashMap<Integer, T>>(); for(int i=0; i < row_list.size(); i++) { HashMap<Integer, T> row = get(i); if(row.size() > 0) return_list.put(i, row); } return return_list; } /** * The row and column IDs of non-empty entries in the matrix. * @return The row and column IDs of non-empty entries in the matrix */ public List<Pair<Integer, Integer>> nonEmptyEntryIDs() { List <Pair<Integer, Integer>> return_list = new ArrayList<Pair<Integer, Integer>>(); for (Map.Entry<Integer, HashMap<Integer, T>> id_row : nonEmptyRows().entrySet()) for (Integer col_id : id_row.getValue().keySet()) return_list.add(new Pair<Integer, Integer>(id_row.getKey(), col_id)); return return_list; } /** * The number of non-empty entries in the matrix. * @return The number of non-empty entries in the matrix */ public int numberOfNonEmptyEntries() { int counter = 0; for (HashMap<Integer, T> row : row_list) counter += row.size(); return counter; } @Override public void grow(int num_rows, int num_cols) { // If necessary, grow rows if (num_rows > numberOfRows()) for (int i = row_list.size(); i < num_rows; i++) row_list.add(new HashMap<Integer, T>()); // If necessary, grow columns if (num_cols > numberOfColumns) numberOfColumns = num_cols; } }