/*
* COMSAT
* Copyright (c) 2013-2016, Parallel Universe Software Co. All rights reserved.
*
* This program and the accompanying materials are dual-licensed under
* either the terms of the Eclipse Public License v1.0 as published by
* the Eclipse Foundation
*
* or (per the licensee's choosing)
*
* under the terms of the GNU Lesser General Public License version 3.0
* as published by the Free Software Foundation.
*/
package co.paralleluniverse.fibers.servlet;
import co.paralleluniverse.common.util.SystemProperties;
import co.paralleluniverse.fibers.Fiber;
import co.paralleluniverse.fibers.SuspendExecution;
import co.paralleluniverse.fibers.Suspendable;
import co.paralleluniverse.strands.SuspendableRunnable;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.text.MessageFormat;
import java.util.Enumeration;
import java.util.ResourceBundle;
import java.util.concurrent.ForkJoinPool;
import javax.servlet.*;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
/**
* Fiber-blocking HttpServlet base class.
*
* @author eithan
* @author circlespainter
*/
public class FiberHttpServlet extends HttpServlet {
private static final String METHOD_DELETE = "DELETE";
private static final String METHOD_HEAD = "HEAD";
private static final String METHOD_GET = "GET";
private static final String METHOD_OPTIONS = "OPTIONS";
private static final String METHOD_POST = "POST";
private static final String METHOD_PUT = "PUT";
private static final String METHOD_TRACE = "TRACE";
private static final String HEADER_IFMODSINCE = "If-Modified-Since";
private static final String HEADER_LASTMOD = "Last-Modified";
private static final String LSTRING_FILE =
"javax.servlet.http.LocalStrings";
private static ResourceBundle lStrings =
ResourceBundle.getBundle(LSTRING_FILE);
private static final String PROP_ASYNC_TIMEOUT = FiberHttpServlet.class.getName() + ".asyncTimeout";
static final Long asyncTimeout;
public static final String PROP_DEBUG_BYPASS_TO_REGULAR_FJP = FiberHttpServlet.class.getName() + ".debug.bypassToRegularFJP";
static final boolean debugBypassToRegularJFPGlobal = SystemProperties.isEmptyOrTrue(PROP_DEBUG_BYPASS_TO_REGULAR_FJP);
public static final String PROP_DISABLE_SYNC_EXCEPTIONS = FiberHttpServlet.class.getName() + ".disableSyncExceptions";
static final boolean disableSyncExceptionsGlobal = SystemProperties.isEmptyOrTrue(PROP_DISABLE_SYNC_EXCEPTIONS);
public static final String PROP_DISABLE_SYNC_FORWARD = FiberHttpServlet.class.getName() + ".disableSyncForward";
static final boolean disableSyncForwardGlobal = SystemProperties.isEmptyOrTrue(PROP_DISABLE_SYNC_FORWARD);
public static final String PROP_DISABLE_JETTY_ASYNC_FIXES = FiberHttpServlet.class.getName() + ".disableJettyAsyncFixes";
static final Boolean disableJettyAsyncFixesGlobal;
public static final String PROP_DISABLE_TOMCAT_ASYNC_FIXES = FiberHttpServlet.class.getName() + ".disableTomcatAsyncFixes";
static final Boolean disableTomcatAsyncFixesGlobal;
static {
asyncTimeout = getLong(PROP_ASYNC_TIMEOUT);
disableJettyAsyncFixesGlobal = getBoolean(PROP_DISABLE_JETTY_ASYNC_FIXES);
disableTomcatAsyncFixesGlobal = getBoolean(PROP_DISABLE_TOMCAT_ASYNC_FIXES);
}
private static final long serialVersionUID = 1L;
private static final String FIBER_ASYNC_REQUEST_EXCEPTION = "co.paralleluniverse.fibers.servlet.exception";
private final ThreadLocal<AsyncContext> currentAsyncContext = new ThreadLocal<>();
private int stackSize = -1;
private transient FiberServletContext contextAD;
private ForkJoinPool fjp = new ForkJoinPool();
boolean debugBypassToRegularJFP, disableSyncExceptions, disableSyncForward, disableJettyAsyncFixes, disableTomcatAsyncFixes;
/**
* @return Wrapped version of the ServletContext initiated by {@link #init(javax.servlet.ServletConfig) }
* @inheritDoc
*/
@Override
public ServletContext getServletContext() {
return disableSyncForward ? super.getServletContext() : contextAD;
}
@Override
public void init(ServletConfig config) throws ServletException {
super.init(config);
this.contextAD = new FiberServletContext(config.getServletContext(), currentAsyncContext);
final String ss = config.getInitParameter("stack-size");
if (ss != null)
stackSize = Integer.parseInt(ss);
final String debugByPassToJFP = config.getInitParameter(PROP_DEBUG_BYPASS_TO_REGULAR_FJP);
if (debugByPassToJFP != null)
debugBypassToRegularJFP = debugBypassToRegularJFPGlobal;
final String disableSE = config.getInitParameter(PROP_DISABLE_SYNC_EXCEPTIONS);
if (disableSE != null)
disableSyncExceptions = disableSyncExceptionsGlobal;
final String disableSF = config.getInitParameter(PROP_DISABLE_SYNC_FORWARD);
if (disableSF != null)
disableSyncForward = disableSyncForwardGlobal;
final String disableJF = config.getInitParameter(PROP_DISABLE_JETTY_ASYNC_FIXES);
if (disableJF != null)
disableJettyAsyncFixes = disableJettyAsyncFixesGlobal != null ? disableJettyAsyncFixesGlobal : !isJetty(config);
final String disableTF = config.getInitParameter(PROP_DISABLE_TOMCAT_ASYNC_FIXES);
if (disableTF != null)
disableTomcatAsyncFixes = disableTomcatAsyncFixesGlobal != null ? disableTomcatAsyncFixesGlobal : !isTomcat(config);
}
protected final void setStackSize(int stackSize) {
this.stackSize = stackSize;
}
protected final int getStackSize() {
return stackSize;
}
@Override
@Suspendable
final public void service(final ServletRequest req, ServletResponse res) throws ServletException, IOException {
if (!disableSyncExceptions && DispatcherType.ASYNC.equals(req.getDispatcherType())) {
final Throwable ex = (Throwable) req.getAttribute(FIBER_ASYNC_REQUEST_EXCEPTION);
if (ex != null)
throw new ServletException(ex);
}
final HttpServletRequest request;
final HttpServletResponse response;
try {
request = (HttpServletRequest) req;
response = (HttpServletResponse) res;
} catch (final ClassCastException cce) {
throw new ServletException("Unsupported non-HTTP request or response detected");
}
if (!disableTomcatAsyncFixes)
req.setAttribute("org.apache.catalina.ASYNC_SUPPORTED", true);
final AsyncContext ac = req.startAsync();
if (asyncTimeout != null)
ac.setTimeout(asyncTimeout);
final HttpServletRequest r =
!disableJettyAsyncFixes ?
new FiberHttpServletRequest(this, request) :
request;
if (debugBypassToRegularJFP)
fjp.execute(new ServletRunnable(this, ac, r, response));
else
new Fiber(null, stackSize, new ServletSuspendableRunnable(this, ac, r, response)).start();
}
private final static class ServletSuspendableRunnable implements SuspendableRunnable {
private final AsyncContext ac;
private final HttpServletRequest request;
private final HttpServletResponse response;
private final FiberHttpServlet servlet;
ServletSuspendableRunnable(FiberHttpServlet servlet, AsyncContext ac, HttpServletRequest request, HttpServletResponse response) {
this.servlet = servlet;
this.ac = ac;
this.request = request;
this.response = response;
}
@Override
public final void run() throws SuspendExecution, InterruptedException {
servlet.exec(servlet, ac, request, response);
}
}
private static class ServletRunnable implements Runnable {
private final FiberHttpServlet servlet;
private final AsyncContext ac;
private final HttpServletRequest request;
private final HttpServletResponse response;
ServletRunnable(FiberHttpServlet servlet, AsyncContext ac, HttpServletRequest request, HttpServletResponse response) {
this.servlet = servlet;
this.ac = ac;
this.request = request;
this.response = response;
}
@Override
public final void run() {
servlet.exec(servlet, ac, request, response);
}
}
@Suspendable
final void exec(FiberHttpServlet servlet, AsyncContext ac, HttpServletRequest request, HttpServletResponse response) {
if (!disableSyncExceptions) {
try {
exec0(servlet, ac, request, response);
} catch (final ServletException | IOException ex) {
// Multi-catch above seems to break ASM during instrumentation in some circumstances
// seemingly tied to structured class-loading, as in standalone servlet containers
servlet.log("Exception in servlet's fiber, dispatching to container", ex);
request.setAttribute(FIBER_ASYNC_REQUEST_EXCEPTION, ex);
if (!disableSyncForward)
servlet.currentAsyncContext.set(null);
ac.dispatch();
}
} else {
try {
exec0(servlet, ac, request, response);
} catch (final Throwable t) {
servlet.log("Error during pool-based execution", t);
((HttpServletResponse) ac.getResponse()).setStatus(500);
try {
ac.complete();
} catch (final IllegalStateException ignored) {}
}
}
}
@Suspendable
private void exec0(FiberHttpServlet servlet, AsyncContext ac, HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
// TODO: check if ac has expired
if (!disableSyncForward)
servlet.currentAsyncContext.set(ac);
servlet.service(request, response);
try {
ac.complete();
} catch (final IllegalStateException ignored) {}
}
private static boolean isJetty(ServletConfig config) {
return config.getClass().getName().startsWith("org.eclipse.jetty.");
}
private static boolean isTomcat(ServletConfig config) {
return config.getClass().getName().startsWith("org.apache.tomcat.");
}
private static Long getLong(String propName) {
final String asyncTimeoutS = System.getProperty(propName);
if (asyncTimeoutS != null) {
Long l = null;
try {
l = Long.parseLong(asyncTimeoutS);
} catch (final NumberFormatException ignored) {
}
return l;
} else {
return null;
}
}
private static Boolean getBoolean(String propName) {
final String disableJettyAsyncFixesGlobalS = System.getProperty(propName);
if (disableJettyAsyncFixesGlobalS != null) {
Boolean b = null;
if (Boolean.TRUE.toString().equals(disableJettyAsyncFixesGlobalS))
b = true;
else if (Boolean.FALSE.toString().equals(disableJettyAsyncFixesGlobalS))
b = false;
return b;
} else {
return null;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// The rest is copied from HttpServlet as we don't instrument by default "java" and "javax" packages
////////////////////////////////////////////////////////////////////////////////////////////////////
@Override
@Suspendable
protected void service(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
final String method = req.getMethod();
//noinspection IfCanBeSwitch
if (method.equals(METHOD_GET)) {
long lastModified = getLastModified(req);
if (lastModified == -1) {
// servlet doesn't support if-modified-since, no reason
// to go through further expensive logic
doGet(req, resp);
} else {
long ifModifiedSince = req.getDateHeader(HEADER_IFMODSINCE);
if (ifModifiedSince < lastModified) {
// If the servlet mod time is later, call doGet()
// Round down to the nearest second for a proper compare
// A ifModifiedSince of -1 will always be less
maybeSetLastModified(resp, lastModified);
doGet(req, resp);
} else {
resp.setStatus(HttpServletResponse.SC_NOT_MODIFIED);
}
}
} else if (method.equals(METHOD_HEAD)) {
final long lastModified = getLastModified(req);
maybeSetLastModified(resp, lastModified);
doHead(req, resp);
} else if (method.equals(METHOD_POST)) {
doPost(req, resp);
} else if (method.equals(METHOD_PUT)) {
doPut(req, resp);
} else if (method.equals(METHOD_DELETE)) {
doDelete(req, resp);
} else if (method.equals(METHOD_OPTIONS)) {
doOptions(req,resp);
} else if (method.equals(METHOD_TRACE)) {
doTrace(req,resp);
} else {
//
// Note that this means NO servlet supports whatever
// method was requested, anywhere on this server.
//
String errMsg = lStrings.getString("http.method_not_implemented");
final Object[] errArgs = new Object[1];
errArgs[0] = method;
errMsg = MessageFormat.format(errMsg, errArgs);
resp.sendError(HttpServletResponse.SC_NOT_IMPLEMENTED, errMsg);
}
}
@Override
@Suspendable
protected void doGet(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
final String protocol = req.getProtocol();
final String msg = lStrings.getString("http.method_get_not_supported");
if (protocol.endsWith("1.1")) {
resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, msg);
} else {
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, msg);
}
}
@Override
@Suspendable
protected void doPost(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
final String protocol = req.getProtocol();
final String msg = lStrings.getString("http.method_post_not_supported");
if (protocol.endsWith("1.1")) {
resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, msg);
} else {
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, msg);
}
}
@Override
@Suspendable
protected void doPut(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
final String protocol = req.getProtocol();
final String msg = lStrings.getString("http.method_put_not_supported");
if (protocol.endsWith("1.1")) {
resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, msg);
} else {
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, msg);
}
}
@Override
@Suspendable
protected void doDelete(HttpServletRequest req,
HttpServletResponse resp)
throws ServletException, IOException
{
final String protocol = req.getProtocol();
final String msg = lStrings.getString("http.method_delete_not_supported");
if (protocol.endsWith("1.1")) {
resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, msg);
} else {
resp.sendError(HttpServletResponse.SC_BAD_REQUEST, msg);
}
}
@Override
@Suspendable
protected void doOptions(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
final Method[] methods = getAllDeclaredMethods(this.getClass());
boolean ALLOW_GET = false;
boolean ALLOW_HEAD = false;
boolean ALLOW_POST = false;
boolean ALLOW_PUT = false;
boolean ALLOW_DELETE = false;
final boolean ALLOW_TRACE = true;
final boolean ALLOW_OPTIONS = true;
for (final Method m : methods) {
if (m.getName().equals("doGet")) {
ALLOW_GET = true;
ALLOW_HEAD = true;
}
if (m.getName().equals("doPost"))
ALLOW_POST = true;
if (m.getName().equals("doPut"))
ALLOW_PUT = true;
if (m.getName().equals("doDelete"))
ALLOW_DELETE = true;
}
String allow = null;
if (ALLOW_GET)
allow=METHOD_GET;
if (ALLOW_HEAD)
if (allow==null) allow=METHOD_HEAD;
else allow += ", " + METHOD_HEAD;
if (ALLOW_POST)
if (allow==null) allow=METHOD_POST;
else allow += ", " + METHOD_POST;
if (ALLOW_PUT)
if (allow==null) allow=METHOD_PUT;
else allow += ", " + METHOD_PUT;
if (ALLOW_DELETE)
if (allow==null) allow=METHOD_DELETE;
else allow += ", " + METHOD_DELETE;
if (ALLOW_TRACE)
if (allow==null) allow=METHOD_TRACE;
else allow += ", " + METHOD_TRACE;
if (ALLOW_OPTIONS)
if (allow==null) allow=METHOD_OPTIONS;
else allow += ", " + METHOD_OPTIONS;
resp.setHeader("Allow", allow);
}
private Method[] getAllDeclaredMethods(Class<?> c) {
if (c.equals(javax.servlet.http.HttpServlet.class)) {
return null;
}
final Method[] parentMethods = getAllDeclaredMethods(c.getSuperclass());
Method[] thisMethods = c.getDeclaredMethods();
if ((parentMethods != null) && (parentMethods.length > 0)) {
final Method[] allMethods =
new Method[parentMethods.length + thisMethods.length];
System.arraycopy(parentMethods, 0, allMethods, 0,
parentMethods.length);
System.arraycopy(thisMethods, 0, allMethods, parentMethods.length,
thisMethods.length);
thisMethods = allMethods;
}
return thisMethods;
}
@Override
@Suspendable
protected void doTrace(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
int responseLength;
final String CRLF = "\r\n";
final StringBuilder buffer = new StringBuilder("TRACE ").append(req.getRequestURI())
.append(" ").append(req.getProtocol());
final Enumeration<String> reqHeaderEnum = req.getHeaderNames();
while (reqHeaderEnum.hasMoreElements()) {
final String headerName = reqHeaderEnum.nextElement();
buffer.append(CRLF).append(headerName).append(": ")
.append(req.getHeader(headerName));
}
buffer.append(CRLF);
responseLength = buffer.length();
resp.setContentType("message/http");
resp.setContentLength(responseLength);
final ServletOutputStream out = resp.getOutputStream();
out.print(buffer.toString());
}
@Override
@Suspendable
protected void doHead(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException
{
final NoBodyResponse response = new NoBodyResponse(resp);
doGet(req, response);
response.setContentLength();
}
private void maybeSetLastModified(HttpServletResponse resp,
long lastModified) {
if (resp.containsHeader(HEADER_LASTMOD))
return;
if (lastModified >= 0)
resp.setDateHeader(HEADER_LASTMOD, lastModified);
}
}
/**
 * Response wrapper used to serve HEAD: everything written to it is discarded, and only its size in bytes
 * is kept so that a matching Content-Length can be reported.
 */
