/*
* Copyright 2016 LinkedIn, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.linkedin.restli.client;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.linkedin.common.callback.Callback;
import com.linkedin.data.DataMap;
import com.linkedin.data.schema.PathSpec;
import com.linkedin.data.template.RecordTemplate;
import com.linkedin.parseq.batching.Batch;
import com.linkedin.parseq.batching.BatchImpl.BatchEntry;
import com.linkedin.parseq.function.Tuple3;
import com.linkedin.parseq.function.Tuples;
import com.linkedin.r2.RemoteInvocationException;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestResponseBuilder;
import com.linkedin.restli.client.response.BatchKVResponse;
import com.linkedin.restli.common.BatchResponse;
import com.linkedin.restli.common.EntityResponse;
import com.linkedin.restli.common.ErrorResponse;
import com.linkedin.restli.common.HttpStatus;
import com.linkedin.restli.common.ProtocolVersion;
import com.linkedin.restli.common.ResourceMethod;
import com.linkedin.restli.common.ResourceSpec;
import com.linkedin.restli.common.RestConstants;
import com.linkedin.restli.internal.client.ResponseImpl;
import com.linkedin.restli.internal.client.response.BatchEntityResponse;
import com.linkedin.restli.internal.common.ProtocolVersionUtil;
import com.linkedin.restli.internal.common.ResponseUtils;
class GetRequestGroup implements RequestGroup {
private static final Logger LOGGER = LoggerFactory.getLogger(GetRequestGroup.class);
private static final RestLiResponseException NOT_FOUND_EXCEPTION =
new RestLiResponseException(new RestResponseBuilder().setStatus(HttpStatus.S_404_NOT_FOUND.getCode()).build(),
null, new ErrorResponse().setStatus(HttpStatus.S_404_NOT_FOUND.getCode()));
private final String _baseUriTemplate; //taken from first request, used to differentiate between groups
private final ResourceSpec _resourceSpec; //taken from first request
private final Map<String, String> _headers; //taken from first request, used to differentiate between groups
private final RestliRequestOptions _requestOptions; //taken from first request, used to differentiate between groups
private final Map<String, Object> _queryParams; //taken from first request, used to differentiate between groups
private final int _maxBatchSize;
@SuppressWarnings("deprecation")
public GetRequestGroup(Request<?> request, int maxBatchSize) {
_baseUriTemplate = request.getBaseUriTemplate();
_headers = request.getHeaders();
_queryParams = getQueryParamsForBatchingKey(request);
_resourceSpec = request.getResourceSpec();
_requestOptions = request.getRequestOptions();
_maxBatchSize = maxBatchSize;
}
private static Map<String, Object> getQueryParamsForBatchingKey(Request<?> request)
{
final Map<String, Object> params = new HashMap<>(request.getQueryParamsObjects());
params.remove(RestConstants.QUERY_BATCH_IDS_PARAM);
params.remove(RestConstants.FIELDS_PARAM);
return params;
}
private static <K, RT extends RecordTemplate> Response<RT> unbatchResponse(BatchGetEntityRequest<K, RT> request,
Response<BatchKVResponse<K, EntityResponse<RT>>> batchResponse, Object id) throws RemoteInvocationException {
final BatchKVResponse<K, EntityResponse<RT>> batchEntity = batchResponse.getEntity();
final ErrorResponse errorResponse = batchEntity.getErrors().get(id);
if (errorResponse != null) {
throw new RestLiResponseException(errorResponse);
}
final EntityResponse<RT> entityResponse = batchEntity.getResults().get(id);
if (entityResponse != null) {
final RT entityResult = entityResponse.getEntity();
if (entityResult != null) {
return new ResponseImpl<>(batchResponse, entityResult);
}
}
LOGGER.debug("No result or error for base URI : {}, id: {}. Verify that the batchGet endpoint returns response keys that match batchGet request IDs.",
request.getBaseUriTemplate(), id);
throw NOT_FOUND_EXCEPTION;
}
private DataMap filterIdsInBatchResult(DataMap data, Set<String> ids) {
DataMap dm = new DataMap(data.size());
data.forEach((key, value) -> {
switch(key) {
case BatchResponse.ERRORS:
dm.put(key, filterIds((DataMap)value, ids));
break;
case BatchResponse.RESULTS:
dm.put(key, filterIds((DataMap)value, ids));
break;
case BatchResponse.STATUSES:
dm.put(key, filterIds((DataMap)value, ids));
break;
default:
dm.put(key, value);
break;
}
});
return dm;
}
private Object filterIds(DataMap data, Set<String> ids) {
DataMap dm = new DataMap(data.size());
data.forEach((key, value) -> {
if (ids.contains(key)) {
dm.put(key, value);
}
});
return dm;
}
//Tuple3: (keys, fields, contains-batch-get)
private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceRequests(final Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
final Request<?> rq) {
return reduceContainsBatch(reduceIds(reduceFields(state, rq), rq), rq);
}
//Tuple3: (keys, fields, contains-batch-get)
private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceContainsBatch(Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
Request<?> request) {
if (request instanceof GetRequest) {
return state;
} else if (request instanceof BatchRequest) {
return Tuples.tuple(state._1(), state._2(), true);
} else {
throw unsupportedGetRequestType(request);
}
}
//Tuple3: (keys, fields, contains-batch-get)
private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceIds(Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
Request<?> request) {
if (request instanceof GetRequest) {
GetRequest<?> getRequest = (GetRequest<?>)request;
state._1().add(getRequest.getObjectId());
return state;
} else if (request instanceof BatchRequest) {
BatchRequest<?> batchRequest = (BatchRequest<?>)request;
state._1().addAll(batchRequest.getObjectIds());
return state;
} else {
throw unsupportedGetRequestType(request);
}
}
//Tuple3: (keys, fields, contains-batch-get)
private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceFields(Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
Request<?> request) {
if (request instanceof GetRequest || request instanceof BatchRequest) {
final Set<PathSpec> requestFields = request.getFields();
if (requestFields != null && !requestFields.isEmpty()) {
if (state._2() != null) {
state._2().addAll(requestFields);
}
return state;
} else {
return Tuples.tuple(state._1(), null, state._3());
}
} else {
throw unsupportedGetRequestType(request);
}
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private <K, RT extends RecordTemplate> void doExecuteBatchGet(final RestClient restClient,
final Batch<RestRequestBatchKey, Response<Object>> batch, final Set<Object> ids, final Set<PathSpec> fields,
Function<Request<?>, RequestContext> requestContextProvider) {
final BatchGetEntityRequestBuilder<K, RT> builder = new BatchGetEntityRequestBuilder<>(_baseUriTemplate, _resourceSpec, _requestOptions);
builder.setHeaders(_headers);
_queryParams.forEach((key, value) -> builder.setParam(key, value));
builder.setParam(RestConstants.QUERY_BATCH_IDS_PARAM, ids);
if (fields != null && !fields.isEmpty()) {
builder.setParam(RestConstants.FIELDS_PARAM, fields.toArray());
}
final BatchGetEntityRequest<K, RT> batchGet = builder.build();
restClient.sendRequest(batchGet, requestContextProvider.apply(batchGet), new Callback<Response<BatchKVResponse<K, EntityResponse<RT>>>>() {
@Override
public void onSuccess(Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch) {
final ProtocolVersion version = ProtocolVersionUtil.extractProtocolVersion(responseToBatch.getHeaders());
batch.entries().stream()
.forEach(entry -> {
try {
RestRequestBatchKey rrbk = entry.getKey();
Request request = rrbk.getRequest();
if (request instanceof GetRequest) {
successGet((GetRequest) request, responseToBatch, batchGet, entry, version);
} else if (request instanceof BatchGetKVRequest) {
successBatchGetKV((BatchGetKVRequest) request, responseToBatch, entry, version);
} else if (request instanceof BatchGetRequest) {
successBatchGet((BatchGetRequest) request, responseToBatch, entry, version);
} else if (request instanceof BatchGetEntityRequest) {
successBatchGetEntity((BatchGetEntityRequest) request, responseToBatch, entry, version);
} else {
entry.getValue().getPromise().fail(unsupportedGetRequestType(request));
}
} catch (RemoteInvocationException e) {
entry.getValue().getPromise().fail(e);
}
});
}
@SuppressWarnings({ "deprecation" })
private void successBatchGetEntity(BatchGetEntityRequest request,
Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch,
Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry, final ProtocolVersion version) {
Set<String> ids = (Set<String>) request.getObjectIds().stream()
.map(o -> BatchResponse.keyToString(o, version))
.collect(Collectors.toSet());
DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
BatchKVResponse br = new BatchEntityResponse<>(dm, request.getResourceSpec().getKeyType(),
request.getResourceSpec().getValueType(), request.getResourceSpec().getKeyParts(),
request.getResourceSpec().getComplexKeyType(), version);
Response rsp = new ResponseImpl(responseToBatch, br);
entry.getValue().getPromise().done(rsp);
}
private void successBatchGet(BatchGetRequest request, Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch,
Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry, final ProtocolVersion version) {
Set<String> ids = (Set<String>) request.getObjectIds().stream()
.map(o -> BatchResponse.keyToString(o, version))
.collect(Collectors.toSet());
DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
BatchResponse br = new BatchResponse<>(dm, request.getResponseDecoder().getEntityClass());
Response rsp = new ResponseImpl(responseToBatch, br);
entry.getValue().getPromise().done(rsp);
}
@SuppressWarnings({ "deprecation" })
private void successBatchGetKV(BatchGetKVRequest request, Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch,
Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry,
final ProtocolVersion version) {
Set<String> ids = (Set<String>) request.getObjectIds().stream()
.map(o -> BatchResponse.keyToString(o, version))
.collect(Collectors.toSet());
DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
BatchKVResponse br = new BatchKVResponse(dm, request.getResourceSpec().getKeyType(),
request.getResourceSpec().getValueType(), request.getResourceSpec().getKeyParts(),
request.getResourceSpec().getComplexKeyType(), version);
Response rsp = new ResponseImpl(responseToBatch, br);
entry.getValue().getPromise().done(rsp);
}
@SuppressWarnings({ "deprecation" })
private void successGet(GetRequest request,
Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch, final BatchGetEntityRequest<K, RT> batchGet,
Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry, final ProtocolVersion version)
throws RemoteInvocationException {
String idString = BatchResponse.keyToString(request.getObjectId(), version);
Object id = ResponseUtils.convertKey(idString, request.getResourceSpec().getKeyType(),
request.getResourceSpec().getKeyParts(), request.getResourceSpec().getComplexKeyType(), version);
Response rsp = unbatchResponse(batchGet, responseToBatch, id);
entry.getValue().getPromise().done(rsp);
}
@Override
public void onError(Throwable e) {
batch.failAll(e);
}
});
}
private static RuntimeException unsupportedGetRequestType(Request<?> request) {
return new RuntimeException("ParSeqRestliClient could not handle this type of GET request: " + request.getClass().getName());
}
@SuppressWarnings({ "rawtypes", "unchecked" })
private <K, RT extends RecordTemplate> void doExecuteGet(final RestClient restClient,
final Batch<RestRequestBatchKey, Response<Object>> batch, final Set<Object> ids, final Set<PathSpec> fields,
Function<Request<?>, RequestContext> requestContextProvider) {
final GetRequestBuilder<K, RT> builder = (GetRequestBuilder<K, RT>) new GetRequestBuilder<>(_baseUriTemplate,
_resourceSpec.getValueClass(), _resourceSpec, _requestOptions);
builder.setHeaders(_headers);
_queryParams.forEach((key, value) -> builder.setParam(key, value));
builder.id((K) ids.iterator().next());
if (fields != null && !fields.isEmpty()) {
builder.setParam(RestConstants.FIELDS_PARAM, fields.toArray());
}
final GetRequest<RT> get = builder.build();
restClient.sendRequest(get, requestContextProvider.apply(get), new Callback<Response<RT>>() {
@Override
public void onError(Throwable e) {
batch.failAll(e);
}
@Override
public void onSuccess(Response<RT> responseToGet) {
batch.entries().stream().forEach(entry -> {
Request request = entry.getKey().getRequest();
if (request instanceof GetRequest) {
entry.getValue().getPromise().done(new ResponseImpl<>(responseToGet, responseToGet.getEntity()));
} else {
entry.getValue().getPromise().fail(unsupportedGetRequestType(request));
}
});
}
});
}
//Tuple3: (keys, fields, contains-batch-get)
private Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceRequests(
final Batch<RestRequestBatchKey, Response<Object>> batch) {
return batch.entries().stream()
.map(Entry::getKey)
.map(RestRequestBatchKey::getRequest)
.reduce(Tuples.tuple(new HashSet<>(), new HashSet<>(), false),
GetRequestGroup::reduceRequests,
GetRequestGroup::combine);
}
private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> combine(Tuple3<Set<Object>, Set<PathSpec>, Boolean> a,
Tuple3<Set<Object>, Set<PathSpec>, Boolean> b) {
Set<Object> ids = a._1();
ids.addAll(b._1());
Set<PathSpec> paths = a._2();
paths.addAll(b._2());
return Tuples.tuple(ids, paths, a._3() || b._3());
}
@Override
public <RT extends RecordTemplate> void executeBatch(final RestClient restClient, final Batch<RestRequestBatchKey, Response<Object>> batch,
Function<Request<?>, RequestContext> requestContextProvider) {
final Tuple3<Set<Object>, Set<PathSpec>, Boolean> reductionResults = reduceRequests(batch);
final Set<Object> ids = reductionResults._1();
final Set<PathSpec> fields = reductionResults._2();
final boolean containsBatchGet = reductionResults._3();
LOGGER.debug("executeBatch, ids: '{}', fields: {}", ids, fields);
if (ids.size() == 1 && !containsBatchGet) {
doExecuteGet(restClient, batch, ids, fields, requestContextProvider);
} else {
doExecuteBatchGet(restClient, batch, ids, fields, requestContextProvider);
}
}
@Override
public String getBaseUriTemplate() {
return _baseUriTemplate;
}
public Map<String, String> getHeaders() {
return _headers;
}
public Map<String, Object> getQueryParams() {
return _queryParams;
}
public ResourceSpec getResourceSpec() {
return _resourceSpec;
}
public RestliRequestOptions getRequestOptions() {
return _requestOptions;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((_baseUriTemplate == null) ? 0 : _baseUriTemplate.hashCode());
result = prime * result + ((_headers == null) ? 0 : _headers.hashCode());
result = prime * result + ((_queryParams == null) ? 0 : _queryParams.hashCode());
result = prime * result + ((_requestOptions == null) ? 0 : _requestOptions.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
GetRequestGroup other = (GetRequestGroup) obj;
if (_baseUriTemplate == null) {
if (other._baseUriTemplate != null)
return false;
} else if (!_baseUriTemplate.equals(other._baseUriTemplate))
return false;
if (_headers == null) {
if (other._headers != null)
return false;
} else if (!_headers.equals(other._headers))
return false;
if (_queryParams == null) {
if (other._queryParams != null)
return false;
} else if (!_queryParams.equals(other._queryParams))
return false;
if (_requestOptions == null) {
if (other._requestOptions != null)
return false;
} else if (!_requestOptions.equals(other._requestOptions))
return false;
return true;
}
@Override
public String toString() {
return "GetRequestGroup [_baseUriTemplate=" + _baseUriTemplate + ", _queryParams=" + _queryParams
+ ", _requestOptions=" + _requestOptions + ", _headers=" + _headers + ", _maxBatchSize=" + _maxBatchSize + "]";
}
@Override
public <K, V> String getBatchName(final Batch<K, V> batch) {
return _baseUriTemplate + " " + (batch.batchSize() == 1 ? ResourceMethod.GET : (ResourceMethod.BATCH_GET +
"(reqs: " + batch.keySize() + ", ids: " + batch.batchSize() + ")"));
}
@Override
public int getMaxBatchSize() {
return _maxBatchSize;
}
@Override
public int keySize(RestRequestBatchKey key) {
return key.ids().size();
}
}