/*
 * Copyright 2017 LinkedIn, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.linkedin.parseq;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linkedin.parseq.internal.PlanCompletionListener;
import com.linkedin.parseq.internal.PlanContext;
import com.linkedin.parseq.internal.TimeUnitHelper;
import com.linkedin.parseq.promise.PromiseException;
import com.linkedin.parseq.promise.Promises;
import com.linkedin.parseq.promise.SettablePromise;
import com.linkedin.parseq.trace.Trace;
import com.linkedin.parseq.trace.TraceUtil;


/**
 * A helper class for ParSeq unit tests.
 *
 * <p>Typical usage: call {@link #setUp()} before a test to create the engine and
 * its executors, use the {@code runAndWait*} methods to execute tasks, and call
 * {@link #tearDown()} afterwards to release the executors.
 *
 * @author Jaroslaw Odzga (jodzga@linkedin.com)
 */
public class ParSeqUnitTestHelper {
  private static final Logger LOG = LoggerFactory.getLogger(ParSeqUnitTestHelper.class);

  /** Hook applied to the {@link EngineBuilder} in {@link #setUp()} before the engine is built. */
  private final Consumer<EngineBuilder> _engineCustomizer;

  // All of the fields below are created in setUp() and cleared in tearDown().
  private volatile ScheduledExecutorService _scheduler;
  private volatile ExecutorService _asyncExecutor;
  private volatile Engine _engine;
  private volatile ListLoggerFactory _loggerFactory;
  private volatile TaskDoneListener _taskDoneListener;

  /**
   * Creates a helper that builds the engine without any extra customization.
   */
  public ParSeqUnitTestHelper() {
    this(engineBuilder -> {});
  }

  /**
   * Creates a helper whose engine is customized by the given callback.
   *
   * @param engineCustomizer invoked with the {@link EngineBuilder} during {@link #setUp()},
   *                         after the default executors and logger factory have been set
   */
  public ParSeqUnitTestHelper(Consumer<EngineBuilder> engineCustomizer) {
    _engineCustomizer = engineCustomizer;
  }

  /**
   * Creates Engine instance to be used for testing.
   *
   * <p>A scheduled thread pool with {@code availableProcessors + 1} threads is used both as
   * the task executor and as the timer scheduler; a separate fixed pool of two threads is
   * registered for {@link AsyncCallableTask}. After the customizer has run, a plan completion
   * listener is installed so that {@link #runAndWaitForPlanToComplete} can wait for whole plans.
   */
  @SuppressWarnings("deprecation")
  public void setUp() throws Exception {
    final int numCores = Runtime.getRuntime().availableProcessors();
    _scheduler = Executors.newScheduledThreadPool(numCores + 1);
    _asyncExecutor = Executors.newFixedThreadPool(2);
    _loggerFactory = new ListLoggerFactory();

    EngineBuilder engineBuilder = new EngineBuilder()
        .setTaskExecutor(_scheduler)
        .setTimerScheduler(_scheduler)
        .setLoggerFactory(_loggerFactory);
    AsyncCallableTask.register(engineBuilder, _asyncExecutor);
    _engineCustomizer.accept(engineBuilder);

    // Add taskDoneListener to engine builder.
    _taskDoneListener = new TaskDoneListener();
    final PlanCompletionListener planCompletionListener = engineBuilder.getPlanCompletionListener();
    if (planCompletionListener == null) {
      engineBuilder.setPlanCompletionListener(_taskDoneListener);
    } else {
      // The customizer installed its own listener: chain it with ours.
      engineBuilder.setPlanCompletionListener(planContext -> {
        try {
          planCompletionListener.onPlanCompleted(planContext);
        } catch (Throwable t) {
          LOG.error("Uncaught exception from custom planCompletionListener.", t);
        } finally {
          // Always notify our listener, even when the custom one throws.
          _taskDoneListener.onPlanCompleted(planContext);
        }
      });
    }
    _engine = engineBuilder.build();
  }

