/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.catalina.nonblocking;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.Socket;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import javax.net.SocketFactory;
import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.apache.catalina.Context;
import org.apache.catalina.core.StandardContext;
import org.apache.catalina.startup.BytesStreamer;
import org.apache.catalina.startup.TesterServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.catalina.valves.TesterAccessLogValve;
import org.apache.tomcat.util.buf.ByteChunk;
public class TestNonBlockingAPI extends TomcatBaseTest {
// Size of each individual write issued by the non-blocking write listener
private static final int CHUNK_SIZE = 1024 * 1024;
// Total number of body bytes the write servlet sends
private static final int WRITE_SIZE = CHUNK_SIZE * 10;
// Reference payload; filled once by the static initializer below
private static final byte[] DATA = new byte[WRITE_SIZE];
// Client-side pause (ms) used to make the server-side write block
private static final int WRITE_PAUSE_MS = 500;

static {
    // Use this sequence for padding to make it easier to spot errors.
    // DATA is built from 16 byte blocks: the leading part of the padding
    // followed by the block index in upper case hex.
    byte[] padding = new byte[] {'z', 'y', 'x', 'w', 'v', 'u', 't', 's',
            'r', 'q', 'p', 'o', 'n', 'm', 'l', 'k'};
    int blockSize = padding.length;
    for (int i = 0; i < WRITE_SIZE / blockSize; i++) {
        String hex = String.format("%01X", Integer.valueOf(i));
        // Hex digits are pure ASCII. Encode them explicitly so the payload
        // does not depend on the platform default charset (the padding
        // above is already fixed ASCII byte values).
        byte[] hexBytes = hex.getBytes(StandardCharsets.US_ASCII);
        int hexSize = hexBytes.length;
        int padSize = blockSize - hexSize;
        System.arraycopy(padding, 0, DATA, i * blockSize, padSize);
        System.arraycopy(hexBytes, 0, DATA, i * blockSize + padSize, hexSize);
    }
}
/**
 * Runs the shared non-blocking read scenario with the read listener
 * consulting {@code isReady()} between reads.
 */
@Test
public void testNonBlockingRead() throws Exception {
    final boolean ignoreIsReady = false;
    doTestNonBlockingRead(ignoreIsReady);
}
/**
 * Runs the shared non-blocking read scenario with the read listener
 * skipping the {@code isReady()} check. The test passes only if an
 * {@link IOException} propagates to the test method.
 */
@Test(expected=IOException.class)
public void testNonBlockingReadIgnoreIsReady() throws Exception {
    final boolean ignoreIsReady = true;
    doTestNonBlockingRead(ignoreIsReady);
}
/**
 * Deploys an {@code NBReadServlet} at "/", POSTs the {@code DataWriter}
 * body to it in streaming mode and asserts that the response status is 200.
 *
 * @param ignoreIsReady passed through to the servlet's read listener
 */
private void doTestNonBlockingRead(boolean ignoreIsReady) throws Exception {
    Tomcat tomcat = getTomcatInstance();

    // Must have a real docBase - just use temp
    String docBase = System.getProperty("java.io.tmpdir");
    StandardContext context = (StandardContext) tomcat.addContext("", docBase);

    NBReadServlet readServlet = new NBReadServlet(ignoreIsReady);
    String name = NBReadServlet.class.getName();
    Tomcat.addServlet(context, name, readServlet);
    context.addServletMappingDecoded("/", name);

    tomcat.start();

    Map<String, List<String>> headers = new HashMap<>();
    DataWriter body = new DataWriter(500);
    String target = "http://localhost:" + getPort() + "/";
    int status = postUrl(true, body, target, new ByteChunk(), headers, null);

    Assert.assertEquals(HttpServletResponse.SC_OK, status);
}
/**
 * Runs the shared non-blocking write scenario without a preceding
 * keep-alive request.
 */
@Test
public void testNonBlockingWrite() throws Exception {
    final boolean keepAlive = false;
    testNonBlockingWriteInternal(keepAlive);
}
/**
 * Runs the shared non-blocking write scenario with a preceding request
 * sent on the same connection.
 */
@Test
public void testNonBlockingWriteWithKeepAlive() throws Exception {
    final boolean keepAlive = true;
    testNonBlockingWriteInternal(keepAlive);
}
/**
 * Deploys an {@code NBWriteServlet} at "/", requests it over a raw socket
 * while pausing periodically between reads, and then validates the status
 * line, the {@code Transfer-Encoding: chunked} header and every chunk of the
 * response. The sum of all chunk sizes must equal {@code WRITE_SIZE}.
 *
 * @param keepAlive if {@code true} an {@code OPTIONS *} request is sent (and
 *                  one read of its response is discarded) on the same
 *                  connection before the real request
 */
private void testNonBlockingWriteInternal(boolean keepAlive) throws Exception {
    Tomcat tomcat = getTomcatInstance();
    // No file system docBase required
    Context ctx = tomcat.addContext("", null);
    NBWriteServlet servlet = new NBWriteServlet();
    String servletName = NBWriteServlet.class.getName();
    Tomcat.addServlet(ctx, servletName, servlet);
    ctx.addServletMappingDecoded("/", servletName);
    tomcat.getConnector().setProperty("socket.txBufSize", "1024");
    tomcat.start();

    ByteChunk result = new ByteChunk();

    // try-with-resources: the socket (and with it both streams) is released
    // even when a read or write fails part way through.
    try (Socket s = SocketFactory.getDefault().createSocket("localhost", getPort())) {
        InputStream is = s.getInputStream();
        OutputStream os = s.getOutputStream();
        byte[] buffer = new byte[8192];

        if (keepAlive) {
            os.write(("OPTIONS * HTTP/1.1\r\n" +
                    "Host: localhost:" + getPort() + "\r\n" +
                    "\r\n").getBytes(StandardCharsets.ISO_8859_1));
            os.flush();
            // The response to the OPTIONS request is not validated
            is.read(buffer);
        }

        os.write(("GET / HTTP/1.1\r\n" +
                "Host: localhost:" + getPort() + "\r\n" +
                "Connection: close\r\n" +
                "\r\n").getBytes(StandardCharsets.ISO_8859_1));
        os.flush();

        int read = 0;
        int readSinceLastPause = 0;
        while (read != -1) {
            read = is.read(buffer);
            if (read > 0) {
                result.append(buffer, 0, read);
                // Only count real data; the -1 end-of-stream marker is not
                // a byte count.
                readSinceLastPause += read;
            }
            if (readSinceLastPause > WRITE_SIZE / 16) {
                readSinceLastPause = 0;
                Thread.sleep(500);
            }
        }
    }

    // Validate the result.
    // Response line
    String resultString = result.toString();
    log.info("Client read " + resultString.length() + " bytes");
    int lineStart = 0;
    int lineEnd = resultString.indexOf('\n', 0);
    String line = resultString.substring(lineStart, lineEnd + 1);
    Assert.assertEquals("HTTP/1.1 200 \r\n", line);

    // Check headers - looking to see if response is chunked (it should be)
    boolean chunked = false;
    while (line.length() > 2) {
        lineStart = lineEnd + 1;
        lineEnd = resultString.indexOf('\n', lineStart);
        line = resultString.substring(lineStart, lineEnd + 1);
        if (line.startsWith("Transfer-Encoding:")) {
            Assert.assertEquals("Transfer-Encoding: chunked\r\n", line);
            chunked = true;
        }
    }
    Assert.assertTrue(chunked);

    // Now check body size
    int totalBodyRead = 0;
    int chunkSize = -1;
    while (chunkSize != 0) {
        // Chunk size in hex
        lineStart = lineEnd + 1;
        lineEnd = resultString.indexOf('\n', lineStart);
        line = resultString.substring(lineStart, lineEnd + 1);
        Assert.assertTrue(line.endsWith("\r\n"));
        line = line.substring(0, line.length() - 2);
        log.info("[" + line + "]");
        chunkSize = Integer.parseInt(line, 16);

        // Read the chunk
        lineStart = lineEnd + 1;
        lineEnd = resultString.indexOf('\n', lineStart);
        log.info("Start : " + lineStart + ", End: " + lineEnd);
        if (lineEnd > lineStart) {
            line = resultString.substring(lineStart, lineEnd + 1);
        } else {
            line = resultString.substring(lineStart);
        }
        if (line.length() > 40) {
            log.info(line.substring(0, 32));
        } else {
            log.info(line);
        }
        if (chunkSize + 2 != line.length()) {
            log.error("Chunk wrong length. Was " + line.length() +
                    " Expected " + (chunkSize + 2));
            logChunkMismatch(resultString, lineStart, line.length(), totalBodyRead);
        }
        Assert.assertTrue(line.endsWith("\r\n"));
        Assert.assertEquals(chunkSize + 2, line.length());
        totalBodyRead += chunkSize;
    }
    Assert.assertEquals(WRITE_SIZE, totalBodyRead);
}

/**
 * Diagnostic helper: logs up to 64 bytes of context either side of the first
 * position at which a received chunk differs from {@code DATA}, or logs that
 * no difference was found (i.e. the data was truncated).
 *
 * @param resultString  the complete response as read by the client
 * @param lineStart     index in {@code resultString} where the chunk starts
 * @param lineLength    length of the chunk line being examined
 * @param totalBodyRead number of body bytes validated before this chunk,
 *                      i.e. the offset in {@code DATA} the chunk should match
 */
private void logChunkMismatch(String resultString, int lineStart, int lineLength,
        int totalBodyRead) {
    // ISO-8859-1 yields one byte per char so that indexes into resultString
    // and resultBytes line up.
    byte[] resultBytes = resultString.getBytes(StandardCharsets.ISO_8859_1);
    // Never walk past the end of the reference data. An
    // ArrayIndexOutOfBoundsException here would hide the assertion failure
    // this method is meant to help diagnose.
    int compareEnd = Math.min(totalBodyRead + lineLength, DATA.length);
    for (int i = totalBodyRead; i < compareEnd; i++) {
        int resultPos = lineStart + i - totalBodyRead;
        if (resultPos >= resultBytes.length) {
            break;
        }
        if (DATA[i] != resultBytes[resultPos]) {
            int dataStart = Math.max(i - 64, 0);
            int dataEnd = Math.min(i + 64, DATA.length);
            int resultStart = Math.max(resultPos - 64, 0);
            int resultEnd = Math.min(resultPos + 64, resultString.length());
            log.error("Mis-match tx: " + new String(
                    DATA, dataStart, dataEnd - dataStart, StandardCharsets.ISO_8859_1));
            log.error("Mis-match rx: " +
                    resultString.substring(resultStart, resultEnd));
            return;
        }
    }
    log.error("No mismatch. Data truncated");
}
/**
 * Requests the {@code NBWriteServlet} response over a raw socket, reads only
 * part of it (with pauses) and then closes the connection. Afterwards the
 * test asserts that the status line was a 200, that the error callback of
 * the servlet's write or read listener was invoked, and that the access log
 * valve recorded a single entry with status 500.
 */
@Test
public void testNonBlockingWriteError() throws Exception {
    Tomcat tomcat = getTomcatInstance();
    // No file system docBase required
    Context ctx = tomcat.addContext("", null);
    TesterAccessLogValve alv = new TesterAccessLogValve();
    ctx.getPipeline().addValve(alv);
    NBWriteServlet servlet = new NBWriteServlet();
    String servletName = NBWriteServlet.class.getName();
    Tomcat.addServlet(ctx, servletName, servlet);
    ctx.addServletMappingDecoded("/", servletName);
    tomcat.getConnector().setProperty("socket.txBufSize", "1024");
    tomcat.start();

    ByteChunk result = new ByteChunk();

    // try-with-resources: the socket is closed at the end of this block both
    // on the normal path (deliberately, before the full response has been
    // read) and when a read or write throws.
    try (Socket s = SocketFactory.getDefault().createSocket("localhost", getPort())) {
        OutputStream os = s.getOutputStream();
        os.write(("GET / HTTP/1.1\r\n" +
                "Host: localhost:" + getPort() + "\r\n" +
                "Connection: close\r\n" +
                "\r\n").getBytes(StandardCharsets.ISO_8859_1));
        os.flush();

        InputStream is = s.getInputStream();
        byte[] buffer = new byte[8192];
        int read = 0;
        int readSinceLastPause = 0;
        int readTotal = 0;
        while (read != -1 && readTotal < WRITE_SIZE / 32) {
            long start = System.currentTimeMillis();
            read = is.read(buffer);
            long end = System.currentTimeMillis();
            log.info("Client read [" + read + "] bytes in [" + (end - start) +
                    "] ms");
            if (read > 0) {
                result.append(buffer, 0, read);
                // Only count real data; the -1 end-of-stream marker is not
                // a byte count.
                readSinceLastPause += read;
                readTotal += read;
            }
            if (readSinceLastPause > WRITE_SIZE / 64) {
                readSinceLastPause = 0;
                Thread.sleep(WRITE_PAUSE_MS);
            }
        }
    }

    String resultString = result.toString();
    log.info("Client read " + resultString.length() + " bytes");
    int lineStart = 0;
    int lineEnd = resultString.indexOf('\n', 0);
    String line = resultString.substring(lineStart, lineEnd + 1);
    Assert.assertEquals("HTTP/1.1 200 \r\n", line);

    // Listeners are invoked and access valve entries created on a different
    // thread so give that thread a chance to complete its work. Both waits
    // share the same counter, so together they poll at most 100 times.
    int count = 0;
    while (count < 100 &&
            !(servlet.wlistener.onErrorInvoked || servlet.rlistener.onErrorInvoked)) {
        Thread.sleep(100);
        count ++;
    }
    while (count < 100 && alv.getEntryCount() < 1) {
        Thread.sleep(100);
        count ++;
    }

    Assert.assertTrue("Error listener should have been invoked.",
            servlet.wlistener.onErrorInvoked || servlet.rlistener.onErrorInvoked);

    // TODO Figure out why non-blocking writes with the NIO connector appear
    // to be slower on Linux
    alv.validateAccessLog(1, 500, WRITE_PAUSE_MS,
            WRITE_PAUSE_MS + 30 * 1000);
}
/**
 * Deploys an {@code NBReadWriteServlet} at "/", POSTs a zero length body to
 * it in non-streaming mode and asserts that the response status is 200.
 */
@Test
public void testBug55438NonBlockingReadWriteEmptyRead() throws Exception {
    Tomcat tomcat = getTomcatInstance();

    // No file system docBase required
    Context context = tomcat.addContext("", null);

    String name = NBReadWriteServlet.class.getName();
    NBReadWriteServlet readWriteServlet = new NBReadWriteServlet();
    Tomcat.addServlet(context, name, readWriteServlet);
    context.addServletMappingDecoded("/", name);

    tomcat.start();

    // A streamer that never has anything to send
    BytesStreamer emptyBody = new BytesStreamer() {
        @Override
        public int getLength() {
            return 0;
        }

        @Override
        public int available() {
            return 0;
        }

        @Override
        public byte[] next() {
            return new byte[] {};
        }
    };

    Map<String, List<String>> headers = new HashMap<>();
    String target = "http://localhost:" + getPort() + "/";
    int status = postUrl(false, emptyBody, target, new ByteChunk(), headers, null);

    Assert.assertEquals(HttpServletResponse.SC_OK, status);
}
/**
 * Request body source that hands out {@code max} chunks: "WANTMORE" for all
 * but the last call and "FINISHED" for the last one. Before every chunk
 * except the first it sleeps for {@code delay} milliseconds (when positive).
 */
public static class DataWriter implements BytesStreamer {
    final int max = 5;
    int count = 0;
    long delay = 0;
    // Both payloads are plain ASCII; encode explicitly rather than relying
    // on the platform default charset.
    byte[] b = "WANTMORE".getBytes(StandardCharsets.ISO_8859_1);
    byte[] f = "FINISHED".getBytes(StandardCharsets.ISO_8859_1);

