/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.coyote.http11.upgrade; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.OutputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.Reader; import java.io.Writer; import java.net.Socket; import javax.net.SocketFactory; import javax.servlet.ReadListener; import javax.servlet.ServletException; import javax.servlet.ServletInputStream; import javax.servlet.ServletOutputStream; import javax.servlet.WriteListener; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpUpgradeHandler; import javax.servlet.http.WebConnection; import org.junit.Assert; import org.junit.Test; import static org.apache.catalina.startup.SimpleHttpClient.CRLF; import org.apache.catalina.Context; import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.TomcatBaseTest; public class TestUpgrade extends TomcatBaseTest { private static final String MESSAGE = "This is a test."; @Test public void testSimpleUpgradeBlocking() throws Exception { UpgradeConnection uc = doUpgrade(EchoBlocking.class); uc.shutdownInput(); uc.shutdownOutput(); } @Test public void testSimpleUpgradeNonBlocking() throws Exception { UpgradeConnection uc = doUpgrade(EchoNonBlocking.class); uc.shutdownInput(); uc.shutdownOutput(); } @Test public void testMessagesBlocking() throws Exception { doTestMessages(EchoBlocking.class); } @Test public void testMessagesNonBlocking() throws Exception { doTestMessages(EchoNonBlocking.class); } @Test public void testSetNullReadListener() throws Exception { doTestCheckClosed(SetNullReadListener.class); } @Test public void testSetNullWriteListener() throws Exception { doTestCheckClosed(SetNullWriteListener.class); } @Test public void testSetReadListenerTwice() throws Exception { doTestCheckClosed(SetReadListenerTwice.class); } @Test public void testSetWriteListenerTwice() throws Exception { doTestCheckClosed(SetWriteListenerTwice.class); } @Test public void testFirstCallToOnWritePossible() throws Exception { doTestFixedResponse(FixedResponseNonBlocking.class); } private void doTestCheckClosed( Class<? extends HttpUpgradeHandler> upgradeHandlerClass) throws Exception { UpgradeConnection conn = doUpgrade(upgradeHandlerClass); Reader r = conn.getReader(); int c = r.read(); Assert.assertEquals(-1, c); } private void doTestFixedResponse( Class<? extends HttpUpgradeHandler> upgradeHandlerClass) throws Exception { UpgradeConnection conn = doUpgrade(upgradeHandlerClass); Reader r = conn.getReader(); int c = r.read(); Assert.assertEquals(FixedResponseNonBlocking.FIXED_RESPONSE, c); } private void doTestMessages ( Class<? extends HttpUpgradeHandler> upgradeHandlerClass) throws Exception { UpgradeConnection uc = doUpgrade(upgradeHandlerClass); PrintWriter pw = new PrintWriter(uc.getWriter()); BufferedReader reader = uc.getReader(); pw.println(MESSAGE); pw.flush(); Thread.sleep(500); pw.println(MESSAGE); pw.flush(); uc.shutdownOutput(); // Note: BufferedReader.readLine() strips new lines // ServletInputStream.readLine() does not strip new lines String response = reader.readLine(); Assert.assertEquals(MESSAGE, response); response = reader.readLine(); Assert.assertEquals(MESSAGE, response); uc.shutdownInput(); } private UpgradeConnection doUpgrade( Class<? extends HttpUpgradeHandler> upgradeHandlerClass) throws Exception { // Setup Tomcat instance Tomcat tomcat = getTomcatInstance(); // No file system docBase required Context ctx = tomcat.addContext("", null); UpgradeServlet servlet = new UpgradeServlet(upgradeHandlerClass); Tomcat.addServlet(ctx, "servlet", servlet); ctx.addServletMappingDecoded("/", "servlet"); tomcat.start(); // Use raw socket so the necessary control is available after the HTTP // upgrade Socket socket = SocketFactory.getDefault().createSocket("localhost", getPort()); socket.setSoTimeout(5000); UpgradeConnection uc = new UpgradeConnection(socket); uc.getWriter().write("GET / HTTP/1.1" + CRLF); uc.getWriter().write("Host: whatever" + CRLF); uc.getWriter().write(CRLF); uc.getWriter().flush(); String status = uc.getReader().readLine(); Assert.assertNotNull(status); Assert.assertEquals("101", getStatusCode(status)); // Skip the remaining response headers String line = uc.getReader().readLine(); while (line != null && line.length() > 0) { // Skip line = uc.getReader().readLine(); } return uc; } private static class UpgradeServlet extends HttpServlet { private static final long serialVersionUID = 1L; private final Class<? extends HttpUpgradeHandler> upgradeHandlerClass; public UpgradeServlet(Class<? extends HttpUpgradeHandler> upgradeHandlerClass) { this.upgradeHandlerClass = upgradeHandlerClass; } @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { req.upgrade(upgradeHandlerClass); } } private static class UpgradeConnection { private final Socket socket; private final Writer writer; private final BufferedReader reader; public UpgradeConnection(Socket socket) { this.socket = socket; InputStream is; OutputStream os; try { is = socket.getInputStream(); os = socket.getOutputStream(); } catch (IOException ioe) { throw new IllegalArgumentException(ioe); } BufferedReader reader = new BufferedReader(new InputStreamReader(is)); Writer writer = new OutputStreamWriter(os); this.writer = writer; this.reader = reader; } public Writer getWriter() { return writer; } public BufferedReader getReader() { return reader; } public void shutdownOutput() throws IOException { writer.flush(); socket.shutdownOutput(); } public void shutdownInput() throws IOException { socket.shutdownInput(); } } public static class EchoBlocking implements HttpUpgradeHandler { @Override public void init(WebConnection connection) { try (ServletInputStream sis = connection.getInputStream(); ServletOutputStream sos = connection.getOutputStream()){ byte[] buffer = new byte[8192]; int read; while ((read = sis.read(buffer)) >= 0) { sos.write(buffer, 0, read); sos.flush(); } } catch (IOException ioe) { throw new IllegalStateException(ioe); } } @Override public void destroy() { // NO-OP } } public static class EchoNonBlocking implements HttpUpgradeHandler { @Override public void init(WebConnection connection) { ServletInputStream sis; ServletOutputStream sos; try { sis = connection.getInputStream(); sos = connection.getOutputStream(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } EchoListener echoListener = new EchoListener(sis, sos); sis.setReadListener(echoListener); sos.setWriteListener(echoListener); } @Override public void destroy() { // NO-OP } private class EchoListener implements ReadListener, WriteListener { private final ServletInputStream sis; private final ServletOutputStream sos; private final byte[] buffer = new byte[8192]; public EchoListener(ServletInputStream sis, ServletOutputStream sos) { this.sis = sis; this.sos = sos; } @Override public void onWritePossible() throws IOException { if (sis.isFinished()) { sis.close(); sos.close(); } while (sis.isReady()) { int read = sis.read(buffer); if (read > 0) { sos.write(buffer, 0, read); if (!sos.isReady()) { break; } } } } @Override public void onDataAvailable() throws IOException { if (sos.isReady()) { onWritePossible(); } } @Override public void onAllDataRead() throws IOException { if (sos.isReady()) { onWritePossible(); } } @Override public void onError(Throwable throwable) { throwable.printStackTrace(); } } } public static class SetNullReadListener implements HttpUpgradeHandler { @Override public void init(WebConnection connection) { ServletInputStream sis; try { sis = connection.getInputStream(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } sis.setReadListener(null); } @Override public void destroy() { // NO-OP } } public static class SetNullWriteListener implements HttpUpgradeHandler { @Override public void init(WebConnection connection) { ServletOutputStream sos; try { sos = connection.getOutputStream(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } sos.setWriteListener(null); } @Override public void destroy() { // NO-OP } } public static class SetReadListenerTwice implements HttpUpgradeHandler { @Override public void init(WebConnection connection) { ServletInputStream sis; ServletOutputStream sos; try { sis = connection.getInputStream(); sos = connection.getOutputStream(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } sos.setWriteListener(new NoOpWriteListener()); ReadListener rl = new NoOpReadListener(); sis.setReadListener(rl); sis.setReadListener(rl); } @Override public void destroy() { // NO-OP } } public static class SetWriteListenerTwice implements HttpUpgradeHandler { @Override public void init(WebConnection connection) { ServletInputStream sis; ServletOutputStream sos; try { sis = connection.getInputStream(); sos = connection.getOutputStream(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } sis.setReadListener(new NoOpReadListener()); WriteListener wl = new NoOpWriteListener(); sos.setWriteListener(wl); sos.setWriteListener(wl); } @Override public void destroy() { // NO-OP } } public static class FixedResponseNonBlocking implements HttpUpgradeHandler { public static final char FIXED_RESPONSE = 'F'; private ServletInputStream sis; private ServletOutputStream sos; @Override public void init(WebConnection connection) { try { sis = connection.getInputStream(); sos = connection.getOutputStream(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } sis.setReadListener(new NoOpReadListener()); sos.setWriteListener(new FixedResponseWriteListener()); } @Override public void destroy() { // NO-OP } private class FixedResponseWriteListener extends NoOpWriteListener { @Override public void onWritePossible() { try { sos.write(FIXED_RESPONSE); sos.flush(); } catch (IOException ioe) { throw new IllegalStateException(ioe); } } } } private static class NoOpReadListener implements ReadListener { @Override public void onDataAvailable() { // NO-OP } @Override public void onAllDataRead() { // Always NO-OP for HTTP Upgrade } @Override public void onError(Throwable throwable) { // NO-OP } } private static class NoOpWriteListener implements WriteListener { @Override public void onWritePossible() { // NO-OP } @Override public void onError(Throwable throwable) { // NO-OP } } }