package org.rakam.server.http; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpContent; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObject; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCounted; import io.netty.util.internal.ConcurrentSet; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.List; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.CONTINUE; import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.netty.handler.codec.http.HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; public class HttpServerHandler extends ChannelInboundHandlerAdapter { private static InputStream EMPTY_BODY = new ByteArrayInputStream(new byte[] {}); private final HttpServer server; private final ConcurrentSet activeChannels; protected RakamHttpRequest request; private List<ByteBuf> body; public HttpServerHandler(ConcurrentSet activeChannels, HttpServer server) { this.server = server; this.activeChannels = activeChannels; } RakamHttpRequest createRequest(ChannelHandlerContext ctx) { return new RakamHttpRequest(ctx); } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { activeChannels.add(ctx); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { activeChannels.remove(ctx); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (HttpHeaders.is100ContinueExpected(request)) { ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, CONTINUE)); } if (msg instanceof HttpRequest) { this.request = createRequest(ctx); this.request.setRequest((io.netty.handler.codec.http.HttpRequest) msg); if (msg instanceof HttpObject) { if (((HttpRequest) msg).getDecoderResult().isFailure()) { Throwable cause = ((HttpRequest) msg).getDecoderResult().cause(); if (request == null) { request = createRequest(ctx); } HttpServer.returnError(request, cause.getMessage(), BAD_REQUEST); } } server.markProcessing(request); server.routeMatcher.handle(request); server.unmarkProcessing(request); } else if (msg instanceof LastHttpContent) { HttpContent chunk = (HttpContent) msg; try { ByteBuf content = chunk.content(); if (content.isReadable()) { InputStream input; if (body == null || body.size() == 0) { input = new ReferenceCountedByteBufInputStream(content); } else { if (body == null) { body = new ArrayList<>(1); body.set(0, content); } else { body.add(content); } input = new ChainByteArrayInputStream(body); body = new ArrayList<>(2); } content.retain(); handleBody(input); } else { // even if body content is empty, call request.handleBody method. if (request.getBodyHandler() != null) { handleBody(EMPTY_BODY); } } } catch (HttpRequestException e) { HttpServer.returnError(request, e.getMessage(), e.getStatusCode()); } } else if (msg instanceof HttpContent) { HttpContent chunk = (HttpContent) msg; ByteBuf content = chunk.content(); if (content.isReadable()) { if (server.maximumBodySize > -1) { long value = content.capacity(); if (body != null) { for (ByteBuf byteBuf : body) { value += byteBuf.capacity(); } } if (value > server.maximumBodySize) { HttpServer.returnError(request, "Body is too large.", REQUEST_ENTITY_TOO_LARGE); ctx.close(); } } content.retain(); if (body == null) { body = new ArrayList<>(1); body.add(content); } else { body.add(content); } } } else if (msg instanceof WebSocketFrame) { server.markProcessing(request); server.routeMatcher.handle(ctx, (WebSocketFrame) msg); server.unmarkProcessing(request); } } private void handleBody(InputStream body) { server.markProcessing(request); request.handleBody(body); server.unmarkProcessing(request); } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { server.uncaughtExceptionHandler.handle(request, cause); cause.printStackTrace(); HttpServer.returnError(request, "An error occurred", INTERNAL_SERVER_ERROR); ctx.close(); } private static class ReferenceCountedByteBufInputStream extends InputStream { private final ByteBuf buffer; public ReferenceCountedByteBufInputStream(ByteBuf buffer) { this.buffer = buffer; } @Override public int available() throws IOException { return buffer.readableBytes(); } @Override public int read() throws IOException { return buffer.readByte(); } @Override public int read(byte[] b, int off, int len) throws IOException { int available = available(); if (available == 0) { return -1; } len = Math.min(available, len); buffer.readBytes(b, off, len); return len; } @Override public void close() { buffer.release(); } } public static class ChainByteArrayInputStream extends InputStream { private final List<ByteBuf> arrays; private int position; private ByteBuf cursor; private int cursorPos; public ChainByteArrayInputStream(List<ByteBuf> arrays) { this.arrays = arrays; reset(); } @Override public int available() { int remaining = cursor.capacity() - position; for (int i = cursorPos; i < arrays.size(); i++) { remaining += arrays.get(0).capacity(); } return remaining; } @Override public int read() throws IOException { if (cursor.capacity() == position) { if (arrays.size() == cursorPos) { return -1; } cursor = arrays.get(cursorPos++); position = 1; return cursor.getByte(0); } return cursor.getByte(position++); } @Override public synchronized void reset() { cursor = arrays.get(0); position = 0; cursorPos = 1; } @Override public void close() throws IOException { arrays.forEach(ReferenceCounted::release); } } }