final class NoBodyResponse extends HttpServletResponseWrapper {
    private static final ResourceBundle lStrings
        = ResourceBundle.getBundle("javax.servlet.http.LocalStrings");

    // Byte-counting sink standing in for the real body.
    private final NoBodyOutputStream noBody = new NoBodyOutputStream();
    // Lazily created on the first getWriter() call.
    private PrintWriter writer;
    // True once an explicit content length has been set.
    private boolean didSetContentLength;
    // True once getOutputStream() has been handed out.
    private boolean usingOutputStream;

    NoBodyResponse(HttpServletResponse r) {
        super(r);
    }

    /**
     * Reports the counted body size as Content-Length, unless a length was already set explicitly.
     * Any pending writer output is flushed first so it gets counted.
     */
    final void setContentLength() {
        if (didSetContentLength)
            return;
        if (writer != null)
            writer.flush();
        setContentLength(noBody.getContentLength());
    }

    @Override
    public final void setContentLength(int len) {
        super.setContentLength(len);
        didSetContentLength = true;
    }

    /** @throws IllegalStateException if {@link #getWriter()} was already used */
    @Override
    public final ServletOutputStream getOutputStream() throws IOException {
        if (writer == null) {
            usingOutputStream = true;
            return noBody;
        }
        throw new IllegalStateException(
            lStrings.getString("err.ise.getOutputStream"));
    }

