/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.task;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import com.google.common.collect.Lists;
import org.apache.tez.common.TezTaskUmbilicalProtocol;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.LogicalIOProcessorRuntimeTask;
import org.apache.tez.runtime.api.events.TaskStatusUpdateEvent;
import org.apache.tez.runtime.api.impl.TaskStatistics;
import org.apache.tez.runtime.api.impl.TezEvent;
import org.apache.tez.runtime.api.impl.TezHeartbeatRequest;
import org.apache.tez.runtime.api.impl.TezHeartbeatResponse;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/**
 * Tests for {@link TaskReporter.HeartbeatCallable}, exercised against mocked
 * {@link TezTaskUmbilicalProtocol} and {@link LogicalIOProcessorRuntimeTask} instances.
 */
public class TestTaskReporter {

  /**
   * The mocked umbilical answers heartbeat requests 1 and 2 with five events each and
   * request 3 with a single event. The test expects exactly three heartbeats to be sent
   * back to back, and no further heartbeat during the following two seconds.
   * Presumably a response carrying the maximum number of events (5, as passed to the
   * callable) is what triggers the immediate follow-up heartbeat.
   */
  @Test(timeout = 10000)
  public void testContinuousHeartbeatsOnMaxEvents() throws Exception {
    final Object lock = new Object();
    // Set (under the lock) by the mocked umbilical once heartbeat request 3 is seen.
    final AtomicBoolean thirdHeartbeatSeen = new AtomicBoolean(false);

    TezTaskUmbilicalProtocol mockUmbilical = mock(TezTaskUmbilicalProtocol.class);
    doAnswer(new Answer<TezHeartbeatResponse>() {
      @Override
      public TezHeartbeatResponse answer(InvocationOnMock invocation) throws Throwable {
        TezHeartbeatRequest request = (TezHeartbeatRequest) invocation.getArguments()[0];
        if (request.getRequestId() == 1 || request.getRequestId() == 2) {
          TezHeartbeatResponse response = new TezHeartbeatResponse(createEvents(5));
          response.setLastRequestId(request.getRequestId());
          return response;
        } else if (request.getRequestId() == 3) {
          TezHeartbeatResponse response = new TezHeartbeatResponse(createEvents(1));
          response.setLastRequestId(request.getRequestId());
          // Wake up the test thread: the third heartbeat has arrived.
          synchronized (lock) {
            thirdHeartbeatSeen.set(true);
            lock.notify();
          }
          return response;
        } else {
          throw new TezUncheckedException(
              "Invalid request id for test: " + request.getRequestId());
        }
      }
    }).when(mockUmbilical).heartbeat(any(TezHeartbeatRequest.class));

    TezTaskAttemptID mockTaskAttemptId = mock(TezTaskAttemptID.class);
    LogicalIOProcessorRuntimeTask mockTask = mock(LogicalIOProcessorRuntimeTask.class);
    doReturn("vertexName").when(mockTask).getVertexName();
    doReturn(mockTaskAttemptId).when(mockTask).getTaskAttemptID();

    TaskReporter.HeartbeatCallable heartbeatCallable =
        newHeartbeatCallable(mockTask, mockUmbilical);

    ExecutorService executor = Executors.newSingleThreadExecutor();
    executor.submit(heartbeatCallable);
    try {
      synchronized (lock) {
        // Loop rather than a single check: Object.wait() may return spuriously, and the
        // verifications below are only valid once the third heartbeat has been observed.
        while (!thirdHeartbeatSeen.get()) {
          lock.wait();
        }
      }
      verify(mockUmbilical, times(3)).heartbeat(any(TezHeartbeatRequest.class));
      // Sleep for 2 seconds, less than the callable sleep time. No more invocations.
      Thread.sleep(2000L);
      verify(mockUmbilical, times(3)).heartbeat(any(TezHeartbeatRequest.class));
    } finally {
      executor.shutdownNow();
    }
  }

  /**
   * The mocked task reports 10000 and then 1 from {@code getMaxEventsToHandle()}, and the
   * umbilical returns a five-event response followed by a one-event response flagged with
   * {@code setShouldDie()}. The test expects the callable to return {@code false} after
   * exactly two heartbeats, with the second request carrying request id 2 and a max-events
   * value of 1.
   */
  @Test(timeout = 10000)
  public void testEventThrottling() throws Exception {
    TezTaskAttemptID mockTaskAttemptId = mock(TezTaskAttemptID.class);
    LogicalIOProcessorRuntimeTask mockTask = mock(LogicalIOProcessorRuntimeTask.class);
    when(mockTask.getMaxEventsToHandle()).thenReturn(10000, 1);
    when(mockTask.getVertexName()).thenReturn("vertexName");
    when(mockTask.getTaskAttemptID()).thenReturn(mockTaskAttemptId);

    TezTaskUmbilicalProtocol mockUmbilical = mock(TezTaskUmbilicalProtocol.class);
    TezHeartbeatResponse resp1 = new TezHeartbeatResponse(createEvents(5));
    resp1.setLastRequestId(1);
    TezHeartbeatResponse resp2 = new TezHeartbeatResponse(createEvents(1));
    resp2.setLastRequestId(2);
    resp2.setShouldDie();
    when(mockUmbilical.heartbeat(isA(TezHeartbeatRequest.class))).thenReturn(resp1, resp2);

    TaskReporter.HeartbeatCallable heartbeatCallable =
        newHeartbeatCallable(mockTask, mockUmbilical);

    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      Future<Boolean> result = executor.submit(heartbeatCallable);
      Assert.assertFalse(result.get());
    } finally {
      executor.shutdownNow();
    }