    /**
     * @param delay pause in milliseconds between chunks; values &lt;= 0
     *              disable the pause
     */
    public DataWriter(long delay) {
        this.delay = delay;
    }

    /**
     * @return total body length; computed from the length of the "WANTMORE"
     *         payload, so it relies on both payloads having the same length
     */
    @Override
    public int getLength() {
        return b.length * max;
    }

    /**
     * @return the size of the next chunk, or 0 once all chunks were handed out
     */
    @Override
    public int available() {
        if (count < max) {
            return b.length;
        } else {
            return 0;
        }
    }

    /**
     * @return the next chunk, or {@code null} once all chunks were handed out
     */
    @Override
    public byte[] next() {
        if (count >= max) {
            return null;
        }
        if (count > 0 && delay > 0) {
            try {
                Thread.sleep(delay);
            } catch (InterruptedException e) {
                // Keep delivering data but do not lose the interrupt
                Thread.currentThread().interrupt();
            }
        }
        count++;
        if (count < max) {
            return b;
        } else {
            return f;
        }
    }
}
/**
 * Servlet that puts the request into async mode and registers a
 * {@code TestReadListener} on the request input stream. The listener is
 * exposed through the {@code listener} field.
 */
@WebServlet(asyncSupported = true)
public class NBReadServlet extends TesterServlet {

