/* * Copyright 2012 LinkedIn, Inc * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package com.linkedin.parseq; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; import com.linkedin.parseq.internal.InternalUtil; import com.linkedin.parseq.promise.Promise; import com.linkedin.parseq.promise.PromiseListener; import com.linkedin.parseq.promise.Promises; import com.linkedin.parseq.promise.SettablePromise; import com.linkedin.parseq.trace.ResultType; /** * A {@link Task} that will run all of the constructor-supplied tasks in parallel. * <p> * Use {@link Tasks#par(Task[])} or {@link Tasks#par(Iterable)} to create an * instance of this class. * * @author Chris Pettitt (cpettitt@linkedin.com) * @author Chi Chan (ckchan@linkedin.com) * @see Task#par(Task, Task) Task.par */ /* package private */ class ParTaskImpl<T> extends BaseTask<List<T>>implements ParTask<T> { private final Task<? extends Task<? extends T>>[] _tasks; @SuppressWarnings("unchecked") private static final Task<? extends Task<?>>[] EMPTY = new Task[0]; @SuppressWarnings("unchecked") public ParTaskImpl(final String name) { super(name); _tasks = (Task<? extends Task<? extends T>>[]) EMPTY; } public ParTaskImpl(final String name, final Iterable<? extends Task<? extends T>> tasks) { super(name); if (tasks instanceof Collection) { _tasks = tasksFromCollection((Collection<? extends Task<? extends T>>) tasks); } else { _tasks = tasksFromIterable(tasks); } if (_tasks.length == 0) { throw new IllegalArgumentException("No tasks to parallelize!"); } } @SuppressWarnings("unchecked") private Task<? extends Task<? extends T>>[] tasksFromIterable(Iterable<? extends Task<? extends T>> tasks) { List<Task<T>> taskList = new ArrayList<Task<T>>(); for (Task<? extends T> task : tasks) { // Safe to coerce Task<? extends T> to Task<T> final Task<T> coercedTask = (Task<T>) task; taskList.add(coercedTask); } return (Task<? extends Task<? extends T>>[]) taskList.toArray(); } @SuppressWarnings("unchecked") private Task<? extends Task<? extends T>>[] tasksFromCollection(Collection<? extends Task<? extends T>> tasks) { Task<? extends Task<? extends T>>[] tasksArr = new Task[tasks.size()]; int i = 0; for (@SuppressWarnings("rawtypes") Task task: tasks) { tasksArr[i++] = task; } return tasksArr; } @Override protected Promise<List<T>> run(final Context context) throws Exception { if (_tasks.length == 0) { return Promises.value(Collections.emptyList()); } final SettablePromise<List<T>> result = Promises.settable(); @SuppressWarnings("unchecked") final PromiseListener<?> listener = resolvedPromise -> { boolean allEarlyFinish = true; final List<T> taskResult = new ArrayList<T>(_tasks.length); List<Throwable> errors = null; for (Task<?> task : _tasks) { if (task.isFailed()) { if (allEarlyFinish && ResultType.fromTask(task) != ResultType.EARLY_FINISH) { allEarlyFinish = false; } if (errors == null) { errors = new ArrayList<Throwable>(); } errors.add(task.getError()); } else { taskResult.add((T) task.get()); } } if (errors != null) { result.fail(allEarlyFinish ? errors.get(0) : new MultiException("Multiple errors in 'ParTask' task.", errors)); } else { result.done(taskResult); } }; InternalUtil.after(listener, _tasks); for (Task<?> task : _tasks) { context.run(task); } return result; } @SuppressWarnings("unchecked") @Override public List<Task<T>> getTasks() { List<Task<T>> tasks = new ArrayList<>(_tasks.length); for (Task<?> task : _tasks) { tasks.add((Task<T>) task); } return tasks; } @SuppressWarnings("unchecked") @Override public List<T> getSuccessful() { if (!this.isFailed()) { return this.get(); } final List<T> taskResult = new ArrayList<>(); for (Task<?> task : _tasks) { if (!task.isFailed()) { taskResult.add((T) task.get()); } } return taskResult; } }