/**
* Copyright (c) 2013 Oculus Info Inc.
* http://www.oculusinfo.com/
*
* Released under the MIT License.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
* of the Software, and to permit persons to whom the Software is furnished to do
* so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package spimedb.util.math.linearalgebra;
/**
 * A square tri-diagonal matrix: only the main diagonal and the diagonals
 * immediately above and below it may hold non-zero entries.
 *
 * <p>Following the usual nomenclature, row {@code i} consists of {@code a[i]}
 * (below the diagonal; absent for the first row), {@code b[i]} (on the
 * diagonal) and {@code c[i]} (above the diagonal; absent for the last row).
 *
 * <p>The only mutable state is the precision set by
 * {@link #setPrecision(double)}; no synchronization is performed.
 */
public class TriDiagonalMatrix {
    /** Precision used for equality and zero tests until setPrecision is called. */
    private static final double DEFAULT_PRECISION = 1E-12;

    // The matrix dimension (rows == columns); the constructor guarantees >= 1
    private final int _n;
    // The n-1 entries below the diagonal; the array is of size n, though, and
    // the first entry is ignored, for consistency with standard nomenclature
    private final double[] _a;
    // The n diagonal entries
    private final double[] _b;
    // The n-1 entries above the diagonal
    private final double[] _c;
    // Magnitudes (or differences) below this threshold are treated as zero
    private double _epsilon = DEFAULT_PRECISION;

    /**
     * Creates a matrix from its potentially non-zero entries, read row by row
     * from left to right:
     * {@code b0, c0, a1, b1, c1, a2, b2, c2, ..., a(n-1), b(n-1)}.
     *
     * @param entries
     *            the 3n-2 entries of an n-by-n tri-diagonal matrix
     * @throws IndexOutOfBoundsException
     *             if the number of entries is not of the form 3n-2 (n &gt;= 1)
     */
    public TriDiagonalMatrix(double... entries) {
        // There should be 3n-2 entries; this also rejects an empty array
        int count = entries.length;
        if (1 != (count % 3))
            throw new IndexOutOfBoundsException("Tridiagonal matrices must have (3n-2) entries");
        _n = (count + 2) / 3;
        _a = new double[_n];
        _b = new double[_n];
        _c = new double[_n - 1];
        for (int i = 0; i < _n; ++i) {
            _b[i] = entries[3 * i];
            if (i < _n - 1) {
                _c[i] = entries[3 * i + 1];
                _a[i + 1] = entries[3 * i + 2];
            }
        }
    }

    /**
     * Set the precision for equality and zero tests for this matrix. If any
     * calculation yields two numbers closer than the given precision, they are
     * deemed equal. If any calculation yields a number whose magnitude is less
     * than the given precision, it is deemed zero.
     *
     * @param precision
     *            The precision for calculations with this matrix.
     */
    public void setPrecision(double precision) {
        _epsilon = precision;
    }

    /**
     * Find the X for which this*X=d.
     *
     * <p>Based on the tridiagonal matrix (Thomas) algorithm, see
     * <a href="https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm">
     * Tridiagonal matrix algorithm</a>, but modified to use recursion instead
     * of iteration, and thereby to handle degenerate cases cleanly. Components
     * the system leaves unconstrained are set to 1 (or derived assuming a free
     * component is 1); components found to be inconsistent are set to NaN.
     *
     * <p>The recursion depth grows linearly with the matrix size.
     *
     * @param d
     *            The result (<code>d</code>) in the above equation
     * @return The <code>X</code> in the above equation
     * @throws IllegalArgumentException
     *             if {@code d} does not have exactly n coordinates
     */
    public Vector solve(Vector d) {
        if (d.size() != _n)
            throw new IllegalArgumentException("Attempt to find tri-diagonal solution with improper-sized vector");
        double[] x = new double[_n];
        // _n >= 1 is guaranteed by the constructor; c0 only exists when _n > 1
        solve(d, x, 0, _b[0], (_n > 1 ? _c[0] : 0), d.coord(0));
        return new Vector(x);
    }