    private static final long serialVersionUID = 1L;

    private final boolean ignoreIsReady;
    public volatile TestReadListener listener;

    public NBReadServlet(boolean ignoreIsReady) {
        this.ignoreIsReady = ignoreIsReady;
    }

    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        // Switch to async mode with an effectively unlimited timeout
        AsyncContext asyncContext = req.startAsync();
        asyncContext.setTimeout(Long.MAX_VALUE);

        // Async life-cycle events are only logged
        AsyncListener loggingListener = new AsyncListener() {
            @Override
            public void onStartAsync(AsyncEvent event) throws IOException {
                log.info("onStartAsync");
            }

            @Override
            public void onComplete(AsyncEvent event) throws IOException {
                log.info("onComplete");
            }

            @Override
            public void onTimeout(AsyncEvent event) throws IOException {
                log.info("onTimeout");
            }

            @Override
            public void onError(AsyncEvent event) throws IOException {
                log.info("AsyncListener.onError");
            }
        };
        asyncContext.addListener(loggingListener);

        // Ask to be notified when request data can be read
        ServletInputStream requestStream = req.getInputStream();
        listener = new TestReadListener(asyncContext, false, ignoreIsReady);
        requestStream.setReadListener(listener);
    }
}
/**
 * Servlet that puts the request into async mode and registers both a
 * {@code TestReadListener} on the request input stream and a
 * {@code TestWriteListener} on the response output stream. Both listeners
 * are exposed through public fields.
 */
