/* * LinearEquations.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.math.matrixAlgebra; /** * Class representing a system of linear equations. * * @author Didier H. Besset */ public class LinearEquations { /** * components is a matrix build from the system's matrix and * the constant vector */ private double[][] rows; /** * Array containing the solution vectors. */ private Vector[] solutions; /** * Construct a system of linear equation Ax = y1, y2,.... * * @param m double[][] * @param c double[][] * @throws IllegalDimension if the system's matrix is not square * if constant dimension does not match * that of the matrix */ public LinearEquations(double[][] m, double[][] c) throws IllegalDimension { int n = m.length; if (m[0].length != n) throw new IllegalDimension("Illegal system: a" + n + " by " + m[0].length + " matrix is not a square matrix"); if (c[0].length != n) throw new IllegalDimension("Illegal system: a " + n + " by " + n + " matrix cannot build a system with a " + c[0].length + "-dimensional vector"); rows = new double[n][n + c.length]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) rows[i][j] = m[i][j]; for (int j = 0; j < c.length; j++) rows[i][n + j] = c[j][i]; } } /** * Construct a system of linear equation Ax = y. * * @param m double[][] components of the system's matrix * @param c double[] components of the constant vector * @throws IllegalDimension if the system's matrix is not square * if constant dimension does not match * that of the matrix */ public LinearEquations(double[][] m, double[] c) throws IllegalDimension { int n = m.length; if (m[0].length != n) throw new IllegalDimension("Illegal system: a" + n + " by " + m[0].length + " matrix is not a square matrix"); if (c.length != n) throw new IllegalDimension("Illegal system: a " + n + " by " + n + " matrix cannot build a system with a " + c.length + "-dimensional vector"); rows = new double[n][n + 1]; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) rows[i][j] = m[i][j]; rows[i][n] = c[i]; } } /** * Construct a system of linear equation Ax = y. * * @param a MatrixAlgebra.Matrix matrix A * @param y MatrixAlgebra.Vector vector y * @throws IllegalDimension if the system's matrix is not square * if vector dimension does not match * that of the matrix */ public LinearEquations(Matrix a, Vector y) throws IllegalDimension { this(a.components, y.components); } /** * Computes the solution for constant vector p applying * backsubstitution. * * @param p int * @throws java.lang.ArithmeticException if one diagonal element * of the triangle matrix is zero. */ private void backSubstitution(int p) throws ArithmeticException { int n = rows.length; double[] answer = new double[n]; double x; for (int i = n - 1; i >= 0; i--) { x = rows[i][n + p]; for (int j = i + 1; j < n; j++) x -= answer[j] * rows[i][j]; answer[i] = x / rows[i][i]; } solutions[p] = new Vector(answer); return; } /** * Finds the position of the largest pivot at step p. * * @param p int step of pivoting. * @return int */ private int largestPivot(int p) { double pivot = Math.abs(rows[p][p]); int answer = p; double x; for (int i = p + 1; i < rows.length; i++) { x = Math.abs(rows[i][p]); if (x > pivot) { answer = i; pivot = x; } } return answer; } /** * Perform pivot operation at location p. * * @param p int * @throws java.lang.ArithmeticException if the pivot element * is zero. */ private void pivot(int p) throws ArithmeticException { double inversePivot = 1 / rows[p][p]; double r; int n = rows.length; int m = rows[0].length; for (int i = p + 1; i < n; i++) { r = inversePivot * rows[i][p]; for (int j = p; j < m; j++) rows[i][j] -= rows[p][j] * r; } return; } /** * Perform optimum pivot operation at location p. * * @param p int */ private void pivotingStep(int p) { swapRows(p, largestPivot(p)); pivot(p); return; } /** * @return Vector solution for the 1st constant vector */ public Vector solution() throws ArithmeticException { return solution(0); } /** * Return the vector solution of constants indexed by p. * * @param p int index of the constant vector fed into the system. * @return matrixAlgebra.Vector * @throws java.lang.ArithmeticException if the system cannot be solved. */ public Vector solution(int p) throws ArithmeticException { if (solutions == null) solve(); if (solutions[p] == null) backSubstitution(p); return solutions[p]; } /** * @throws java.lang.ArithmeticException if the system cannot be solved. */ private void solve() throws ArithmeticException { int n = rows.length; for (int i = 0; i < n; i++) pivotingStep(i); solutions = new Vector[rows[0].length - n]; } /** * Swaps rows p and q. * * @param p int * @param q int */ private void swapRows(int p, int q) { if (p != q) { double temp; int m = rows[p].length; for (int j = 0; j < m; j++) { temp = rows[p][j]; rows[p][j] = rows[q][j]; rows[q][j] = temp; } } return; } /** * Returns a string representation of the system. * * @return java.lang.String */ public String toString() { StringBuffer sb = new StringBuffer(); char[] separator = {'[', ' '}; int n = rows.length; int m = rows[0].length; for (int i = 0; i < n; i++) { separator[0] = '('; for (int j = 0; j < n; j++) { sb.append(separator); sb.append(rows[i][j]); separator[0] = ','; } separator[0] = ':'; for (int j = n; j < m; j++) { sb.append(separator); sb.append(rows[i][j]); separator[0] = ','; } sb.append(')'); sb.append('\n'); } return sb.toString(); } }