/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.mapreduce.input;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.DataSourceDescriptor;
import org.apache.tez.mapreduce.hadoop.MRInputHelpers;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.junit.Test;
/**
 * Unit tests for {@link MRInput}, driven through a Mockito-mocked {@link InputContext}.
 */
public class TestMRInput {

  // Values stubbed on the mocked InputContext by testAttributesInJobConf and later
  // compared against what MRInputHelpers reads back from the JobConf.
  private static final String TEST_ATTRIBUTES_DAG_NAME = "dagName";
  private static final String TEST_ATTRIBUTES_VERTEX_NAME = "vertexName";
  private static final String TEST_ATTRIBUTES_INPUT_NAME = "inputName";
  private static final ApplicationId TEST_ATTRIBUTES_APPLICATION_ID =
      ApplicationId.newInstance(0, 0);
  private static final String TEST_ATTRIBUTES_UNIQUE_IDENTIFIER = "uniqueId";
  private static final int TEST_ATTRIBUTES_DAG_INDEX = 1000;
  private static final int TEST_ATTRIBUTES_VERTEX_INDEX = 2000;
  private static final int TEST_ATTRIBUTES_TASK_INDEX = 3000;
  private static final int TEST_ATTRIBUTES_TASK_ATTEMPT_INDEX = 4000;
  private static final int TEST_ATTRIBUTES_INPUT_INDEX = 5000;
  private static final int TEST_ATTRIBUTES_DAG_ATTEMPT_NUMBER = 6000;

  // Identifier strings the test expects MRInputHelpers to report for the values above.
  private static final String TEST_ATTRIBUTES_APPLICATION_ID_STRING = "application_0_0000";
  private static final String TEST_ATTRIBUTES_DAG_ID_STRING = "dag_0_0000_1000";
  private static final String TEST_ATTRIBUTES_VERTEX_ID_STRING = "vertex_0_0000_1000_2000";
  private static final String TEST_ATTRIBUTES_TASK_ID_STRING = "task_0_0000_1000_2000_003000";
  private static final String TEST_ATTRIBUTES_TASK_ATTEMPT_ID_STRING =
      "attempt_0_0000_1000_2000_003000_4000";

  /**
   * Checks an {@link MRInput} created with 0 physical inputs: its reader must report no
   * records, progress must be notified exactly once, and {@code handleEvents} must fail
   * with an {@link IllegalStateException}.
   *
   * @throws Exception if the input fails in any way other than the expected one
   */
  @Test(timeout = 5000)
  public void test0PhysicalInputs() throws Exception {
    InputContext inputContext = mock(InputContext.class);

    DataSourceDescriptor dsd = MRInput.createConfigBuilder(new Configuration(false),
        FileInputFormat.class, "testPath").build();

    ApplicationId applicationId = ApplicationId.newInstance(1000, 1);
    doReturn(dsd.getInputDescriptor().getUserPayload()).when(inputContext).getUserPayload();
    doReturn(applicationId).when(inputContext).getApplicationId();
    doReturn("dagName").when(inputContext).getDAGName();
    doReturn("vertexName").when(inputContext).getTaskVertexName();
    doReturn("inputName").when(inputContext).getSourceVertexName();
    doReturn("uniqueIdentifier").when(inputContext).getUniqueIdentifier();
    doReturn(1).when(inputContext).getTaskIndex();
    doReturn(1).when(inputContext).getTaskAttemptNumber();
    doReturn(new TezCounters()).when(inputContext).getCounters();

    MRInput mrInput = new MRInput(inputContext, 0);
    mrInput.initialize();
    mrInput.start();

    assertFalse(mrInput.getReader().next());
    verify(inputContext, times(1)).notifyProgress();

    List<Event> events = new LinkedList<>();
    try {
      mrInput.handleEvents(events);
      fail("HandleEvents should cause an input with 0 physical inputs to fail");
    } catch (IllegalStateException expected) {
      // Expected failure mode. Any other exception type is deliberately left uncaught so
      // that it fails the test with its own message and stack trace.
    }
  }

