package org.wikibrain.utils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* @author Shilad Sen
* Utilities to run for each loops in parallel.
*/
public class ParallelForEach {
public static final Logger LOG = LoggerFactory.getLogger(ParallelForEach.class);
/**
* Construct a parallel loop on [from, to).
*
* @param from bottom of range (inclusive)
* @param to top of range (exclusive)
* @param numThreads
* @param fn callback
*/
public static void range(int from, int to, int numThreads, final Procedure<Integer> fn) {
iterate(new IntRangeIterator(from, to), numThreads, 10000, fn, Integer.MAX_VALUE);
}
public static void range(int from, int to, final Procedure<Integer> fn) {
iterate(new IntRangeIterator(from, to), WpThreadUtils.getMaxThreads(), 10000, fn, Integer.MAX_VALUE);
}
public static <T,R> List<R> range(int from, int to, int numThreads, final Function<Integer, R> fn) {
List<Integer> range = new ArrayList<Integer>();
for (int i = from; i < to; i++) { range.add(i); }
return loop(range, numThreads, fn);
}
public static <T,R> List<R> range(int from, int to, final Function<Integer, R> fn) {
return range(from, to, WpThreadUtils.getMaxThreads(), fn);
}
public static <T,R> List<R> loop(
Collection<T> collection,
int numThreads,
final Function<T,R> fn) {
return loop(collection, numThreads, fn, 50);
}
public static <T,R> List<R> loop(
Collection<T> collection,
final Function<T,R> fn) {
return loop(collection, WpThreadUtils.getMaxThreads(), fn, 50);
}
public static <T> void loop(
Collection<T> collection,
final Procedure<T> fn) {
loop(collection, WpThreadUtils.getMaxThreads(), fn, 50);
}
public static <T> void loop(
Collection<T> collection,
int numThreads,
final Procedure<T> fn) {
loop(collection, numThreads, fn, 50);
}
public static <T> void loop(
Collection<T> collection,
int numThreads,
final Procedure<T> fn,
final int logModulo) {
loop(collection, numThreads, new Function<T, Object> () {
public Object call(T arg) throws Exception {
fn.call(arg);
return null;
}
}, logModulo);
}
public static <T> void loop(
Collection<T> collection,
final Procedure<T> fn,
final int logModulo) {
loop(collection, WpThreadUtils.getMaxThreads(), new Function<T, Object> () {
public Object call(T arg) throws Exception {
fn.call(arg);
return null;
}
}, logModulo);
}
public static <T,R> List<R> loop(
Collection<T> collection,
final Function<T,R> fn,
final int logModulo) {
return loop(collection, WpThreadUtils.getMaxThreads(), fn, logModulo);
}
public static <T,R> List<R> loop(
Collection<T> collection,
int numThreads,
final Function<T,R> fn,
final int logModulo) {
final List<R> result = new ArrayList<R>();
for (int i = 0; i < collection.size(); i++) result.add(null);
final ExecutorService exec = new ThreadPoolErrors(numThreads);
final CountDownLatch latch = new CountDownLatch(collection.size());
try {
// create a copy so that modifications to original list are safe
final List<T> asList = new ArrayList<T>(collection);
for (int i = 0; i < asList.size(); i++) {
final int finalI = i;
exec.submit(new Runnable() {
public void run() {
T obj = asList.get(finalI);
try {
if (finalI % logModulo == 0) {
LOG.info("processing list element " + (finalI+1) + " of " + asList.size());
}
R r = fn.call(obj);
result.set(finalI, r);
} catch (Exception e) {
e.printStackTrace();
LOG.error("error processing list element " + obj, e);
LOG.error("stacktrace: " + ExceptionUtils.getStackTrace(e).replaceAll("\n", " ").replaceAll("\\s+", " "));
} finally {
latch.countDown();
}
}});
}
latch.await();
return result;
} catch (InterruptedException e) {
LOG.error("Interrupted parallel for each", e);
throw new RuntimeException(e);
} finally {
exec.shutdown();
}
}
public static <T> void iterate(Iterator<T> iterator, final Procedure<T> fn, int logModulo) {
iterate(iterator, WpThreadUtils.getMaxThreads(), 100, fn, logModulo);
}
public static <T> void iterate(Iterator<T> iterator, final Procedure<T> fn) {
iterate(iterator, WpThreadUtils.getMaxThreads(), 100, fn, -1);
}
public static <T> void iterate(
Iterator<T> iterator,
int numThreads,
int queueSize,
final Procedure<T> fn,
final int logModulo) {
final ExecutorService exec = new ThreadPoolErrors(numThreads);
BoundedExecutor boundedExec = new BoundedExecutor(exec, queueSize);
final AtomicInteger counter = new AtomicInteger(0);
final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger elemsToGo = new AtomicInteger(0);
try {
// create a copy so that modifications to original list are safe
elemsToGo.incrementAndGet();
while (iterator.hasNext()) {
final T obj = iterator.next();
elemsToGo.incrementAndGet();
boundedExec.submitTask(new Runnable() {
public void run() {
try {
int i = counter.incrementAndGet();
if (logModulo >= 0 && i % logModulo == 0) {
LOG.info("processing iterable " + i);
}
fn.call(obj);
} catch (Exception e) {
e.printStackTrace();
LOG.error("error processing list element " + obj, e);
LOG.error("stacktrace: " + ExceptionUtils.getStackTrace(e).replaceAll("\n", " ").replaceAll("\\s+", " "));
} finally {
if (elemsToGo.decrementAndGet() == 0) {
latch.countDown();
}
}
}
});
}
if (elemsToGo.decrementAndGet() > 0) {
latch.await();
}
} catch (InterruptedException e) {
LOG.error("Interrupted parallel for each", e);
throw new RuntimeException(e);
} finally {
exec.shutdown();
}
}
/**
* This code adapted from:
* http://stackoverflow.com/questions/2248131/handling-exceptions-from-java-executorservice-tasks
*/
private static class ThreadPoolErrors extends ThreadPoolExecutor {
public ThreadPoolErrors(int threads) {
super( threads, // core threads
threads, // max threads
0, // timeout
TimeUnit.MILLISECONDS, // timeout units
new LinkedBlockingQueue<Runnable>() // work queue
);
}
protected void afterExecute(Runnable r, Throwable t) {
super.afterExecute(r, t);
if (t == null && r instanceof Future<?>) {
try {
Future<?> future = (Future<?>) r;
if (future.isDone()) {
future.get();
}
} catch (CancellationException ce) {
t = ce;
} catch (ExecutionException ee) {
t = ee.getCause();
} catch (InterruptedException ie) {
Thread.currentThread().interrupt(); // ignore/reset
}
}
if (t != null) {
LOG.error("Uncaught Exception: ", t);
LOG.error("stacktrace: " + ExceptionUtils.getStackTrace(t).replaceAll("\n", " ").replaceAll("\\s+", " "));
}
}
}
/**
* From
*/
public static class BoundedExecutor {
private final Executor exec;
private final Semaphore semaphore;
public BoundedExecutor(Executor exec, int bound) {
this.exec = exec;
this.semaphore = new Semaphore(bound);
}
public void submitTask(final Runnable command)
throws InterruptedException, RejectedExecutionException {
semaphore.acquire();
try {
exec.execute(new Runnable() {
public void run() {
try {
command.run();
} finally {
semaphore.release();
}
}
});
} catch (RejectedExecutionException e) {
semaphore.release();
throw e;
}
}
}
}