    /**
     * Solves rows {@code currentColumn..n-1} of M x = d, writing into
     * {@code x}. The top row of this sub-problem is described by the
     * (possibly already reduced) values {@code b0, c0, d0}; all later rows
     * come straight from the matrix and from {@code d}.
     */
    private void solve(Vector d, double[] x, int currentColumn, double b0, double c0, double d0) {
        // We do this recursively, solving one place at a time, pretending in
        // each case we are at the top left of the matrix.
        //
        // So, in each case, we have one of three cases, depending on whether
        // there are one, two, or three-or-more rows left to solve.
        int rowsLeft = _n - currentColumn;
        if (0 == rowsLeft) {
            // Nothing left to solve
        } else if (1 == rowsLeft) {
            solveSingleRow(d, x, currentColumn, b0, c0, d0);
        } else if (2 == rowsLeft) {
            solveDoubleRow(d, x, currentColumn, b0, c0, d0);
        } else {
            solveTripleRow(d, x, currentColumn, b0, c0, d0);
        }
    }

    /** Returns 1 (an arbitrary valid choice) unless testValue is NaN, in which case NaN. */
    private static double anythingIfNotNaN(double testValue) {
        return ifNotNaN(testValue, 1.0);
    }

    /** Returns value unless testValue is NaN, in which case NaN. */
    private static double ifNotNaN(double testValue, double value) {
        if (Double.isNaN(testValue)) return Double.NaN;
        else return value;
    }

    /** Solves the last row, b0 x0 = d0. */
    private void solveSingleRow(Vector d, double[] x, int currentColumn, double b0, double c0, double d0) {
        if (currentColumn != (_n - 1))
            throw new IllegalArgumentException("Attempt to solve a single row when not on the last row.");
        // If we have only one row, we have:
        // | b0 | | x0 | = | d0 |
        // and the solution is trivial: x0 = d0/b0
        if (Math.abs(b0) < _epsilon) {
            if (Math.abs(d0) < _epsilon) {
                // Anything will work
                x[currentColumn] = 1;
            } else {
                // Nothing will work
                x[currentColumn] = Double.NaN;
            }
        } else {
            x[currentColumn] = d0 / b0;
        }
    }

