/*
* Copyright 2016 LinkedIn, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.linkedin.restli.client;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;
import java.util.Collection;
import org.testng.annotations.Test;
import com.linkedin.parseq.Task;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.restli.client.config.RequestConfigOverridesBuilder;
import com.linkedin.restli.common.ResourceMethod;
import com.linkedin.restli.examples.greetings.api.Greeting;
import com.linkedin.restli.examples.greetings.client.GreetingsBuilders;
/**
 * Integration test for the request-context provider hook of the ParSeq rest.li client.
 *
 * <p>The test installs a provider that tags every {@link RequestContext} with the resource
 * method of the request it was created for, wraps the rest client in a
 * {@link CapturingRestClient}, and then checks the contexts that the capturing client
 * recorded for plain, batchable-but-single, and actually batched GET requests.
 */
public class TestRequestContextProvider extends ParSeqRestClientIntegrationTest {

  /** Local attribute key under which the provider stores the request's resource method. */
  private static final String METHOD_ATTR = "method";

  /**
   * Inbound request name used by the batching tests. Presumably this is what makes the
   * {@code withBatching.*} config entries below apply -- confirm against the config key format.
   */
  private static final String BATCHING_INBOUND_NAME = "withBatching";

  /** Wrapper around the real rest client; exposes the request contexts seen per request. */
  private CapturingRestClient _capturingRestClient;

  /**
   * Supplies the client configuration used by this test: a set of timeout entries plus
   * batching (max batch size 3) for keys starting with {@code withBatching}.
   */
  @Override
  public ParSeqRestliClientConfig getParSeqRestClientConfig() {
    return new ParSeqRestliClientConfigBuilder()
        // Timeout entries, registered in this exact order.
        .addTimeoutMs("*.*/greetings.GET", 9999L)
        .addTimeoutMs("*.*/greetings.*", 10001L)
        .addTimeoutMs("*.*/*.GET", 10002L)
        .addTimeoutMs("foo.*/greetings.GET", 10003L)
        .addTimeoutMs("foo.GET/greetings.GET", 10004L)
        .addTimeoutMs("foo.ACTION-*/greetings.GET", 10005L)
        .addTimeoutMs("foo.ACTION-bar/greetings.GET", 10006L)
        // Batching entries.
        .addBatchingEnabled("withBatching.*/*.*", true)
        .addMaxBatchSize("withBatching.*/*.*", 3)
        .build();
  }

  /**
   * Wraps the rest client created by the superclass in a {@link CapturingRestClient} and
   * keeps a reference to the wrapper so tests can inspect and clear what it captured.
   */
  @Override
  protected RestClient createRestClient() {
    final CapturingRestClient capturing = new CapturingRestClient(null, null, super.createRestClient());
    _capturingRestClient = capturing;
    return capturing;
  }

  /**
   * Request-context provider under test: returns a fresh {@link RequestContext} whose
   * {@code "method"} local attribute is the resource method of the given request.
   *
   * @param request the request a context is being created for
   * @return a new context tagged with {@code request.getMethod()}
   */
  private <T> RequestContext createRequestContext(Request<T> request) {
    final RequestContext tagged = new RequestContext();
    tagged.putLocalAttr(METHOD_ATTR, request.getMethod());
    return tagged;
  }

  /** Registers {@link #createRequestContext(Request)} as the client's request-context provider. */
  @Override
  protected void customizeParSeqRestliClient(ParSeqRestliClientBuilder parSeqRestliClientBuilder) {
    parSeqRestliClientBuilder.setRequestContextProvider(this::createRequestContext);
  }

  /**
   * A single GET issued without the {@code withBatching} inbound context must be captured
   * together with a context tagged with the request's own method.
   */
  @Test
  public void testNonBatchableRequest() {
    try {
      GetRequest<Greeting> getGreeting = newGreetingGet(1L);
      Task<?> getTask = _parseqClient.createTask(getGreeting);
      runAndWait(getTestClassName() + ".testNonBatchableRequest", getTask);
      verifyRequestContext(getGreeting);
    } finally {
      _capturingRestClient.clearCapturedRequestContexts();
    }
  }

  /**
   * A single GET issued under the {@code withBatching} inbound context (so there is nothing
   * to batch it with) must still be captured under the original request with its own method.
   */
  @Test
  public void testBatchableRequestNotBatched() {
    try {
      useBatchingInboundContext();
      GetRequest<Greeting> getGreeting = newGreetingGet(1L);
      Task<?> getTask = _parseqClient.createTask(getGreeting);
      runAndWait(getTestClassName() + ".testBatchableRequestNotBatched", getTask);
      verifyRequestContext(getGreeting);
    } finally {
      _capturingRestClient.clearCapturedRequestContexts();
    }
  }

  /**
   * Two GETs run in parallel under the {@code withBatching} inbound context must result in
   * exactly one captured context, and that context must be tagged with {@code BATCH_GET}.
   */
  @Test
  public void testBatchableRequestBatched() {
    try {
      useBatchingInboundContext();
      GetRequest<Greeting> firstGet = newGreetingGet(1L);
      GetRequest<Greeting> secondGet = newGreetingGet(2L);
      Task<?> firstTask = _parseqClient.createTask(firstGet);
      Task<?> secondTask = _parseqClient.createTask(secondGet);
      Task<?> both = Task.par(firstTask, secondTask);
      runAndWait(getTestClassName() + ".testBatchableRequestBatched", both);

      Collection<RequestContext> captured = _capturingRestClient.getCapturedRequestContexts().values();
      assertEquals(captured.size(), 1);

      RequestContext onlyContext = captured.iterator().next();
      Object methodAttr = onlyContext.getLocalAttr(METHOD_ATTR);
      assertNotNull(methodAttr);
      assertEquals(methodAttr, ResourceMethod.BATCH_GET);
    } finally {
      _capturingRestClient.clearCapturedRequestContexts();
    }
  }

  /** Builds a greetings GET request for the given id. */
  private GetRequest<Greeting> newGreetingGet(Long id) {
    return new GreetingsBuilders().get().id(id).build();
  }

  /** Sets an inbound request context named {@code "withBatching"} for the current test. */
  private void useBatchingInboundContext() {
    InboundRequestContextBuilder inbound = new InboundRequestContextBuilder();
    setInboundRequestContext(inbound.setName(BATCHING_INBOUND_NAME).build());
  }

  /**
   * Asserts that a context was captured for exactly this request and that its
   * {@code "method"} local attribute is present and equal to the request's method.
   *
   * @param request the request that was sent through the client
   */
  private void verifyRequestContext(Request<?> request) {
    assertTrue(_capturingRestClient.getCapturedRequestContexts().containsKey(request));
    RequestContext captured = _capturingRestClient.getCapturedRequestContexts().get(request);
    Object methodAttr = captured.getLocalAttr(METHOD_ATTR);
    assertNotNull(methodAttr);
    assertEquals(methodAttr, request.getMethod());
  }
}