/**
* Copyright 2016-2017 Sixt GmbH & Co. Autovermietung KG
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may obtain a
* copy of the License at http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.sixt.service.framework.jetty;
import com.codahale.metrics.MetricRegistry;
import com.google.common.primitives.Ints;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import com.google.protobuf.Message;
import com.sixt.service.framework.*;
import com.sixt.service.framework.metrics.GoTimer;
import com.sixt.service.framework.protobuf.ProtobufUtil;
import com.sixt.service.framework.protobuf.RpcEnvelope;
import com.sixt.service.framework.rpc.RpcCallException;
import com.sixt.service.framework.util.ReflectionUtil;
import io.opentracing.Span;
import io.opentracing.Tracer;
import io.opentracing.tag.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Map;
/**
 * Handles RPC calls that arrive as a length-prefixed protobuf stream.
 *
 * <p>Wire format, as read and written by this class (both directions):
 * <pre>
 *   4 bytes   size of the envelope chunk (decoded/encoded with Guava {@code Ints})
 *   N bytes   serialized {@code RpcEnvelope.Request} / {@code RpcEnvelope.Response}
 *   4 bytes   size of the body chunk (0 means "no body")
 *   M bytes   serialized request / response message
 * </pre>
 *
 * <p>The {@code handlers} and {@code serviceProps} members used below are not declared in
 * this class; presumably they are inherited from {@code RpcHandler}.
 */
@Singleton
public class ProtobufHandler extends RpcHandler {

    private static final Logger logger = LoggerFactory.getLogger(ProtobufHandler.class);

    /** Number of bytes in the size prefix that precedes every chunk. */
    private static final int CHUNK_SIZE_PREFIX_LENGTH = 4;

    @Inject
    public ProtobufHandler(MethodHandlerDictionary handlers, MetricRegistry registry,
                           RpcHandlerMetrics handlerMetrics, ServiceProperties serviceProperties,
                           Tracer tracer) {
        super(handlers, registry, handlerMetrics, serviceProperties, tracer);
    }

