/*
 * Copyright 2012 LinkedIn, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.linkedin.parseq;

import java.util.Collection;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linkedin.parseq.internal.IdGenerator;
import com.linkedin.parseq.internal.TaskLogger;
import com.linkedin.parseq.promise.DelegatingPromise;
import com.linkedin.parseq.promise.Promise;
import com.linkedin.parseq.promise.Promises;
import com.linkedin.parseq.promise.SettablePromise;
import com.linkedin.parseq.trace.Relationship;
import com.linkedin.parseq.trace.ResultType;
import com.linkedin.parseq.trace.ShallowTrace;
import com.linkedin.parseq.trace.ShallowTraceBuilder;
import com.linkedin.parseq.trace.Trace;
import com.linkedin.parseq.trace.TraceBuilder;


/**
 * An abstract base class that can be used to build implementations of
 * {@link Task}.
 * <p>
 * The task's lifecycle state and priority are held together in a single
 * immutable {@link State} stored in an {@link AtomicReference}; every state
 * change is performed with a compare-and-set retry loop.
 *
 * @author Chris Pettitt (cpettitt@linkedin.com)
 * @author Chi Chan (ckchan@linkedin.com)
 */
public abstract class BaseTask<T> extends DelegatingPromise<T> implements Task<T> {

  /** Names longer than this are truncated by the constructor. */
  private static final int TASK_NAME_MAX_LENGTH = 1024;

  static final Logger LOGGER = LoggerFactory.getLogger(BaseTask.class);

  /** Used to recognise BaseTask frames when stitching stack traces together. */
  private static final String CANONICAL_NAME = BaseTask.class.getCanonicalName();

  private enum StateType {
    // The initial state of the task.
    //
    // A task in this state can be cancelled and have its priority changed.
    INIT,

    // The task is currently executing. That is, a thread is in the run()
    // method for the TaskDef.
    //
    // A task in this state is not cancellable and cannot have its priority
    // changed.
    RUN,

    // The task has finished running, but the result has not yet been set.
    // This occurs for Tasks with AsyncTaskDefs.
    //
    // A task in this state is cancellable, and cannot have its priority
    // changed.
    PENDING,

    // The task is resolved.
    //
    // A task in this state is not cancellable and cannot have its priority
    // changed.
    DONE
  }

  private final Long _id = IdGenerator.getNextId();

  /** Current lifecycle state and priority; only ever replaced via compareAndSet. */
  private final AtomicReference<State> _stateRef;

  /** Name given at construction (possibly truncated); may be {@code null}. */
  private final String _name;

  protected final ShallowTraceBuilder _shallowTraceBuilder;

  /** Optional function used to render a successful value into the trace. */
  protected volatile Function<T, String> _traceValueProvider;

  /** Set once the task transitions to RUN; {@code null} before that. */
  private volatile TraceBuilder _traceBuilder;

  /**
   * Captured at construction time when cross-thread stack traces are enabled,
   * otherwise {@code null}. Its stack trace records where the task was created.
   */
  private final Throwable _taskStackTraceHolder;

  /**
   * Constructs a base task without a specified name. The name for this task
   * will be the {@link #toString} representation for this instance. It is
   * usually best to use {@link BaseTask#BaseTask(String)}.
   */
  public BaseTask() {
    this(null);
  }

  /**
   * Constructs a base task with a name.
   *
   * @param name the name for this task.
   */
  public BaseTask(final String name) {
    this(name, null);
  }

  /**
   * Constructs a base task with a name and type of task
   *
   * @param name the name for this task.
   * @param taskType the type of the task
   */
  public BaseTask(final String name, final String taskType) {
    super(Promises.settable());
    _name = truncate(name);
    final State state = State.INIT;
    _shallowTraceBuilder = new ShallowTraceBuilder(_id);
    _shallowTraceBuilder.setName(getName());
    _shallowTraceBuilder.setResultType(ResultType.UNFINISHED);
    if (taskType != null) {
      _shallowTraceBuilder.setTaskType(taskType);
    }
    _stateRef = new AtomicReference<>(state);

    if (ParSeqGlobalConfiguration.isCrossThreadStackTracesEnabled()) {
      _taskStackTraceHolder = new Throwable();
    } else {
      _taskStackTraceHolder = null;
    }
  }

