package com.devicehive.shim.kafka.test;
/*
* #%L
* DeviceHive Shim Kafka Implementation
* %%
* Copyright (C) 2016 DataArt
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
import com.devicehive.json.adapters.RuntimeTypeAdapterFactory;
import com.devicehive.test.rule.KafkaEmbeddedRule;
import com.devicehive.shim.api.Body;
import com.devicehive.shim.api.Request;
import com.devicehive.shim.api.Response;
import com.devicehive.shim.api.client.RpcClient;
import com.devicehive.shim.api.server.RequestHandler;
import com.devicehive.shim.api.server.RpcServer;
import com.devicehive.shim.kafka.builder.ClientBuilder;
import com.devicehive.shim.kafka.builder.ServerBuilder;
import com.devicehive.shim.kafka.fixture.RequestHandlerWrapper;
import com.devicehive.shim.kafka.fixture.TestRequestBody;
import com.devicehive.shim.kafka.fixture.TestResponseBody;
import com.devicehive.shim.kafka.serializer.RequestSerializer;
import com.devicehive.shim.kafka.serializer.ResponseSerializer;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.junit.*;
import org.junit.rules.Timeout;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.junit.Assert.*;
/**
 * Round-trip tests for the Kafka-backed {@link RpcClient} / {@link RpcServer} pair.
 *
 * <p>One client and one server are created for the whole class and run against the
 * broker managed by {@link KafkaEmbeddedRule}. Each test installs its own
 * {@link RequestHandler} through the shared {@link RequestHandlerWrapper}, so the
 * server-side behaviour can be changed per test without restarting the server.
 */
public class KafkaRpcClientServerCommunicationTest {

    private static final String REQUEST_TOPIC = "request_topic";
    private static final String RESPONSE_TOPIC = "response_topic";

    /** Upper bound, in seconds, for waiting on anything that has to travel through Kafka. */
    private static final long AWAIT_TIMEOUT_SECONDS = 10;

    /** Number of additional responses pushed straight to the reply topic in the multi-response test. */
    private static final int EXTRA_RESPONSES = 9;

    @ClassRule
    public static KafkaEmbeddedRule kafkaRule = new KafkaEmbeddedRule(true, 1, REQUEST_TOPIC, RESPONSE_TOPIC);

    /** Safety net: no single test may run longer than three minutes. */
    @Rule
    public Timeout testTimeout = new Timeout(3, TimeUnit.MINUTES);

    private static RpcServer server;
    private static RpcClient client;
    private static final RequestHandlerWrapper handlerWrapper = new RequestHandlerWrapper();
    private static Gson gson;

    /**
     * Builds the shared Gson instance, then starts the server followed by the client.
     *
     * <p>The server consumes {@link #REQUEST_TOPIC}; the client publishes to it and
     * listens for replies on {@link #RESPONSE_TOPIC}.
     *
     * @throws Exception if building or starting the server or client fails
     */
    @BeforeClass
    public static void setUp() throws Exception {
        // Both request and response bodies are discriminated by the "action" JSON field.
        RuntimeTypeAdapterFactory<Body> requestFactory = RuntimeTypeAdapterFactory.of(Body.class, "action")
                .registerSubtype(TestRequestBody.class, "test_request")
                .registerSubtype(TestResponseBody.class, "test_response");
        gson = new GsonBuilder()
                .registerTypeAdapterFactory(requestFactory)
                .create();

        server = new ServerBuilder()
                .withConsumerProps(kafkaRule.getConsumerProperties())
                .withProducerProps(kafkaRule.getProducerProperties())
                .withConsumerValueDeserializer(new RequestSerializer(gson))
                .withProducerValueSerializer(new ResponseSerializer(gson))
                .withConsumerThreads(1)
                .withWorkerThreads(1)
                .withRequestHandler(handlerWrapper)
                .withTopic(REQUEST_TOPIC)
                .build();
        server.start();

        client = new ClientBuilder()
                .withProducerProps(kafkaRule.getProducerProperties())
                .withConsumerProps(kafkaRule.getConsumerProperties())
                .withProducerValueSerializer(new RequestSerializer(gson))
                .withConsumerValueDeserializer(new ResponseSerializer(gson))
                .withReplyTopic(RESPONSE_TOPIC)
                .withRequestTopic(REQUEST_TOPIC)
                .withConsumerThreads(1)
                .build();
        client.start();
    }

    /**
     * Shuts down the client and then the server.
     *
     * <p>The server shutdown sits in a {@code finally} block so that a failing client
     * shutdown cannot leave the server running.
     *
     * @throws Exception if either shutdown fails
     */
    @AfterClass
    public static void tearDown() throws Exception {
        try {
            if (client != null) {
                client.shutdown();
            }
        } finally {
            if (server != null) {
                server.shutdown();
            }
        }
    }

    /**
     * Pushes a request through the client and verifies that the server-side handler
     * receives a request equal to the one that was sent.
     */
    @Test
    public void shouldSendRequestToServer() throws Exception {
        CompletableFuture<Request> received = new CompletableFuture<>();
        handlerWrapper.setDelegate(incoming -> {
            // Hand the request over to the test thread before replying.
            received.complete(incoming);
            return Response.newBuilder()
                    .withBody(new TestResponseBody("Response"))
                    .withCorrelationId(incoming.getCorrelationId())
                    .withLast(true)
                    .buildSuccess();
        });

        Request request = newTestRequest(true);
        client.push(request);

        Request receivedRequest = received.get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        assertEquals(request, receivedRequest);
    }