    /**
     * Reads one RPC request from {@code req}, dispatches it to the registered method handler
     * and writes the outcome to {@code resp}.
     *
     * <p>No exception escapes this method: every failure is converted into an error response
     * (where possible), the tracing span is tagged as an error and the failure counter is
     * incremented exactly once.
     *
     * @param req  the incoming HTTP request carrying the length-prefixed protobuf payload
     * @param resp the HTTP response the RPC response is written to
     */
    public void doPost(HttpServletRequest req, HttpServletResponse resp) {
        logger.debug("Handling protobuf request");
        RpcEnvelope.Request rpcRequest = null;
        String methodName = null;
        Span span = null;
        Map<String, String> headers = gatherHttpHeaders(req);
        OrangeContext context = new OrangeContext(headers);
        try {
            MDC.put(OrangeContext.CORRELATION_ID, context.getCorrelationId());
            ServletInputStream in = req.getInputStream();
            rpcRequest = readRpcEnvelope(in);
            methodName = rpcRequest.getServiceMethod();
            span = getSpan(methodName, headers, context);

            ServiceMethodHandler handler = handlers.getMethodHandler(methodName);
            if (handler == null) {
                // The failure counter is incremented by the generic Exception handler below;
                // incrementing it here as well would count this failure twice.
                throw new IllegalArgumentException("Invalid method: " +
                        rpcRequest.getServiceMethod());
            }

            @SuppressWarnings("unchecked")
            Class<? extends Message> requestClass = (Class<? extends Message>)
                    ReflectionUtil.findSubClassParameterType(handler, 0);
            Message pbRequest = readRpcBody(in, requestClass);

            GoTimer methodTimer = getMethodTimer(methodName, context.getRpcOriginService(),
                    context.getRpcOriginMethod());
            long startTime = methodTimer.start();
            Message pbResponse = invokeHandlerChain(methodName, handler, pbRequest, context);
            resp.setContentType(RpcServlet.TYPE_OCTET);
            sendSuccessfulResponse(resp, rpcRequest, pbResponse);

            //TODO: should we check the response for errors?
            methodTimer.recordSuccess(startTime);
            incrementSuccessCounter(methodName, context.getRpcOriginService(),
                    context.getRpcOriginMethod());
        } catch (RpcCallException rpcEx) {
            sendErrorResponse(resp, rpcRequest, rpcEx.toString(),
                    rpcEx.getCategory().getHttpStatus());
            flagSpanAndCountFailure(span, methodName, context);
        } catch (RpcReadException ex) {
            logger.warn("Bad request, cannot decode rpc message: {}", ex.toJson(req));
            sendErrorResponse(resp, rpcRequest, ex.getMessage(),
                    HttpServletResponse.SC_BAD_REQUEST);
            flagSpanAndCountFailure(span, methodName, context);
        } catch (Exception ex) {
            logger.warn("Uncaught exception", ex);
            sendErrorResponse(resp, rpcRequest, ex.getMessage(),
                    HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            flagSpanAndCountFailure(span, methodName, context);
        } finally {
            if (span != null) {
                span.finish();
            }
            MDC.remove(OrangeContext.CORRELATION_ID);
        }
    }

    /**
     * Shared failure bookkeeping for all error paths of {@link #doPost}: tags the span
     * (if one was started) as an error and increments the failure counter.
     *
     * @param span       the tracing span of the call, or {@code null} if none was created yet
     * @param methodName the RPC method name, or {@code null} if the envelope could not be read
     * @param context    the context of the call, used for the origin service and method
     */
    private void flagSpanAndCountFailure(Span span, String methodName, OrangeContext context) {
        if (span != null) {
            Tags.ERROR.set(span, true);
        }
        incrementFailureCounter(methodName, context.getRpcOriginService(),
                context.getRpcOriginMethod());
    }

    /**
     * Writes a successful RPC response: status 200, the response envelope echoing the
     * request's method and sequence number, followed by the (possibly empty) body.
     *
     * <p>An {@link IOException} while writing is deliberately only logged at debug level,
     * on the assumption that the client has disconnected.
     *
     * @param response   the HTTP response to write to
     * @param rpcRequest the request envelope being answered
     * @param pbResponse the handler result; {@code null} is sent as a zero-length body
     * @throws IOException declared for callers; write failures are handled internally
     */
    private void sendSuccessfulResponse(HttpServletResponse response,
                                        RpcEnvelope.Request rpcRequest,
                                        Message pbResponse) throws IOException {
        response.setStatus(HttpServletResponse.SC_OK);
        RpcEnvelope.Response rpcResponse = RpcEnvelope.Response.newBuilder().
                setServiceMethod(rpcRequest.getServiceMethod()).
                setSequenceNumber(rpcRequest.getSequenceNumber()).build();
        byte[] responseHeader = rpcResponse.toByteArray();
        byte[] responseBody = (pbResponse == null) ? new byte[0] : pbResponse.toByteArray();

        try {
            ServletOutputStream out = response.getOutputStream();
            out.write(Ints.toByteArray(responseHeader.length));
            out.write(responseHeader);
            out.write(Ints.toByteArray(responseBody.length));
            out.write(responseBody);
        } catch (IOException ioex) {
            //there is nothing we can do, client probably went away
            logger.debug("Caught IOException, assuming client disconnected");
        }
    }

    /**
     * Writes an error response consisting of an envelope carrying {@code message} and a
     * zero-length body.
     *
     * <p>The HTTP status is {@code httpStatusCode} only when
     * {@code FeatureFlags.shouldExposeErrorsToHttp} is enabled; otherwise 200 is sent and the
     * error is visible solely in the envelope.
     *
     * <p>NOTE: when {@code rpcRequest} is {@code null} (the request envelope could not be
     * read) nothing is written at all, neither status nor body.
     *
     * @param resp           the HTTP response to write to
     * @param rpcRequest     the request envelope being answered, may be {@code null}
     * @param message        the error text; {@code null} is sent as the string "null"
     * @param httpStatusCode the status to use when errors are exposed over HTTP
     */
    private void sendErrorResponse(HttpServletResponse resp,
                                   RpcEnvelope.Request rpcRequest,
                                   String message,
                                   int httpStatusCode) {
        if (rpcRequest == null) {
            return;
        }
        try {
            if (FeatureFlags.shouldExposeErrorsToHttp(serviceProps)) {
                resp.setStatus(httpStatusCode);
            } else {
                resp.setStatus(HttpServletResponse.SC_OK);
            }
            if (message == null) {
                message = "null";
            }
            RpcEnvelope.Response rpcResponse = RpcEnvelope.Response.newBuilder().
                    setServiceMethod(rpcRequest.getServiceMethod()).
                    setSequenceNumber(rpcRequest.getSequenceNumber()).
                    setError(message).build();
            byte[] responseHeader = rpcResponse.toByteArray();
            ServletOutputStream out = resp.getOutputStream();
            out.write(Ints.toByteArray(responseHeader.length));
            out.write(responseHeader);
            out.write(Ints.toByteArray(0)); //zero-length (no) body
        } catch (Exception ex) {
            logger.warn("Error writing error response", ex);
        }
    }

    /**
     * Reads the size prefix and the request envelope from the stream.
     *
     * @param in the request stream, positioned at the start of the payload
     * @return the parsed request envelope
     * @throws Exception an {@code RpcReadException} if the announced size is not within
     *                   {@code 1..ProtobufUtil.MAX_HEADER_CHUNK_SIZE}; otherwise whatever
     *                   reading or parsing raises
     */
    private RpcEnvelope.Request readRpcEnvelope(ServletInputStream in) throws Exception {
        byte[] chunkSize = readChunkSizePrefix(in);
        int size = Ints.fromByteArray(chunkSize);
        if (size <= 0 || size > ProtobufUtil.MAX_HEADER_CHUNK_SIZE) {
            String message = "Invalid header chunk size: " + size;
            throw new RpcReadException(chunkSize, in, message);
        }
        byte[] headerData = readyFully(in, size);
        return RpcEnvelope.Request.parseFrom(headerData);
    }

    /**
     * Reads the size prefix and the request body from the stream.
     *
     * @param in           the request stream, positioned directly after the envelope
     * @param requestClass the protobuf type the body is decoded into
     * @return the decoded message, or an empty message of {@code requestClass} when the
     *         announced size is 0
     * @throws Exception an {@code RpcReadException} if the announced size is negative or
     *                   exceeds {@code ProtobufUtil.MAX_BODY_CHUNK_SIZE}; otherwise whatever
     *                   reading or decoding raises
     */
    private Message readRpcBody(ServletInputStream in,
                                Class<? extends Message> requestClass) throws Exception {
        byte[] chunkSize = readChunkSizePrefix(in);
        int size = Ints.fromByteArray(chunkSize);
        if (size == 0) {
            return ProtobufUtil.newEmptyMessage(requestClass);
        }
        // A negative size must be rejected here: it would otherwise reach the array
        // allocation in readyFully and surface as an internal error instead of a bad request.
        if (size < 0 || size > ProtobufUtil.MAX_BODY_CHUNK_SIZE) {
            String message = "Invalid body chunk size: " + size;
            throw new RpcReadException(chunkSize, in, message);
        }
        byte[] bodyData = readyFully(in, size);
        return ProtobufUtil.byteArrayToProtobuf(bodyData, requestClass);
    }

    /**
     * Reads the 4-byte size prefix of a chunk.
     *
     * <p>A single {@code read(byte[])} may legally deliver fewer bytes than requested, so
     * this keeps reading until the prefix is complete. If the stream ends first, the bytes
     * not received stay zero, which keeps the outcome of an exhausted stream the same as a
     * prefix of 0.
     *
     * @param in the stream to read from
     * @return a 4-byte array holding the prefix
     * @throws IOException if reading from the stream fails
     */
    private byte[] readChunkSizePrefix(ServletInputStream in) throws IOException {
        byte[] prefix = new byte[CHUNK_SIZE_PREFIX_LENGTH];
        int filled = 0;
        while (filled < prefix.length) {
            int read = in.read(prefix, filled, prefix.length - filled);
            if (read == -1) {
                break;
            }
            filled += read;
        }
        return prefix;
    }

    /**
     * Reads exactly {@code totalSize} bytes from the stream.
     *
     * @param in        the stream to read from
     * @param totalSize the number of bytes to read
     * @return an array of length {@code totalSize} filled from the stream
     * @throws Exception an {@code RpcCallException} of category InternalServerError if the
     *                   stream ends early or an {@link IOException} occurs while reading
     */
    private byte[] readyFully(ServletInputStream in, int totalSize) throws Exception {
        byte[] retval = new byte[totalSize];
        int bytesRead = 0;
        while (bytesRead < totalSize) {
            try {
                int read = in.read(retval, bytesRead, totalSize - bytesRead);
                if (read == -1) {
                    throw new RpcCallException(RpcCallException.Category.InternalServerError,
                            "Unable to read complete request or response");
                }
                bytesRead += read;
            } catch (IOException e) {
                throw new RpcCallException(RpcCallException.Category.InternalServerError,
                        "IOException reading data: " + e);
            }
        }
        return retval;
    }
}