  /**
   * Limits a task name to {@link #TASK_NAME_MAX_LENGTH} characters.
   *
   * @param name the requested name, may be {@code null}
   * @return {@code name} unchanged when it is {@code null} or short enough,
   *         otherwise its first {@link #TASK_NAME_MAX_LENGTH} characters
   */
  private String truncate(String name) {
    if (name == null || name.length() <= TASK_NAME_MAX_LENGTH) {
      return name;
    } else {
      return name.substring(0, TASK_NAME_MAX_LENGTH);
    }
  }

  @Override
  public Long getId() {
    return _id;
  }

  @Override
  public int getPriority() {
    return _stateRef.get().getPriority();
  }

  /**
   * Changes the priority of this task. Only possible while the task is still
   * in its initial state.
   *
   * @param priority the new priority
   * @return {@code true} if the priority was changed, {@code false} if the
   *         task has already left its initial state
   * @throws IllegalArgumentException if {@code priority} is outside
   *         {@code [Priority.MIN_PRIORITY, Priority.MAX_PRIORITY]}
   */
  @Override
  public boolean setPriority(final int priority) {
    if (priority < Priority.MIN_PRIORITY || priority > Priority.MAX_PRIORITY) {
      throw new IllegalArgumentException("Priority out of bounds: " + priority);
    }

    State state;
    State newState;
    do {
      state = _stateRef.get();
      if (state.getType() != StateType.INIT) {
        return false;
      }
      newState = new State(state.getType(), priority);
    } while (!_stateRef.compareAndSet(state, newState));

    return true;
  }

  /**
   * @return the trace builder recorded when this task started running, or
   *         {@code null} if it has not transitioned to the running state
   */
  @Override
  public TraceBuilder getTraceBuilder() {
    return _traceBuilder;
  }

  /**
   * Runs this task within the given context, if it has not already been run,
   * cancelled or otherwise completed.
   * <p>
   * When the task wins the transition to the running state, definite trace
   * relationships to {@code parent} and {@code predecessors} are recorded, the
   * {@link #run(Context)} template method is invoked and a listener is attached
   * to the returned promise that completes this task. Any throwable raised
   * along the way fails this task. When the transition is lost, only
   * "potential"/"possible" relationships are recorded.
   *
   * @param context the context to run in
   * @param parent the parent task, may be {@code null}
   * @param predecessors tasks that preceded this one
   */
  @Override
  public final void contextRun(final Context context, final Task<?> parent,
      final Collection<Task<?>> predecessors) {
    final TaskLogger taskLogger = context.getTaskLogger();
    final TraceBuilder traceBuilder = context.getTraceBuilder();
    if (transitionRun(traceBuilder)) {
      markTaskStarted();
      final Promise<T> promise;
      try {
        if (parent != null) {
          traceBuilder.addRelationship(Relationship.CHILD_OF, getShallowTraceBuilder(),
              parent.getShallowTraceBuilder());
        }
        for (Task<?> predecessor : predecessors) {
          traceBuilder.addRelationship(Relationship.SUCCESSOR_OF, getShallowTraceBuilder(),
              predecessor.getShallowTraceBuilder());
        }

        taskLogger.logTaskStart(this);
        try {
          final Context wrapperContext = new WrappedContext(context);
          promise = doContextRun(wrapperContext);
        } finally {
          // Leave the RUN state even when run() throws, so the task is
          // recorded as pending before it is failed below.
          transitionPending();
        }

        promise.addListener(resolvedPromise -> {
          if (resolvedPromise.isFailed()) {
            fail(resolvedPromise.getError(), taskLogger);
          } else {
            done(resolvedPromise.get(), taskLogger);
          }
        });
      } catch (Throwable t) {
        fail(t, taskLogger);
      }
    } else {
      //this is possible when task was cancelled or has been executed multiple times
      //e.g. task has multiple paths it can be executed with or has been completed by
      //a FusionTask
      if (parent != null) {
        traceBuilder.addRelationship(Relationship.POTENTIAL_CHILD_OF, getShallowTraceBuilder(),
            parent.getShallowTraceBuilder());
      }
      for (Task<?> predecessor : predecessors) {
        traceBuilder.addRelationship(Relationship.POSSIBLE_SUCCESSOR_OF, getShallowTraceBuilder(),
            predecessor.getShallowTraceBuilder());
      }
    }
  }