  /**
   * Equivalent to {@code tearDown(200, TimeUnit.MILLISECONDS);}.
   * @see #tearDown(int, TimeUnit)
   */
  public void tearDown() throws Exception {
    tearDown(200, TimeUnit.MILLISECONDS);
  }

  /**
   * Shuts down the engine, waits up to the given time for it to terminate and then
   * releases the executors and the logger factory created by {@link #setUp()}.
   *
   * <p>The executors are shut down even if waiting for the engine fails, and resources
   * that were never created (or were already released) are skipped, so calling this
   * method more than once is harmless.
   *
   * @param time amount of time to wait for engine termination
   * @param unit unit of time
   */
  public void tearDown(final int time, final TimeUnit unit) throws Exception {
    try {
      final Engine engine = _engine;
      if (engine != null) {
        engine.shutdown();
        engine.awaitTermination(time, unit);
      }
    } finally {
      _engine = null;

      final ScheduledExecutorService scheduler = _scheduler;
      _scheduler = null;
      if (scheduler != null) {
        scheduler.shutdownNow();
      }

      final ExecutorService asyncExecutor = _asyncExecutor;
      _asyncExecutor = null;
      if (asyncExecutor != null) {
        asyncExecutor.shutdownNow();
      }

      final ListLoggerFactory loggerFactory = _loggerFactory;
      _loggerFactory = null;
      if (loggerFactory != null) {
        loggerFactory.reset();
      }
    }
  }

  /**
   * Returns the engine created by {@link #setUp()}, or {@code null} before
   * {@code setUp()} / after {@code tearDown()}.
   */
  public Engine getEngine() {
    return _engine;
  }

  /**
   * Returns the scheduler created by {@link #setUp()}, or {@code null} before
   * {@code setUp()} / after {@code tearDown()}.
   */
  public ScheduledExecutorService getScheduler() {
    return _scheduler;
  }

  /**
   * Equivalent to {@code runAndWait("runAndWait", task)}.
   * @see #runAndWait(String, Task, long, TimeUnit)
   */
  public <T> T runAndWait(Task<T> task) {
    return runAndWait("runAndWait", task);
  }

  /**
   * Equivalent to {@code runAndWait("runAndWait", task, time, timeUnit)}.
   * @see #runAndWait(String, Task, long, TimeUnit)
   */
  public <T> T runAndWait(Task<T> task, long time, TimeUnit timeUnit) {
    return runAndWait("runAndWait", task, time, timeUnit);
  }

  /**
   * Equivalent to {@code runAndWait(desc, task, 5, TimeUnit.SECONDS)}.
   * @see #runAndWait(String, Task, long, TimeUnit)
   */
  public <T> T runAndWait(final String desc, Task<T> task) {
    return runAndWait(desc, task, 5, TimeUnit.SECONDS);
  }

  /**
   * Runs task, verifies that task finishes within specified amount of time,
   * logs trace from the task execution and return value which task completed with.
   * If task completes with an exception, it is re-thrown by this method.
   *
   * @param desc description of a test
   * @param task task to run
   * @param time amount of time to wait for task completion
   * @param timeUnit unit of time
   * @param <T> task result type
   * @return value task was completed with or exception is being thrown if task failed
   */
  public <T> T runAndWait(final String desc, Task<T> task, long time, TimeUnit timeUnit) {
    try {
      _engine.run(task);
      assertTrue(task.await(time, timeUnit), "Task did not complete within " + time + " " + timeUnit);
      return task.get();
    } catch (InterruptedException e) {
      // Preserve the interrupt status for code further up the stack.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } finally {
      logTracingResults(desc, task);
    }
  }

