/*
* Copyright 2016 LinkedIn, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.linkedin.restli.client;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import com.linkedin.common.callback.FutureCallback;
import com.linkedin.common.util.None;
import com.linkedin.parseq.BaseEngineTest;
import com.linkedin.parseq.Engine;
import com.linkedin.parseq.EngineBuilder;
import com.linkedin.parseq.Task;
import com.linkedin.parseq.batching.BatchingSupport;
import com.linkedin.parseq.trace.Trace;
import com.linkedin.r2.transport.common.Client;
import com.linkedin.r2.transport.common.bridge.client.TransportClientAdapter;
import com.linkedin.r2.transport.http.client.HttpClientFactory;
import com.linkedin.r2.transport.http.server.HttpServer;
import com.linkedin.restli.client.config.RequestConfigOverrides;
import com.linkedin.restli.common.BatchResponse;
import com.linkedin.restli.common.EmptyRecord;
import com.linkedin.restli.examples.RestLiIntTestServer;
import com.linkedin.restli.examples.greetings.api.Greeting;
import com.linkedin.restli.examples.greetings.client.GreetingsBuilders;
public abstract class ParSeqRestClientIntegrationTest extends BaseEngineTest {
private static final AtomicInteger PORTER = new AtomicInteger(14497);
private final int _port = PORTER.getAndIncrement();
protected final String URI_PREFIX = "http://localhost:" + _port + "/";
private ScheduledExecutorService _serverScheduler;
private Engine _serverEngine;
private HttpServer _server;
private HttpClientFactory _clientFactory;
private List<Client> _transportClients;
private RestClient _restClient;
private final BatchingSupport _batchingSupport = new BatchingSupport();
private final ThreadLocal<InboundRequestContext> _inboundRequestContext = new ThreadLocal<>();
protected ParSeqRestliClient _parseqClient;
protected abstract ParSeqRestliClientConfig getParSeqRestClientConfig();
protected void setInboundRequestContext(InboundRequestContext irc) {
_inboundRequestContext.set(irc);
}
protected void clearInboundRequestContext() {
_inboundRequestContext.remove();
}
@BeforeClass
public void init() throws Exception {
_serverScheduler = Executors.newScheduledThreadPool(Runtime.getRuntime().availableProcessors() + 1);
EngineBuilder serverEngineBuilder = new EngineBuilder();
serverEngineBuilder.setTaskExecutor(_serverScheduler).setTimerScheduler(_serverScheduler)
.setPlanDeactivationListener(_batchingSupport);
_serverEngine = serverEngineBuilder.build();
_server = RestLiIntTestServer.createServer(_serverEngine, _port,
RestLiIntTestServer.supportedCompression, true, 5000);
_server.start();
_clientFactory = new HttpClientFactory();
_transportClients = new ArrayList<Client>();
_restClient = createRestClient();
}
protected RestClient createRestClient() {
Client client = newTransportClient(Collections.<String, String> emptyMap());
return new RestClient(client, URI_PREFIX);
}
@AfterClass
public void shutdown() throws Exception {
if (_server != null) {
_server.stop();
}
if (_serverEngine != null) {
_serverEngine.shutdown();
}
if (_serverScheduler != null) {
_serverScheduler.shutdownNow();
}
for (Client client : _transportClients) {
FutureCallback<None> callback = new FutureCallback<None>();
client.shutdown(callback);
callback.get();
}
if (_clientFactory != null) {
FutureCallback<None> callback = new FutureCallback<None>();
_clientFactory.shutdown(callback);
callback.get();
}
}
private Client newTransportClient(Map<String, ? extends Object> properties) {
Client client = new TransportClientAdapter(_clientFactory.getClient(properties));
_transportClients.add(client);
return client;
}
protected void customizeParSeqRestliClient(ParSeqRestliClientBuilder parSeqRestliClientBuilder) {
}
@Override
protected void customizeEngine(EngineBuilder engineBuilder) {
engineBuilder.setPlanDeactivationListener(_batchingSupport);
ParSeqRestliClientBuilder parSeqRestliClientBuilder = new ParSeqRestliClientBuilder()
.setRestClient(_restClient)
.setBatchingSupport(_batchingSupport)
.setConfig(getParSeqRestClientConfig())
.setInboundRequestContextFinder(() -> Optional.ofNullable(_inboundRequestContext.get()));
customizeParSeqRestliClient(parSeqRestliClientBuilder);
_parseqClient = parSeqRestliClientBuilder.build();
}
protected Task<String> toMessage(Task<Response<Greeting>> greeting) {
return greeting.map("toMessage", g -> g.getEntity().getMessage());
}
protected Task<Response<Greeting>> greetingGet(Long id) {
return _parseqClient.createTask(new GreetingsBuilders().get().id(id).build());
}
protected Task<Response<Greeting>> greetingGet(Long id, RequestConfigOverrides configOverrides) {
return _parseqClient.createTask(new GreetingsBuilders().get().id(id).build(), configOverrides);
}
protected Task<Response<EmptyRecord>> greetingDel(Long id) {
return _parseqClient.createTask(new GreetingsBuilders().delete().id(id).build());
}
protected Task<Response<EmptyRecord>> greetingDel(Long id, RequestConfigOverrides configOverrides) {
return _parseqClient.createTask(new GreetingsBuilders().delete().id(id).build(), configOverrides);
}
protected Task<Response<BatchResponse<Greeting>>> greetings(Long... ids) {
return _parseqClient.createTask(new GreetingsBuilders().batchGet().ids(ids).build());
}
protected Task<Response<BatchResponse<Greeting>>> greetings(RequestConfigOverrides configOverrides, Long... ids) {
return _parseqClient.createTask(new GreetingsBuilders().batchGet().ids(ids).build(), configOverrides);
}
protected boolean hasTask(final String name, final Trace trace) {
return trace.getTraceMap().values().stream().anyMatch(shallowTrace -> shallowTrace.getName().equals(name));
}
protected String getTestClassName() {
return this.getClass().getName();
}
protected static <T> void addProperty(Map<String, Map<String, Object>> config, String property, String key, T value) {
Map<String, Object> map = config.computeIfAbsent(property, k -> new HashMap<>());
map.put(key, value);
}
}