  /**
   * Invokes {@link #run(Context)} and narrows the returned
   * {@code Promise<? extends T>} to {@code Promise<T>}.
   */
  @SuppressWarnings("unchecked")
  private Promise<T> doContextRun(final Context context) throws Throwable {
    return (Promise<T>) run(context);
  }

  /**
   * Returns the name of this task. If no name was set during construction
   * this method will return the value of {@link #toString()}. In most
   * cases it is preferable to explicitly set a name.
   *
   * @return the name of this task
   */
  @Override
  public String getName() {
    return _name == null ? toString() : _name;
  }

  /**
   * Attempts to cancel this task. On success the task's promise is failed
   * with a {@link CancellationException} wrapping {@code rootReason}.
   *
   * @param rootReason the reason for the cancellation
   * @return {@code true} if this call cancelled the task, {@code false} if the
   *         task is currently running or already resolved
   */
  @Override
  public boolean cancel(final Exception rootReason) {
    if (transitionCancel(rootReason)) {
      final Exception reason = new CancellationException(rootReason);
      try {
        traceFailure(reason);
      } catch (Throwable ex) {
        // Tracing is best effort; it must not prevent the cancellation.
        LOGGER.warn("Exception thrown while tracing task cancellation", ex);
      } finally {
        // The state is already DONE, so the promise has to be resolved no
        // matter what happened above; otherwise waiters would hang forever.
        getSettableDelegate().fail(reason);
      }
      return true;
    }
    return false;
  }

  /**
   * Records a failure in this task's shallow trace: an early finish is marked
   * as {@link ResultType#EARLY_FINISH}, anything else as
   * {@link ResultType#ERROR} together with a string rendering of the failure.
   *
   * @param reason the failure to record
   */
  protected void traceFailure(final Throwable reason) {
    if (Exceptions.isEarlyFinish(reason)) {
      _shallowTraceBuilder.setResultType(ResultType.EARLY_FINISH);
    } else {
      _shallowTraceBuilder.setResultType(ResultType.ERROR);
      _shallowTraceBuilder.setValue(Exceptions.failureToString(reason));
    }
  }

  @Override
  public ShallowTraceBuilder getShallowTraceBuilder() {
    return _shallowTraceBuilder;
  }

  @Override
  public ShallowTrace getShallowTrace() {
    return _shallowTraceBuilder.build();
  }

  @Override
  public void setTraceValueSerializer(final Function<T, String> traceValueProvider) {
    _traceValueProvider = traceValueProvider;
  }

  /**
   * @return the trace built from this task's trace builder, or a trace made of
   *         only this task's shallow trace when no trace builder has been set
   */
  @Override
  public Trace getTrace() {
    TraceBuilder traceBuilder = getTraceBuilder();
    if (traceBuilder != null) {
      return traceBuilder.build();
    } else {
      return Trace.single(getShallowTrace(), "none", 0L);
    }
  }

  /**
   * This template method is invoked when the task is run.
   *
   * @param context the context to use while running this task
   * @return a promise that will have its value set when this task is finished
   * @throws Throwable if an error occurs while running the task
   */
  protected abstract Promise<? extends T> run(final Context context) throws Throwable;

  /**
   * Marks the shallow trace as successful and, when a trace value provider is
   * set, stores its rendering of {@code value}. If the provider throws an
   * {@link Exception}, a string rendering of that exception is stored instead.
   */
  private void traceDone(final T value) {
    _shallowTraceBuilder.setResultType(ResultType.SUCCESS);
    // Read the volatile field once so the null check and the call agree.
    final Function<T, String> traceValueProvider = _traceValueProvider;
    if (traceValueProvider != null) {
      try {
        _shallowTraceBuilder.setValue(traceValueProvider.apply(value));
      } catch (Exception e) {
        _shallowTraceBuilder.setValue(Exceptions.failureToString(e));
      }
    }
  }

