package com.linkedin.parseq.batching;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertTrue;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.testng.annotations.Test;
import com.linkedin.parseq.BaseEngineTest;
import com.linkedin.parseq.EngineBuilder;
import com.linkedin.parseq.Task;
import com.linkedin.parseq.function.Failure;
import com.linkedin.parseq.function.Success;
import com.linkedin.parseq.function.Try;
/**
 * Tests for task-based batching, driven through {@link RecordingTaskStrategy} instances
 * registered on a shared {@link BatchingSupport} that is installed as the engine's
 * plan deactivation listener.
 *
 * <p>Each test builds a strategy from two lambdas: the first produces the {@link Try}
 * result for a key, the second maps a key to its group (presumably the classifier used
 * to decide which keys are batched together). The tests then inspect what the strategy
 * recorded: classified keys, executed batches and executed singletons.
 */
public class TestTaskBatchingStrategy extends BaseEngineTest {

  private final BatchingSupport _batchingSupport = new BatchingSupport();

  /** Installs the shared {@link BatchingSupport} as the engine's plan deactivation listener. */
  @Override
  protected void customizeEngine(EngineBuilder engineBuilder) {
    engineBuilder.setPlanDeactivationListener(_batchingSupport);
  }

  /** Asserts that each of the given keys appears among the keys the strategy recorded as classified. */
  private static void assertClassified(RecordingTaskStrategy<Integer, Integer, String> strategy, int... keys) {
    for (int key : keys) {
      assertTrue(strategy.getClassifiedKeys().contains(key), "key " + key + " was not classified");
    }
  }