  /**
   * Checks that the attributes stubbed on the {@link InputContext} can be read back through
   * {@link MRInputHelpers} from the {@link JobConf} handed to
   * {@link TestInputFormat#getRecordReader}, and that the
   * {@link TaskCounter#INPUT_SPLIT_LENGTH_BYTES} counter equals the length reported by
   * {@link TestInputSplit}.
   *
   * @throws Exception if setting up or driving the input fails
   */
  @Test(timeout = 5000)
  public void testAttributesInJobConf() throws Exception {
    // The flag is static: clear it so a value left over from an earlier run in the same JVM
    // cannot satisfy the final assertion.
    TestInputFormat.invoked.set(false);

    InputContext inputContext = mock(InputContext.class);
    doReturn(TEST_ATTRIBUTES_DAG_INDEX).when(inputContext).getDagIdentifier();
    doReturn(TEST_ATTRIBUTES_VERTEX_INDEX).when(inputContext).getTaskVertexIndex();
    doReturn(TEST_ATTRIBUTES_TASK_INDEX).when(inputContext).getTaskIndex();
    doReturn(TEST_ATTRIBUTES_TASK_ATTEMPT_INDEX).when(inputContext).getTaskAttemptNumber();
    doReturn(TEST_ATTRIBUTES_INPUT_INDEX).when(inputContext).getInputIndex();
    doReturn(TEST_ATTRIBUTES_DAG_ATTEMPT_NUMBER).when(inputContext).getDAGAttemptNumber();
    doReturn(TEST_ATTRIBUTES_DAG_NAME).when(inputContext).getDAGName();
    doReturn(TEST_ATTRIBUTES_VERTEX_NAME).when(inputContext).getTaskVertexName();
    doReturn(TEST_ATTRIBUTES_INPUT_NAME).when(inputContext).getSourceVertexName();
    doReturn(TEST_ATTRIBUTES_APPLICATION_ID).when(inputContext).getApplicationId();
    doReturn(TEST_ATTRIBUTES_UNIQUE_IDENTIFIER).when(inputContext).getUniqueIdentifier();

    DataSourceDescriptor dsd = MRInput.createConfigBuilder(new Configuration(false),
        TestInputFormat.class).groupSplits(false).build();
    doReturn(dsd.getInputDescriptor().getUserPayload()).when(inputContext).getUserPayload();
    doReturn(new TezCounters()).when(inputContext).getCounters();

    MRInput mrInput = new MRInput(inputContext, 1);
    mrInput.initialize();

    // Deliver a single split event that names TestInputSplit as the split class.
    MRRuntimeProtos.MRSplitProto splitProto = MRRuntimeProtos.MRSplitProto.newBuilder()
        .setSplitClassName(TestInputSplit.class.getName())
        .build();
    InputDataInformationEvent diEvent = InputDataInformationEvent
        .createWithSerializedPayload(0, splitProto.toByteString().asReadOnlyByteBuffer());

    List<Event> events = new LinkedList<>();
    events.add(diEvent);
    mrInput.handleEvents(events);

    TezCounter counter = mrInput.getContext().getCounters()
        .findCounter(TaskCounter.INPUT_SPLIT_LENGTH_BYTES);
    // JUnit argument order is (expected, actual).
    assertEquals(TestInputSplit.length, counter.getValue());
    assertTrue(TestInputFormat.invoked.get());
  }

  /**
   * {@link InputFormat} stub whose {@link #getRecordReader} asserts that the {@link JobConf}
   * it receives exposes the expected {@code TEST_ATTRIBUTES_*} values through
   * {@link MRInputHelpers}, and then records that it was invoked.
   */
  static class TestInputFormat implements InputFormat {

    /** Set to {@code true} once {@link #getRecordReader} has passed all of its assertions. */
    private static final AtomicBoolean invoked = new AtomicBoolean(false);

    /**
     * Provides no splits.
     *
     * @return always {@code null}
     */
    @Override
    public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
      return null;
    }

