/*
* Copyright 2016 LinkedIn, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.linkedin.restli.client;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.linkedin.common.callback.Callback;
import com.linkedin.parseq.Task;
import com.linkedin.parseq.batching.Batch;
import com.linkedin.parseq.batching.BatchingStrategy;
import com.linkedin.parseq.internal.ArgumentUtil;
import com.linkedin.parseq.promise.Promise;
import com.linkedin.parseq.promise.Promises;
import com.linkedin.parseq.promise.SettablePromise;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.restli.client.config.ConfigValue;
import com.linkedin.restli.client.config.RequestConfig;
import com.linkedin.restli.client.config.RequestConfigBuilder;
import com.linkedin.restli.client.config.RequestConfigOverrides;
import com.linkedin.restli.client.config.RequestConfigProvider;
import com.linkedin.restli.client.metrics.BatchingMetrics;
import com.linkedin.restli.client.metrics.Metrics;
import com.linkedin.restli.common.OperationNameGenerator;
/**
 * ParSeq-friendly wrapper around a Rest.li {@link RestClient}.
 * <p>
 * Requests can either be sent directly ({@link #sendRequest}), which returns a {@link Promise}, or wrapped in a
 * ParSeq {@link Task} ({@link #createTask}). Task creation works in three steps:
 * <ul>
 *   <li>It resolves a {@link RequestConfig} for the request through the configured {@link RequestConfigProvider},
 *   optionally applying {@link RequestConfigOverrides}.</li>
 *   <li>It applies the configured timeout.</li>
 *   <li>It routes requests accepted by {@link RequestGroup#isBatchable} through this class's
 *   {@link BatchingStrategy} implementation.</li>
 * </ul>
 *
 * @author Jaroslaw Odzga (jodzga@linkedin.com)
 */