    ArgumentCaptor<TezHeartbeatRequest> captor =
        ArgumentCaptor.forClass(TezHeartbeatRequest.class);
    verify(mockUmbilical, times(2)).heartbeat(captor.capture());
    // getValue() yields the most recently captured argument, i.e. the second request.
    TezHeartbeatRequest req = captor.getValue();
    Assert.assertEquals(2, req.getRequestId());
    Assert.assertEquals(1, req.getMaxEvents());
  }

  /**
   * Checks what {@code getStatusUpdateEvent(boolean)} reads from the task in three steps:
   * an uninitialized task (nothing read), an initialized task with the flag {@code false}
   * (progress only), and an initialized task with the flag {@code true} (progress,
   * statistics and counters).
   */
  @Test(timeout = 5000)
  public void testStatusUpdateAfterInitializationAndCounterFlag() {
    TezTaskAttemptID mockTaskAttemptId = mock(TezTaskAttemptID.class);
    LogicalIOProcessorRuntimeTask mockTask = mock(LogicalIOProcessorRuntimeTask.class);
    doReturn("vertexName").when(mockTask).getVertexName();
    doReturn(mockTaskAttemptId).when(mockTask).getTaskAttemptID();
    boolean progressNotified = false;
    doReturn(progressNotified).when(mockTask).getAndClearProgressNotification();
    TezTaskUmbilicalProtocol mockUmbilical = mock(TezTaskUmbilicalProtocol.class);

    float progress = 0.5f;
    TaskStatistics stats = new TaskStatistics();
    TezCounters counters = new TezCounters();
    doReturn(progress).when(mockTask).getProgress();
    doReturn(stats).when(mockTask).getTaskStatistics();
    doReturn(counters).when(mockTask).getCounters();

    TaskReporter.HeartbeatCallable heartbeatCallable =
        newHeartbeatCallable(mockTask, mockUmbilical);

    // task not initialized - nothing obtained from task
    doReturn(false).when(mockTask).hasInitialized();
    TaskStatusUpdateEvent event = heartbeatCallable.getStatusUpdateEvent(true);
    verify(mockTask, times(1)).hasInitialized();
    verify(mockTask, times(0)).getProgress();
    verify(mockTask, times(0)).getAndClearProgressNotification();
    verify(mockTask, times(0)).getTaskStatistics();
    verify(mockTask, times(0)).getCounters();
    Assert.assertEquals(0, event.getProgress(), 0);
    Assert.assertFalse(event.getProgressNotified());
    Assert.assertNull(event.getCounters());
    Assert.assertNull(event.getStatistics());

    // task is initialized - progress obtained but not counters since flag is false
    doReturn(true).when(mockTask).hasInitialized();
    event = heartbeatCallable.getStatusUpdateEvent(false);
    verify(mockTask, times(2)).hasInitialized();
    verify(mockTask, times(1)).getProgress();
    verify(mockTask, times(1)).getAndClearProgressNotification();
    verify(mockTask, times(0)).getTaskStatistics();
    verify(mockTask, times(0)).getCounters();
    Assert.assertEquals(progress, event.getProgress(), 0);
    Assert.assertEquals(progressNotified, event.getProgressNotified());
    Assert.assertNull(event.getCounters());
    Assert.assertNull(event.getStatistics());

    // task is initialized - progress obtained and also counters since flag is true
    progressNotified = true;
    doReturn(progressNotified).when(mockTask).getAndClearProgressNotification();
    doReturn(true).when(mockTask).hasInitialized();
    event = heartbeatCallable.getStatusUpdateEvent(true);
    verify(mockTask, times(3)).hasInitialized();
    verify(mockTask, times(2)).getProgress();
    verify(mockTask, times(2)).getAndClearProgressNotification();
    verify(mockTask, times(1)).getTaskStatistics();
    verify(mockTask, times(1)).getCounters();
    Assert.assertEquals(progress, event.getProgress(), 0);
    Assert.assertEquals(progressNotified, event.getProgressNotified());
    Assert.assertEquals(counters, event.getCounters());
    Assert.assertEquals(stats, event.getStatistics());
  }

  /**
   * Builds the {@link TaskReporter.HeartbeatCallable} shared by all tests in this class.
   *
   * @param task the (mocked) task handed to the callable
   * @param umbilical the (mocked) umbilical the callable sends heartbeats to
   * @return a callable configured with the fixed test settings
   */
  private static TaskReporter.HeartbeatCallable newHeartbeatCallable(
      LogicalIOProcessorRuntimeTask task, TezTaskUmbilicalProtocol umbilical) {
    // Setup the sleep time to be way higher than the test timeout
    return new TaskReporter.HeartbeatCallable(task, umbilical, 100000, 100000, 5,
        new AtomicLong(0),
        "containerIdStr");
  }

  /**
   * Creates a list of mocked {@link TezEvent}s.
   *
   * @param numEvents number of mocked events to create
   * @return a list containing {@code numEvents} distinct mocks
   */
  private List<TezEvent> createEvents(int numEvents) {
    List<TezEvent> list = Lists.newArrayListWithCapacity(numEvents);
    for (int i = 0; i < numEvents; i++) {
      list.add(mock(TezEvent.class));
    }
    return list;
  }
}