  /** Two keys mapped to the same group (0) are expected to run as exactly one batch and no singletons. */
  @Test
  public void testBatchInvoked() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> 0);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(strategy.batchable(0), strategy.batchable(1))
        .map("concat", (s0, s1) -> s0 + s1);

    String result = runAndWait("TestTaskBatchingStrategy.testBatchInvoked", task);

    assertEquals(result, "01");
    assertClassified(strategy, 0, 1);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    // Both keys share a group, so nothing should have been executed as a singleton;
    // sibling tests verify both counters and this one now does too.
    assertEquals(strategy.getExecutedSingletons().size(), 0);
  }

  /** Two keys mapped to distinct groups (the key itself) are expected to run as two singletons and no batch. */
  @Test
  public void testSingletonsInvoked() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(strategy.batchable(0), strategy.batchable(1))
        .map("concat", (s0, s1) -> s0 + s1);

    String result = runAndWait("TestTaskBatchingStrategy.testSingletonsInvoked", task);

    assertEquals(result, "01");
    assertClassified(strategy, 0, 1);
    assertEquals(strategy.getExecutedBatches().size(), 0);
    assertEquals(strategy.getExecutedSingletons().size(), 2);
  }

  /** With grouping by {@code key % 2}, keys 0 and 2 share a group and key 1 is alone: one batch, one singleton. */
  @Test
  public void testBatchAndSingleton() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key % 2);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(strategy.batchable(0), strategy.batchable(1), strategy.batchable(2))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testBatchAndSingleton", task);

    assertEquals(result, "012");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * Even keys succeed and odd keys fail. The failing key 1 (alone in its group) is recovered
   * to {@code "failed"} while the batched keys 0 and 2 complete normally.
   */
  @Test
  public void testBatchAndFailedSingleton() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> {
          if (key % 2 == 0) {
            return Success.of(String.valueOf(key));
          } else {
            return Failure.of(new Exception());
          }
        }, key -> key % 2);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0),
            strategy.batchable(1).recover(e -> "failed"),
            strategy.batchable(2))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testBatchAndFailedSingleton", task);

    assertEquals(result, "0failed2");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * Odd keys succeed and even keys fail. The batched keys 0 and 2 are each recovered to
   * {@code "failed"} while key 1 (alone in its group) completes normally. The batch and the
   * singleton are still both expected to be recorded as executed.
   */
  @Test
  public void testFailedBatchAndSingleton() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> {
          if (key % 2 == 1) {
            return Success.of(String.valueOf(key));
          } else {
            return Failure.of(new Exception());
          }
        }, key -> key % 2);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0).recover(e -> "failed"),
            strategy.batchable(1),
            strategy.batchable(2).recover(e -> "failed"))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testFailedBatchAndSingleton", task);

    assertEquals(result, "failed1failed");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * The grouping lambda {@code key / key} throws for key 0 (integer division by zero) and
   * yields group 1 for key 1. Key 0 is expected to fail and be recovered to {@code "failed"},
   * while key 1 runs as the only singleton.
   */
  @Test
  public void testClassifyFailure() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key / key);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0).recover(e -> "failed"),
            strategy.batchable(1).recover(e -> "failed"))
        .map("concat", (s0, s1) -> s0 + s1);

    String result = runAndWait("TestTaskBatchingStrategy.testClassifyFailure", task);

    assertEquals(result, "failed1");
    assertEquals(strategy.getExecutedBatches().size(), 0);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * {@code taskForBatch} is overridden to throw. Every batchable task is expected to fail and
   * be recovered to {@code "failed"}, and nothing is recorded as an executed batch or singleton.
   */
  @Test
  public void testExecuteBatchFailure() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key % 2) {
          @Override
          public Task<Map<Integer, Try<String>>> taskForBatch(Integer group, Set<Integer> keys) {
            throw new RuntimeException();
          }
        };
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0).recover(e -> "failed"),
            strategy.batchable(1).recover(e -> "failed"),
            strategy.batchable(2).recover(e -> "failed"))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testExecuteBatchFailure", task);

    assertEquals(result, "failedfailedfailed");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 0);
    assertEquals(strategy.getExecutedSingletons().size(), 0);
  }

  /** A registered strategy whose batchable tasks are never used is expected to record nothing. */
  @Test
  public void testNothingToDoForStrategy() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> 0);
    _batchingSupport.registerStrategy(strategy);

    // The plan consists only of plain value tasks; the strategy is registered but not involved.
    Task<String> task = Task.par(Task.value("0"), Task.value("1"))
        .map("concat", (s0, s1) -> s0 + s1);

    String result = runAndWait("TestTaskBatchingStrategy.testNothingToDoForStrategy", task);

    assertEquals(result, "01");
    assertEquals(strategy.getClassifiedKeys().size(), 0);
    assertEquals(strategy.getExecutedBatches().size(), 0);
    assertEquals(strategy.getExecutedSingletons().size(), 0);
  }

  /**
   * Six batchable tasks over only three distinct keys (0, 1, 2, each requested twice) are
   * expected to produce the same execution counts as three tasks: one batch and one singleton.
   */
  @Test
  public void testDeduplication() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key % 2);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0), strategy.batchable(1), strategy.batchable(2),
            strategy.batchable(0), strategy.batchable(1), strategy.batchable(2))
        .map("concat", (s0, s1, s2, s3, s4, s5) -> s0 + s1 + s2 + s3 + s4 + s5);

    String result = runAndWait("TestTaskBatchingStrategy.testDeduplication", task);

    assertEquals(result, "012012");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * The result of {@code taskForBatch} is delayed by 250ms. Key 0 (part of the batch with key 2)
   * carries a 10ms timeout and is recovered to the exception's class name; keys 1 and 2 have no
   * timeout and are expected to complete with their values.
   */
  @Test
  public void testBatchWithTimeoutAndSingleton() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key % 2) {
          @Override
          public Task<Map<Integer, Try<String>>> taskForBatch(Integer group, Set<Integer> keys) {
            return super.taskForBatch(group, keys).flatMap(map -> delayedValue(map, 250, TimeUnit.MILLISECONDS));
          }
        };
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0)
                .withTimeout(10, TimeUnit.MILLISECONDS)
                .recover("toExceptionName", e -> e.getClass().getName()),
            strategy.batchable(1),
            strategy.batchable(2))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testBatchWithTimeoutAndSingleton", task);

    assertEquals(result, "java.util.concurrent.TimeoutException12");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * The result of {@code taskForBatch} is delayed by 250ms. Key 1 (alone in its group) carries
   * a 10ms timeout and is recovered to the exception's class name; the batched keys 0 and 2
   * have no timeout and are expected to complete with their values.
   */
  @Test
  public void testBatchAndSingletonWithTimeout() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key % 2) {
          @Override
          public Task<Map<Integer, Try<String>>> taskForBatch(Integer group, Set<Integer> keys) {
            return super.taskForBatch(group, keys).flatMap(map -> delayedValue(map, 250, TimeUnit.MILLISECONDS));
          }
        };
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0),
            strategy.batchable(1)
                .withTimeout(10, TimeUnit.MILLISECONDS)
                .recover("toExceptionName", e -> e.getClass().getName()),
            strategy.batchable(2))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testBatchAndSingletonWithTimeout", task);

    assertEquals(result, "0java.util.concurrent.TimeoutException2");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * {@code taskForBatch} is overridden to remove key 1 from the map it returns. The task for
   * key 1 is expected to fail and be recovered to {@code "missing"}, while keys 0 and 2
   * complete with their values.
   */
  @Test
  public void testEntriesMissingInReturnedMap() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> Success.of(String.valueOf(key)), key -> key % 2) {
          @Override
          public Task<Map<Integer, Try<String>>> taskForBatch(Integer group, Set<Integer> keys) {
            return super.taskForBatch(group, keys).andThen(map -> map.remove(1));
          }
        };
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0),
            strategy.batchable(1).recover(e -> "missing"),
            strategy.batchable(2))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testEntriesMissingInReturnedMap", task);

    assertEquals(result, "0missing2");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }

  /**
   * Odd keys produce a {@link Failure} carrying the message {@code "failure message"}. The
   * task for key 1 recovers using the exception's message, which shows that the failure
   * returned by the strategy is the one seen by the task.
   */
  @Test
  public void testFailureReturned() {
    RecordingTaskStrategy<Integer, Integer, String> strategy =
        new RecordingTaskStrategy<Integer, Integer, String>(key -> {
          if (key % 2 == 1) {
            return Failure.of(new Exception("failure message"));
          } else {
            return Success.of(String.valueOf(key));
          }
        }, key -> key % 2);
    _batchingSupport.registerStrategy(strategy);

    Task<String> task = Task.par(
            strategy.batchable(0),
            strategy.batchable(1).recover(e -> e.getMessage()),
            strategy.batchable(2))
        .map("concat", (s0, s1, s2) -> s0 + s1 + s2);

    String result = runAndWait("TestTaskBatchingStrategy.testFailureReturned", task);

    assertEquals(result, "0failure message2");
    assertClassified(strategy, 0, 1, 2);
    assertEquals(strategy.getExecutedBatches().size(), 1);
    assertEquals(strategy.getExecutedSingletons().size(), 1);
  }
}