  /**
   * Runs task, verifies that the entire plan(including side-effect tasks)
   * finishes within specified amount of time, logs trace from the task execution
   * and return value which task completed with.
   * If task completes with an exception, it is re-thrown by this method.
   *
   * @param desc description of a test
   * @param task task to run
   * @param time amount of time to wait for task completion
   * @param timeUnit unit of time
   * @param <T> task result type
   * @return value task was completed with or exception is being thrown if task failed
   */
  public <T> T runAndWaitForPlanToComplete(final String desc, Task<T> task, long time, TimeUnit timeUnit) {
    try {
      _engine.run(task);
      assertTrue(_taskDoneListener.await(task, time, timeUnit),
          "Plan did not complete within " + time + " " + timeUnit);
      return task.get();
    } catch (InterruptedException e) {
      // Preserve the interrupt status for code further up the stack.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } finally {
      logTracingResults(desc, task);
    }
  }

  /**
   * Equivalent to
   * {@code runAndWaitForPlanToComplete("runAndWaitForPlanToComplete", task, time, timeUnit)}.
   * @see #runAndWaitForPlanToComplete(String, Task, long, TimeUnit)
   */
  public <T> T runAndWaitForPlanToComplete(Task<T> task, long time, TimeUnit timeUnit) {
    return runAndWaitForPlanToComplete("runAndWaitForPlanToComplete", task, time, timeUnit);
  }

  /**
   * Runs a task and verifies that it finishes with an error.
   *
   * <p>The trace is logged by the underlying {@link #runAndWait(String, Task, long, TimeUnit)}
   * call, so it is not logged again here.
   *
   * @param desc description of a test
   * @param task task to run
   * @param exceptionClass expected exception class
   * @param time amount of time to wait for task completion
   * @param timeUnit unit of time
   * @param <T> expected exception type
   * @return error returned by the task
   */
  public <T extends Throwable> T runAndWaitException(final String desc, Task<?> task, Class<T> exceptionClass,
      long time, TimeUnit timeUnit) {
    try {
      runAndWait(desc, task, time, timeUnit);
      fail("An exception is expected, but the task succeeded");
      // just to make the compiler happy, we will never get here
      return null;
    } catch (PromiseException pe) {
      Throwable cause = pe.getCause();
      assertEquals(cause.getClass(), exceptionClass);
      return exceptionClass.cast(cause);
    }
  }

  /**
   * Equivalent to {@code runAndWaitException(desc, task, exceptionClass, 5, TimeUnit.SECONDS)}.
   * @see #runAndWaitException(String, Task, Class, long, TimeUnit)
   */
  public <T extends Throwable> T runAndWaitException(final String desc, Task<?> task, Class<T> exceptionClass) {
    return runAndWaitException(desc, task, exceptionClass, 5, TimeUnit.SECONDS);
  }

  /**
   * Equivalent to {@code runAndWaitException("runAndWaitException", task, exceptionClass)}.
   * @see #runAndWaitException(String, Task, Class, long, TimeUnit)
   */
  public <T extends Throwable> T runAndWaitException(Task<?> task, Class<T> exceptionClass) {
    return runAndWaitException("runAndWaitException", task, exceptionClass);
  }

  /**
   * Equivalent to
   * {@code runAndWaitException("runAndWaitException", task, exceptionClass, time, timeUnit)}.
   * @see #runAndWaitException(String, Task, Class, long, TimeUnit)
   */
  public <T extends Throwable> T runAndWaitException(Task<?> task, Class<T> exceptionClass, long time,
      TimeUnit timeUnit) {
    return runAndWaitException("runAndWaitException", task, exceptionClass, time, timeUnit);
  }

  /**
   * Runs task.
   * @param task task to run
   */
  public void run(Task<?> task) {
    _engine.run(task);
  }

  /**
   * Logs the JSON trace of the given task at INFO level. A failure to encode the
   * trace is logged as an error and otherwise ignored.
   *
   * @param test description of the test, included in the log message
   * @param task task whose trace should be logged
   */
  public void logTracingResults(final String test, final Task<?> task) {
    try {
      LOG.info("Trace [{}]:\n{}", test, TraceUtil.getJsonTrace(task));
    } catch (IOException e) {
      LOG.error("Failed to encode JSON", e);
    }
  }

