package org.wikibrain.spatial.distance;
import com.vividsolutions.jts.geom.Geometry;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import org.wikibrain.core.dao.DaoException;
import org.wikibrain.spatial.constants.Precision;
import org.wikibrain.spatial.dao.SpatialDataDao;
import org.wikibrain.spatial.util.ClosestPointIndex;
import org.wikibrain.utils.ParallelForEach;
import org.wikibrain.utils.Procedure;
import org.wikibrain.utils.WpThreadUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Estimates the number of kilometers between geometries.
*
* @author Shilad Sen
*/
public class GraphDistanceMetric implements SpatialDistanceMetric {
private static final Logger LOG = LoggerFactory.getLogger(GraphDistanceMetric.class);
private final SpatialDataDao spatialDao;
private final TIntObjectMap<TIntSet> adjacencyList = new TIntObjectHashMap<TIntSet>();
private final ClosestPointIndex index;
private int numNeighbors = 100;
private int maxDistance = 30;
private TIntSet concepts;
private TIntSet validNodes;
private boolean directed = false;
public GraphDistanceMetric(SpatialDataDao dao, ClosestPointIndex index) {
this.spatialDao = dao;
this.index = index;
}
public GraphDistanceMetric(SpatialDataDao dao, SphericalDistanceMetric spherical) {
this.spatialDao = dao;
this.index = spherical.getIndex();
if (spherical.getValidConcepts() != null) {
LOG.warn("Warning: ClosestPoint index has been constrained to particular concepts. You probably don't want this. Instead, let GraphDistanceMetric create its own index");
}
}
public GraphDistanceMetric(SpatialDataDao dao) {
this.spatialDao = dao;
this.index = new ClosestPointIndex();
}
public void setNumNeighbors(int numNeighbors) {
this.numNeighbors = numNeighbors;
}
public void setMaxDistance(int maxDistance) {
this.maxDistance = maxDistance;
}
/**
* Sets the nodes that can be traversed as neighbors.
* @param nodes
*/
public void setValidNodes(TIntSet nodes) {
this.validNodes = nodes;
}
/**
* Sets the nodes that can be RETURNED as neighbors.
* @param concepts
*/
@Override
public void setValidConcepts(TIntSet concepts) {
this.concepts = concepts;
}
@Override
public void enableCache(boolean enable) throws DaoException {
if (index == null) throw new NullPointerException();
final AtomicInteger numEdges = new AtomicInteger();
final Map<Integer, Geometry> points = this.spatialDao.getAllGeometriesInLayer("wikidata", Precision.LatLonPrecision.HIGH);
// Insert points into the index if necessary.
if (index.size() == 0) {
ParallelForEach.loop(points.keySet(), WpThreadUtils.getMaxThreads(),
new Procedure<Integer>() {
@Override
public void call(Integer conceptId) throws Exception {
index.insert(conceptId, points.get(conceptId));
}
}, 50000);
}
ParallelForEach.loop(points.keySet(), WpThreadUtils.getMaxThreads(),
new Procedure<Integer>() {
@Override
public void call(Integer conceptId) throws Exception {
if (validNodes != null && !validNodes.contains(conceptId)) {
return;
}
final TIntSet neighbors = new TIntHashSet();
for (ClosestPointIndex.Result r : index.query(points.get(conceptId), numNeighbors)) {
neighbors.add(r.id);
}
numEdges.addAndGet(neighbors.size());
synchronized (adjacencyList) {
adjacencyList.put(conceptId, neighbors);
}
}
}, 50000);
// Make links symmetric if necessary
if (!directed) {
for (int id1 : adjacencyList.keys()) {
for (int id2 : adjacencyList.get(id1).toArray()) {
if (!adjacencyList.containsKey(id2)) {
adjacencyList.put(id2, new TIntHashSet());
}
adjacencyList.get(id2).add(id1);
}
}
}
LOG.info("Found " + adjacencyList.size() + " edges and " + numEdges.get() + " edges.");
}
@Override
public String getName() {
return "graph distance metric";
}
@Override
public double distance(Geometry g1, Geometry g2) {
if (adjacencyList.isEmpty()) {
throw new UnsupportedOperationException();
}
// Hack: Replace g2 with CLOSEST concept
List<ClosestPointIndex.Result> closest = index.query(g2, 1);
int maxSteps = maxDistance;
if (maxSteps == 0 || closest.isEmpty()) {
return Double.POSITIVE_INFINITY;
}
if (g1 == g2 || g1.equals(g2)) {
return 0;
}
int targetId = closest.get(0).id;
TIntSet seen = new TIntHashSet();
TIntLinkedList queue = new TIntLinkedList();
for (ClosestPointIndex.Result n : index.query(g1, numNeighbors)) {
if (n.id== targetId) {
return 1;
}
queue.add(n.id);
seen.add(n.id);
}
for (int level = 2; level <= maxSteps; level++) {
// Do all nodes at this level
int nodes = queue.size();
for (int i = 0; i < nodes; i++) {
int id = queue.removeAt(0);
if (!adjacencyList.containsKey(id)) {
continue;
}
for (int id2 : adjacencyList.get(id).toArray()) {
if (id2 == targetId) {
return level;
}
if (!seen.contains(id2)) {
queue.add(id2);
seen.add(id2);
}
}
}
// System.err.println("at level " + level + " saw " + seen.size());
}
return Double.POSITIVE_INFINITY;
}
@Override
public float[][] distance(List<Geometry> rowGeometries, List<Geometry> colGeometries) {
return new float[0][];
}
@Override
public float[][] distance(List<Geometry> geometries) {
return new float[0][];
}
@Override
public List<Neighbor> getNeighbors(Geometry g, int maxNeighbors) {
return getNeighbors(g, maxNeighbors, Integer.MAX_VALUE);
}
@Override
public List<Neighbor> getNeighbors(Geometry g, int maxNeighbors, double maxDistance) {
List<Neighbor> result = new ArrayList<Neighbor>();
int maxSteps = (int) Math.round(maxDistance);
if (maxSteps == 0) {
return result;
}
TIntSet seen = new TIntHashSet();
TIntLinkedList queue = new TIntLinkedList();
for (ClosestPointIndex.Result r : index.query(g, numNeighbors)) {
queue.add(r.id);
seen.add(r.id);
if (concepts == null || concepts.contains(r.id)) {
result.add(new Neighbor(r.id, 1));
}
}
for (int level = 2; !queue.isEmpty() && level <= maxSteps; level++) {
// System.err.println("at level " + level + ", size is " + seen.size() + ", " + result.size());
// Do all nodes at this level
int nodes = queue.size();
for (int i = 0; i < nodes; i++) {
int id = queue.removeAt(0);
if (!adjacencyList.containsKey(id)) {
continue;
}
for (int id2 : adjacencyList.get(id).toArray()) {
if (!seen.contains(id2)) {
queue.add(id2);
seen.add(id2);
if (concepts == null || concepts.contains(id2)) {
result.add(new Neighbor(id2, level));
}
}
if (result.size() >= maxNeighbors) {
return result;
}
}
}
}
return result;
}
public void setDirected(boolean directed) {
this.directed = directed;
}
}