    /**
     * Calls the server and verifies that the successful response produced by the
     * handler reaches the client callback with the request's correlation id.
     */
    @Test
    public void shouldSuccessfullyReplyToRequest() throws Exception {
        handlerWrapper.setDelegate(incoming -> Response.newBuilder()
                .withBody(new TestResponseBody("ResponseFromServer"))
                .withCorrelationId(incoming.getCorrelationId())
                .withLast(true)
                .buildSuccess());

        Request request = newTestRequest(true);
        CompletableFuture<Response> future = new CompletableFuture<>();
        client.call(request, future::complete);

        Response response = future.get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        assertNotNull(response);
        assertEquals(request.getCorrelationId(), response.getCorrelationId());
        assertTrue(response.getBody() instanceof TestResponseBody);
        assertEquals("ResponseFromServer", ((TestResponseBody) response.getBody()).getResponseBody());
        assertTrue(response.isLast());
        assertFalse(response.isFailed());
    }

    /**
     * Installs a handler that always throws and verifies that the client still gets a
     * response: failed, marked as last, without a body, and carrying the request's
     * correlation id.
     */
    @Test
    public void shouldSendErrorToClient() throws Exception {
        handlerWrapper.setDelegate(incoming -> {
            throw new RuntimeException("Something went wrong");
        });

        Request request = newTestRequest(true);
        CompletableFuture<Response> future = new CompletableFuture<>();
        client.call(request, future::complete);

        Response response = future.get(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        assertNotNull(response);
        assertEquals(request.getCorrelationId(), response.getCorrelationId());
        assertTrue(response.isLast());
        assertTrue(response.isFailed());
        assertNull(response.getBody());
    }

    /**
     * Verifies that a request that does not expect a single reply keeps delivering
     * responses to the same client callback.
     *
     * <p>One response comes from the request handler; {@link #EXTRA_RESPONSES} more are
     * sent directly to the reply topic through the server's dispatcher, all carrying
     * the same correlation id. The callback must observe every one of them.
     */
    @Test
    public void shouldSendMultipleResponsesToClient() throws Exception {
        final int expectedResponses = EXTRA_RESPONSES + 1;

        handlerWrapper.setDelegate(incoming -> Response.newBuilder()
                .withBody(new TestResponseBody("ResponseFromServer"))
                .withCorrelationId(incoming.getCorrelationId())
                .withLast(incoming.isSingleReplyExpected())
                .buildSuccess());

        Request request = newTestRequest(false);

        CountDownLatch latch = new CountDownLatch(expectedResponses);
        // The callback is presumably invoked from a client consumer thread, so the
        // list is wrapped for safe access from the test thread.
        List<Response> responses = Collections.synchronizedList(new ArrayList<>());
        Consumer<Response> callback = response -> {
            responses.add(response);
            latch.countDown();
        };
        client.call(request, callback);

        ExecutorService executor = Executors.newFixedThreadPool(2);
        try {
            for (int i = 0; i < EXTRA_RESPONSES; i++) {
                final int number = i;
                executor.execute(() -> {
                    Response response = Response.newBuilder()
                            .withBody(new TestResponseBody(number + "-response"))
                            .withCorrelationId(request.getCorrelationId())
                            .withLast(false)
                            .buildSuccess();
                    server.getDispatcher().send(RESPONSE_TOPIC, response);
                });
            }
            // Bounded wait: a lost response fails fast with a useful message instead
            // of hanging until the class-wide timeout rule fires.
            assertTrue("Expected " + expectedResponses + " responses but received " + responses.size(),
                    latch.await(AWAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS));
        } finally {
            // Release the pool's non-daemon worker threads; already submitted tasks still run.
            executor.shutdown();
        }

        assertEquals(expectedResponses, responses.size());

        Set<String> correlationIds = responses.stream()
                .map(Response::getCorrelationId)
                .collect(Collectors.toSet());
        assertEquals(1, correlationIds.size());
        assertTrue(correlationIds.contains(request.getCorrelationId()));

        Set<String> bodies = responses.stream()
                .map(Response::getBody)
                .map(TestResponseBody.class::cast)
                .map(TestResponseBody::getResponseBody)
                .collect(Collectors.toSet());
        assertEquals(expectedResponses, bodies.size());
        assertTrue(bodies.contains("ResponseFromServer"));
        for (int i = 0; i < EXTRA_RESPONSES; i++) {
            assertTrue(bodies.contains(i + "-response"));
        }
    }

    /**
     * Creates the request used by every test in this class.
     *
     * @param singleReply value passed to {@code withSingleReply} on the request builder
     * @return a new request with a {@link TestRequestBody} payload
     */
    private static Request newTestRequest(boolean singleReply) {
        return Request.newBuilder()
                .withBody(new TestRequestBody("RequestResponseTest"))
                .withSingleReply(singleReply)
                .build();
    }
}