/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.dag.app; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import java.io.IOException; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.security.Credentials; import org.apache.hadoop.yarn.api.records.ApplicationAccessType; import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.Container; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.api.records.NodeId; import org.apache.hadoop.yarn.event.Event; import org.apache.hadoop.yarn.event.EventHandler; import org.apache.hadoop.yarn.util.SystemClock; import org.apache.tez.common.TezUtils; import org.apache.tez.dag.api.NamedEntityDescriptor; import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.dag.api.TezConstants; import org.apache.tez.dag.api.TezException; import org.apache.tez.dag.api.UserPayload; import org.apache.tez.dag.app.dag.Vertex; import org.apache.tez.dag.records.TezDAGID; import org.apache.tez.dag.records.TezTaskID; import org.apache.tez.dag.records.TezVertexID; import org.apache.tez.runtime.api.TaskFailureType; import org.apache.tez.runtime.api.events.TaskAttemptFailedEvent; import org.apache.tez.runtime.api.impl.EventMetaData; import org.apache.tez.runtime.api.impl.TezEvent; import org.apache.tez.serviceplugins.api.TaskAttemptEndReason; import org.apache.tez.dag.app.dag.DAG; import org.apache.tez.dag.app.dag.event.TaskAttemptEventAttemptFailed; import org.apache.tez.dag.app.dag.event.TaskAttemptEventAttemptKilled; import org.apache.tez.dag.app.rm.container.AMContainer; import org.apache.tez.dag.app.rm.container.AMContainerMap; import org.apache.tez.dag.app.rm.container.AMContainerTask; import org.apache.tez.dag.records.TaskAttemptTerminationCause; import org.apache.tez.dag.records.TezTaskAttemptID; import org.apache.tez.runtime.api.impl.TaskSpec; import org.apache.tez.serviceplugins.api.TaskHeartbeatRequest; import org.junit.Test; import org.mockito.ArgumentCaptor; public class TestTaskCommunicatorManager2 { @SuppressWarnings("unchecked") @Test(timeout = 5000) public void testTaskAttemptFailedKilled() throws IOException, TezException { TaskCommunicatorManagerWrapperForTest wrapper = new TaskCommunicatorManagerWrapperForTest(); TaskSpec taskSpec1 = wrapper.createTaskSpec(); AMContainerTask amContainerTask1 = new AMContainerTask(taskSpec1, null, null, false, 10); TaskSpec taskSpec2 = wrapper.createTaskSpec(); AMContainerTask amContainerTask2 = new AMContainerTask(taskSpec2, null, null, false, 10); ContainerId containerId1 = wrapper.createContainerId(1); wrapper.registerRunningContainer(containerId1); wrapper.registerTaskAttempt(containerId1, amContainerTask1); ContainerId containerId2 = wrapper.createContainerId(2); wrapper.registerRunningContainer(containerId2); wrapper.registerTaskAttempt(containerId2, amContainerTask2); wrapper.getTaskCommunicatorManager().taskFailed(amContainerTask1.getTask().getTaskAttemptID(), TaskFailureType.NON_FATAL, TaskAttemptEndReason.COMMUNICATION_ERROR, "Diagnostics1"); wrapper.getTaskCommunicatorManager().taskKilled(amContainerTask2.getTask().getTaskAttemptID(), TaskAttemptEndReason.EXECUTOR_BUSY, "Diagnostics2"); ArgumentCaptor<Event> argumentCaptor = ArgumentCaptor.forClass(Event.class); verify(wrapper.getEventHandler(), times(2)).handle(argumentCaptor.capture()); assertTrue(argumentCaptor.getAllValues().get(0) instanceof TaskAttemptEventAttemptFailed); assertTrue(argumentCaptor.getAllValues().get(1) instanceof TaskAttemptEventAttemptKilled); TaskAttemptEventAttemptFailed failedEvent = (TaskAttemptEventAttemptFailed) argumentCaptor.getAllValues().get(0); TaskAttemptEventAttemptKilled killedEvent = (TaskAttemptEventAttemptKilled) argumentCaptor.getAllValues().get(1); assertEquals("Diagnostics1", failedEvent.getDiagnosticInfo()); assertEquals(TaskAttemptTerminationCause.COMMUNICATION_ERROR, failedEvent.getTerminationCause()); assertEquals("Diagnostics2", killedEvent.getDiagnosticInfo()); assertEquals(TaskAttemptTerminationCause.SERVICE_BUSY, killedEvent.getTerminationCause()); // TODO TEZ-2003. Verify unregistration from the registered list } // Tests fatal and non fatal @SuppressWarnings("unchecked") @Test(timeout = 5000) public void testTaskAttemptFailureViaHeartbeat() throws IOException, TezException { TaskCommunicatorManagerWrapperForTest wrapper = new TaskCommunicatorManagerWrapperForTest(); TaskSpec taskSpec1 = wrapper.createTaskSpec(); AMContainerTask amContainerTask1 = new AMContainerTask(taskSpec1, null, null, false, 10); TaskSpec taskSpec2 = wrapper.createTaskSpec(); AMContainerTask amContainerTask2 = new AMContainerTask(taskSpec2, null, null, false, 10); ContainerId containerId1 = wrapper.createContainerId(1); wrapper.registerRunningContainer(containerId1); wrapper.registerTaskAttempt(containerId1, amContainerTask1); ContainerId containerId2 = wrapper.createContainerId(2); wrapper.registerRunningContainer(containerId2); wrapper.registerTaskAttempt(containerId2, amContainerTask2); List<TezEvent> events = new LinkedList<>(); EventMetaData sourceInfo1 = new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, "testVertex", null, taskSpec1.getTaskAttemptID()); TaskAttemptFailedEvent failedEvent1 = new TaskAttemptFailedEvent("non-fatal test error", TaskFailureType.NON_FATAL); TezEvent failedEventT1 = new TezEvent(failedEvent1, sourceInfo1); events.add(failedEventT1); TaskHeartbeatRequest taskHeartbeatRequest1 = new TaskHeartbeatRequest(containerId1.toString(), taskSpec1.getTaskAttemptID(), events, 0, 0, 0); wrapper.getTaskCommunicatorManager().heartbeat(taskHeartbeatRequest1); ArgumentCaptor<Event> argumentCaptor = ArgumentCaptor.forClass(Event.class); verify(wrapper.getEventHandler(), times(1)).handle(argumentCaptor.capture()); assertTrue(argumentCaptor.getAllValues().get(0) instanceof TaskAttemptEventAttemptFailed); TaskAttemptEventAttemptFailed failedEvent = (TaskAttemptEventAttemptFailed) argumentCaptor.getAllValues().get(0); assertEquals(TaskFailureType.NON_FATAL, failedEvent.getTaskFailureType()); assertTrue(failedEvent.getDiagnosticInfo().contains("non-fatal")); events.clear(); reset(wrapper.getEventHandler()); EventMetaData sourceInfo2 = new EventMetaData(EventMetaData.EventProducerConsumerType.PROCESSOR, "testVertex", null, taskSpec2.getTaskAttemptID()); TaskAttemptFailedEvent failedEvent2 = new TaskAttemptFailedEvent("-fatal- test error", TaskFailureType.FATAL); TezEvent failedEventT2 = new TezEvent(failedEvent2, sourceInfo2); events.add(failedEventT2); TaskHeartbeatRequest taskHeartbeatRequest2 = new TaskHeartbeatRequest(containerId2.toString(), taskSpec2.getTaskAttemptID(), events, 0, 0, 0); wrapper.getTaskCommunicatorManager().heartbeat(taskHeartbeatRequest2); argumentCaptor = ArgumentCaptor.forClass(Event.class); verify(wrapper.getEventHandler(), times(1)).handle(argumentCaptor.capture()); assertTrue(argumentCaptor.getAllValues().get(0) instanceof TaskAttemptEventAttemptFailed); failedEvent = (TaskAttemptEventAttemptFailed) argumentCaptor.getAllValues().get(0); assertEquals(TaskFailureType.FATAL, failedEvent.getTaskFailureType()); assertTrue(failedEvent.getDiagnosticInfo().contains("-fatal-")); } // Tests fatal and non fatal @SuppressWarnings("unchecked") @Test(timeout = 5000) public void testTaskAttemptFailureViaContext() throws IOException, TezException { TaskCommunicatorManagerWrapperForTest wrapper = new TaskCommunicatorManagerWrapperForTest(); TaskSpec taskSpec1 = wrapper.createTaskSpec(); AMContainerTask amContainerTask1 = new AMContainerTask(taskSpec1, null, null, false, 10); TaskSpec taskSpec2 = wrapper.createTaskSpec(); AMContainerTask amContainerTask2 = new AMContainerTask(taskSpec2, null, null, false, 10); ContainerId containerId1 = wrapper.createContainerId(1); wrapper.registerRunningContainer(containerId1); wrapper.registerTaskAttempt(containerId1, amContainerTask1); ContainerId containerId2 = wrapper.createContainerId(2); wrapper.registerRunningContainer(containerId2); wrapper.registerTaskAttempt(containerId2, amContainerTask2); // non-fatal wrapper.getTaskCommunicatorManager() .taskFailed(taskSpec1.getTaskAttemptID(), TaskFailureType.NON_FATAL, TaskAttemptEndReason.CONTAINER_EXITED, "--non-fatal--"); ArgumentCaptor<Event> argumentCaptor = ArgumentCaptor.forClass(Event.class); verify(wrapper.getEventHandler(), times(1)).handle(argumentCaptor.capture()); assertTrue(argumentCaptor.getAllValues().get(0) instanceof TaskAttemptEventAttemptFailed); TaskAttemptEventAttemptFailed failedEvent = (TaskAttemptEventAttemptFailed) argumentCaptor.getAllValues().get(0); assertEquals(TaskFailureType.NON_FATAL, failedEvent.getTaskFailureType()); assertTrue(failedEvent.getDiagnosticInfo().contains("--non-fatal--")); reset(wrapper.getEventHandler()); // fatal wrapper.getTaskCommunicatorManager() .taskFailed(taskSpec2.getTaskAttemptID(), TaskFailureType.FATAL, TaskAttemptEndReason.OTHER, "--fatal--"); argumentCaptor = ArgumentCaptor.forClass(Event.class); verify(wrapper.getEventHandler(), times(1)).handle(argumentCaptor.capture()); assertTrue(argumentCaptor.getAllValues().get(0) instanceof TaskAttemptEventAttemptFailed); failedEvent = (TaskAttemptEventAttemptFailed) argumentCaptor.getAllValues().get(0); assertEquals(TaskFailureType.FATAL, failedEvent.getTaskFailureType()); assertTrue(failedEvent.getDiagnosticInfo().contains("--fatal--")); } @SuppressWarnings("unchecked") private static class TaskCommunicatorManagerWrapperForTest { ApplicationId appId = ApplicationId.newInstance(1000, 1); ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1); Credentials credentials = new Credentials(); AppContext appContext = mock(AppContext.class); EventHandler eventHandler = mock(EventHandler.class); DAG dag = mock(DAG.class); Vertex vertex = mock(Vertex.class); TezDAGID dagId; TezVertexID vertexId; AMContainerMap amContainerMap = mock(AMContainerMap.class); Map<ApplicationAccessType, String> appAcls = new HashMap<ApplicationAccessType, String>(); Configuration conf = new TezConfiguration(); UserPayload userPayload; TaskCommunicatorManager taskCommunicatorManager; private AtomicInteger taskIdCounter = new AtomicInteger(0); TaskCommunicatorManagerWrapperForTest() throws IOException, TezException { dagId = TezDAGID.getInstance(appId, 1); vertexId = TezVertexID.getInstance(dagId, 100); doReturn(eventHandler).when(appContext).getEventHandler(); doReturn(dag).when(appContext).getCurrentDAG(); doReturn(vertex).when(dag).getVertex(eq(vertexId)); doReturn(new TaskAttemptEventInfo(0, new LinkedList<TezEvent>(), 0)).when(vertex) .getTaskAttemptTezEvents(any(TezTaskAttemptID.class), anyInt(), anyInt(), anyInt()); doReturn(appAttemptId).when(appContext).getApplicationAttemptId(); doReturn(credentials).when(appContext).getAppCredentials(); doReturn(appAcls).when(appContext).getApplicationACLs(); doReturn(amContainerMap).when(appContext).getAllContainers(); doReturn(new SystemClock()).when(appContext).getClock(); NodeId nodeId = NodeId.newInstance("localhost", 0); AMContainer amContainer = mock(AMContainer.class); Container container = mock(Container.class); doReturn(nodeId).when(container).getNodeId(); doReturn(amContainer).when(amContainerMap).get(any(ContainerId.class)); doReturn(container).when(amContainer).getContainer(); userPayload = TezUtils.createUserPayloadFromConf(conf); taskCommunicatorManager = new TaskCommunicatorManager(appContext, mock(TaskHeartbeatHandler.class), mock(ContainerHeartbeatHandler.class), Lists.newArrayList(new NamedEntityDescriptor( TezConstants.getTezYarnServicePluginName(), null).setUserPayload(userPayload))); } TaskCommunicatorManager getTaskCommunicatorManager() { return taskCommunicatorManager; } EventHandler getEventHandler() { return eventHandler; } private void registerRunningContainer(ContainerId containerId) { taskCommunicatorManager.registerRunningContainer(containerId, 0); } private void registerTaskAttempt(ContainerId containerId, AMContainerTask amContainerTask) { taskCommunicatorManager.registerTaskAttempt(amContainerTask, containerId, 0); } private TaskSpec createTaskSpec() { TaskSpec taskSpec = mock(TaskSpec.class); TezTaskID taskId = TezTaskID.getInstance(vertexId, taskIdCounter.incrementAndGet()); TezTaskAttemptID taskAttemptId = TezTaskAttemptID.getInstance(taskId, 0); doReturn(taskAttemptId).when(taskSpec).getTaskAttemptID(); return taskSpec; } @SuppressWarnings("deprecation") private ContainerId createContainerId(int id) { ApplicationAttemptId appAttemptId = ApplicationAttemptId.newInstance(appId, 1); ContainerId containerId = ContainerId.newInstance(appAttemptId, id); return containerId; } } }