    /** Solves the last two rows directly, handling a singular 2x2 system. */
    private void solveDoubleRow(Vector d, double[] x, int currentColumn, double b0, double c0, double d0) {
        if (currentColumn != (_n - 2))
            throw new IllegalArgumentException("Attempt to solve a double row when not on the second to last row.");
        // If we have two rows, we have:
        // | b0 c0 | | x0 | _ | d0 |
        // | a1 b1 | | x1 | - | d1 |
        // or
        // b0 x0 + c0 x1 = d0
        // a1 x0 + b1 x1 = d1
        //
        // which can, of course, be solved one of two ways:
        // a1 b0 x0 + a1 c0 x1 = a1 d0
        // a1 b0 x0 + b0 b1 x1 = b0 d1
        // (a1 c0 - b0 b1) x1 = (a1 d0 - b0 d1)
        // or
        // b0 b1 x0 + b1 c0 x1 = b1 d0
        // a1 c0 x0 + b1 c0 x1 = c0 d1
        // (b0 b1 - a1 c0) x0 = (b1 d0 - c0 d1)
        //
        // in either case, we require
        // b0 b1 - a1 c0 != 0
        // (i.e., non-zero determinant)
        //
        // If the determinant is 0, then the two rows are dependent, and we
        // depend on the right-hand side also being in the same proportion to
        // be able to solve them.
        double a1 = _a[currentColumn + 1];
        double b1 = _b[currentColumn + 1];
        double d1 = d.coord(currentColumn + 1);
        double determinant = b0 * b1 - a1 * c0;
        if (Math.abs(determinant) < _epsilon) {
            if (Math.abs(b0) < _epsilon &&
                Math.abs(c0) < _epsilon &&
                Math.abs(a1) < _epsilon &&
                Math.abs(b1) < _epsilon) {
                // We are the zero matrix. This is fine if d is the zero
                // vector, in which case anything will work; otherwise, there
                // is no solution.
                if (Math.abs(d0) < _epsilon && Math.abs(d1) < _epsilon) {
                    x[currentColumn] = 1;
                    x[currentColumn + 1] = 1;
                } else {
                    x[currentColumn] = Double.NaN;
                    x[currentColumn + 1] = Double.NaN;
                }
            } else {
                // Figure out the proportion. All these are solved the same
                // way; we just need to pick a pivot based on a known non-zero
                // element. solveDegenerate2D takes the pivot row first, and
                // the other row's coefficients in the same variable order.
                //
                // If solvable, the non-pivot variable is taken to be 1.
                if (Math.abs(a1) >= _epsilon) {
                    // |a1| > 0: pivot on x0 in row 1
                    double sln = solveDegenerate2D(a1, b1, d1, b0, c0, d0);
                    x[currentColumn] = sln;
                    x[currentColumn + 1] = anythingIfNotNaN(sln);
                } else if (Math.abs(b0) >= _epsilon) {
                    // |b0| > 0: pivot on x0 in row 0; row 1 is (a1, b1)
                    double sln = solveDegenerate2D(b0, c0, d0, a1, b1, d1);
                    x[currentColumn] = sln;
                    x[currentColumn + 1] = anythingIfNotNaN(sln);
                } else if (Math.abs(c0) >= _epsilon) {
                    // |c0| > 0: pivot on x1 in row 0; row 1 is (b1, a1)
                    double sln = solveDegenerate2D(c0, b0, d0, b1, a1, d1);
                    x[currentColumn] = anythingIfNotNaN(sln);
                    x[currentColumn + 1] = sln;
                } else {
                    // |b1| > 0 (reached when row 0 and a1 are all zero):
                    // pivot on x1 in row 1; row 0 is (c0, b0)
                    double sln = solveDegenerate2D(b1, a1, d1, c0, b0, d0);
                    x[currentColumn] = anythingIfNotNaN(sln);
                    x[currentColumn + 1] = sln;
                }
            }
        } else {
            // From above:
            // (a1 c0 - b0 b1) x1 = (a1 d0 - b0 d1)
            // (b0 b1 - a1 c0) x0 = (b1 d0 - c0 d1)
            x[currentColumn] = (b1 * d0 - c0 * d1) / determinant;
            x[currentColumn + 1] = (b0 * d1 - a1 * d0) / determinant;
        }
    }

    /**
     * Solves a singular 2x2 system for its pivot variable.
     *
     * <p>The system is {@code a0 u + b0 v = d0} (the pivot row, {@code a0}
     * non-zero) and {@code a1 u + b1 v = d1}, whose coefficients are assumed
     * proportional. Returns {@code u} assuming {@code v = 1}, or NaN if the
     * right-hand sides are not in the same proportion as the rows.
     */
    private double solveDegenerate2D(double a0, double b0, double d0, double a1, double b1, double d1) {
        double rowRatio = a1 / a0;
        // If d0 == 0 this is infinite (d1 != 0: inconsistent, the test below
        // fails) or NaN (d1 == 0: NaN comparisons are false, so consistent).
        double dRatio = d1 / d0;
        if (Math.abs(rowRatio - dRatio) > _epsilon)
            return Double.NaN;
        else
            // a0 u + b0 v = d0, so u = (d0 - b0 v) / a0 with v = 1
            return (d0 - b0) / a0;
    }