@WebServlet(asyncSupported = true)
public class NBWriteServlet extends TesterServlet {

    private static final long serialVersionUID = 1L;

    public volatile TestWriteListener wlistener;
    public volatile TestReadListener rlistener;

    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        // Switch to async mode with an effectively unlimited timeout
        AsyncContext asyncContext = req.startAsync();
        asyncContext.setTimeout(Long.MAX_VALUE);

        // Async life-cycle events are only logged
        AsyncListener loggingListener = new AsyncListener() {
            @Override
            public void onStartAsync(AsyncEvent event) throws IOException {
                log.info("onStartAsync");
            }

            @Override
            public void onComplete(AsyncEvent event) throws IOException {
                log.info("onComplete");
            }

            @Override
            public void onTimeout(AsyncEvent event) throws IOException {
                log.info("onTimeout");
            }

            @Override
            public void onError(AsyncEvent event) throws IOException {
                log.info("AsyncListener.onError");
            }
        };
        asyncContext.addListener(loggingListener);

        // Read side: the listener is told that non-blocking writes are in use
        ServletInputStream requestStream = req.getInputStream();
        rlistener = new TestReadListener(asyncContext, true, false);
        requestStream.setReadListener(rlistener);

        // Write side: 200 KiB response buffer, then register the listener
        ServletOutputStream responseStream = resp.getOutputStream();
        resp.setBufferSize(200 * 1024);
        wlistener = new TestWriteListener(asyncContext);
        responseStream.setWriteListener(wlistener);
    }
}
/**
 * Servlet that puts the request into async mode and registers a
 * {@code TestReadWriteListener} on the request input stream. The listener is
 * exposed through the {@code rwlistener} field.
 */
