/* * Copyright (c) 2012-2014 Spotify AB * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package com.spotify.netty4.handler.codec.zmtp.benchmarks; import com.google.common.collect.FluentIterable; import com.google.common.collect.Maps; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.SettableFuture; import com.spotify.netty4.handler.codec.zmtp.ZMTPCodec; import com.spotify.netty4.handler.codec.zmtp.ZMTPDecoder; import com.spotify.netty4.handler.codec.zmtp.ZMTPEncoder; import com.spotify.netty4.handler.codec.zmtp.ZMTPEstimator; import com.spotify.netty4.handler.codec.zmtp.ZMTPHandshakeSuccess; import com.spotify.netty4.handler.codec.zmtp.ZMTPWriter; import com.spotify.netty4.util.BatchFlusher; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.security.SecureRandom; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPromise; import io.netty.channel.DefaultChannelPromise; import io.netty.channel.MessageSizeEstimator; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.internal.chmv8.ForkJoinPool; import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.DEALER; import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.ROUTER; import static io.netty.util.CharsetUtil.UTF_8; import static java.util.Arrays.asList; public class CustomReqRepBenchmark { private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private static final InetSocketAddress ANY_PORT = new InetSocketAddress("127.0.0.1", 0); private static final Thread.UncaughtExceptionHandler UNCAUGHT_EXCEPTION_HANDLER = new Thread.UncaughtExceptionHandler() { @Override public void uncaughtException(final Thread thread, final Throwable throwable) { throwable.printStackTrace(); } }; public static void main(final String... args) throws InterruptedException { final ProgressMeter meter = new ProgressMeter("requests", true); // Codecs final ZMTPCodec serverCodec = ZMTPCodec.builder() .socketType(ROUTER) .encoder(ReplyEncoder.class) .decoder(RequestDecoder.class) .build(); final ZMTPCodec clientCodec = ZMTPCodec.builder() .socketType(DEALER) .encoder(RequestEncoder.class) .decoder(ReplyDecoder.class) .build(); // Server final Executor serverExecutor = new ForkJoinPool( 1, ForkJoinPool.defaultForkJoinWorkerThreadFactory, UNCAUGHT_EXCEPTION_HANDLER, true); final ServerBootstrap serverBootstrap = new ServerBootstrap() .group(new NioEventLoopGroup(1), new NioEventLoopGroup()) .channel(NioServerSocketChannel.class) .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) .childOption(ChannelOption.MESSAGE_SIZE_ESTIMATOR, ByteBufSizeEstimator.INSTANCE) .childHandler(new ChannelInitializer<NioSocketChannel>() { @Override protected void initChannel(final NioSocketChannel ch) throws Exception { ch.pipeline().addLast(serverCodec); ch.pipeline().addLast(new ServerRequestTracker()); ch.pipeline().addLast(new ServerHandler(serverExecutor)); } }); final Channel server = serverBootstrap.bind(ANY_PORT).awaitUninterruptibly().channel(); // Client final Executor clientExecutor = new ForkJoinPool( 1, ForkJoinPool.defaultForkJoinWorkerThreadFactory, UNCAUGHT_EXCEPTION_HANDLER, true); final SocketAddress address = server.localAddress(); final Bootstrap clientBootstrap = new Bootstrap() .group(new NioEventLoopGroup()) .channel(NioSocketChannel.class) .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) .option(ChannelOption.MESSAGE_SIZE_ESTIMATOR, ByteBufSizeEstimator.INSTANCE) .handler(new ChannelInitializer<NioSocketChannel>() { @Override protected void initChannel(final NioSocketChannel ch) throws Exception { ch.pipeline().addLast(clientCodec); ch.pipeline().addLast(new ClientRequestTracker()); ch.pipeline().addLast(new ClientHandler(meter, clientExecutor)); } }); final Channel client = clientBootstrap.connect(address).awaitUninterruptibly().channel(); // Run until client is closed client.closeFuture().await(); } private static class ServerHandler extends ChannelInboundHandlerAdapter { public static final ByteBuffer REPLY_PAYLOAD = UTF_8.encode("hello world"); private final Executor executor; private BatchFlusher flusher; public ServerHandler(final Executor executor) { this.executor = executor; } @Override public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { super.channelRegistered(ctx); this.flusher = new BatchFlusher(ctx.channel()); } @Override public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { executor.execute(new Runnable() { @Override public void run() { final Request request = (Request) msg; ctx.write(request.reply(200, REPLY_PAYLOAD)); flusher.flush(); } }); } } private static class ClientHandler extends ChannelInboundHandlerAdapter { private static final int CONCURRENCY = 1000; private final ProgressMeter meter; private final Executor executor; private BatchFlusher flusher; private ChannelHandlerContext ctx; private long seq = new SecureRandom().nextLong(); public ClientHandler(final ProgressMeter meter, final Executor executor) { this.meter = meter; this.executor = executor; } @Override public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; this.flusher = new BatchFlusher(ctx.channel()); } @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { if (evt instanceof ZMTPHandshakeSuccess) { for (int i = 0; i < CONCURRENCY; i++) { send(ctx); } } } private void send(final ChannelHandlerContext ctx) { final RequestPromise promise = new RequestPromise(this.ctx.channel()); ctx.write(req(), promise); flusher.flush(); Futures.addCallback(promise.replyFuture(), new FutureCallback<Reply>() { @Override public void onSuccess(final Reply reply) { final long latency = System.nanoTime() - reply.id().timestamp(); meter.inc(1, latency); send(ctx); } @Override public void onFailure(final Throwable t) { System.err.println("failure"); } }, executor); } public static long rand(long x) { x ^= (x << 21); x ^= (x >>> 35); x ^= (x << 4); return x; } private Request req() { seq = rand(seq); final MessageId id = new MessageId(seq, System.nanoTime()); return new Request(id, "foo://bar/some/resource", "GET", EMPTY_BUFFER); } } private static class RequestPromise extends DefaultChannelPromise { private final SettableFuture<Reply> replyFuture = SettableFuture.create(); public RequestPromise(final Channel channel) { super(channel); } public SettableFuture<Reply> replyFuture() { return replyFuture; } } private static class Message { private final MessageId id; private final CharSequence uri; private final CharSequence method; private final ByteBuffer payload; public Message(final MessageId id, final CharSequence method, final CharSequence uri, final ByteBuffer payload) { this.id = id; this.method = method; this.uri = uri; this.payload = payload; } public MessageId id() { return id; } public CharSequence uri() { return uri; } public CharSequence method() { return method; } public ByteBuffer payload() { return payload; } } private static class Request extends Message { public Request(final MessageId id, final CharSequence uri, final CharSequence method, final ByteBuffer payload) { super(id, method, uri, payload); } public Reply reply(final int code, final ByteBuffer payload) { return new Reply(id(), uri(), method(), code, payload); } } private static class Reply extends Message { private final int code; public Reply(final MessageId id, final CharSequence uri, final CharSequence method, final int code, final ByteBuffer payload) { super(id, method, uri, payload); this.code = code; } public int statusCode() { return code; } } private static class MessageId { private final long seq; private final long timestamp; public MessageId(final long seq, final long timestamp) { this.seq = seq; this.timestamp = timestamp; } public long seq() { return seq; } public long timestamp() { return timestamp; } public static MessageId from(final long seq, final long timestamp) { return new MessageId(seq, timestamp); } @Override public boolean equals(final Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } final MessageId messageId = (MessageId) o; if (seq != messageId.seq) { return false; } return timestamp == messageId.timestamp; } @Override public int hashCode() { int result = (int) (seq ^ (seq >>> 32)); result = 31 * result + (int) (timestamp ^ (timestamp >>> 32)); return result; } @Override public String toString() { return "MessageId{" + "seq=" + seq + ", timestamp=" + timestamp + '}'; } } private static class RequestEncoder implements ZMTPEncoder { @Override public void estimate(final Object message, final ZMTPEstimator estimator) { final Request request = (Request) message; estimator.frame(request.uri().length()); estimator.frame(request.method().length()); estimator.frame(16); estimator.frame(request.payload().remaining()); } @Override public void encode(final Object message, final ZMTPWriter writer) { final Request request = (Request) message; writeAscii(writer, request.uri()); writeAscii(writer, request.method()); writeId(writer, request.id()); writePayload(writer, request.payload()); } @Override public void close() { } } private static class ReplyEncoder implements ZMTPEncoder { @Override public void estimate(final Object message, final ZMTPEstimator estimator) { final Reply reply = (Reply) message; estimator.frame(reply.uri().length()); estimator.frame(reply.method().length()); estimator.frame(16); estimator.frame(4); estimator.frame(reply.payload().remaining()); } @Override public void encode(final Object message, final ZMTPWriter writer) { final Reply reply = (Reply) message; writeAscii(writer, reply.uri()); writeAscii(writer, reply.method()); writeId(writer, reply.id()); writer.frame(4, true).writeInt(reply.statusCode()); writePayload(writer, reply.payload()); } @Override public void close() { } } private static void writeAscii(final ZMTPWriter writer, final CharSequence s) { final ByteBuf frame = writer.frame(s.length(), true); if (s instanceof AsciiString) { ((AsciiString) s).write(frame); } else { for (int i = 0; i < s.length(); i++) { frame.writeByte(s.charAt(i)); } } } private static void writeId(final ZMTPWriter writer, final MessageId id) { writer.frame(16, true) .writeLong(id.seq()) .writeLong(id.timestamp()); } private static void writePayload(final ZMTPWriter writer, final ByteBuffer payload) { final ByteBuf buf = writer.frame(payload.remaining(), false); if (payload.hasArray()) { buf.writeBytes(payload.array(), payload.arrayOffset() + payload.position(), payload.remaining()); } else { final int pos = payload.position(); for (int i = 0; i < payload.remaining(); i++) { buf.writeByte(payload.get(pos + i)); } payload.position(pos); } } private static ByteBuffer readPayload(final ByteBuf data, final int size) { if (size == 0) { return EMPTY_BUFFER; } final ByteBuffer buffer = ByteBuffer.allocate(size); data.readBytes(buffer); buffer.flip(); return buffer; } private static int readStatusCode(final ByteBuf data, final int size) { if (size != 4) { throw new IllegalArgumentException(); } return data.readInt(); } private static CharSequence readAscii(final ByteBuf data, final int size) { final byte[] chars = new byte[size]; data.readBytes(chars); return new AsciiString(chars); } private static final AsciiString[] METHODS = FluentIterable .from(asList("GET", "POST", "PUT", "DELETE", "PATCH")) .transform(AsciiString.ASCII_STRING_FROM_STRING) .toArray(AsciiString.class); private static CharSequence readMethod(final ByteBuf data, final int size) { for (final AsciiString method : METHODS) { if (asciiEquals(method, data, size)) { data.skipBytes(size); return method; } } return readAscii(data, size); } private static boolean asciiEquals(final AsciiString s, final ByteBuf data, final int size) { final int ix = data.readerIndex(); if (size != s.length()) { return false; } for (int i = 0; i < size; i++) { char c = (char) data.getByte(ix + i); if (c != s.charAt(i)) { return false; } } return true; } private static MessageId readId(final ByteBuf data, final int size) { if (size != 16) { throw new IllegalArgumentException(); } final long seq = data.readLong(); final long timestamp = data.readLong(); return MessageId.from(seq, timestamp); } private static class RequestDecoder implements ZMTPDecoder { enum State { URI, METHOD, ID, PAYLOAD } private State state = State.URI; private CharSequence uri; private CharSequence method; private MessageId id; private ByteBuffer payload; private int frameLength; @Override public void header(final ChannelHandlerContext ctx, final long length, final boolean more, final List<Object> out) { if (length > Integer.MAX_VALUE) { throw new IllegalArgumentException("length"); } frameLength = (int) length; } @Override public void content(final ChannelHandlerContext ctx, final ByteBuf data, final List<Object> out) { if (data.readableBytes() < frameLength) { return; } switch (state) { case URI: uri = readAscii(data, frameLength); state = State.METHOD; break; case METHOD: method = readMethod(data, frameLength); state = State.ID; break; case ID: id = readId(data, frameLength); state = State.PAYLOAD; break; case PAYLOAD: payload = readPayload(data, frameLength); state = State.URI; break; } } @Override public void finish(final ChannelHandlerContext ctx, final List<Object> out) { out.add(new Request(id, uri, method, payload)); } @Override public void close() { } } private static class ReplyDecoder implements ZMTPDecoder { enum State { URI, METHOD, ID, STATUSCODE, PAYLOAD } private State state = State.URI; private int frameLength; private CharSequence uri; private CharSequence method; private MessageId id; private int statusCode; private ByteBuffer payload; @Override public void header(final ChannelHandlerContext ctx, final long length, final boolean more, final List<Object> out) { if (length > Integer.MAX_VALUE) { throw new IllegalArgumentException("length"); } frameLength = (int) length; } @Override public void content(final ChannelHandlerContext ctx, final ByteBuf data, final List<Object> out) { if (data.readableBytes() < frameLength) { return; } switch (state) { case URI: uri = readAscii(data, frameLength); state = State.METHOD; break; case METHOD: method = readMethod(data, frameLength); state = State.ID; break; case ID: id = readId(data, frameLength); state = State.STATUSCODE; break; case STATUSCODE: statusCode = readStatusCode(data, frameLength); state = State.PAYLOAD; break; case PAYLOAD: payload = readPayload(data, frameLength); state = State.URI; break; } } @Override public void finish(final ChannelHandlerContext ctx, final List<Object> out) { out.add(new Reply(id, uri, method, statusCode, payload)); } @Override public void close() { } } private static class ServerRequestTracker extends ChannelDuplexHandler { private final Map<MessageId, Request> outstanding = Maps.newHashMap(); @Override public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) throws Exception { final Reply reply = (Reply) msg; final Request request = outstanding.remove(reply.id()); if (request == null) { System.err.println("Unexpected reply: " + reply); } else { super.write(ctx, msg, promise); } } public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { final Request request = (Request) msg; outstanding.put(request.id(), request); super.channelRead(ctx, msg); } } private static class ClientRequestTracker extends ChannelDuplexHandler { private final Map<MessageId, SettableFuture<Reply>> outstanding = Maps.newHashMap(); @Override public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) throws Exception { final Request request = (Request) msg; final RequestPromise requestPromise = (RequestPromise) promise; outstanding.put(request.id(), requestPromise.replyFuture); super.write(ctx, msg, promise); } @Override public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { final Reply reply = (Reply) msg; final SettableFuture<Reply> future = outstanding.remove(reply.id()); if (future == null) { System.err.println("unexpected reply: " + reply); } else { future.set(reply); } } } private static class ByteBufSizeEstimator implements MessageSizeEstimator, MessageSizeEstimator.Handle { public static final ByteBufSizeEstimator INSTANCE = new ByteBufSizeEstimator(); @Override public Handle newHandle() { return this; } @Override public int size(final Object msg) { if (msg instanceof ByteBuf) { return ((ByteBuf) msg).readableBytes(); } return 0; } } }