  /**
   * Completes this task successfully, unless it is already resolved.
   */
  private void done(final T value, final TaskLogger taskLogger) {
    if (transitionDone()) {
      traceDone(value);
      getSettableDelegate().done(value);
      taskLogger.logTaskEnd(BaseTask.this, _traceValueProvider);
    }
  }

  /**
   * Fails this task with {@code error}, unless it is already resolved.
   */
  private void fail(final Throwable error, final TaskLogger taskLogger) {
    if (transitionDone()) {
      try {
        appendTaskStackTrace(error);
        traceFailure(error);
      } catch (Throwable ex) {
        // Tracing is best effort; it must not prevent the task from failing.
        LOGGER.warn("Exception thrown while tracing task failure", ex);
      } finally {
        // The state is already DONE, so the promise has to be resolved no
        // matter what happened above; otherwise waiters would hang forever.
        getSettableDelegate().fail(error);
      }
      taskLogger.logTaskEnd(BaseTask.this, _traceValueProvider);
    }
  }

  // Concatenate stack traces if kept the original stack trace from the task creation
  private void appendTaskStackTrace(final Throwable error) {
    StackTraceElement[] taskStackTrace = _taskStackTraceHolder != null ? _taskStackTraceHolder.getStackTrace() : null;

    // At a minimum, any stack trace should have at least 3 stack frames (caller + BaseTask + getStackTrace).
    // So if there are less than 3 stack frames available then there's something fishy and it's better to ignore them.
    if (!ParSeqGlobalConfiguration.isCrossThreadStackTracesEnabled() || error == null
        || taskStackTrace == null || taskStackTrace.length <= 2) {
      return;
    }

    StackTraceElement[] errorStackTrace = error.getStackTrace();
    if (errorStackTrace.length <= 2) {
      return;
    }

    // Skip stack frames up to the BaseTask (useless java.util.concurrent stuff)
    int skipErrorFrames = 1;
    while (skipErrorFrames < errorStackTrace.length) {
      int index = errorStackTrace.length - 1 - skipErrorFrames;
      if (!errorStackTrace[index].getClassName().equals(CANONICAL_NAME)
          && errorStackTrace[index + 1].getClassName().equals(CANONICAL_NAME)) {
        break;
      }
      skipErrorFrames++;
    }

    // Safeguard against accidentally removing entire stack trace
    if (skipErrorFrames == errorStackTrace.length) {
      skipErrorFrames = 0;
    }

    // Skip stack frames up to the BaseTask (useless Thread.getStackTrace stuff)
    int skipTaskFrames = 1;
    while (skipTaskFrames < taskStackTrace.length) {
      if (!taskStackTrace[skipTaskFrames].getClassName().equals(CANONICAL_NAME)
          && taskStackTrace[skipTaskFrames - 1].getClassName().equals(CANONICAL_NAME)) {
        break;
      }
      skipTaskFrames++;
    }

    // Safeguard against accidentally removing entire stack trace
    if (skipTaskFrames == taskStackTrace.length) {
      skipTaskFrames = 0;
    }

    int combinedLength = errorStackTrace.length - skipErrorFrames + taskStackTrace.length - skipTaskFrames;
    if (combinedLength <= 0) {
      return;
    }

    // One extra slot for the separator frame placed between the two traces.
    StackTraceElement[] concatenatedStackTrace = new StackTraceElement[combinedLength + 1];
    System.arraycopy(errorStackTrace, 0, concatenatedStackTrace,
        0, errorStackTrace.length - skipErrorFrames);
    concatenatedStackTrace[errorStackTrace.length - skipErrorFrames] =
        new StackTraceElement("********** Task \"" + getName() + "\" (above) was instantiated as following (below): **********", "", null, 0);
    System.arraycopy(taskStackTrace, skipTaskFrames, concatenatedStackTrace,
        errorStackTrace.length - skipErrorFrames + 1, taskStackTrace.length - skipTaskFrames);
    error.setStackTrace(concatenatedStackTrace);
  }