@WebServlet(asyncSupported = true)
public class NBReadWriteServlet extends TesterServlet {

    private static final long serialVersionUID = 1L;

    public volatile TestReadWriteListener rwlistener;

    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        // Switch to async mode with an effectively unlimited timeout
        AsyncContext asyncContext = req.startAsync();
        asyncContext.setTimeout(Long.MAX_VALUE);

        // Ask to be notified when request data can be read
        ServletInputStream requestStream = req.getInputStream();
        rwlistener = new TestReadWriteListener(asyncContext);
        requestStream.setReadListener(rwlistener);
    }
}
/**
 * Read listener that accumulates the request body as text. When
 * non-blocking writes are not in use it answers "OK" or "FAILED" once all
 * data has been read, depending on whether the body ends with "FINISHED",
 * and completes the async context.
 */
private class TestReadListener implements ReadListener {

    private final AsyncContext ctx;
    private final boolean usingNonBlockingWrite;
    private final boolean ignoreIsReady;
    private final StringBuilder body = new StringBuilder();
    public volatile boolean onErrorInvoked = false;

    public TestReadListener(AsyncContext ctx,
            boolean usingNonBlockingWrite,
            boolean ignoreIsReady) {
        this.ctx = ctx;
        this.usingNonBlockingWrite = usingNonBlockingWrite;
        this.ignoreIsReady = ignoreIsReady;
    }