  /**
   * Sets the log level on the named logger of the helper's {@link ListLoggerFactory}.
   *
   * @param loggerName name of the logger
   * @param level level passed to the logger's {@code setLogLevel}
   */
  public void setLogLevel(final String loggerName, final int level) {
    _loggerFactory.getLogger(loggerName).setLogLevel(level);
  }

  /**
   * Returns the entries recorded by the named logger of the helper's {@link ListLoggerFactory}.
   *
   * @param loggerName name of the logger
   * @return entries of that logger
   */
  public List<ListLogger.Entry> getLogEntries(final String loggerName) {
    return _loggerFactory.getLogger(loggerName).getEntries();
  }

  /**
   * Resets the helper's {@link ListLoggerFactory}.
   */
  public void resetLoggers() {
    _loggerFactory.reset();
  }

  /**
   * Returns task which completes with given value after specified period
   * of time. The delay is scheduled from inside the callable handed to
   * {@code Task.async}, so the timer starts when that callable is invoked
   * (presumably when the task is run), not when this method is called.
   *
   * @param value value to complete the task with; may be {@code null}
   * @param time delay before the task completes
   * @param timeUnit unit of the delay
   */
  public <T> Task<T> delayedValue(T value, long time, TimeUnit timeUnit) {
    // String.valueOf tolerates a null value, unlike value.toString().
    final String name = String.valueOf(value) + " delayed " + time + " " + TimeUnitHelper.toString(timeUnit);
    return Task.async(name, () -> {
      final SettablePromise<T> promise = Promises.settable();
      _scheduler.schedule(() -> promise.done(value), time, timeUnit);
      return promise;
    });
  }

  /**
   * Returns task which fails with given error after specified period
   * of time. The delay is scheduled from inside the callable handed to
   * {@code Task.async}, so the timer starts when that callable is invoked
   * (presumably when the task is run), not when this method is called.
   *
   * @param error error to fail the task with
   * @param time delay before the task fails
   * @param timeUnit unit of the delay
   */
  public <T> Task<T> delayedFailure(Throwable error, long time, TimeUnit timeUnit) {
    final String name = error.toString() + " delayed " + time + " " + TimeUnitHelper.toString(timeUnit);
    return Task.async(name, () -> {
      final SettablePromise<T> promise = Promises.settable();
      _scheduler.schedule(() -> promise.fail(error), time, timeUnit);
      return promise;
    });
  }

  /**
   * Returns the number of entries in the trace map of the given trace.
   *
   * @param trace trace to inspect
   * @return size of {@code trace.getTraceMap()}
   */
  public int countTasks(Trace trace) {
    return trace.getTraceMap().size();
  }

  /**
   * Plan completion listener that lets a test wait until the plan of a given
   * root task has completed. One latch is kept per root task; whichever side
   * (completion or waiter) arrives first creates it.
   */
  private static final class TaskDoneListener implements PlanCompletionListener {

    private final ConcurrentMap<Task<?>, CountDownLatch> _taskDoneLatch = new ConcurrentHashMap<>();

    @Override
    public void onPlanCompleted(PlanContext planContext) {
      latchFor(planContext.getRootTask()).countDown();
    }

    /**
     * Waits until the plan rooted at the given task has completed.
     *
     * @param root root task of the plan
     * @param timeout maximum time to wait
     * @param unit unit of the timeout
     * @return {@code true} if the plan completed, {@code false} if the timeout elapsed first
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public boolean await(Task<?> root, long timeout, TimeUnit unit) throws InterruptedException {
      return latchFor(root).await(timeout, unit);
    }

    /** Returns the latch associated with the root task, creating it atomically if absent. */
    private CountDownLatch latchFor(Task<?> root) {
      return _taskDoneLatch.computeIfAbsent(root, key -> new CountDownLatch(1));
    }
  }
}