/* * NumericalSpaceTimeProbs2D.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.geo; import dr.math.distributions.MultivariateNormalDistribution; import java.awt.geom.Point2D; import java.awt.geom.Rectangle2D; import java.io.FileNotFoundException; import java.io.PrintWriter; /** * @author Alexei Drummond */ public class NumericalSpaceTimeProbs2D { final int latticeWidth; final int latticeHeight; final int tsteps; final int subtsteps; final double minx, miny, dx, dy, dt; final MultivariateNormalDistribution D; final SpaceTimeRejector rejector; int[][][][][] counts; int[][][] normalization; int[][][] maxCount; public NumericalSpaceTimeProbs2D( int latticeWidth, int latticeHeight, int tsteps, int subtsteps, double dt, Rectangle2D bounds, MultivariateNormalDistribution distribution, SpaceTimeRejector rejector) { this.latticeWidth = latticeWidth; this.latticeHeight = latticeHeight; this.tsteps = tsteps; this.subtsteps = subtsteps; this.D = distribution; this.rejector = rejector; this.dt = dt; minx = bounds.getMinX(); miny = bounds.getMinY(); dx = (bounds.getMaxX() - minx) / latticeWidth; dy = (bounds.getMaxY() - miny) / latticeHeight; counts = new int[latticeWidth][latticeHeight][latticeWidth][latticeHeight][tsteps]; normalization = new int[latticeWidth][latticeHeight][tsteps]; maxCount = new int[latticeWidth][latticeHeight][tsteps]; } public void populate(Point2D start, int paths, boolean includeSubpaths) { populate(x(start.getX()), y(start.getY()), paths, includeSubpaths); } public int populateAbsorbing(Point2D start, int paths) { return populateAbsorbing(x(start.getX()), y(start.getY()), paths); } /** * @param i * @param j * @param paths * @return the number of successfully simulated paths */ public int populateAbsorbing(int i, int j, int paths) { double subdt = dt / (double) subtsteps; double[] next = new double[2]; double[] start = new double[2]; int[] pathx = new int[tsteps]; int[] pathy = new int[tsteps]; int successes = 0; for (int reps = 0; reps < paths; reps += 1) { double time = 0.0; start[0] = (i + Math.random()) * dx + minx; start[1] = (j + Math.random()) * dy + miny; while (rejector.reject(0, start)) { start[0] = (i + Math.random()) * dx + minx; start[1] = (j + Math.random()) * dy + miny; } boolean reject = false; for (int t = 0; t < tsteps && !reject; t++) { for (int s = 0; s < subtsteps && !reject; s++) { D.nextScaledMultivariateNormal(start, subdt, next); time += subdt; reject = rejector.reject(time, next); if (!reject) { start[0] = next[0]; start[1] = next[1]; } } if (!reject) { pathx[t] = x(next[0]); pathy[t] = y(next[1]); increment(i, j, pathx[t], pathy[t], t); } } if (!reject) successes += 1; if (reps % 10000 == 0) { System.out.print("."); System.out.flush(); } } System.out.println(); return successes; } public void populate(int i, int j, int paths, boolean includeSubpaths) { double subdt = dt / (double) subtsteps; double[] next = new double[2]; double[] start = new double[2]; int[] pathx = new int[tsteps]; int[] pathy = new int[tsteps]; for (int reps = 0; reps < paths; reps += 1) { double time = 0.0; start[0] = (i + Math.random()) * dx + minx; start[1] = (j + Math.random()) * dy + miny; while (rejector.reject(0, start)) { start[0] = (i + Math.random()) * dx + minx; start[1] = (j + Math.random()) * dy + miny; } for (int t = 0; t < tsteps; t++) { for (int s = 0; s < subtsteps; s++) { do { D.nextScaledMultivariateNormal(start, subdt, next); time += subdt; } while (rejector.reject(time, next)); start[0] = next[0]; start[1] = next[1]; } pathx[t] = x(next[0]); pathy[t] = y(next[1]); increment(i, j, pathx[t], pathy[t], t); } if (includeSubpaths) { for (int t = 0; t < tsteps; t++) { for (int s = t + 1; s < tsteps; s++) { increment(pathx[t], pathy[t], pathx[s], pathy[s], s - t - 1); } } } if (reps % 1000 == 0) { System.out.print("."); System.out.flush(); } } } private void increment(int i, int j, int k, int l, int t) { counts[i][j][k][l][t] += 1; normalization[i][j][t] += 1; if (counts[i][j][k][l][t] > maxCount[i][j][t]) { maxCount[i][j][t] = counts[i][j][k][l][t]; } } public void populate(int paths) { System.out.println("Populating numerical transition probabilities"); for (int i = 0; i < latticeWidth; i++) { for (int j = 0; j < latticeHeight; j++) { populate(i, j, paths, true); } System.out.print("."); System.out.flush(); } System.out.println(latticeWidth * latticeHeight * paths + " new paths computed."); } public final int x(double x) { return (int) ((x - minx) / dx); } public final int y(double y) { return (int) ((y - miny) / dy); } public final int t(double time) { return (int) (time / dt); } public double getProb(Point2D start, Point2D end, double time) { int i = x(start.getX()); int j = x(start.getY()); int k = x(end.getX()); int l = x(end.getY()); if (time > tsteps * dt) { System.err.println("Time = " + time + ", max time estimated is " + tsteps * dt); return (double) counts[i][j][k][l][tsteps - 1] / normalization[i][j][tsteps - 1]; } else { //time interpolation int t = t(time); double tlow = t * dt; double thigh = tlow + dt; double weightlow = (thigh - time) / dt; return weightlow * p(i, j, k, l, t) + (1.0 - weightlow) * p(i, j, k, l, t + 1); } } public double p(int i, int j, int k, int l, int t) { return (double) counts[i][j][k][l][t] / (double) normalization[i][j][t]; } public double r(int i, int j, int k, int l, int t) { return (double) counts[i][j][k][l][t] / (double) maxCount[i][j][t]; } public void writeToFile(String s) throws FileNotFoundException { PrintWriter writer = new PrintWriter(s); writer.write("xsteps=" + latticeWidth + "\n"); writer.write("ysteps=" + latticeHeight + "\n"); writer.write("tsteps=" + tsteps + "\n"); writer.write("dx=" + dx + "\n"); writer.write("dy=" + dy + "\n"); writer.write("dt=" + dt + "\n"); writer.write("minx=" + minx + "\n"); writer.write("miny=" + miny + "\n"); writer.write("D=" + matrixString()); for (int i = 0; i < latticeWidth; i++) { for (int j = 0; j < latticeHeight; j++) { for (int k = 0; k < latticeWidth; k++) { for (int l = 0; l < latticeHeight; l++) { for (int t = 0; t < tsteps; t++) { writer.write(i + "\t" + j + "\t" + k + "\t" + l + "\t" + t + "\t" + counts[i][j][k][l][t] + "\n"); } } } } } writer.close(); } private String matrixString() { double[][] m = D.getScaleMatrix(); StringBuilder builder = new StringBuilder(); builder.append("["); for (int i = 0; i < m.length; i++) { builder.append("["); builder.append(m[i][0]); for (int j = 1; j < m[i].length; j++) { builder.append("," + m[i][0]); } builder.append("]"); } builder.append("]"); return builder.toString(); } public static void main(String[] args) throws FileNotFoundException { Rectangle2D bounds = new Rectangle2D.Double(0, 0, 1, 1); MultivariateNormalDistribution D = new MultivariateNormalDistribution(new double[]{0.0}, new double[][]{{1, 0}, {0, 1}}); NumericalSpaceTimeProbs2D nstp2D = new NumericalSpaceTimeProbs2D(50, 50, 50, 1, 0.02, bounds, D, SpaceTimeRejector.Utils.createSimpleBounds2D(bounds)); long startTime = System.currentTimeMillis(); nstp2D.populate(0, 0, 1000, true); long stopTime = System.currentTimeMillis(); System.out.println("Time taken = " + (stopTime - startTime) / 1000 + " seconds"); // System.out.println("Writing to file..."); // nstp2D.writeToFile("unitSquareDiffusion.txt"); for (int i = 0; i < 10; i++) { Point2D start = new Point2D.Double(Math.random(), Math.random()); Point2D end = new Point2D.Double(Math.random(), Math.random()); double time = Math.random(); double p = nstp2D.getProb(start, end, time); System.out.println("Pr(" + end.getX() + ", " + end.getY() + " | " + start.getX() + ", " + start.getY() + ", t=" + time + ") = " + p); } } }