/**
* Copyright 2016-2017 Sixt GmbH & Co. Autovermietung KG
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may obtain a
* copy of the License at http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.sixt.service.framework.jetty;
import com.codahale.metrics.MetricRegistry;
import com.google.protobuf.Message;
import com.sixt.service.framework.*;
import com.sixt.service.framework.metrics.GoTimer;
import com.sixt.service.framework.rpc.RpcCallException;
import io.opentracing.Span;
import io.opentracing.SpanContext;
import io.opentracing.Tracer;
import io.opentracing.propagation.Format;
import io.opentracing.propagation.TextMapExtractAdapter;
import io.opentracing.tag.Tags;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;
import static com.google.common.collect.ImmutableSortedSet.of;
/**
 * Shared plumbing for RPC handler implementations: delegation to the handler
 * metrics, creation of server-side tracing spans, collection of incoming HTTP
 * headers, invocation of the pre-hook / handler / post-hook chain, and writing
 * of the HTTP response.
 */
public abstract class RpcHandler {

    private static final Logger logger = LoggerFactory.getLogger(RpcHandler.class);

    protected final MethodHandlerDictionary handlers;
    protected final MetricRegistry metricRegistry;
    protected final RpcHandlerMetrics handlerMetrics;
    protected final ServiceProperties serviceProps;
    protected final Tracer tracer;

    // For now, we block services from getting certain input headers.
    // The reason is that these headers are also then used for outgoing requests.
    // If you need the incoming headers, we can create an additional bucket inside
    // of OrangeContext to hold them.
    // Entries must be lower-case: gatherHttpHeaders() looks names up after
    // lower-casing them.
    private static final Set<String> blackListedHeaders = of("user-agent", "content-length", "content-type",
            "date", "expect", "host");

    /**
     * Stores the collaborators used by the helper methods of this class.
     *
     * @param handlers          dictionary queried for the pre- and post-hooks of a method
     * @param registry          metric registry, kept for subclasses
     * @param handlerMetrics    target of the counter/timer helper methods
     * @param serviceProperties consulted when deciding which HTTP status to expose
     * @param tracer            tracer used by {@link #getSpan}; may be {@code null},
     *                          in which case no spans are created
     */
    public RpcHandler(MethodHandlerDictionary handlers, MetricRegistry registry,
                      RpcHandlerMetrics handlerMetrics, ServiceProperties serviceProperties,
                      Tracer tracer) {
        this.handlers = handlers;
        this.metricRegistry = registry;
        this.handlerMetrics = handlerMetrics;
        this.serviceProps = serviceProperties;
        this.tracer = tracer;
    }

    /** Delegates to {@code handlerMetrics.incrementFailureCounter} with the same arguments. */
    protected void incrementFailureCounter(String methodName, String originService,
                                           String originMethod) {
        handlerMetrics.incrementFailureCounter(methodName, originService, originMethod);
    }

    /** Delegates to {@code handlerMetrics.incrementSuccessCounter} with the same arguments. */
    protected void incrementSuccessCounter(String methodName, String originService,
                                           String originMethod) {
        handlerMetrics.incrementSuccessCounter(methodName, originService, originMethod);
    }

    /** Delegates to {@code handlerMetrics.getMethodTimer} with the same arguments. */
    protected GoTimer getMethodTimer(String methodName, String originService,
                                     String originMethod) {
        return handlerMetrics.getMethodTimer(methodName, originService, originMethod);
    }

    /**
     * Starts a tracing span named after the invoked method.
     * <p>
     * If a span context can be extracted from {@code headers}, the new span is
     * created as a child of it; otherwise a span without that parent reference
     * is started. The span is tagged with the correlation id, the server span
     * kind and the origin service taken from {@code context}, and the span's
     * context is stored in {@code context} via {@code setTracingContext}.
     *
     * @param methodName operation name of the span
     * @param headers    incoming headers used as the extraction carrier
     * @param context    request context that supplies the tag values and receives
     *                   the tracing context
     * @return the started span, or {@code null} when no tracer is configured
     */
    protected Span getSpan(String methodName, Map<String, String> headers, OrangeContext context) {
        Span span = null;
        if (tracer != null) {
            SpanContext spanContext = tracer.extract(Format.Builtin.HTTP_HEADERS, new TextMapExtractAdapter(headers));
            if (spanContext != null) {
                span = tracer.buildSpan(methodName).asChildOf(spanContext).start();
            } else {
                span = tracer.buildSpan(methodName).start();
            }
            span.setTag("correlation_id", context.getCorrelationId());
            Tags.SPAN_KIND.set(span, Tags.SPAN_KIND_SERVER);
            Tags.PEER_SERVICE.set(span, context.getRpcOriginService());
            context.setTracingContext(span.context());
        }
        return span;
    }

