package org.mockserver.proxy;
import com.google.common.base.Strings;
import com.google.common.net.MediaType;
import org.mockserver.client.netty.NettyHttpClient;
import org.mockserver.client.serialization.HttpRequestSerializer;
import org.mockserver.client.serialization.VerificationSequenceSerializer;
import org.mockserver.client.serialization.VerificationSerializer;
import org.mockserver.filters.*;
import org.mockserver.mappers.HttpServletRequestToMockServerRequestDecoder;
import org.mockserver.mappers.MockServerResponseToHttpServletResponseEncoder;
import org.mockserver.model.HttpRequest;
import org.mockserver.model.HttpResponse;
import org.mockserver.model.HttpStatusCode;
import org.mockserver.streams.IOStreamUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE;
import static org.mockserver.model.HttpResponse.notFoundResponse;
import static org.mockserver.model.OutboundHttpRequest.outboundRequest;
/**
* @author jamesdbloom
*/
public class ProxyServlet extends HttpServlet {
private final Logger logger = LoggerFactory.getLogger(this.getClass());
public RequestLogFilter requestLogFilter = new RequestLogFilter();
public RequestResponseLogFilter requestResponseLogFilter = new RequestResponseLogFilter();
// mockserver
private Filters filters = new Filters();
// http client
private NettyHttpClient httpClient = new NettyHttpClient();
// mappers
private HttpServletRequestToMockServerRequestDecoder httpServletRequestToMockServerRequestDecoder = new HttpServletRequestToMockServerRequestDecoder();
private MockServerResponseToHttpServletResponseEncoder mockServerResponseToHttpServletResponseEncoder = new MockServerResponseToHttpServletResponseEncoder();
// serializers
private HttpRequestSerializer httpRequestSerializer = new HttpRequestSerializer();
private VerificationSerializer verificationSerializer = new VerificationSerializer();
private VerificationSequenceSerializer verificationSequenceSerializer = new VerificationSequenceSerializer();
public ProxyServlet() {
filters.withFilter(new HttpRequest(), new HopByHopHeaderFilter());
filters.withFilter(new HttpRequest(), requestLogFilter);
filters.withFilter(new HttpRequest(), requestResponseLogFilter);
}
/**
* Add filter for HTTP requests, each filter get called before each request is proxied, if the filter return null then the request is not proxied
*
* @param httpRequest the request to match against for this filter
* @param filter the filter to execute for this request, if the filter returns null the request will not be proxied
*/
public ProxyServlet withFilter(HttpRequest httpRequest, RequestFilter filter) {
filters.withFilter(httpRequest, filter);
return this;
}
/**
* Add filter for HTTP response, each filter get called after each request has been proxied
*
* @param httpRequest the request to match against for this filter
* @param filter the filter that is executed after this request has been proxied
*/
public ProxyServlet withFilter(HttpRequest httpRequest, ResponseFilter filter) {
filters.withFilter(httpRequest, filter);
return this;
}
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) {
forwardRequest(request, response);
}
@Override
protected void doHead(HttpServletRequest request, HttpServletResponse response) {
forwardRequest(request, response);
}
@Override
protected void doPost(HttpServletRequest request, HttpServletResponse response) {
forwardRequest(request, response);
}
@Override
protected void doPut(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
try {
String requestPath = httpServletRequest.getPathInfo() != null && httpServletRequest.getContextPath() != null ? httpServletRequest.getPathInfo() : httpServletRequest.getRequestURI();
if (requestPath.equals("/status")) {
httpServletResponse.setStatus(HttpStatusCode.OK_200.code());
} else if (requestPath.equals("/clear")) {
requestLogFilter.clear(httpRequestSerializer.deserialize(IOStreamUtils.readInputStreamToString(httpServletRequest)));
httpServletResponse.setStatus(HttpStatusCode.ACCEPTED_202.code());
} else if (requestPath.equals("/reset")) {
requestLogFilter.reset();
httpServletResponse.setStatus(HttpStatusCode.ACCEPTED_202.code());
} else if (requestPath.equals("/dumpToLog")) {
requestResponseLogFilter.dumpToLog(httpRequestSerializer.deserialize(IOStreamUtils.readInputStreamToString(httpServletRequest)), "java".equals(httpServletRequest.getParameter("type")));
httpServletResponse.setStatus(HttpStatusCode.ACCEPTED_202.code());
} else if (requestPath.equals("/retrieve")) {
HttpRequest[] requests = requestLogFilter.retrieve(httpRequestSerializer.deserialize(IOStreamUtils.readInputStreamToString(httpServletRequest)));
httpServletResponse.setStatus(HttpStatusCode.OK_200.code());
httpServletResponse.setHeader(CONTENT_TYPE.toString(), MediaType.JSON_UTF_8.toString());
IOStreamUtils.writeToOutputStream(httpRequestSerializer.serialize(requests).getBytes(), httpServletResponse);
} else if (requestPath.equals("/verify")) {
String result = requestLogFilter.verify(verificationSerializer.deserialize(IOStreamUtils.readInputStreamToString(httpServletRequest)));
if (result.isEmpty()) {
httpServletResponse.setStatus(HttpStatusCode.ACCEPTED_202.code());
} else {
httpServletResponse.setStatus(HttpStatusCode.NOT_ACCEPTABLE_406.code());
httpServletResponse.setHeader(CONTENT_TYPE.toString(), MediaType.JSON_UTF_8.toString());
IOStreamUtils.writeToOutputStream(result.getBytes(), httpServletResponse);
}
} else if (requestPath.equals("/verifySequence")) {
String result = requestLogFilter.verify(verificationSequenceSerializer.deserialize(IOStreamUtils.readInputStreamToString(httpServletRequest)));
if (result.isEmpty()) {
httpServletResponse.setStatus(HttpStatusCode.ACCEPTED_202.code());
} else {
httpServletResponse.setStatus(HttpStatusCode.NOT_ACCEPTABLE_406.code());
httpServletResponse.setHeader(CONTENT_TYPE.toString(), MediaType.JSON_UTF_8.toString());
IOStreamUtils.writeToOutputStream(result.getBytes(), httpServletResponse);
}
} else if (requestPath.equals("/stop")) {
httpServletResponse.setStatus(HttpStatusCode.NOT_IMPLEMENTED_501.code());
} else {
forwardRequest(httpServletRequest, httpServletResponse);
}
} catch (Exception e) {
logger.error("Exception processing " + httpServletRequest, e);
httpServletResponse.setStatus(HttpStatusCode.BAD_REQUEST_400.code());
}
}
@Override
protected void doDelete(HttpServletRequest request, HttpServletResponse response) {
forwardRequest(request, response);
}
@Override
protected void doOptions(HttpServletRequest request, HttpServletResponse response) {
forwardRequest(request, response);
}
@Override
protected void doTrace(HttpServletRequest request, HttpServletResponse response) {
forwardRequest(request, response);
}
private void forwardRequest(HttpServletRequest request, HttpServletResponse httpServletResponse) {
HttpResponse httpResponse = sendRequest(filters.applyOnRequestFilters(httpServletRequestToMockServerRequestDecoder.mapHttpServletRequestToMockServerRequest(request)));
mockServerResponseToHttpServletResponseEncoder.mapMockServerResponseToHttpServletResponse(httpResponse, httpServletResponse);
}
private HttpResponse sendRequest(HttpRequest httpRequest) {
// if HttpRequest was set to null by a filter don't send request
if (httpRequest != null) {
String hostHeader = httpRequest.getFirstHeader("Host");
if (!Strings.isNullOrEmpty(hostHeader)) {
String[] hostHeaderParts = hostHeader.split(":");
boolean isSsl = httpRequest.isSecure() != null && httpRequest.isSecure();
Integer port = (isSsl ? 443 : 80); // default
if (hostHeaderParts.length > 1) {
port = Integer.parseInt(hostHeaderParts[1]); // non-default
}
HttpResponse httpResponse = filters.applyOnResponseFilters(httpRequest, httpClient.sendRequest(outboundRequest(hostHeaderParts[0], port, "", httpRequest)));
if (httpResponse != null) {
return httpResponse;
}
} else {
logger.error("Host header must be provided for requests being forwarded, the following request does not include the \"Host\" header:" + System.getProperty("line.separator") + httpRequest);
throw new IllegalArgumentException("Host header must be provided for requests being forwarded");
}
}
return notFoundResponse();
}
}