// Copyright (C) 2010 Steffen Rendle, Zeno Gantner, Chris Newell // Copyright (C) 2011 Zeno Gantner // //This file is part of MyMediaLite. // //MyMediaLite is free software: you can redistribute it and/or modify //it under the terms of the GNU General Public License as published by //the Free Software Foundation, either version 3 of the License, or //(at your option) any later version. // //MyMediaLite is distributed in the hope that it will be useful, //but WITHOUT ANY WARRANTY; without even the implied warranty of //MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //GNU General Public License for more details. // //You should have received a copy of the GNU General Public License //along with MyMediaLite. If not, see <http://www.gnu.org/licenses/>. package org.mymedialite.datatype; import java.util.*; /** * Class for storing dense matrices. * The data is stored in row-major mode. * Indexes are zero-based. * @param <T> the type of the matrix entries * @version 2.03 */ public class Matrix<T> implements IMatrix<T> { /** Data array: data is stored in columns. */ public Object[] data; /** Dimension 1, the number of rows */ public int dim1; /** Dimension 2, the number of columns */ public int dim2; /** * Initializes a new instance of the Matrix class * @param dim1 the number of rows * @param dim2 the number of columns */ public Matrix(int dim1, int dim2) { this(dim1, dim2, null); } /** * Initializes a new instance of the Matrix class * @param dim1 the number of rows * @param dim2 the number of columns * @param d the default value for the elements */ public Matrix(int dim1, int dim2, T d) { if (dim1 < 0) throw new IllegalArgumentException("dim1 must be at least 0"); if (dim2 < 0) throw new IllegalArgumentException("dim2 must be at least 0"); this.dim1 = dim1; this.dim2 = dim2; this.data = new Object[dim1 * dim2]; if(d != null) for (int i = 0; i < dim1 * dim2; i++) data[i] = d; } /** * Copy constructor. Creates a deep copy of the given matrix. * @param matrix the matrix to be copied */ public Matrix(Matrix<T> matrix) { this.dim1 = matrix.dim1; this.dim2 = matrix.dim2; this.data = Arrays.copyOf(matrix.data, matrix.data.length); } /** * Constructor that takes a list of lists to initialize the matrix. * @param data a list of lists of T */ public Matrix(List<List<T>> data) { this.dim1 = data.size(); this.dim2 = data.get(0).size(); this.data = new Object[dim1 * dim2]; for (int i = 0; i < dim1; i++) for (int j = 0; j < dim2; j++) this.data[i * dim2 + j] = data.get(i).get(j); } @Override public IMatrix<T> createMatrix(int num_rows, int num_columns) { return new Matrix<T>(num_rows, num_columns, null); } @Override public IMatrix<T> transpose() { Matrix<T> transpose = new Matrix<T>(dim2, dim1, null); for (int i = 0; i < dim1; i++) for (int j = 0; j < dim2; j++) transpose.data[j * dim1 + i] = data[i * dim2 + j]; return transpose; } @Override public int numberOfRows() { return dim1; } @Override public int numberOfColumns() { return dim2; } @Override @SuppressWarnings({"unchecked"}) public T get(int i, int j) { return (T) data[i * dim2 + j]; } @Override public void set(int i, int j, T value) { data[i * dim2 + j] = value; } @Override public boolean isSymmetric() { if (dim1 != dim2) return false; for (int i = 0; i < dim1; i++) for (int j = i + 1; j < dim2; j++) if (!get(i, j).equals(get(j, i))) return false; return true; } /** * Returns a copy of the i-th row of the matrix * @param i the row ID * @return a List<T> containing the row data */ public List<T> getRow(int i) { List<T> row = new ArrayList<T>(this.dim2); for (int x = 0; x < this.dim2; x++) row.set(x, get(i, x)); return row; } /** * Returns a copy of the j-th column of the matrix * @param j the column ID * @return T[] containing the column data */ public List<T> getColumn(int j) { List<T> column = new ArrayList<T>(this.dim1); for (int x = 0; x < this.dim1; x++) column.set(x, get(x, j)); return column; } /** * Sets the values of the i-th row to the values in a given array * @param i the row ID * @param row A of length dim1 */ public void setRow(int i, List<T> row) { if (row.size() != this.dim2) throw new IllegalArgumentException("Array length " + row.size() + " must equal number of columns " + this.dim2); for (int j = 0; j < this.dim2; j++) set(i, j, row.get(j)); } /** * Sets the values of the j-th column to the values in a given array * @param j the column ID * @param column A T[] of length dim2 */ public void setColumn(int j, List<T> column) { if (column.size() != this.dim1) throw new IllegalArgumentException("Array length " + column.size() + " must equal number of rows " + this.dim1); for (int i = 0; i < this.dim1; i++) set(i, j, column.get(i)); } /** * Enlarges the matrix to num_rows rows * Do nothing if num_rows is less than dim1. * The new entries are filled with zeros. * @param num_rows the minimum number of rows */ public void addRows(int num_rows) { if (num_rows > dim1) { // create new data structure Object[] data_new = new Object[num_rows * dim2]; System.arraycopy(data, 0, data_new, 0, data.length); // replace old data structure this.dim1 = num_rows; this.data = data_new; } } /** * Grows the matrix to the requested size, if necessary * The new entries are filled with zeros. * @param num_rows the minimum number of rows * @param num_cols the minimum number of columns */ @Override public void grow(int num_rows, int num_cols) { if (num_rows > dim1 || num_cols > dim2) { // create new data structure Object[] new_data = new Object[num_rows * num_cols]; for (int i = 0; i < dim1; i++) { for (int j = 0; j < dim2; j++) { new_data[i * num_cols + j] = get(i, j); } } // Replace old data structure this.dim1 = num_rows; this.dim2 = num_cols; this.data = new_data; } } /** * Sets an entire row to a specified value * @param v the value to be used * @param i the row ID */ public void setRowToOneValue(int i, T v) { for (int j = 0; j < dim2; j++) { set(i, j, v); } } /** * Sets an entire column to a specified value * @param v the value to be used * @param j the column ID */ public void setColumnToOneValue(int j, T v) { for (int i = 0; i < dim1; i++) { set(i, j, v); } } }