public class ParSeqRestClient extends BatchingStrategy<RequestGroup, RestRequestBatchKey, Response<Object>>
    implements ParSeqRestliClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(ParSeqRestClient.class);

  private final RestClient _restClient;
  private final BatchingMetrics _batchingMetrics = new BatchingMetrics();
  private final RequestConfigProvider _clientConfig;
  /** Supplies the {@link RequestContext} used whenever a caller does not pass one explicitly. */
  private final Function<Request<?>, RequestContext> _requestContextProvider;

  /**
   * Creates a client from its collaborators. This constructor is package-private; instances are expected to be
   * created through {@link ParSeqRestliClientBuilder}.
   *
   * @param restClient underlying client that actually sends requests; must not be null
   * @param config provider resolving the {@link RequestConfig} for each request; must not be null
   * @param requestContextProvider produces a {@link RequestContext} for requests sent without an explicit one;
   *        must not be null
   */
  ParSeqRestClient(final RestClient restClient, final RequestConfigProvider config,
      Function<Request<?>, RequestContext> requestContextProvider) {
    ArgumentUtil.requireNotNull(restClient, "restClient");
    ArgumentUtil.requireNotNull(config, "config");
    // Previously this re-checked `config`, letting a null provider through until the first request.
    ArgumentUtil.requireNotNull(requestContextProvider, "requestContextProvider");
    _restClient = restClient;
    _clientConfig = config;
    _requestContextProvider = requestContextProvider;
  }

  /**
   * Creates new ParSeqRestClient with default configuration.
   * <p>
   * Each request gets a fresh, empty {@link RequestContext}.
   *
   * @param restClient underlying client that actually sends requests; must not be null
   * @deprecated Please use {@link ParSeqRestliClientBuilder} to create instances.
   */
  @Deprecated
  public ParSeqRestClient(final RestClient restClient) {
    ArgumentUtil.requireNotNull(restClient, "restClient");
    _restClient = restClient;
    _clientConfig = RequestConfigProvider.build(new ParSeqRestliClientConfigBuilder().build(), () -> Optional.empty());
    _requestContextProvider = request -> new RequestContext();
  }

  /**
   * Sends the request using a {@link RequestContext} obtained from the request context provider.
   *
   * @param request request to send
   * @return promise completed with the response, or failed with the error reported by the underlying client
   */
  @Override
  public <T> Promise<Response<T>> sendRequest(final Request<T> request) {
    return sendRequest(request, _requestContextProvider.apply(request));
  }

  /**
   * Sends the request through the underlying {@link RestClient} with the given context.
   *
   * @param request request to send
   * @param requestContext context passed to the underlying client
   * @return promise completed with the response, or failed with the error reported by the underlying client
   */
  @Override
  public <T> Promise<Response<T>> sendRequest(final Request<T> request, final RequestContext requestContext) {
    final SettablePromise<Response<T>> promise = Promises.settable();
    _restClient.sendRequest(request, requestContext, new PromiseCallbackAdapter<T>(promise));
    return promise;
  }

  /**
   * Adapts a Rest.li {@link Callback} to a ParSeq {@link SettablePromise}. Success completes the promise, and
   * errors fail it.
   */
  static class PromiseCallbackAdapter<T> implements Callback<Response<T>> {
    private final SettablePromise<Response<T>> _promise;

    public PromiseCallbackAdapter(final SettablePromise<Response<T>> promise) {
      _promise = promise;
    }

    @Override
    public void onSuccess(final Response<T> result) {
      try {
        _promise.done(result);
      } catch (Exception e) {
        // If completing the promise throws, report that exception through the error path instead of
        // propagating it into the caller of this callback.
        onError(e);
      }
    }

    @Override
    public void onError(final Throwable e) {
      _promise.fail(e);
    }
  }

  /**
   * Creates a task for the request. The task uses a generated name, a context from the request context provider
   * and the configuration resolved for the request.
   *
   * @param request request to execute
   * @return task that executes the request (possibly batched, possibly with a timeout)
   */
  @Override
  public <T> Task<Response<T>> createTask(final Request<T> request) {
    return createTask(request, _requestContextProvider.apply(request));
  }

  /**
   * Creates a task for the request with a generated name, the given context and the configuration resolved for
   * the request.
   *
   * @param request request to execute
   * @param requestContext context passed to the underlying client
   * @return task that executes the request (possibly batched, possibly with a timeout)
   */
  @Override
  public <T> Task<Response<T>> createTask(final Request<T> request, final RequestContext requestContext) {
    return createTask(generateTaskName(request), request, requestContext, _clientConfig.apply(request));
  }

  /**
   * Creates a task for the request with an explicitly provided task name.
   *
   * @param name task name
   * @param request request to execute
   * @param requestContext context passed to the underlying client
   * @return task that executes the request (possibly batched, possibly with a timeout)
   * @deprecated ParSeqRestClient generates consistent names for tasks based on request parameters and it is
   * recommended to use default names.
   */
  @Deprecated
  public <T> Task<Response<T>> createTask(final String name, final Request<T> request, final RequestContext requestContext) {
    return createTask(name, request, requestContext, _clientConfig.apply(request));
  }

  /**
   * Creates a task for the request. The task uses a context from the request context provider, with the given
   * overrides applied on top of the configuration resolved for the request.
   *
   * @param request request to execute
   * @param configOverrides overrides applied to the resolved configuration
   * @return task that executes the request (possibly batched, possibly with a timeout)
   */
  @Override
  public <T> Task<Response<T>> createTask(Request<T> request, RequestConfigOverrides configOverrides) {
    return createTask(request, _requestContextProvider.apply(request), configOverrides);
  }

  /**
   * Creates a task for the request. The task uses the given context, with the given overrides applied on top of
   * the configuration resolved for the request.
   *
   * @param request request to execute
   * @param requestContext context passed to the underlying client
   * @param configOverrides overrides applied to the resolved configuration
   * @return task that executes the request (possibly batched, possibly with a timeout)
   */
  @Override
  public <T> Task<Response<T>> createTask(Request<T> request, RequestContext requestContext,
      RequestConfigOverrides configOverrides) {
    RequestConfig config = _clientConfig.apply(request);
    RequestConfigBuilder configBuilder = new RequestConfigBuilder(config);
    RequestConfig effectiveConfig = configBuilder.applyOverrides(configOverrides).build();
    return createTask(generateTaskName(request), request, requestContext, effectiveConfig);
  }

  /**
   * Generates a task name for the request.
   *
   * @param request request whose base URI template, method and method name make up the name
   * @return a task name: the base URI template, a space, then the generated operation name
   */
  static String generateTaskName(final Request<?> request) {
    return request.getBaseUriTemplate() + " "
        + OperationNameGenerator.generate(request.getMethod(), request.getMethodName());
  }

  /**
   * Wraps the task with the timeout from {@code config}, when that timeout is set and positive. If the timeout
   * value has a known source, the timeout description is {@code "src: "} followed by that source.
   *
   * @param task task to wrap
   * @param config configuration holding the timeout in milliseconds
   * @return the task with a timeout applied, or the original task when no positive timeout is configured
   */
  private <T> Task<Response<T>> withTimeout(final Task<Response<T>> task, RequestConfig config) {
    final ConfigValue<Long> timeout = config.getTimeoutMs();
    final Long timeoutMs = timeout.getValue();
    if (timeoutMs == null || timeoutMs <= 0) {
      return task;
    }
    if (timeout.getSource().isPresent()) {
      return task.withTimeout("src: " + timeout.getSource().get(), timeoutMs, TimeUnit.MILLISECONDS);
    }
    return task.withTimeout(timeoutMs, TimeUnit.MILLISECONDS);
  }

  /**
   * Common task-creation path. Batchable requests go through {@link #batchable}; all others are sent directly.
   * Both paths get the configured timeout.
   */
  private <T> Task<Response<T>> createTask(final String name, final Request<T> request,
      final RequestContext requestContext, RequestConfig config) {
    LOGGER.debug("createTask, name: '{}', config: {}", name, config);
    if (RequestGroup.isBatchable(request, config)) {
      return withTimeout(createBatchableTask(name, request, requestContext, config), config);
    } else {
      return withTimeout(Task.async(name, () -> sendRequest(request, requestContext)), config);
    }
  }

  /** Builds the batching key for a request, capturing its context and resolved configuration. */
  private RestRequestBatchKey createKey(Request<Object> request, RequestContext requestContext,
      RequestConfig config) {
    return new RestRequestBatchKey(request, requestContext, config);
  }

  /**
   * Creates a batchable task. The batching strategy is typed on {@code Response<Object>}, so the request and the
   * resulting task are cast between the caller's type {@code T} and {@code Object}.
   */
  @SuppressWarnings("unchecked")
  private <T> Task<Response<T>> createBatchableTask(String name, Request<T> request, RequestContext requestContext,
      RequestConfig config) {
    return cast(batchable(name, createKey((Request<Object>) request, requestContext, config)));
  }

  /** Unchecked cast helper used to convert the batching strategy's task type to the caller's type. */
  @SuppressWarnings({ "rawtypes", "unchecked" })
  private static <X> Task<X> cast(Task t) {
    return (Task<X>) t;
  }

  /**
   * Executes a batch by delegating to the group. Batch size is recorded in the batching metrics only for
   * {@link GetRequestGroup}s.
   */
  @Override
  public void executeBatch(RequestGroup group, Batch<RestRequestBatchKey, Response<Object>> batch) {
    if (group instanceof GetRequestGroup) {
      _batchingMetrics.recordBatchSize(group.getBaseUriTemplate(), batch.batchSize());
    }
    group.executeBatch(_restClient, batch, _requestContextProvider);
  }

  /**
   * Classifies a key into a {@link RequestGroup} based on its request and the max batch size from its
   * configuration.
   */
  @Override
  public RequestGroup classify(RestRequestBatchKey key) {
    Request<?> request = key.getRequest();
    return RequestGroup.fromRequest(request, key.getRequestConfig().getMaxBatchSize().getValue());
  }

  @Override
  public String getBatchName(RequestGroup group, Batch<RestRequestBatchKey, Response<Object>> batch) {
    return group.getBatchName(batch);
  }

  @Override
  public int keySize(RequestGroup group, RestRequestBatchKey key) {
    return group.keySize(key);
  }

  @Override
  public int maxBatchSizeForGroup(RequestGroup group) {
    return group.getMaxBatchSize();
  }

  /** @return batching metrics collected by this client */
  public BatchingMetrics getBatchingMetrics() {
    return _batchingMetrics;
  }

  /** @return metrics view exposing this client's batching metrics */
  @Override
  public Metrics getMetrics() {
    return () -> _batchingMetrics;
  }
}