package com.linkedin.parseq.exec; import java.nio.file.Files; import java.nio.file.Path; import java.util.Comparator; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ConcurrentSkipListSet; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.linkedin.parseq.Exceptions; import com.linkedin.parseq.Task; import com.linkedin.parseq.promise.Promises; import com.linkedin.parseq.promise.SettablePromise; public class Exec { private static final Logger LOGGER = LoggerFactory.getLogger(Exec.class); private final ConcurrentMap<Process, ProcessEntry> _runningProcesses = new ConcurrentHashMap<>(); private final ConcurrentMap<Long, Process> _runningProcessesByTaskId = new ConcurrentHashMap<>(); private final ScheduledExecutorService _reaperExecutor = Executors.newSingleThreadScheduledExecutor(); private final AtomicLong _seqGenerator = new AtomicLong(0); private final ConcurrentSkipListSet<ProcessRequest> _processRequestQueue = new ConcurrentSkipListSet<>(Comparator.comparingLong(request -> request.getSeq())); private final int _parallelizationLevel; private final long _reaperDelayMs; private final int _maxProcessQueueSize; private final AtomicInteger _processQueueSize = new AtomicInteger(0); private volatile boolean _shutdownInitiated = false; public Exec(int parallelizationLevel, long reaperDelayMs, int maxProcessQueueSize) { _parallelizationLevel = parallelizationLevel; _reaperDelayMs = reaperDelayMs; _maxProcessQueueSize = maxProcessQueueSize; } public static class Result { private final Path _stdout; private final Path _stderr; private final int status; public Result(int status, Path stdout, Path stderr) { this.status = status; _stdout = stdout; _stderr = stderr; } public Path getStdout() { return _stdout; } public Path getStderr() { return _stderr; } public int getStatus() { return status; } } private static class ProcessEntry { private final SettablePromise<Result> _resultPromise; private final Path _stdout; private final Path _stderr; private final Long _taskId; public ProcessEntry(SettablePromise<Result> resultPromise, Path stdout, Path stderr, Long taskId) { _resultPromise = resultPromise; _stderr = stderr; _stdout = stdout; _taskId = taskId; } public SettablePromise<Result> getResultPromise() { return _resultPromise; } public Path getStdout() { return _stdout; } public Path getStderr() { return _stderr; } public Long getTaskId() { return _taskId; } } private static class ProcessRequest { private final long _seq; private final ProcessBuilder _builder; private final ProcessEntry _entry; private final long _timeout; private final TimeUnit _timeUnit; private final Long _taskId; public ProcessRequest(long seq, ProcessBuilder builder, ProcessEntry entry, final long timeout, final TimeUnit timeUnit, Long taskId) { _seq = seq; _builder = builder; _entry = entry; _timeout = timeout; _timeUnit = timeUnit; _taskId = taskId; } public long getSeq() { return _seq; } public ProcessBuilder getBuilder() { return _builder; } public ProcessEntry getEntry() { return _entry; } public long getTimeout() { return _timeout; } public TimeUnit getTimeUnit() { return _timeUnit; } public Long getTaskId() { return _taskId; } } public Task<Result> command(final String desc, final long timeout, final TimeUnit timeUnit, final String... command) { final Task<Result> task = Task.async(desc, ctx -> { int queueSize = _processQueueSize.get(); if (_shutdownInitiated) { throw new IllegalStateException("can't start new process because Exec has been shut down"); } else if (queueSize >= _maxProcessQueueSize) { throw new RuntimeException("queue for processes to run is full, size: " + queueSize); } else { final SettablePromise<Result> result = Promises.settable(); final ProcessBuilder builder = new ProcessBuilder(command); final Path stderr = Files.createTempFile("parseq-Exec", ".stderr"); final Path stdout = Files.createTempFile("parseq-Exec", ".stdout"); builder.redirectError(stderr.toFile()); builder.redirectOutput(stdout.toFile()); final ProcessRequest request = new ProcessRequest(_seqGenerator.getAndIncrement(), builder, new ProcessEntry(result, stdout, stderr, ctx.getTaskId()), timeout, timeUnit, ctx.getTaskId()); _processRequestQueue.add(request); _processQueueSize.incrementAndGet(); return result; } }); task.addListener(p -> { if (p.isFailed() && Exceptions.isCancellation(p.getError())) { //best effort to try to kill process in case task was cancelled Process process = _runningProcessesByTaskId.get(task.getId()); if (process != null) { process.destroyForcibly(); } } }); return task; } public void start() { _reaperExecutor.scheduleWithFixedDelay(() -> { try { for (Entry<Process, ProcessEntry> en: _runningProcesses.entrySet()) { final Process process = en.getKey(); final ProcessEntry entry = en.getValue(); if (!process.isAlive()) { _runningProcesses.remove(process); _runningProcessesByTaskId.remove(entry.getTaskId()); final Result result = new Result(process.exitValue(), entry.getStdout(), entry.getStderr()); entry.getResultPromise().done(result); } } while (_runningProcesses.size() < _parallelizationLevel && !_processRequestQueue.isEmpty()) { ProcessRequest request = _processRequestQueue.pollFirst(); if (request != null) { _processQueueSize.decrementAndGet(); final Process process = request.getBuilder().start(); _runningProcesses.put(process, request.getEntry()); _runningProcessesByTaskId.put(request.getTaskId(), process); _reaperExecutor.schedule(() -> { if (process.isAlive()) { process.destroyForcibly(); } }, request.getTimeout(), request.getTimeUnit()); } } } catch (Exception e) { LOGGER.error("error while checking process status", e); } }, _reaperDelayMs, _reaperDelayMs, TimeUnit.MILLISECONDS); } public void stop() { _shutdownInitiated = true; _reaperExecutor.shutdown(); } }