/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.runtime.task; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; import java.io.IOException; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.ReentrantLock; import org.apache.hadoop.ipc.ProtocolSignature; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.tez.common.ContainerContext; import org.apache.tez.common.ContainerTask; import org.apache.tez.common.TezTaskUmbilicalProtocol; import org.apache.tez.dag.api.TezException; import org.apache.tez.dag.records.TezTaskAttemptID; import org.apache.tez.runtime.api.AbstractLogicalIOProcessor; import org.apache.tez.runtime.api.Event; import org.apache.tez.runtime.api.TaskFailureType; import org.apache.tez.runtime.api.LogicalInput; import org.apache.tez.runtime.api.LogicalOutput; import org.apache.tez.runtime.api.ProcessorContext; import org.apache.tez.runtime.api.events.TaskAttemptCompletedEvent; import org.apache.tez.runtime.api.events.TaskAttemptFailedEvent; import org.apache.tez.runtime.api.events.TaskAttemptKilledEvent; import org.apache.tez.runtime.api.impl.TezEvent; import org.apache.tez.runtime.api.impl.TezHeartbeatRequest; import org.apache.tez.runtime.api.impl.TezHeartbeatResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class TaskExecutionTestHelpers { public static final String HEARTBEAT_EXCEPTION_STRING = "HeartbeatException"; // Uses static fields for signaling. Ensure only used by one test at a time. public static class TestProcessor extends AbstractLogicalIOProcessor { private static final int EMPTY = 0; private static final int THROW_IO_EXCEPTION = 1; private static final int THROW_TEZ_EXCEPTION = 2; private static final int SIGNAL_DEPRECATEDFATAL_AND_THROW = 3; private static final int SIGNAL_DEPRECATEDFATAL_AND_LOOP = 4; private static final int SIGNAL_DEPRECATEDFATAL_AND_COMPLETE = 5; private static final int SIGNAL_FATAL_AND_THROW = 6; private static final int SIGNAL_NON_FATAL_AND_THROW = 7; private static final int SELF_KILL_AND_COMPLETE = 8; public static final byte[] CONF_EMPTY = new byte[]{EMPTY}; public static final byte[] CONF_THROW_IO_EXCEPTION = new byte[]{THROW_IO_EXCEPTION}; public static final byte[] CONF_THROW_TEZ_EXCEPTION = new byte[]{THROW_TEZ_EXCEPTION}; public static final byte[] CONF_SIGNAL_DEPRECATEDFATAL_AND_THROW = new byte[]{SIGNAL_DEPRECATEDFATAL_AND_THROW}; public static final byte[] CONF_SIGNAL_DEPRECATEDFATAL_AND_LOOP = new byte[]{SIGNAL_DEPRECATEDFATAL_AND_LOOP}; public static final byte[] CONF_SIGNAL_DEPRECATEDFATAL_AND_COMPLETE = new byte[]{SIGNAL_DEPRECATEDFATAL_AND_COMPLETE}; public static final byte[] CONF_SIGNAL_FATAL_AND_THROW = new byte[]{SIGNAL_FATAL_AND_THROW}; public static final byte[] CONF_SIGNAL_NON_FATAL_AND_THROW = new byte[]{SIGNAL_NON_FATAL_AND_THROW}; public static final byte[] CONF_SELF_KILL_AND_COMPLETE = new byte[]{SELF_KILL_AND_COMPLETE}; private static final Logger LOG = LoggerFactory.getLogger(TestProcessor.class); private static final ReentrantLock processorLock = new ReentrantLock(); private static final Condition processorCondition = processorLock.newCondition(); private static final Condition loopCondition = processorLock.newCondition(); private static final Condition completionCondition = processorLock.newCondition(); private static final Condition runningCondition = processorLock.newCondition(); private static volatile boolean completed = false; private static volatile boolean running = false; private static volatile boolean looping = false; private static volatile boolean signalled = false; private static boolean receivedInterrupt = false; private static volatile boolean wasAborted = false; private boolean throwIOException = false; private boolean throwTezException = false; private boolean signalDeprecatedFatalAndThrow = false; private boolean signalDeprecatedFatalAndLoop = false; private boolean signalDeprecatedFatalAndComplete = false; private boolean signalFatalAndThrow = false; private boolean signalNonFatalAndThrow = false; private boolean selfKillAndComplete = false; public TestProcessor(ProcessorContext context) { super(context); } @Override public void initialize() throws Exception { parseConf(getContext().getUserPayload().deepCopyAsArray()); } @Override public void handleEvents(List<Event> processorEvents) { } @Override public void close() throws Exception { } private void parseConf(byte[] bytes) { byte b = bytes[0]; throwIOException = (b == THROW_IO_EXCEPTION); throwTezException = (b == THROW_TEZ_EXCEPTION); signalDeprecatedFatalAndThrow = (b == SIGNAL_DEPRECATEDFATAL_AND_THROW); signalDeprecatedFatalAndLoop = (b == SIGNAL_DEPRECATEDFATAL_AND_LOOP); signalDeprecatedFatalAndComplete = (b == SIGNAL_DEPRECATEDFATAL_AND_COMPLETE); signalFatalAndThrow = (b == SIGNAL_FATAL_AND_THROW); signalNonFatalAndThrow = (b == SIGNAL_NON_FATAL_AND_THROW); selfKillAndComplete = (b == SELF_KILL_AND_COMPLETE); } public static void reset() { signalled = false; receivedInterrupt = false; completed = false; running = false; wasAborted = false; } public static void signal() { LOG.info("Signalled"); processorLock.lock(); try { signalled = true; processorCondition.signal(); } finally { processorLock.unlock(); } } public static void awaitStart() throws InterruptedException { LOG.info("Awaiting Process run"); processorLock.lock(); try { if (running) { return; } runningCondition.await(); } finally { processorLock.unlock(); } } public static void awaitLoop() throws InterruptedException { LOG.info("Awaiting loop after signalling error"); processorLock.lock(); try { if (looping) { return; } loopCondition.await(); } finally { processorLock.unlock(); } } public static void awaitCompletion() throws InterruptedException { LOG.info("Await completion"); processorLock.lock(); try { if (completed) { return; } else { completionCondition.await(); } } finally { processorLock.unlock(); } } public static boolean wasInterrupted() { processorLock.lock(); try { return receivedInterrupt; } finally { processorLock.unlock(); } } public static boolean wasAborted() { processorLock.lock(); try { return wasAborted; } finally { processorLock.unlock(); } } @Override public void abort() { wasAborted = true; } @SuppressWarnings("deprecation") @Override public void run(Map<String, LogicalInput> inputs, Map<String, LogicalOutput> outputs) throws Exception { processorLock.lock(); running = true; runningCondition.signal(); try { try { LOG.info("Signal is: " + signalled); if (!signalled) { LOG.info("Waiting for processor signal"); processorCondition.await(); } if (Thread.currentThread().isInterrupted()) { throw new InterruptedException(); } LOG.info("Received processor signal"); if (throwIOException) { throw createProcessorIOException(); } else if (throwTezException) { throw createProcessorTezException(); } else if (signalDeprecatedFatalAndThrow) { IOException io = new IOException(IOException.class.getSimpleName()); getContext().fatalError(io, IOException.class.getSimpleName()); throw io; } else if (signalDeprecatedFatalAndComplete) { IOException io = new IOException(IOException.class.getSimpleName()); getContext().fatalError(io, IOException.class.getSimpleName()); return; } else if (signalDeprecatedFatalAndLoop) { IOException io = createProcessorIOException(); getContext().fatalError(io, IOException.class.getSimpleName()); LOG.info("looping"); looping = true; loopCondition.signal(); LOG.info("Waiting for Processor signal again"); processorCondition.await(); LOG.info("Received second processor signal"); } else if (signalFatalAndThrow) { IOException io = new IOException(IOException.class.getSimpleName()); getContext().reportFailure(TaskFailureType.FATAL, io, IOException.class.getSimpleName()); LOG.info("throwing"); throw io; } else if (signalNonFatalAndThrow) { IOException io = new IOException(IOException.class.getSimpleName()); getContext().reportFailure(TaskFailureType.NON_FATAL, io, IOException.class.getSimpleName()); LOG.info("throwing"); throw io; } else if (selfKillAndComplete) { LOG.info("Reporting kill self"); getContext().killSelf(new IOException(IOException.class.getSimpleName()), "SELFKILL"); LOG.info("Returning"); } } catch (InterruptedException e) { receivedInterrupt = true; } } finally { completed = true; completionCondition.signal(); processorLock.unlock(); } } } public static TezException createProcessorTezException() { return new TezException("TezException"); } public static IOException createProcessorIOException() { return new IOException("IOException"); } public static class TezTaskUmbilicalForTest implements TezTaskUmbilicalProtocol { private static final Logger LOG = LoggerFactory.getLogger(TezTaskUmbilicalForTest.class); private final List<TezEvent> requestEvents = new LinkedList<TezEvent>(); private final ReentrantLock umbilicalLock = new ReentrantLock(); private final Condition eventCondition = umbilicalLock.newCondition(); private boolean pendingEvent = false; private boolean eventEnacted = false; volatile int getTaskInvocations = 0; private boolean shouldThrowException = false; private boolean shouldSendDieSignal = false; public void signalThrowException() { umbilicalLock.lock(); try { shouldThrowException = true; pendingEvent = true; } finally { umbilicalLock.unlock(); } } public void signalSendShouldDie() { umbilicalLock.lock(); try { shouldSendDieSignal = true; pendingEvent = true; } finally { umbilicalLock.unlock(); } } public void awaitRegisteredEvent() throws InterruptedException { umbilicalLock.lock(); try { if (eventEnacted) { return; } LOG.info("Awaiting event"); eventCondition.await(); } finally { umbilicalLock.unlock(); } } public void resetTrackedEvents() { umbilicalLock.lock(); try { requestEvents.clear(); } finally { umbilicalLock.unlock(); } } public void verifyNoCompletionEvents() { umbilicalLock.lock(); try { for (TezEvent event : requestEvents) { if (event.getEvent() instanceof TaskAttemptFailedEvent) { fail("Found a TaskAttemptFailedEvent when not expected"); } if (event.getEvent() instanceof TaskAttemptCompletedEvent) { fail("Found a TaskAttemptCompletedvent when not expected"); } } } finally { umbilicalLock.unlock(); } } public void verifyTaskFailedEvent(String diagnostics) { umbilicalLock.lock(); try { for (TezEvent event : requestEvents) { if (event.getEvent() instanceof TaskAttemptFailedEvent) { TaskAttemptFailedEvent failedEvent = (TaskAttemptFailedEvent) event.getEvent(); if (failedEvent.getDiagnostics().startsWith(diagnostics)) { return; } else { fail("Diagnostic message does not match expected message. Found [" + failedEvent.getDiagnostics() + "], Expected: [" + diagnostics + "]"); } } } fail("No TaskAttemptFailedEvents sent over umbilical"); } finally { umbilicalLock.unlock(); } } public void verifyTaskFailedEvent(String diagStart, String diagContains) { verifyTaskFailedEvent(diagStart, diagContains, TaskFailureType.NON_FATAL); } public void verifyTaskFailedEvent(String diagStart, String diagContains, TaskFailureType taskFailureType) { umbilicalLock.lock(); try { for (TezEvent event : requestEvents) { if (event.getEvent() instanceof TaskAttemptFailedEvent) { TaskAttemptFailedEvent failedEvent = (TaskAttemptFailedEvent) event.getEvent(); if (failedEvent.getDiagnostics().startsWith(diagStart)) { if (diagContains != null) { if (failedEvent.getDiagnostics().contains(diagContains)) { assertEquals(taskFailureType, failedEvent.getTaskFailureType()); return; } else { fail("Diagnostic message does not contain expected message. Found [" + failedEvent.getDiagnostics() + "], Expected: [" + diagContains + "]"); } } } else { fail("Diagnostic message does not start with expected message. Found [" + failedEvent.getDiagnostics() + "], Expected: [" + diagStart + "]"); } } } fail("No TaskAttemptFailedEvents sent over umbilical"); } finally { umbilicalLock.unlock(); } } public void verifyTaskKilledEvent(String diagStart, String diagContains) { umbilicalLock.lock(); try { for (TezEvent event : requestEvents) { if (event.getEvent() instanceof TaskAttemptKilledEvent) { TaskAttemptKilledEvent killedEvent = (TaskAttemptKilledEvent) event.getEvent(); if (killedEvent.getDiagnostics().startsWith(diagStart)) { if (diagContains != null) { if (killedEvent.getDiagnostics().contains(diagContains)) { return; } else { fail("Diagnostic message does not contain expected message. Found [" + killedEvent.getDiagnostics() + "], Expected: [" + diagContains + "]"); } } } else { fail("Diagnostic message does not start with expected message. Found [" + killedEvent.getDiagnostics() + "], Expected: [" + diagStart + "]"); } } } fail("No TaskAttemptKilledEvents sent over umbilical"); } finally { umbilicalLock.unlock(); } } public void verifyTaskSuccessEvent() { umbilicalLock.lock(); try { for (TezEvent event : requestEvents) { if (event.getEvent() instanceof TaskAttemptCompletedEvent) { return; } } fail("No TaskAttemptFailedEvents sent over umbilical"); } finally { umbilicalLock.unlock(); } } @Override public long getProtocolVersion(String protocol, long clientVersion) throws IOException { return 0; } @Override public ProtocolSignature getProtocolSignature(String protocol, long clientVersion, int clientMethodsHash) throws IOException { return null; } @Override public ContainerTask getTask(ContainerContext containerContext) throws IOException { // Return shouldDie = true getTaskInvocations++; return new ContainerTask(null, true, null, null, false); } @Override public boolean canCommit(TezTaskAttemptID taskid) throws IOException { return true; } @Override public TezHeartbeatResponse heartbeat(TezHeartbeatRequest request) throws IOException, TezException { umbilicalLock.lock(); if (request.getEvents() != null) { requestEvents.addAll(request.getEvents()); } try { if (shouldThrowException) { LOG.info("TestUmbilical throwing Exception"); throw new IOException(HEARTBEAT_EXCEPTION_STRING); } TezHeartbeatResponse response = new TezHeartbeatResponse(); response.setLastRequestId(request.getRequestId()); if (shouldSendDieSignal) { LOG.info("TestUmbilical returning shouldDie=true"); response.setShouldDie(); } return response; } finally { if (pendingEvent) { eventEnacted = true; LOG.info("Signalling Event"); eventCondition.signal(); } umbilicalLock.unlock(); } } } @SuppressWarnings("deprecation") public static ContainerId createContainerId(ApplicationId appId) { ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1); ContainerId containerId = ContainerId.newInstance(appAttemptId, 1); return containerId; } public static TaskReporter createTaskReporter(ApplicationId appId, TezTaskUmbilicalForTest umbilical) { TaskReporter taskReporter = new TaskReporter(umbilical, 100, 1000, 100, new AtomicLong(0), createContainerId(appId).toString()); return taskReporter; } }