    /**
     * Collects the request's HTTP headers into a map keyed by lower-cased header name.
     * <p>
     * Headers named in {@code blackListedHeaders} are skipped. When a header
     * carries several values only the first one is kept; the remaining values
     * are logged at debug level and discarded.
     *
     * @param req the incoming request
     * @return a new mutable map of lower-cased header name to first header value;
     *         empty when the request exposes no header names
     */
    protected Map<String, String> gatherHttpHeaders(HttpServletRequest req) {
        Map<String, String> headers = new HashMap<>();
        Enumeration<String> headerNames = req.getHeaderNames();
        if (headerNames == null) {
            return headers;
        }
        while (headerNames.hasMoreElements()) {
            String headerName = headerNames.nextElement();
            Enumeration<String> headerValues = req.getHeaders(headerName);
            if (headerValues == null) {
                continue;
            }
            // Locale.ROOT keeps the mapping locale-independent. With the default
            // locale, a Turkish/Azeri JVM would turn 'I' into dotless 'ı' and the
            // resulting keys would no longer match their ASCII spellings.
            String headerNameLower = headerName.toLowerCase(Locale.ROOT);
            if (blackListedHeaders.contains(headerNameLower)) {
                logger.trace("Blocking header {}", headerNameLower);
                continue;
            }
            if (headerValues.hasMoreElements()) {
                headers.put(headerNameLower, headerValues.nextElement());
            }
            while (headerValues.hasMoreElements()) {
                logger.debug("Duplicate http-header, discarding: {} = {}", headerName, headerValues.nextElement());
            }
        }
        return headers;
    }

    /**
     * Runs the request through the pre-hooks registered for {@code methodName},
     * then the handler, then the post-hooks registered for {@code methodName}.
     * Each hook receives the output of the previous step and its return value
     * replaces the request (pre-hooks) or the response (post-hooks).
     * <p>
     * Intended invocation order (the relative order of global and
     * method-specific hooks is whatever order {@code handlers.getPreHooksFor}
     * and {@code handlers.getPostHooksFor} return them in):
     * <ol><li>Global pre-handler hooks</li>
     * <li>Pre-handler hooks for this methodName</li>
     * <li>The handler</li>
     * <li>Post-handler hooks for this methodName</li>
     * <li>Global post-handler hooks</li></ol>
     *
     * @param methodName name used to look up the hooks
     * @param handler    the handler to invoke between the hooks
     * @param request    the incoming request message
     * @param context    request context handed to every hook and to the handler
     * @return the response after all post-hooks have been applied
     * @throws RpcCallException if a hook or the handler fails with it
     */
    @SuppressWarnings("unchecked")
    protected Message invokeHandlerChain(String methodName, ServiceMethodHandler handler,
                                         Message request, OrangeContext context) throws RpcCallException {
        List<ServiceMethodPreHook<? extends Message>> preHooks = handlers.getPreHooksFor(methodName);
        for (ServiceMethodPreHook hook : preHooks) {
            request = hook.handleRequest(request, context);
        }
        Message response = handler.handleRequest(request, context);
        List<ServiceMethodPostHook<? extends Message>> postHooks = handlers.getPostHooksFor(methodName);
        for (ServiceMethodPostHook hook : postHooks) {
            response = hook.handleRequest(response, context);
        }
        return response;
    }

    /**
     * Sets the HTTP status and writes {@code s} as the response body, then flushes.
     * <p>
     * A non-200 {@code statusCode} is only sent when
     * {@code FeatureFlags.shouldExposeErrorsToHttp(serviceProps)} returns true;
     * in every other case the status is 200.
     *
     * @param resp       the servlet response to write to
     * @param statusCode the status the caller would like to send
     * @param s          the response body
     * @throws IOException if obtaining the response writer fails
     */
    protected void writeResponse(HttpServletResponse resp, int statusCode, String s) throws IOException {
        if (statusCode != 200 && FeatureFlags.shouldExposeErrorsToHttp(serviceProps)) {
            resp.setStatus(statusCode);
        } else {
            resp.setStatus(HttpServletResponse.SC_OK);
        }
        resp.getWriter().write(s);
        resp.getWriter().flush();
    }
}