/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.mapreduce.common;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.split.TezGroupedSplit;
import org.apache.tez.dag.api.DataSourceDescriptor;
import org.apache.tez.dag.api.UserPayload;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.mapreduce.TezTestUtils;
import org.apache.tez.mapreduce.input.MRInput;
import org.apache.tez.mapreduce.lib.MRInputUtils;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos.MRSplitProto;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputConfigureVertexTasksEvent;
import org.apache.tez.runtime.api.events.InputDataInformationEvent;
import org.junit.Test;
import com.google.protobuf.ByteString;
public class TestMRInputAMSplitGenerator {
private static String SPLITS_LENGTHS = "splits.length";
@Test(timeout = 5000)
public void testGroupSplitsDisabledSortSplitsEnabled()
throws Exception {
testGroupSplitsAndSortSplits(false, true);
}
@Test(timeout = 5000)
public void testGroupSplitsDisabledSortSplitsDisabled()
throws Exception {
testGroupSplitsAndSortSplits(false, false);
}
@Test(timeout = 5000)
public void testGroupSplitsEnabledSortSplitsEnabled()
throws Exception {
testGroupSplitsAndSortSplits(true, true);
}
@Test(timeout = 5000)
public void testGroupSplitsEnabledSortSplitsDisabled()
throws Exception {
testGroupSplitsAndSortSplits(true, false);
}
private void testGroupSplitsAndSortSplits(boolean groupSplitsEnabled,
boolean sortSplitsEnabled) throws Exception {
Configuration conf = new Configuration();
String[] splitLengths = new String[50];
for (int i = 0; i < splitLengths.length; i++) {
splitLengths[i] = Integer.toString(1000 * (i + 1));
}
conf.setStrings(SPLITS_LENGTHS, splitLengths);
DataSourceDescriptor dataSource = MRInput.createConfigBuilder(
conf, InputFormatForTest.class).
groupSplits(groupSplitsEnabled).sortSplits(sortSplitsEnabled).build();
UserPayload userPayload = dataSource.getInputDescriptor().getUserPayload();
InputInitializerContext context =
new TezTestUtils.TezRootInputInitializerContextForTest(userPayload);
MRInputAMSplitGenerator splitGenerator =
new MRInputAMSplitGenerator(context);
List<Event> events = splitGenerator.initialize();
assertTrue(events.get(0) instanceof InputConfigureVertexTasksEvent);
boolean shuffled = false;
InputSplit previousIs = null;
int numRawInputSplits = 0;
for (int i = 1; i < events.size(); i++) {
assertTrue(events.get(i) instanceof InputDataInformationEvent);
InputDataInformationEvent diEvent = (InputDataInformationEvent) (events.get(i));
assertNull(diEvent.getDeserializedUserPayload());
assertNotNull(diEvent.getUserPayload());
MRSplitProto eventProto = MRSplitProto.parseFrom(ByteString.copyFrom(
diEvent.getUserPayload()));
InputSplit is = MRInputUtils.getNewSplitDetailsFromEvent(
eventProto, new Configuration());
if (groupSplitsEnabled) {
numRawInputSplits += ((TezGroupedSplit)is).getGroupedSplits().size();
for (InputSplit inputSplit : ((TezGroupedSplit)is).getGroupedSplits()) {
assertTrue(inputSplit instanceof InputSplitForTest);
}
assertTrue(((TezGroupedSplit)is).getGroupedSplits().get(0)
instanceof InputSplitForTest);
} else {
numRawInputSplits++;
assertTrue(is instanceof InputSplitForTest);
}
// The splits in the list returned from InputFormat has ascending
// size in order.
// If sortSplitsEnabled is true, MRInputAMSplitGenerator will sort the
// splits in descending order.
// If sortSplitsEnabled is false, MRInputAMSplitGenerator will shuffle
// the splits.
if (previousIs != null) {
if (sortSplitsEnabled) {
assertTrue(is.getLength() <= previousIs.getLength());
} else {
shuffled |= (is.getLength() > previousIs.getLength());
}
}
previousIs = is;
}
assertEquals(splitLengths.length, numRawInputSplits);
if (!sortSplitsEnabled) {
assertTrue(shuffled);
}
}
private static class InputFormatForTest
extends InputFormat<IntWritable, IntWritable> {
@Override
public RecordReader<IntWritable, IntWritable> createRecordReader(
org.apache.hadoop.mapreduce.InputSplit split,
TaskAttemptContext context) throws IOException,
InterruptedException {
return new RecordReader<IntWritable, IntWritable>() {
private boolean done = false;
@Override
public void close() throws IOException {
}
@Override
public IntWritable getCurrentKey() throws IOException,
InterruptedException {
return new IntWritable(0);
}
@Override
public IntWritable getCurrentValue() throws IOException,
InterruptedException {
return new IntWritable(0);
}
@Override
public float getProgress() throws IOException, InterruptedException {
return done ? 0 : 1;
}
@Override
public void initialize(org.apache.hadoop.mapreduce.InputSplit split, TaskAttemptContext context)
throws IOException, InterruptedException {
}
@Override
public boolean nextKeyValue() throws IOException, InterruptedException {
if (!done) {
done = true;
return true;
}
return false;
}
};
}
@Override
public List<org.apache.hadoop.mapreduce.InputSplit> getSplits(
JobContext context) throws IOException, InterruptedException {
List<org.apache.hadoop.mapreduce.InputSplit> list = new ArrayList<org.apache.hadoop.mapreduce.InputSplit>();
int[] lengths = context.getConfiguration().getInts(SPLITS_LENGTHS);
for (int i = 0; i < lengths.length; i++) {
list.add(new InputSplitForTest(i + 1, lengths[i]));
}
return list;
}
}
@Private
public static class InputSplitForTest extends InputSplit
implements Writable {
private int identifier;
private int length;
@SuppressWarnings("unused")
public InputSplitForTest() {
// For writable
}
public int getIdentifier() {
return this.identifier;
}
public InputSplitForTest(int identifier, int length) {
this.identifier = identifier;
this.length = length;
}
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(identifier);
out.writeInt(length);
}
@Override
public void readFields(DataInput in) throws IOException {
identifier = in.readInt();
length = in.readInt();
}
@Override
public long getLength() throws IOException {
return length;
}
@Override
public String[] getLocations() throws IOException {
return new String[] {"localhost"};
}
}
}