  /**
   * Attempts to move this task from INIT to RUN. On success the given trace
   * builder is remembered and this task's shallow trace is added to it.
   *
   * @param traceBuilder the trace builder of the plan this task runs in
   * @return {@code true} if this call performed the transition, {@code false}
   *         if the task was not in the INIT state
   */
  protected boolean transitionRun(final TraceBuilder traceBuilder) {
    State state;
    State newState;
    do {
      state = _stateRef.get();
      if (state.getType() != StateType.INIT) {
        return false;
      }
      newState = state.transitionRun();
    } while (!_stateRef.compareAndSet(state, newState));
    _traceBuilder = traceBuilder;
    traceBuilder.addShallowTrace(_shallowTraceBuilder);
    return true;
  }

  /** Records the current {@link System#nanoTime()} as the task's start time. */
  protected void markTaskStarted() {
    _shallowTraceBuilder.setStartNanos(System.nanoTime());
  }

  /**
   * Moves this task from RUN to PENDING and records the pending timestamp.
   * Does nothing if the task is not in the RUN state.
   */
  protected void transitionPending() {
    State state;
    State newState;
    do {
      state = _stateRef.get();
      if (state.getType() != StateType.RUN) {
        return;
      }
      newState = state.transitionPending();
    } while (!_stateRef.compareAndSet(state, newState));
    markTaskPending();
  }

  /** Records the current {@link System#nanoTime()} as the task's pending time. */
  protected void markTaskPending() {
    _shallowTraceBuilder.setPendingNanos(System.nanoTime());
  }

  /**
   * Attempts to move this task to DONE because of a cancellation.
   *
   * @param reason the cancellation reason; not used by this implementation
   * @return {@code true} if this call performed the transition, {@code false}
   *         if the task is in the RUN or DONE state
   */
  protected boolean transitionCancel(final Exception reason) {
    State state;
    State newState;
    do {
      state = _stateRef.get();
      final StateType type = state.getType();
      if (type == StateType.RUN || type == StateType.DONE) {
        return false;
      }
      newState = state.transitionDone();
    } while (!_stateRef.compareAndSet(state, newState));
    return true;
  }

  /**
   * Attempts to move this task to DONE.
   *
   * @return {@code true} if this call performed the transition, {@code false}
   *         if the task was already in the DONE state
   */
  protected boolean transitionDone() {
    State state;
    State newState;
    do {
      state = _stateRef.get();
      if (state.getType() == StateType.DONE) {
        return false;
      }
      newState = state.transitionDone();
    } while (!_stateRef.compareAndSet(state, newState));
    return true;
  }

  /**
   * @return the delegate promise, viewed as the {@link SettablePromise}
   *         that was passed to the super constructor
   */
  protected SettablePromise<T> getSettableDelegate() {
    return (SettablePromise<T>) super.getDelegate();
  }

  /**
   * Immutable pair of lifecycle state type and priority.
   * <p>
   * The shared constants {@link #INIT}, {@link #RUN}, {@link #PENDING} and
   * {@link #DONE} all carry the default priority and transition to one
   * another without allocating. A state created with a custom priority
   * allocates a new instance on each transition so the priority is kept.
   */
  protected static class State {
    private final StateType _type;
    private final int _priority;

    private State(final StateType type, final int priority) {
      _type = type;
      _priority = priority;
    }

    public StateType getType() {
      return _type;
    }

    public int getPriority() {
      return _priority;
    }

    public State transitionRun() {
      return new State(StateType.RUN, _priority);
    }

    public State transitionPending() {
      return new State(StateType.PENDING, _priority);
    }

    public State transitionDone() {
      return new State(StateType.DONE, _priority);
    }

    public static final State INIT = new State(StateType.INIT, Priority.DEFAULT_PRIORITY) {
      @Override
      public final State transitionDone() {
        return DONE;
      }

      @Override
      public final State transitionRun() {
        return RUN;
      }

      @Override
      public final State transitionPending() {
        return PENDING;
      }
    };

    public static final State RUN = new State(StateType.RUN, Priority.DEFAULT_PRIORITY) {
      @Override
      public final State transitionDone() {
        return DONE;
      }

      @Override
      public final State transitionRun() {
        return RUN;
      }

      @Override
      public final State transitionPending() {
        return PENDING;
      }
    };

