package org.rakam.server.http; import com.google.common.collect.ImmutableSet; import io.airlift.log.Logger; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.DecoderResult; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultHttpResponse; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.cookie.Cookie; import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.time.Duration; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Consumer; import static io.netty.handler.codec.http.HttpHeaders.Names.*; import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static io.netty.handler.codec.http.cookie.ServerCookieDecoder.STRICT; import static io.netty.util.CharsetUtil.UTF_8; public class RakamHttpRequest implements HttpRequest, Comparable { private final static Logger LOGGER = Logger.get(HttpServer.class); private final static InputStream REQUEST_DONE_STREAM = new InvalidInputStream(); private final ChannelHandlerContext ctx; private HttpRequest request; private FullHttpResponse response; @Override public boolean equals(Object o) { return o == this; } private Consumer<InputStream> bodyHandler; private Set<Cookie> cookies; private InputStream body; private QueryStringDecoder qs; private String remoteAddress; public RakamHttpRequest(ChannelHandlerContext ctx) { this.ctx = ctx; } void setRequest(HttpRequest request) { this.request = request; } public String getRemoteAddress() { return remoteAddress == null ? ((InetSocketAddress) context().channel().remoteAddress()).getHostString() : remoteAddress; } HttpRequest getRequest() { return request; } @Override public HttpMethod getMethod() { return request.getMethod(); } @Override public HttpRequest setMethod(HttpMethod method) { return request.setMethod(method); } @Override public String getUri() { return request.getUri(); } @Override public HttpRequest setUri(String uri) { return request.setUri(uri); } @Override public HttpVersion getProtocolVersion() { return request.getProtocolVersion(); } @Override public io.netty.handler.codec.http.HttpRequest setProtocolVersion(HttpVersion version) { return request.setProtocolVersion(version); } @Override public HttpHeaders headers() { return request.headers(); } public ChannelHandlerContext context() { return ctx; } @Override public DecoderResult getDecoderResult() { return request.getDecoderResult(); } protected Consumer<InputStream> getBodyHandler() { return bodyHandler; } @Override public void setDecoderResult(DecoderResult result) { request.setDecoderResult(result); } public void bodyHandler(Consumer<InputStream> function) { bodyHandler = function; } public RakamHttpRequest response(String content) { final ByteBuf byteBuf = Unpooled.wrappedBuffer(content.getBytes(UTF_8)); response = new DefaultFullHttpResponse(HTTP_1_1, OK, byteBuf); return this; } public Set<Cookie> cookies() { if (cookies == null) { String header = request.headers().get(COOKIE); cookies = header != null ? STRICT.decode(header) : ImmutableSet.of(); } return cookies; } public RakamHttpRequest response(byte[] content) { final ByteBuf byteBuf = Unpooled.copiedBuffer(content); response = new DefaultFullHttpResponse(HTTP_1_1, OK, byteBuf); return this; } public RakamHttpRequest response(byte[] content, HttpResponseStatus status) { final ByteBuf byteBuf = Unpooled.copiedBuffer(content); response = new DefaultFullHttpResponse(HTTP_1_1, status, byteBuf); return this; } public RakamHttpRequest response(String content, HttpResponseStatus status) { final ByteBuf byteBuf = Unpooled.wrappedBuffer(content.getBytes(UTF_8)); response = new DefaultFullHttpResponse(HTTP_1_1, status, byteBuf); return this; } public RakamHttpRequest response(FullHttpResponse response) { this.response = response; return this; } public Map<String, List<String>> params() { if (qs == null) { qs = new QueryStringDecoder(request.getUri()); } return qs.parameters(); } public String path() { if (qs == null) { qs = new QueryStringDecoder(request.getUri()); } return qs.path(); } public void end() { if (body != null) { try { body.close(); } catch (IOException e) { LOGGER.error(e); } } else { body = REQUEST_DONE_STREAM; } if (response == null) { response = new DefaultFullHttpResponse(HTTP_1_1, OK, Unpooled.wrappedBuffer(new byte[0])); } boolean keepAlive = HttpHeaders.isKeepAlive(request); String origin = request.headers().get(ORIGIN); if (origin != null) { response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, origin); } if (keepAlive) { response.headers().set(CONTENT_LENGTH, response.content().readableBytes()); response.headers().set(CONNECTION, HttpHeaders.Values.KEEP_ALIVE); ctx.writeAndFlush(response); } else { ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); } } void handleBody(InputStream body) { if(body == REQUEST_DONE_STREAM) { try { body.close(); } catch (IOException e) { LOGGER.error(e); } } this.body = body; if (bodyHandler != null) { bodyHandler.accept(body); } } InputStream getBody() { return body; } public StreamResponse streamResponse(Duration retryDuration) { StreamResponse streamResponse = streamResponse(); ByteBuf msg = Unpooled.wrappedBuffer(("retry:" + retryDuration.toMillis() + "\n\n").getBytes(UTF_8)); ctx.writeAndFlush(msg); return streamResponse; } public StreamResponse streamResponse() { HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK); response.headers().set(CONTENT_TYPE, "text/event-stream"); response.headers().set(ACCESS_CONTROL_ALLOW_ORIGIN, "*"); response.headers().set(CONNECTION, HttpHeaders.Values.KEEP_ALIVE); ctx.writeAndFlush(response); return new StreamResponse(ctx); } void setRemoteAddress(String remoteAddress) { this.remoteAddress = remoteAddress; } @Override public int compareTo(Object o) { return o == null ? -1 : (o == this ? 0 : 1); } public class StreamResponse { private final ChannelHandlerContext ctx; private ChannelFuture lastBufferData = null; public StreamResponse(ChannelHandlerContext ctx) { this.ctx = ctx; } public synchronized StreamResponse send(String event, String data) { if (ctx.isRemoved()) { throw new IllegalStateException(); } ByteBuf msg = Unpooled.wrappedBuffer(("event:" + event + "\ndata: " + data.replaceAll("\n", "\ndata: ") + "\n\n").getBytes(UTF_8)); lastBufferData = ctx.writeAndFlush(msg); return this; } public boolean isClosed() { return ctx.isRemoved(); } public void listenClose(Runnable runnable) { ctx.channel().closeFuture().addListener(future -> runnable.run()); } public synchronized void end() { if(ctx.isRemoved()) { return; } if (lastBufferData != null) { lastBufferData.addListener(ChannelFutureListener.CLOSE); } else { ctx.close(); } } } // if request.end() is called before Netty fetches the body data, handleBody() // will be called after request.end() and it needs to release the buffer immediately. // this class is used to identify if the request is ended. private static class InvalidInputStream extends InputStream { @Override public int read() throws IOException { throw new UnsupportedOperationException(); } } }