/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * <p/> * http://www.apache.org/licenses/LICENSE-2.0 * <p/> * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.mapreduce.lib; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.RecordReader; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.tez.common.counters.TaskCounter; import org.apache.tez.common.counters.TezCounter; import org.apache.tez.common.counters.TezCounters; import org.apache.tez.runtime.api.InputContext; import org.junit.Before; import org.junit.Test; import java.io.IOException; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; public class TestKVReadersWithMR { private JobConf conf; private TezCounters counters; private TezCounter inputRecordCounter; @Before public void setup() { conf = new JobConf(); counters = new TezCounters(); inputRecordCounter = counters.findCounter(TaskCounter.INPUT_RECORDS_PROCESSED); } @Test(timeout = 10000) public void testMRReaderMapred() throws IOException { //empty testWithSpecificNumberOfKV(0); testWithSpecificNumberOfKV(10); //empty testWithSpecificNumberOfKV_MapReduce(0); testWithSpecificNumberOfKV_MapReduce(10); } public void testWithSpecificNumberOfKV(int kvPairs) throws IOException { InputContext mockContext = mock(InputContext.class); MRReaderMapred reader = new MRReaderMapred(conf, counters, inputRecordCounter, mockContext); reader.recordReader = new DummyRecordReader(kvPairs); int records = 0; while (reader.next()) { records++; verify(mockContext, times(records)).notifyProgress(); } assertTrue(kvPairs == records); //reading again should fail try { boolean hasNext = reader.next(); fail(); } catch (IOException e) { assertTrue(e.getMessage().contains("For usage, please refer to")); } } public void testWithSpecificNumberOfKV_MapReduce(int kvPairs) throws IOException { InputContext mockContext = mock(InputContext.class); MRReaderMapReduce reader = new MRReaderMapReduce(conf, counters, inputRecordCounter, -1, 1, 10, 20, 30, mockContext); reader.recordReader = new DummyRecordReaderMapReduce(kvPairs); int records = 0; while (reader.next()) { records++; verify(mockContext, times(records)).notifyProgress(); } assertTrue(kvPairs == records); //reading again should fail try { boolean hasNext = reader.next(); fail(); } catch (IOException e) { assertTrue(e.getMessage().contains("For usage, please refer to")); } } static class DummyRecordReader implements RecordReader { int records; public DummyRecordReader(int records) { this.records = records; } @Override public boolean next(Object o, Object o2) throws IOException { return (records-- > 0); } @Override public Object createKey() { return null; } @Override public Object createValue() { return null; } @Override public long getPos() throws IOException { return 0; } @Override public void close() throws IOException { } @Override public float getProgress() throws IOException { return 0; } } static class DummyRecordReaderMapReduce extends org.apache.hadoop.mapreduce.RecordReader { int records; public DummyRecordReaderMapReduce(int records) { this.records = records; } @Override public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException { } @Override public boolean nextKeyValue() throws IOException, InterruptedException { return (records-- > 0); } @Override public Object getCurrentKey() throws IOException, InterruptedException { return null; } @Override public Object getCurrentValue() throws IOException, InterruptedException { return null; } @Override public float getProgress() throws IOException, InterruptedException { return 0; } @Override public void close() throws IOException { } } }