    /**
     * Asserts the attributes readable from {@code job}, marks this format as invoked and
     * returns a reader that yields no records.
     *
     * @param split unused
     * @param job the configuration whose attributes are checked
     * @param reporter unused
     * @return a reader whose {@code next} always returns {@code false}
     */
    @Override
    public RecordReader getRecordReader(InputSplit split, JobConf job, Reporter reporter)
        throws IOException {
      assertEquals(TEST_ATTRIBUTES_DAG_NAME, MRInputHelpers.getDagName(job));
      assertEquals(TEST_ATTRIBUTES_VERTEX_NAME, MRInputHelpers.getVertexName(job));
      assertEquals(TEST_ATTRIBUTES_INPUT_NAME, MRInputHelpers.getInputName(job));
      assertEquals(TEST_ATTRIBUTES_DAG_INDEX, MRInputHelpers.getDagIndex(job));
      assertEquals(TEST_ATTRIBUTES_VERTEX_INDEX, MRInputHelpers.getVertexIndex(job));
      assertEquals(TEST_ATTRIBUTES_APPLICATION_ID.toString(),
          MRInputHelpers.getApplicationIdString(job));
      assertEquals(TEST_ATTRIBUTES_UNIQUE_IDENTIFIER, MRInputHelpers.getUniqueIdentifier(job));
      assertEquals(TEST_ATTRIBUTES_TASK_INDEX, MRInputHelpers.getTaskIndex(job));
      assertEquals(TEST_ATTRIBUTES_TASK_ATTEMPT_INDEX, MRInputHelpers.getTaskAttemptIndex(job));
      assertEquals(TEST_ATTRIBUTES_INPUT_INDEX, MRInputHelpers.getInputIndex(job));
      assertEquals(TEST_ATTRIBUTES_DAG_ATTEMPT_NUMBER, MRInputHelpers.getDagAttemptNumber(job));
      assertEquals(TEST_ATTRIBUTES_APPLICATION_ID_STRING,
          MRInputHelpers.getApplicationIdString(job));
      assertEquals(TEST_ATTRIBUTES_DAG_ID_STRING, MRInputHelpers.getDagIdString(job));
      assertEquals(TEST_ATTRIBUTES_VERTEX_ID_STRING, MRInputHelpers.getVertexIdString(job));
      assertEquals(TEST_ATTRIBUTES_TASK_ID_STRING, MRInputHelpers.getTaskIdString(job));
      assertEquals(TEST_ATTRIBUTES_TASK_ATTEMPT_ID_STRING,
          MRInputHelpers.getTaskAttemptIdString(job));

      // Only reached when every assertion above has passed.
      invoked.set(true);

      return new RecordReader() {
        @Override
        public boolean next(Object key, Object value) throws IOException {
          return false;
        }

        @Override
        public Object createKey() {
          return null;
        }

        @Override
        public Object createValue() {
          return null;
        }

        @Override
        public long getPos() throws IOException {
          return 0;
        }

        @Override
        public void close() throws IOException {
        }

        @Override
        public float getProgress() throws IOException {
          return 0;
        }
      };
    }
  }

  /**
   * {@link InputSplit} stub that reports a random, non-negative length chosen once when the
   * class is initialized, has no locations and serializes no state.
   */
  public static class TestInputSplit implements InputSplit {

    /**
     * Length reported by every instance. The sign bit is masked off rather than using
     * {@code Math.abs}, because {@code Math.abs(Long.MIN_VALUE)} is still negative.
     */
    public static long length = new Random().nextLong() & Long.MAX_VALUE;

    /** @return the shared {@link #length} value */
    @Override
    public long getLength() throws IOException {
      return length;
    }

    /** @return an empty array */
    @Override
    public String[] getLocations() throws IOException {
      return new String[0];
    }

    /** Writes nothing. */
    @Override
    public void write(DataOutput out) throws IOException {
    }

    /** Reads nothing. */
    @Override
    public void readFields(DataInput in) throws IOException {
    }
  }
}