/* * Copyright (c) 2012-2015 Spotify AB * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package com.spotify.netty4.handler.codec.zmtp; import com.google.common.base.Function; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.AsyncFunction; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.spotify.netty4.util.BatchFlusher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Closeable; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.URI; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicInteger; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.group.ChannelGroup; import io.netty.channel.group.DefaultChannelGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.concurrent.GlobalEventExecutor; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.util.concurrent.Futures.immediateFailedFuture; import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.common.util.concurrent.Futures.transform; import static com.spotify.netty4.handler.codec.zmtp.ListenableFutureAdapter.listenable; /** * A simple ZMTP socket implementation for testing purposes. */ public class ZMTPSocket implements Closeable { private static final Logger log = LoggerFactory.getLogger(ZMTPSocket.class); private static final ListeningExecutorService EXECUTOR = MoreExecutors.listeningDecorator(GlobalEventExecutor.INSTANCE); /** * Represents a connected peer. */ public interface ZMTPPeer { /** * Get the ZMTP session for this peer. */ ZMTPSession session(); /** * Send a message to this peer. */ ListenableFuture<Void> send(ZMTPMessage message); } /** * Handles incoming messages and connection events. */ public interface Handler { /** * A peer connected. */ void connected(ZMTPSocket socket, ZMTPPeer peer); /** * A peer disconnected. */ void disconnected(ZMTPSocket socket, ZMTPPeer peer); /** * A message was received from a peer. */ void message(ZMTPSocket socket, ZMTPPeer peer, ZMTPMessage message); } private interface Sender { ListenableFuture<Void> send(ZMTPMessage message); } private interface Receiver { void receive(final ZMTPPeer peer, ZMTPMessage message); } private static final ThreadFactory DAEMON = new ThreadFactoryBuilder().setDaemon(true).build(); private static final NioEventLoopGroup GROUP = new NioEventLoopGroup(1, DAEMON); private final ChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE); private volatile List<ZMTPPeer> peers = ImmutableList.of(); private volatile Map<ByteBuf, ZMTPPeer> routing = ImmutableMap.of(); private final Object lock = new Object(); private final Sender sender; private final Receiver receiver; private final Handler handler; private final ZMTPConfig config; private volatile boolean closed; /** * Create a new socket. */ private ZMTPSocket(final Handler handler, final ZMTPConfig config) { this.handler = checkNotNull(handler, "handler"); this.config = config.toBuilder() .identityGenerator(new IdentityGenerator()) .decoder(decoder(config.socketType())) .encoder(encoder(config.socketType())) .build(); this.sender = sender(config.socketType()); this.receiver = receiver(config.socketType()); } /** * Bind this socket to an endpoint. */ public ListenableFuture<InetSocketAddress> bind(final String endpoint) { return transform(address(endpoint), new AsyncFunction<InetSocketAddress, InetSocketAddress>() { @Override public ListenableFuture<InetSocketAddress> apply( @SuppressWarnings("NullableProblems") final InetSocketAddress input) throws Exception { return bind(input); } }); } /** * Bind this socket to an address. */ public ListenableFuture<InetSocketAddress> bind(final InetSocketAddress address) { final ServerBootstrap b = new ServerBootstrap() .channel(NioServerSocketChannel.class) .group(GROUP) .childHandler(new ChannelInitializer()); final ChannelFuture f = b.bind(address); channelGroup.add(f.channel()); if (closed) { f.channel().close(); return immediateFailedFuture(new ClosedChannelException()); } return transform(listenable(f), new Function<Void, InetSocketAddress>() { @Override public InetSocketAddress apply(final Void input) { return (InetSocketAddress) f.channel().localAddress(); } }); } /** * Connect this socket to an endpoint. */ public ListenableFuture<Void> connect(final String endpoint) { return transform(address(endpoint), new AsyncFunction<InetSocketAddress, Void>() { @Override public ListenableFuture<Void> apply( @SuppressWarnings("NullableProblems") final InetSocketAddress address) throws Exception { return connect(address); } }); } /** * Connect this socket to an address. */ public ListenableFuture<Void> connect(final InetSocketAddress address) { final Bootstrap b = new Bootstrap() .group(GROUP) .channel(NioSocketChannel.class) .handler(new ChannelInitializer()); final ChannelFuture f = b.connect(address); if (closed) { f.channel().close(); return immediateFailedFuture(new ClosedChannelException()); } return listenable(f); } /** * Send a message on this socket. */ public ListenableFuture<Void> send(final ZMTPMessage message) { return sender.send(message); } /** * Close this socket. */ @Override public void close() { closed = true; channelGroup.close().awaitUninterruptibly(); } /** * Get a list of all connected peers. */ public List<ZMTPPeer> peers() { return peers; } /** * Get a sender for a socket type. */ private ZMTPEncoder.Factory encoder(final ZMTPSocketType socketType) { switch (socketType) { case ROUTER: return RoutingEncoder.FACTORY; default: return ZMTPMessageEncoder.FACTORY; } } /** * Get a decoder for a socket type. */ private ZMTPDecoder.Factory decoder(final ZMTPSocketType socketType) { switch (socketType) { case ROUTER: return RoutingDecoder.FACTORY; default: return ZMTPMessageDecoder.FACTORY; } } /** * Get a receiver that implements the appropriate socket type behavior. */ private Receiver receiver(final ZMTPSocketType socketType) { switch (socketType) { case PUSH: return new DropReceiver(); case DEALER: case ROUTER: case SUB: case PUB: case PULL: return new PassReceiver(); default: throw new IllegalArgumentException("Unsupported socket type: " + socketType); } } /** * Get a sender that implements the appropriate socket type behavior. */ private Sender sender(final ZMTPSocketType socketType) { switch (socketType) { case PUSH: case DEALER: return new RoundRobinSender(); case ROUTER: return new RoutingSender(); case SUB: case PUB: return new BroadcastSender(); case PULL: return new UnsupportedOperationSender(); default: throw new IllegalArgumentException("Unsupported socket type: " + socketType); } } /** * Resolve an endpoint into an address. Async to avoid blocking on DNS resolution. */ private static ListenableFuture<InetSocketAddress> address(final String endpoint) { return EXECUTOR.submit(new Callable<InetSocketAddress>() { @Override public InetSocketAddress call() throws Exception { final URI uri = URI.create(endpoint); checkArgument("tcp".equals(uri.getScheme()), "Unsupported endpoint type: %s", uri.getScheme()); final List<String> parts = Splitter.on(':').splitToList(uri.getAuthority()); final String hostString = parts.get(0); final InetAddress host = hostString.equals("*") ? null : InetAddress.getByName(hostString); final String portString = parts.get(1); final int port = portString.equals("*") ? 0 : Integer.valueOf(portString); return new InetSocketAddress(host, port); } }); } /** * Handles a single connected peer. */ private class Peer extends ChannelInboundHandlerAdapter implements ZMTPPeer { private final Channel ch; private final BatchFlusher flusher; private final ZMTPSession session; public Peer(final Channel ch, final ZMTPSession session) { this.ch = ch; this.session = session; this.flusher = new BatchFlusher(ch); } @Override public ZMTPSession session() { return session; } @Override public ListenableFuture<Void> send(final ZMTPMessage message) { final ChannelFuture f = ch.write(message); flusher.flush(); return listenable(f); } @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { if (evt instanceof ZMTPHandshakeSuccess) { register(session.peerIdentity()); try { handler.connected(ZMTPSocket.this, this); } catch (Exception e) { log.error("handler threw exception", e); } } } @Override public void channelInactive(final ChannelHandlerContext ctx) throws Exception { deregister(session.peerIdentity()); try { handler.disconnected(ZMTPSocket.this, this); } catch (Exception e) { log.error("handler threw exception", e); } } private void register(final ByteBuffer identity) { synchronized (lock) { peers = ImmutableList.<ZMTPPeer>builder() .addAll(peers) .add(this) .build(); routing = ImmutableMap.<ByteBuf, ZMTPPeer>builder() .putAll(routing) .put(Unpooled.wrappedBuffer(identity), this) .build(); } } private void deregister(final ByteBuffer identity) { synchronized (lock) { final ImmutableList.Builder<ZMTPPeer> newPeers = ImmutableList.builder(); for (final ZMTPPeer handler : peers) { if (handler != this) { newPeers.add(handler); } } peers = newPeers.build(); final ImmutableMap.Builder<ByteBuf, ZMTPPeer> newRouting = ImmutableMap.builder(); final ByteBuf id = Unpooled.wrappedBuffer(identity); for (final Map.Entry<ByteBuf, ZMTPPeer> entry : routing.entrySet()) { if (!entry.getKey().equals(id)) { newRouting.put(entry.getKey(), entry.getValue()); } } routing = newRouting.build(); } } @Override public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { if (msg instanceof ZMTPMessage) { try { receiver.receive(this, (ZMTPMessage) msg); } catch (Exception e) { log.error("handler threw exception", e); } } } @Override public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception { log.warn("exception", cause); ctx.close(); } } private class ChannelInitializer extends io.netty.channel.ChannelInitializer { @Override protected void initChannel(final Channel ch) throws Exception { channelGroup.add(ch); final ZMTPCodec codec = ZMTPCodec.from(config); final Peer peer = new Peer(ch, codec.session()); ch.pipeline().addLast(codec, peer); } } /** * A sender that round robin load balances messages over all connected peers. */ private class RoundRobinSender implements Sender { private final AtomicInteger i = new AtomicInteger(); @Override public ListenableFuture<Void> send(final ZMTPMessage message) { final List<ZMTPPeer> channels = peers(); if (channels.size() == 0) { return immediateFailedFuture(new ClosedChannelException()); } final ZMTPPeer handler = next(channels); return handler.send(message); } private ZMTPPeer next(final List<ZMTPPeer> channels) { assert !channels.isEmpty(); int next; int prev; do { prev = i.get(); next = prev + 1; if (next >= channels.size()) { next = 0; } } while (!i.compareAndSet(prev, next)); return channels.get(next); } } /** * A sender that routes message to the appropriate peer using the front identity frame. */ private class RoutingSender implements Sender { @Override public ListenableFuture<Void> send(final ZMTPMessage message) { if (message.size() == 0) { return immediateFailedFuture(new IllegalArgumentException("empty message")); } final ByteBuf identity = message.frame(0); final ZMTPPeer peer = routing.get(identity); if (peer == null) { message.release(); return immediateFailedFuture(new ClosedChannelException()); } return peer.send(message); } } /** * A sender that broadcasts messages to all connected peers. */ private class BroadcastSender implements Sender { @Override public ListenableFuture<Void> send(final ZMTPMessage message) { final List<ZMTPPeer> channels = peers(); for (final ZMTPPeer handler : channels) { handler.send(message); } return immediateFuture(null); } } /** * A sender that immediately fails all send operations. */ private class UnsupportedOperationSender implements Sender { @Override public ListenableFuture<Void> send(final ZMTPMessage message) { return immediateFailedFuture(new UnsupportedOperationException()); } } /** * A receiver that drops all incoming messages. */ private class DropReceiver implements Receiver { @Override public void receive(final ZMTPPeer peer, final ZMTPMessage message) { message.release(); } } /** * A receive that passes on all incoming messages to the {@link Handler}. */ private class PassReceiver implements Receiver { @Override public void receive(final ZMTPPeer peer, final ZMTPMessage message) { try { handler.message(ZMTPSocket.this, peer, message); } catch (Exception e) { log.error("handler threw exception", e); } } } /** * A {@link ZMTPMessage} decoder that pushes the peer identity onto the front of the message. */ private static class RoutingDecoder implements ZMTPDecoder { private static final ZMTPDecoder.Factory FACTORY = new Factory() { @Override public ZMTPDecoder decoder(final ZMTPSession session) { return new RoutingDecoder(session.peerIdentity()); } }; private final ByteBuf DELIMITER = Unpooled.EMPTY_BUFFER; private final ByteBuffer identity; private final List<ByteBuf> frames = new ArrayList<ByteBuf>(); private int frameLength; RoutingDecoder(final ByteBuffer identity) { this.identity = identity; reset(); } private void reset() { frames.clear(); frames.add(Unpooled.wrappedBuffer(identity)); frameLength = 0; } @Override public void header(final ChannelHandlerContext ctx, final long length, final boolean more, final List<Object> out) { frameLength = (int) length; } @Override public void content(final ChannelHandlerContext ctx, final ByteBuf data, final List<Object> out) { if (data.readableBytes() < frameLength) { return; } if (frameLength == 0) { frames.add(DELIMITER); return; } final ByteBuf frame = data.readSlice(frameLength); frame.retain(); frames.add(frame); } @Override public void finish(final ChannelHandlerContext ctx, final List<Object> out) { final ZMTPMessage message = ZMTPMessage.from(frames); reset(); out.add(message); } @Override public void close() { for (final ByteBuf frame : frames) { frame.release(); } frames.clear(); } } /** * A {@link ZMTPMessage} encoder that pops the peer identity from the front of the message. */ private static class RoutingEncoder implements ZMTPEncoder { private static final ZMTPEncoder.Factory FACTORY = new Factory() { @Override public ZMTPEncoder encoder(final ZMTPSession session) { return new RoutingEncoder(); } }; @Override public void estimate(final Object msg, final ZMTPEstimator estimator) { final ZMTPMessage message = (ZMTPMessage) msg; for (int i = 1; i < message.size(); i++) { final ByteBuf frame = message.frame(i); estimator.frame(frame.readableBytes()); } } @Override public void encode(final Object msg, final ZMTPWriter writer) { final ZMTPMessage message = (ZMTPMessage) msg; for (int i = 1; i < message.size(); i++) { final ByteBuf frame = message.frame(i); final boolean more = i < message.size() - 1; final ByteBuf dst = writer.frame(frame.readableBytes(), more); dst.writeBytes(frame, frame.readerIndex(), frame.readableBytes()); } } @Override public void close() { } } public static Builder builder() { return new Builder(); } public static class Builder { private final ZMTPConfig.Builder config = ZMTPConfig.builder(); private Handler handler; public Builder handler(final Handler handler) { this.handler = handler; return this; } public Builder protocol(final ZMTPProtocol protocol) { config.protocol(protocol); return this; } public Builder interop(final boolean interop) { config.interop(interop); return this; } public Builder type(final ZMTPSocketType socketType) { config.socketType(socketType); return this; } public Builder identity(final CharSequence identity) { config.localIdentity(identity); return this; } public Builder identity(final byte[] identity) { config.localIdentity(identity); return this; } public Builder identity(final ByteBuffer localIdentity) { config.localIdentity(localIdentity); return this; } public ZMTPSocket build() { return new ZMTPSocket(handler, config.build()); } } /** * An identity generator that keeps an integer counter per {@link ZMTPSocket}. */ private static class IdentityGenerator implements ZMTPIdentityGenerator { private final AtomicInteger peerIdCounter = new AtomicInteger(new SecureRandom().nextInt()); @Override public ByteBuffer generateIdentity(final ZMTPSession session) { final ByteBuffer generated = ByteBuffer.allocate(5); generated.put((byte) 0); generated.putInt(peerIdCounter.incrementAndGet()); generated.flip(); return generated; } } }