    /** @throws IllegalStateException if {@link #getOutputStream()} was already used */
    @Override
    public final PrintWriter getWriter() throws UnsupportedEncodingException {
        if (usingOutputStream)
            throw new IllegalStateException(
                lStrings.getString("err.ise.getWriter"));
        if (writer == null)
            writer = new PrintWriter(new OutputStreamWriter(noBody, getCharacterEncoding()));
        return writer;
    }
}
/** Output stream that throws away everything written to it while counting the bytes. */
final class NoBodyOutputStream extends ServletOutputStream {
    private static final String LSTRING_FILE = "javax.servlet.http.LocalStrings";
    private static final ResourceBundle lStrings = ResourceBundle.getBundle(LSTRING_FILE);

    // Number of bytes "written" so far.
    private int byteCount = 0;

    NoBodyOutputStream() {}

    /** @return the number of bytes written so far */
    final int getContentLength() {
        return byteCount;
    }

    @Override
    public final void write(int b) {
        byteCount++;
    }

    /**
     * Counts {@code len} bytes without reading {@code buf}.
     *
     * @throws IOException if {@code len} is negative
     */
    @Override
    public final void write(byte buf[], int offset, int len)
        throws IOException
    {
        if (len < 0) {
            // This should have thrown an IllegalArgumentException, but
            // changing this would break backwards compatibility
            throw new IOException(lStrings.getString("err.io.negativelength"));
        }
        byteCount += len;
    }
}