/* * Copyright (c) 2012-2013 Spotify AB * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package com.spotify.netty4.handler.codec.zmtp; import com.google.common.util.concurrent.SettableFuture; import org.junit.After; import org.junit.Before; import org.junit.experimental.theories.Theories; import org.junit.experimental.theories.Theory; import org.junit.experimental.theories.suppliers.TestedOn; import org.junit.runner.RunWith; import java.net.InetSocketAddress; import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.util.ReferenceCountUtil; import static com.spotify.netty4.handler.codec.zmtp.ZMTPProtocols.ZMTP20; import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.ROUTER; import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.Assert.assertFalse; @RunWith(Theories.class) public class ProtocolViolationTests { private Channel serverChannel; private InetSocketAddress serverAddress; private final String identity = "identity"; private NioEventLoopGroup bossGroup; private NioEventLoopGroup group; @ChannelHandler.Sharable private static class MockHandler extends ChannelInboundHandlerAdapter { private SettableFuture<Void> active = SettableFuture.create(); private SettableFuture<Throwable> exception = SettableFuture.create(); private SettableFuture<Void> inactive = SettableFuture.create(); private volatile boolean handshaked; private volatile boolean read; @Override public void channelActive(final ChannelHandlerContext ctx) throws Exception { active.set(null); } @Override public void channelInactive(final ChannelHandlerContext ctx) throws Exception { inactive.set(null); } @Override public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { if (evt instanceof ZMTPHandshakeSuccess) { handshaked = true; } } @Override public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { ReferenceCountUtil.release(msg); read = true; } @Override public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception { exception.set(cause); ctx.close(); } } private final MockHandler mockHandler = new MockHandler(); @Before public void setup() { final ServerBootstrap serverBootstrap = new ServerBootstrap(); serverBootstrap.channel(NioServerSocketChannel.class); bossGroup = new NioEventLoopGroup(1); group = new NioEventLoopGroup(); serverBootstrap.group(bossGroup, group); serverBootstrap.childHandler(new ChannelInitializer<NioSocketChannel>() { @Override protected void initChannel(final NioSocketChannel ch) throws Exception { ch.pipeline().addLast( ZMTPCodec.builder() .protocol(ZMTP20) .socketType(ROUTER) .localIdentity(identity) .build(), mockHandler); } }); serverChannel = serverBootstrap.bind(new InetSocketAddress("localhost", 0)) .awaitUninterruptibly().channel(); serverAddress = (InetSocketAddress) serverChannel.localAddress(); } @After public void teardown() { if (serverChannel != null) { serverChannel.close(); } if (bossGroup != null) { bossGroup.shutdownGracefully(); } if (group != null) { group.shutdownGracefully(); } } @Theory public void protocolErrorsCauseException( @TestedOn(ints = {16, 17, 27, 32, 48, 53}) final int payloadSize) throws Exception { final Bootstrap b = new Bootstrap(); b.group(new NioEventLoopGroup()); b.channel(NioSocketChannel.class); b.handler(new ChannelInitializer<NioSocketChannel>() { @Override protected void initChannel(final NioSocketChannel ch) throws Exception { ch.pipeline().addLast(new MockHandler()); } }); final Channel channel = b.connect(serverAddress).awaitUninterruptibly().channel(); final ByteBuf payload = Unpooled.buffer(payloadSize); for (int i = 0; i < payloadSize; i++) { payload.writeByte(0); } channel.writeAndFlush(payload); mockHandler.active.get(5, SECONDS); mockHandler.exception.get(5, SECONDS); mockHandler.inactive.get(5, SECONDS); assertFalse(mockHandler.handshaked); assertFalse(mockHandler.read); } }