    @Override
    public void onDataAvailable() throws IOException {
        ServletInputStream in = ctx.getRequest().getInputStream();
        StringBuilder received = new StringBuilder();
        byte[] buf = new byte[8192];
        for (;;) {
            int n = in.read(buf);
            if (n == -1) {
                break;
            }
            received.append(new String(buf, 0, n));
            // Keep reading while the stream reports ready, or
            // unconditionally when isReady() is deliberately ignored
            if (!ignoreIsReady && !in.isReady()) {
                break;
            }
        }
        String text = received.toString();
        log.info(text);
        body.append(text);
    }

    @Override
    public void onAllDataRead() {
        log.info("onAllDataRead");
        // If non-blocking writes are being used, don't write here as it
        // will inject unexpected data into the write output.
        if (usingNonBlockingWrite) {
            return;
        }
        String msg = body.toString().endsWith("FINISHED") ? "OK" : "FAILED";
        try {
            ctx.getResponse().getOutputStream().print(msg);
        } catch (IOException ioe) {
            // Ignore
        }
        ctx.complete();
    }

    @Override
    public void onError(Throwable throwable) {
        log.info("ReadListener.onError");
        throwable.printStackTrace();
        onErrorInvoked = true;
    }
}
/**
 * Write listener that sends {@code DATA} in {@code CHUNK_SIZE} pieces for as
 * long as the output stream reports ready, and completes the async context
 * once {@code WRITE_SIZE} bytes were written and the stream is ready again.
 */
private class TestWriteListener implements WriteListener {

    AsyncContext ctx;
    int written = 0;
    public volatile boolean onErrorInvoked = false;

    public TestWriteListener(AsyncContext ctx) {
        this.ctx = ctx;
    }

    @Override
    public void onWritePossible() throws IOException {
        final long startedAt = System.currentTimeMillis();
        final int writtenBefore = written;
        ServletOutputStream out = ctx.getResponse().getOutputStream();

        while (written < WRITE_SIZE && out.isReady()) {
            out.write(DATA, written, CHUNK_SIZE);
            written += CHUNK_SIZE;
        }

        final boolean everythingWritten = (written == WRITE_SIZE);
        if (everythingWritten) {
            // Clear the output buffer else data may be lost when
            // calling complete
            ctx.getResponse().flushBuffer();
        }

        long elapsed = System.currentTimeMillis() - startedAt;
        log.info("Write took: " + elapsed +
                " ms. Bytes before=" + writtenBefore + " after=" + written);

        // Completing while a write is still in progress is illegal, so only
        // complete when the stream is ready again and nothing is left to send
        if (out.isReady() && everythingWritten) {
            ctx.complete();
        }
    }

    @Override
    public void onError(Throwable throwable) {
        log.info("WriteListener.onError");
        throwable.printStackTrace();
        onErrorInvoked = true;
    }
}
/**
 * Read listener that accumulates the request body as text and, once all
 * data has been read, registers a write listener that writes the body back
 * (when the output stream is ready) and completes the async context.
 */
private class TestReadWriteListener implements ReadListener {

    AsyncContext ctx;
    private final StringBuilder body = new StringBuilder();

    public TestReadWriteListener(AsyncContext ctx) {
        this.ctx = ctx;
    }

    @Override
    public void onDataAvailable() throws IOException {
        ServletInputStream in = ctx.getRequest().getInputStream();
        StringBuilder received = new StringBuilder();
        byte[] buf = new byte[8192];
        for (;;) {
            int n = in.read(buf);
            if (n == -1) {
                break;
            }
            received.append(new String(buf, 0, n));
            if (!in.isReady()) {
                break;
            }
        }
        String text = received.toString();
        log.info("Read [" + text + "]");
        body.append(text);
    }