    public static final State PENDING = new State(StateType.PENDING, Priority.DEFAULT_PRIORITY) {
      @Override
      public final State transitionDone() {
        return DONE;
      }

      @Override
      public final State transitionRun() {
        return RUN;
      }

      @Override
      public final State transitionPending() {
        return PENDING;
      }
    };

    public static final State DONE = new State(StateType.DONE, Priority.DEFAULT_PRIORITY) {
      @Override
      public final State transitionDone() {
        return DONE;
      }

      @Override
      public final State transitionRun() {
        return RUN;
      }

      @Override
      public final State transitionPending() {
        return PENDING;
      }
    };
  }

  /**
   * A {@link Context} that forwards every call to the wrapped context and,
   * for calls that schedule other tasks, additionally records a
   * {@link Relationship#POTENTIAL_PARENT_OF} relationship from the wrapped
   * context's shallow trace to each scheduled task.
   */
  private class WrappedContext implements Context {
    private final Context _context;

    public WrappedContext(final Context context) {
      _context = context;
    }

    @Override
    public Cancellable createTimer(final long time, final TimeUnit unit, final Task<?> task) {
      final Cancellable cancellable = _context.createTimer(time, unit, task);
      getTraceBuilder().addRelationship(Relationship.POTENTIAL_PARENT_OF, getShallowTraceBuilder(),
          task.getShallowTraceBuilder());
      return cancellable;
    }

    @Override
    public void run(final Task<?>... tasks) {
      _context.run(tasks);
      for (Task<?> task : tasks) {
        getTraceBuilder().addRelationship(Relationship.POTENTIAL_PARENT_OF, getShallowTraceBuilder(),
            task.getShallowTraceBuilder());
      }
    }

    @Override
    public After after(final Promise<?>... promises) {
      return new WrappedAfter(_context.after(promises));
    }

    @Override
    public Object getEngineProperty(String key) {
      return _context.getEngineProperty(key);
    }

    @Override
    public TraceBuilder getTraceBuilder() {
      return _context.getTraceBuilder();
    }

    /**
     * An {@link After} that forwards to the wrapped instance and records a
     * {@link Relationship#POTENTIAL_PARENT_OF} relationship for the task
     * being scheduled.
     */
    private class WrappedAfter implements After {
      private final After _after;

      public WrappedAfter(final After after) {
        _after = after;
      }

      @Override
      public void run(final Task<?> task) {
        _after.run(task);
        getTraceBuilder().addRelationship(Relationship.POTENTIAL_PARENT_OF, getShallowTraceBuilder(),
            task.getShallowTraceBuilder());
      }

      @Override
      public void run(Supplier<Task<?>> taskSupplier) {
        // The task is only known once the supplier is invoked, so the
        // relationship is recorded inside the wrapping supplier.
        _after.run(() -> {
          Task<?> task = taskSupplier.get();
          if (task != null) {
            getTraceBuilder().addRelationship(Relationship.POTENTIAL_PARENT_OF, getShallowTraceBuilder(),
                task.getShallowTraceBuilder());
          }
          return task;
        });
      }
    }

    @Override
    public ShallowTraceBuilder getShallowTraceBuilder() {
      return _context.getShallowTraceBuilder();
    }

    @Override
    public Long getPlanId() {
      return _context.getPlanId();
    }

    @Override
    public Long getTaskId() {
      return _context.getTaskId();
    }

    @Override
    public TaskLogger getTaskLogger() {
      return _context.getTaskLogger();
    }

    @Override
    public void runSideEffect(Task<?>... tasks) {
      _context.runSideEffect(tasks);
      for (Task<?> task : tasks) {
        getTraceBuilder().addRelationship(Relationship.POTENTIAL_PARENT_OF, getShallowTraceBuilder(),
            task.getShallowTraceBuilder());
      }
    }

    @Override
    public String getPlanClass() {
      return _context.getPlanClass();
    }
  }

  @Override
  public String toString() {
    return "Task [id=" + _id + ", name=" + _name + "]";
  }
}