    /** Solves three or more remaining rows by eliminating x0 and recursing. */
    private void solveTripleRow(Vector d, double[] x, int currentColumn, double b0, double c0, double d0) {
        // If we have three or more rows, we have:
        // | b0 c0 0 ... | | x0 | | d0 |
        // | a1 b1 c1 ... | | x1 | | d1 |
        // | . | | x2 | = | d2 |
        // | . | | . | | . |
        // | . | | . | | . |
        //
        // or, put another way:
        // b0 x0 + c0 x1 = d0
        // a1 x0 + b1 x1 + c1 x2 = d1
        double a1 = _a[currentColumn + 1];
        double b1 = _b[currentColumn + 1];
        double c1 = _c[currentColumn + 1];
        double d1 = d.coord(currentColumn + 1);
        if (Math.abs(b0) < _epsilon && Math.abs(a1) < _epsilon) {
            // b0 and a1 are both zero, so x0 appears in no equation.
            // First, skip this row and just go on
            solve(d, x, currentColumn + 1, b1, c1, d1);
            // We just need to make sure that our input row does work
            if (Math.abs(c0 * x[currentColumn + 1] - d0) >= _epsilon) {
                // Nope; doesn't work.
                for (int i = currentColumn; i < _n; ++i)
                    x[i] = Double.NaN;
            } else {
                x[currentColumn] = anythingIfNotNaN(x[currentColumn + 1]);
            }
        } else if (Math.abs(b0) < _epsilon) {
            // b0 is 0, so row 0 reads c0 x1 = d0.
            // Use it in place of row 1 and go on
            solve(d, x, currentColumn + 1, c0, 0, d0);
            // Now, use row 1 to solve our current value
            double x1 = x[currentColumn + 1];
            double x2 = x[currentColumn + 2];
            // a1 x0 + b1 x1 + c1 x2 = d1
            // a1 is known not to be zero, so we shouldn't have any problems.
            x[currentColumn] = (d1 - b1 * x1 - c1 * x2) / a1;
        } else if (Math.abs(a1) < _epsilon) {
            // a1 is 0, so row 1 does not involve x0.
            // First, skip this row and just go on
            solve(d, x, currentColumn + 1, b1, c1, d1);
            // Now, use the current row to solve our current value
            // b0 is known not to be zero, so we shouldn't have any problems.
            double x1 = x[currentColumn + 1];
            // b0 x0 + c0 x1 = d0
            x[currentColumn] = (d0 - c0 * x1) / b0;
        } else {
            // neither is 0
            // a1 x0 + b1 x1 + c1 x2 = d1
            // b0 x0 + c0 x1 = d0
            //
            // Subtract (a1 / b0) times row 0 from row 1 to remove x0:
            //
            // (b1 - a1 c0 / b0) x1 + c1 x2 = (d1 - a1 d0 / b0)
            //
            // Dividing by b0 (rather than multiplying row 1 through by b0)
            // keeps the reduced coefficients at the scale of the original
            // matrix. Scaling by b0 at every level makes them grow or shrink
            // geometrically, which overflows (Inf - Inf = NaN) on long
            // matrices or pushes them below the absolute precision threshold.
            double factor = a1 / b0;
            solve(d, x, currentColumn + 1,
                  b1 - factor * c0,
                  c1,
                  d1 - factor * d0);
            // Now we back-solve to get x0
            // b0 is known not to be zero, so we shouldn't have any problems.
            double x1 = x[currentColumn + 1];
            // b0 x0 + c0 x1 = d0
            x[currentColumn] = (d0 - c0 * x1) / b0;
        }
    }

    /**
     * Find this*X
     *
     * @param X
     *            The vector by which we are being multiplied
     * @return The result of multiplying us by <code>X</code>
     * @throws IllegalArgumentException
     *             if {@code X} does not have exactly n coordinates
     */
    public Vector times(Vector X) {
        if (X.size() != _n)
            throw new IllegalArgumentException("Illegal result - matrix multiplication can't result in a vector of size "
                                               + X.size());
        double[] r = new double[_n];
        for (int i = 0; i < _n; ++i) {
            double entry = 0;
            if (i > 0)
                entry += _a[i] * X.coord(i - 1);
            entry += _b[i] * X.coord(i);
            if (i < _n - 1)
                entry += _c[i] * X.coord(i + 1);
            r[i] = entry;
        }
        return new Vector(r);
    }
}