    @Override
    public void onAllDataRead() throws IOException {
        log.info("onAllDataRead");
        WriteListener echoListener = new WriteListener() {
            @Override
            public void onWritePossible() throws IOException {
                ServletOutputStream out = ctx.getResponse().getOutputStream();
                if (out.isReady()) {
                    String payload = body.toString();
                    log.info("Writing [" + payload + "]");
                    out.write(payload.getBytes("utf-8"));
                }
                // Completed whether or not anything was written
                ctx.complete();
            }

            @Override
            public void onError(Throwable throwable) {
                log.info("ReadWriteListener.onError");
                throwable.printStackTrace();
            }
        };
        ctx.getResponse().getOutputStream().setWriteListener(echoListener);
    }

    @Override
    public void onError(Throwable throwable) {
        log.info("ReadListener.onError");
        throwable.printStackTrace();
    }
}
/**
 * Sends a request with a body to {@code path} via {@link HttpURLConnection},
 * waits one second after the response code has been received and then, for
 * a 200 response, closes the response stream and disconnects.
 *
 * @param stream   if {@code true} (and {@code streamer} is not null) the
 *                 body is sent in fixed length streaming mode when the
 *                 streamer reports a positive length, otherwise in chunked
 *                 streaming mode
 * @param streamer source of the request body; may be {@code null} for an
 *                 empty body
 * @param path     the URL to send the request to
 * @param reqHead  request headers to set; multiple values for one name are
 *                 joined with ','. May be {@code null}
 * @param resHead  if not {@code null}, receives the response headers
 * @return the HTTP response code
 * @throws IOException if the request cannot be sent or the response cannot
 *                     be read
 */
public static int postUrlWithDisconnect(boolean stream, BytesStreamer streamer, String path,
        Map<String, List<String>> reqHead, Map<String, List<String>> resHead) throws IOException {
    URL url = new URL(path);
    HttpURLConnection connection = (HttpURLConnection) url.openConnection();
    connection.setDoOutput(true);
    connection.setReadTimeout(1000000);
    if (reqHead != null) {
        for (Map.Entry<String, List<String>> entry : reqHead.entrySet()) {
            StringBuilder valueList = new StringBuilder();
            for (String value : entry.getValue()) {
                if (valueList.length() > 0) {
                    valueList.append(',');
                }
                valueList.append(value);
            }
            connection.setRequestProperty(entry.getKey(), valueList.toString());
        }
    }
    if (streamer != null && stream) {
        if (streamer.getLength() > 0) {
            connection.setFixedLengthStreamingMode(streamer.getLength());
        } else {
            connection.setChunkedStreamingMode(1024);
        }
    }
    connection.connect();

    // Write the request body
    try (OutputStream os = connection.getOutputStream()) {
        while (streamer != null && streamer.available() > 0) {
            byte[] next = streamer.next();
            os.write(next);
            os.flush();
        }
    }

    int rc = connection.getResponseCode();
    if (resHead != null) {
        Map<String, List<String>> head = connection.getHeaderFields();
        resHead.putAll(head);
    }
    try {
        Thread.sleep(1000);
    } catch (InterruptedException e) {
        // Cut the wait short but preserve the interrupt for the caller
        Thread.currentThread().interrupt();
    }
    if (rc == HttpServletResponse.SC_OK) {
        connection.getInputStream().close();
        connection.disconnect();
    }
    return rc;
}
/**
 * Currently ignored. Issues one request that parks an emitter in the
 * {@code DelayedNBWriteServlet}, waits (up to 3 seconds) for the servlet to
 * signal that, then issues a second request with {@code notify=true} and
 * waits (up to 3 seconds) for both requests to finish. Fails if either
 * request recorded an exception.
 */
