package org.wikibrain.spatial.distance; import com.vividsolutions.jts.geom.*; import org.junit.Test; import org.wikibrain.core.dao.DaoException; import org.wikibrain.spatial.constants.Precision; import org.wikibrain.spatial.dao.SpatialDataDao; import java.util.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; /** * @author Shilad Sen */ public class TestOrdinalDistanceMetric { private Random random = new Random(); private GeometryFactory factory = new GeometryFactory(new PrecisionModel(),8307); @Test public void testKnn() throws DaoException { Map<Integer, Geometry> points = new HashMap<Integer, Geometry>(); for (int i = 0; i < 10000; i++) { points.put(i * 3, makePoint()); } SpatialDataDao dao = mock(SpatialDataDao.class); when(dao.getAllGeometriesInLayer("wikidata", Precision.LatLonPrecision.HIGH)) .thenReturn(points); SphericalDistanceMetric spherical = new SphericalDistanceMetric(dao); OrdinalDistanceMetric ordinal = new OrdinalDistanceMetric(dao, spherical); spherical.enableCache(true); ordinal.enableCache(true); assertEquals(spherical.getNumConcepts(), points.size()); int numNeighbors = (int) (spherical.getNumConcepts() * ordinal.getFractionRankedExactly()); for (int i = 0; i < 100; i++) { Point p = makePoint(); List<SpatialDistanceMetric.Neighbor> sphericalNeighbors = spherical.getNeighbors(p, numNeighbors); List<SphericalDistanceMetric.Neighbor> ordinalNeighbors = ordinal.getNeighbors(p, numNeighbors); assertEquals(ordinalNeighbors.size(), numNeighbors); for (int j = 0; j < numNeighbors; j++) { assertEquals(sphericalNeighbors.get(j).conceptId, ordinalNeighbors.get(j).conceptId); assertEquals(j, ordinalNeighbors.get(j).distance, 0.01); } } } @Test public void testPairwise() throws DaoException { Map<Integer, Geometry> points = new HashMap<Integer, Geometry>(); for (int i = 0; i < 10000; i++) { points.put(i * 3, makePoint()); } SpatialDataDao dao = mock(SpatialDataDao.class); when(dao.getAllGeometriesInLayer("wikidata", Precision.LatLonPrecision.HIGH)) .thenReturn(points); final SphericalDistanceMetric spherical = new SphericalDistanceMetric(dao); final OrdinalDistanceMetric ordinal = new OrdinalDistanceMetric(dao, spherical); spherical.enableCache(true); ordinal.enableCache(true); assertEquals(spherical.getNumConcepts(), points.size()); int numNeighbors = (int) (spherical.getNumConcepts() * ordinal.getFractionRankedExactly()); for (int i = 0; i < 50; i++) { final Point p1 = makePoint(); assertTrue(ordinal.distance(p1, p1) <= 1.0); List<SpatialDistanceMetric.Neighbor> neighbors = spherical.getNeighbors(p1, numNeighbors); double furthest = neighbors.get(neighbors.size() - 1).distance; List<Point> others = new ArrayList<Point>(); for (int j = 0; j < 50; j++) { others.add(makePoint()); } // Sort others by spherical distance Collections.sort(others, new Comparator<Point>() { @Override public int compare(Point p2, Point p3) { return Double.compare(spherical.distance(p1, p2), spherical.distance(p1, p3)); } }); double lastDistance = -1; for (Point p2 : others) { double d = ordinal.distance(p1, p2); double sphericald = spherical.distance(p1, p2); assertTrue(d >= lastDistance); if (sphericald < furthest) { int j = (int)Math.round(d); if (j == 0) { assertTrue(sphericald < neighbors.get(0).distance); } else { assertTrue(sphericald >= neighbors.get(j-1).distance); assertTrue(sphericald <= neighbors.get(j).distance); } } else { assertTrue(d >= numNeighbors); } lastDistance = d; } } } @Test public void testMatrix() throws DaoException { Map<Integer, Geometry> points = new HashMap<Integer, Geometry>(); for (int i = 0; i < 10000; i++) { points.put(i * 3, makePoint()); } SpatialDataDao dao = mock(SpatialDataDao.class); when(dao.getAllGeometriesInLayer("wikidata", Precision.LatLonPrecision.HIGH)) .thenReturn(points); SphericalDistanceMetric spherical = new SphericalDistanceMetric(dao); OrdinalDistanceMetric ordinal = new OrdinalDistanceMetric(dao, spherical); spherical.enableCache(true); ordinal.enableCache(true); List<Geometry> cols = new ArrayList<Geometry>(); List<Geometry> rows = new ArrayList<Geometry>(); for (int i = 0; i < 1000; i++) { cols.add(makePoint()); } for (int i = 0; i < 1000; i++) { rows.add(makePoint()); } float [][] ordinalDistance = ordinal.distance(rows, cols); float [][] sphericalDistance = spherical.distance(rows, cols); // Check that neighbor comparisons are correctly ordered with respect to // spherical distance. This isn't exhaustive, but it should eventually catch // errors. for (int i = 0; i < rows.size(); i++) { for (int j = 1; j < cols.size(); j++) { if (sphericalDistance[i][j] < sphericalDistance[i][j-1]) { assertTrue(ordinalDistance[i][j] <= ordinalDistance[i][j-1]); } else if (sphericalDistance[i][j] > sphericalDistance[i][j-1]) { assertTrue(ordinalDistance[i][j] >= ordinalDistance[i][j-1]); } else { // Otherwise the ordinal distances may differ by one (should basically never happen!) assertTrue(Math.abs(ordinalDistance[i][j] - ordinalDistance[i][j - 1]) <= 1.01); } } } } private Point makePoint() { double lat = 90 - random.nextDouble() * 180; double lon = 180 - random.nextDouble() * 360; return factory.createPoint(new Coordinate(lon, lat)); } }