/*
* Copyright 2016 LinkedIn, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.linkedin.restli.client;
import static org.testng.Assert.assertTrue;
import java.util.HashSet;
import java.util.Set;
import org.testng.annotations.Test;
import com.linkedin.data.DataMap;
import com.linkedin.parseq.Task;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.restli.examples.greetings.api.Greeting;
import com.linkedin.restli.internal.client.response.BatchEntityResponse;
public class TestRequest404WithBatching extends ParSeqRestClientIntegrationTest {
private CapturingRestClient _capturingRestClient;
@Override
public ParSeqRestliClientConfig getParSeqRestClientConfig() {
return new ParSeqRestliClientConfigBuilder()
.addTimeoutMs("*.*/greetings.GET", 9999L)
.addTimeoutMs("*.*/greetings.*", 10001L)
.addTimeoutMs("*.*/*.GET", 10002L)
.addTimeoutMs("foo.*/greetings.GET", 10003L)
.addTimeoutMs("foo.GET/greetings.GET", 10004L)
.addTimeoutMs("foo.ACTION-*/greetings.GET", 10005L)
.addTimeoutMs("foo.ACTION-bar/greetings.GET", 10006L)
.addBatchingEnabled("withBatching.*/*.*", true)
.addMaxBatchSize("withBatching.*/*.*", 3)
.build();
}
private Object remove404(Object o) {
if (o instanceof Response) {
Response r = (Response) o;
Object entity = r.getEntity();
if (entity instanceof BatchEntityResponse) {
BatchEntityResponse ber = (BatchEntityResponse) entity;
DataMap data = ber.data();
DataMap errors = (DataMap) data.getDataMap("errors");
Set<String> keys = new HashSet<>(errors.keySet());
keys.forEach(key -> {
DataMap error = errors.getDataMap(key);
if (error.getInteger("status").equals(404)) {
errors.remove(key);
}
});
}
}
return o;
}
@Override
protected RestClient createRestClient() {
_capturingRestClient = new CapturingRestClient(null, null, super.createRestClient(), this::remove404);
return _capturingRestClient;
}
private <T> RequestContext createRequestContext(Request<T> request) {
RequestContext requestContext = new RequestContext();
requestContext.putLocalAttr("method", request.getMethod());
return requestContext;
}
@Override
protected void customizeParSeqRestliClient(ParSeqRestliClientBuilder parSeqRestliClientBuilder) {
parSeqRestliClientBuilder.setRequestContextProvider(this::createRequestContext);
}
@Test
public void testBatchGet404() {
try {
setInboundRequestContext(new InboundRequestContextBuilder()
.setName("withBatching")
.build());
Task<?> task = Task.par(greetingGet(1L).map(Response::getEntity).map(Greeting::getMessage),
greetingGet(2L).map(Response::getEntity).map(Greeting::getMessage),
greetingGet(400L).map(Response::getEntity).map(Greeting::getMessage).recover(t -> t.toString()))
.map((a, b, c) -> a + b + c);
runAndWait(getTestClassName() + ".testBatchGet404", task);
assertTrue(task.get().toString().contains("Good morning!"));
assertTrue(task.get().toString().contains("Guten Morgen!"));
assertTrue(task.get().toString().contains("com.linkedin.restli.client.RestLiResponseException: Response status 404"));
} finally {
clearInboundRequestContext();
}
}
}