@Ignore
@Test
public void testDelayedNBWrite() throws Exception {
    Tomcat tomcat = getTomcatInstance();
    Context context = tomcat.addContext("", null);

    CountDownLatch emitterRegistered = new CountDownLatch(1);
    String name = DelayedNBWriteServlet.class.getName();
    Tomcat.addServlet(context, name, new DelayedNBWriteServlet(emitterRegistered));
    context.addServletMappingDecoded("/", name);

    tomcat.start();

    CountDownLatch requestsDone = new CountDownLatch(2);
    List<Throwable> exceptions = new ArrayList<>();

    // First request: registers the emitter
    RequestExecutor waiting = new RequestExecutor(
            "http://localhost:" + getPort() + "/", requestsDone, exceptions);
    new Thread(waiting).start();

    emitterRegistered.await(3000, TimeUnit.MILLISECONDS);

    // Second request: triggers the emitters
    RequestExecutor notifying = new RequestExecutor(
            "http://localhost:" + getPort() + "/?notify=true", requestsDone, exceptions);
    new Thread(notifying).start();

    requestsDone.await(3000, TimeUnit.MILLISECONDS);

    if (!exceptions.isEmpty()) {
        Assert.fail();
    }
}
/**
 * Runnable that issues a GET request to the configured URL and checks that
 * the response is a 200 whose body contains "OK". Any failure is printed and
 * recorded in the supplied list; the latch is always counted down.
 */
private static final class RequestExecutor implements Runnable {

    private final String url;
    private final CountDownLatch latch;
    private final List<Throwable> exceptions;

    /**
     * @param url        the URL to request
     * @param latch      counted down once when {@link #run()} finishes
     * @param exceptions receives any failure raised while executing the
     *                   request or checking the response
     */
    public RequestExecutor(String url, CountDownLatch latch, List<Throwable> exceptions) {
        this.url = url;
        this.latch = latch;
        this.exceptions = exceptions;
    }

    @Override
    public void run() {
        try {
            ByteChunk result = new ByteChunk();
            int rc = getUrl(url, result, null);
            // assertEquals reports the actual status code on failure
            Assert.assertEquals(HttpServletResponse.SC_OK, rc);
            Assert.assertTrue(result.toString().contains("OK"));
        } catch (Throwable e) {
            e.printStackTrace();
            // The same list is shared by executors running on different
            // threads and the caller passes a plain ArrayList, so guard
            // the update.
            synchronized (exceptions) {
                exceptions.add(e);
            }
        } finally {
            latch.countDown();
        }
    }
}
/**
 * Servlet with two modes, selected by the {@code notify} request parameter.
 * Without it, the request is put into async mode, an {@code Emitter} for it
 * is stored and the latch is counted down. With {@code notify=true}, every
 * stored emitter is triggered and the current request is answered with "OK"
 * and completed.
 */
@WebServlet(asyncSupported = true)
private static final class DelayedNBWriteServlet extends TesterServlet {

    private static final long serialVersionUID = 1L;

    private final Set<Emitter> emitters = new HashSet<>();
    private final CountDownLatch latch;

    public DelayedNBWriteServlet(CountDownLatch latch) {
        this.latch = latch;
    }

    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        boolean triggerEmitters = Boolean.parseBoolean(request.getParameter("notify"));

        AsyncContext asyncContext = request.startAsync();
        asyncContext.setTimeout(1000);

        if (triggerEmitters) {
            for (Emitter emitter : emitters) {
                emitter.emit();
            }
            response.getOutputStream().println("OK");
            response.getOutputStream().flush();
            asyncContext.complete();
        } else {
            // Park this request until a notify request arrives
            emitters.add(new Emitter(asyncContext));
            latch.countDown();
        }
    }
}
/**
 * Wraps an async context. {@link #emit()} registers a write listener on the
 * context's response that prints "OK" once, flushes, and completes the
 * context when the output stream reports ready after the flush.
 */
private static final class Emitter {

    private final AsyncContext ctx;

    Emitter(AsyncContext ctx) {
        this.ctx = ctx;
    }

    void emit() throws IOException {
        WriteListener okWriter = new WriteListener() {
            // Ensures "OK" is printed only once across callbacks
            private boolean payloadSent = false;

            @Override
            public void onWritePossible() throws IOException {
                ServletOutputStream stream = ctx.getResponse().getOutputStream();
                if (stream.isReady() && !payloadSent) {
                    stream.println("OK");
                    payloadSent = true;
                }
                if (stream.isReady() && payloadSent) {
                    stream.flush();
                    // Only complete once the flush has gone through
                    if (stream.isReady()) {
                        ctx.complete();
                    }
                }
            }

            @Override
            public void onError(Throwable t) {
                t.printStackTrace();
            }
        };
        ctx.getResponse().getOutputStream().setWriteListener(okWriter);
    }
}
}