package com.spotify.netty4.handler.codec.zmtp;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;
import java.nio.ByteBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import static com.spotify.netty4.handler.codec.zmtp.Buffers.buf;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.PUB;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.REQ;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.ROUTER;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.SUB;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPVersion.ZMTP10;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPVersion.ZMTP20;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
/**
 * Tests for the ZMTP handshake.
 *
 * <p>Covers the greetings produced by the ZMTP/1.0 and ZMTP/2.0 handshakers and the handshakes
 * they negotiate with peers of either version. This includes ZMTP/2.0 "interop" mode, in which a
 * ZMTP/2.0 handshaker can also complete a handshake with a ZMTP/1.0 peer. It also covers the
 * low-level greeting parsing helpers in {@code ZMTP10WireFormat} and {@code ZMTP20WireFormat}.
 */
@RunWith(MockitoJUnitRunner.class)
public class HandshakeTest {

  // Identities used throughout the tests.
  // NOTE(review): these are shared static ByteBuffers, which carry a mutable position. The tests
  // assume the code under test never advances them (e.g. by reading them relatively). Otherwise
  // later uses would see an empty buffer.
  private static final ByteBuffer FOO = UTF_8.encode("foo");
  private static final ByteBuffer BAR = UTF_8.encode("bar");

  // Mocked context passed to the handshakers; tests verify what gets written to it.
  @Mock ChannelHandlerContext ctx;

  /**
   * Checks the exact greeting bytes produced by each handshaker variant.
   */
  @Test
  public void testGreeting() {
    // ZMTP/1.0, identity "foo": 0x04 (identity length + 1), 0x00, then the identity bytes.
    ZMTPHandshaker h = new ZMTP10Protocol.Handshaker(FOO);
    assertThat(h.greeting(), is(buf(0x04, 0x00, 0x66, 0x6f, 0x6f)));

    // ZMTP/1.0, anonymous (empty identity).
    h = new ZMTP10Protocol.Handshaker(ByteBuffer.allocate(0));
    assertThat(h.greeting(), is(buf(0x01, 0x00)));

    // ZMTP/2.0 interop mode: only a 10-byte prefix. Byte 9 is identity length + 1.
    h = new ZMTP20Protocol.Handshaker(SUB, FOO, true);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 4, 0x7f)));

    // Plain ZMTP/2.0: the 10-byte prefix, followed by 0x01, the REQ socket type (0x03), 0x00,
    // the identity length and the identity.
    h = new ZMTP20Protocol.Handshaker(REQ, FOO, false);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f,
                                    0x01, 0x03, 0x00, 3, 0x66, 0x6f, 0x6f)));
  }

  /**
   * A ZMTP/1.0 handshaker completes against a ZMTP/1.0 greeting in a single call, without
   * writing anything to the context.
   */
  @Test
  public void test1to1Handshake() throws Exception {
    final ZMTP10Protocol.Handshaker h = new ZMTP10Protocol.Handshaker(FOO);
    assertThat(h.greeting(), is(buf(0x04, 0x00, 0x66, 0x6f, 0x6f)));

    final ZMTPHandshake handshake = h.handshake(buf(0x04, 0x00, 0x62, 0x61, 0x72), ctx);
    assertThat(handshake, is(notNullValue()));
    verifyZeroInteractions(ctx);
    assertEquals(ZMTPHandshake.of(ZMTP10, BAR, null), handshake);
  }

  /**
   * A ZMTP/2.0 interop handshaker that receives a ZMTP/1.0 greeting writes its identity bytes
   * and settles on ZMTP/1.0 (no remote socket type).
   */
  @Test
  public void test2InteropTo1Handshake() throws Exception {
    ZMTPHandshaker h = new ZMTP20Protocol.Handshaker(ROUTER, FOO, true);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x7f)));

    ZMTPHandshake handshake = h.handshake(buf(0x04, 0x00, 0x62, 0x61, 0x72), ctx);
    assertThat(handshake, is(notNullValue()));
    verify(ctx).writeAndFlush(buf(0x66, 0x6f, 0x6f));
    assertEquals(ZMTPHandshake.of(ZMTP10, BAR, null), handshake);
  }

  /**
   * Two interop handshakers: the peer's 10-byte prefix yields no result yet. It triggers a
   * write of the remainder of the local greeting. The peer's remainder then completes a
   * ZMTP/2.0 handshake.
   */
  @Test
  public void test2InteropTo2InteropHandshake() throws Exception {
    ZMTPHandshaker h = new ZMTP20Protocol.Handshaker(PUB, FOO, true);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x7f)));

    ZMTPHandshake handshake = h.handshake(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x7f), ctx);
    assertThat(handshake, is(nullValue()));
    verify(ctx).writeAndFlush(buf(0x01, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f));

    handshake = h.handshake(buf(0x01, 0x01, 0x00, 0x03, 0x62, 0x61, 0x72), ctx);
    assertThat(handshake, is(notNullValue()));
    verifyNoMoreInteractions(ctx);
    assertEquals(ZMTPHandshake.of(ZMTPVersion.ZMTP20, BAR, PUB), handshake);
  }

  /**
   * An interop handshaker fed a full plain ZMTP/2.0 greeting in one buffer.
   *
   * <p>The first call returns null after writing the rest of the local greeting. The second
   * call on the same buffer completes the handshake. This shows the handshaker consumes the
   * input incrementally.
   */
  @Test
  public void test2InteropTo2Handshake() throws Exception {
    ZMTPHandshaker h = new ZMTP20Protocol.Handshaker(PUB, FOO, true);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0x04, 0x7f)));

    ByteBuf cb = buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f, 0x01, 0x01, 0x00, 0x03, 0x62, 0x61, 0x72);
    ZMTPHandshake handshake = h.handshake(cb, ctx);
    assertThat(handshake, is(nullValue()));
    verify(ctx).writeAndFlush(buf(0x01, 0x01, 0x00, 0x03, 0x66, 0x6f, 0x6f));

    handshake = h.handshake(cb, ctx);
    assertThat(handshake, is(notNullValue()));
    verifyNoMoreInteractions(ctx);
    assertEquals(ZMTPHandshake.of(ZMTPVersion.ZMTP20, BAR, PUB), handshake);
  }

  /**
   * A plain ZMTP/2.0 handshaker talking to an interop peer.
   *
   * <p>The 10-byte interop prefix alone is not enough input and throws. The complete greeting
   * succeeds.
   */
  @Test
  public void test2To2InteropHandshake() throws Exception {
    ZMTPHandshaker h = new ZMTP20Protocol.Handshaker(PUB, FOO, false);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f,
                                    0x1, 0x1, 0, 0x3, 0x66, 0x6f, 0x6f)));

    try {
      h.handshake(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0x4, 0x7f), ctx);
      fail("not enough data in greeting (because compat mode) should have thrown exception");
    } catch (IndexOutOfBoundsException expected) {
      // expected: the prefix is truncated from a plain ZMTP/2.0 point of view
    }

    ZMTPHandshake handshake = h.handshake(
        buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0x4, 0x7f, 0x1, 0x1, 0, 0x03, 0x62, 0x61, 0x72), ctx);
    assertThat(handshake, is(notNullValue()));
    assertEquals(ZMTPHandshake.of(ZMTPVersion.ZMTP20, BAR, PUB), handshake);
  }

  /**
   * Two plain ZMTP/2.0 handshakers complete in a single call on the full greeting.
   */
  @Test
  public void test2To2Handshake() throws Exception {
    ZMTPHandshaker h = new ZMTP20Protocol.Handshaker(PUB, FOO, false);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f,
                                    0x1, 0x1, 0, 0x3, 0x66, 0x6f, 0x6f)));

    ZMTPHandshake handshake = h.handshake(buf(
        0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f, 0x1, 0x1, 0, 0x03, 0x62, 0x61, 0x72), ctx);
    assertThat(handshake, is(notNullValue()));
    assertEquals(ZMTPHandshake.of(ZMTPVersion.ZMTP20, BAR, PUB), handshake);
  }

  /**
   * A plain (non-interop) ZMTP/2.0 handshaker rejects a ZMTP/1.0 greeting with a
   * {@link ZMTPException}.
   */
  @Test
  public void test2To1Handshake() {
    ZMTPHandshaker h = new ZMTP20Protocol.Handshaker(PUB, FOO, false);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f,
                                    0x1, 0x1, 0, 0x3, 0x66, 0x6f, 0x6f)));

    try {
      // Only the exception matters here; any return value means the check failed.
      h.handshake(buf(0x04, 0, 0x62, 0x61, 0x72), ctx);
      fail("An ZMTP/1 greeting is invalid in plain ZMTP/2. Should have thrown exception");
    } catch (ZMTPException expected) {
      // pass
    }
  }

  /**
   * An interop handshaker with an 8-byte identity that has received only part of the peer's
   * greeting.
   *
   * <p>Input is the 10-byte prefix plus two more bytes. The handshaker returns null and writes
   * the rest of its own greeting.
   */
  @Test
  public void test2To2CompatTruncated() throws Exception {
    final ByteBuffer identity = UTF_8.encode("identity");
    ZMTP20Protocol.Handshaker h = new ZMTP20Protocol.Handshaker(PUB, identity, true);
    assertThat(h.greeting(), is(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 9, 0x7f)));

    ZMTPHandshake handshake = h.handshake(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 1, 0x7f, 1, 5), ctx);
    assertThat(handshake, is(nullValue()));
    verify(ctx).writeAndFlush(buf(1, 1, 0, 8, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79));
  }

  /**
   * Parses a complete ZMTP/2.0 greeting and extracts the one-byte identity "a".
   */
  @Test
  public void testReadZMTP2Greeting() throws Exception {
    final ByteBuf in = buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f, 0x01, 0x02, 0x00, 0x01, 0x61);
    final ZMTP20WireFormat.Greeting greeting = ZMTP20WireFormat.readGreeting(in);
    assertThat(greeting.identity(), is(UTF_8.encode("a")));
  }

  /**
   * Reads ZMTP/1.0 identities.
   *
   * <p>A named identity is returned as-is. An anonymous greeting yields an empty, non-null
   * buffer.
   */
  @Test
  public void testReadZMTP1RemoteIdentity() throws Exception {
    ByteBuffer identity = ZMTP10WireFormat.readIdentity(buf(0x04, 0x00, 0x62, 0x61, 0x72));
    assertThat(identity, is(notNullValue()));
    assertEquals(BAR, identity);

    // Anonymous handshake.
    identity = ZMTP10WireFormat.readIdentity(buf(0x01, 0x00));
    assertThat(identity, is(notNullValue()));
    assertThat(identity.remaining(), is(0));
  }

  /**
   * Pins PUSH to ordinal 8.
   *
   * <p>The socket-type byte in the ZMTP/2.0 greetings above appears to be the enum ordinal
   * (PUB is 0x01, REQ is 0x03). Reordering the enum would therefore presumably change the wire
   * format.
   */
  @Test
  public void testTypeToConst() {
    assertEquals(8, ZMTPSocketType.PUSH.ordinal());
  }

  /**
   * Protocol detection from the first bytes of a greeting.
   *
   * <ul>
   *   <li>Input that is too short throws.</li>
   *   <li>A first byte other than 0xff means ZMTP/1.0.</li>
   *   <li>Otherwise the 10th byte decides: 0 means ZMTP/1.0, 1 means ZMTP/2.0.</li>
   * </ul>
   */
  @Test
  public void testDetectProtocolVersion() {
    try {
      ZMTP20WireFormat.detectProtocolVersion(Unpooled.wrappedBuffer(new byte[0]));
      fail("Should have thrown IndexOutOfBoundsException");
    } catch (IndexOutOfBoundsException expected) {
      // empty input
    }
    try {
      ZMTP20WireFormat.detectProtocolVersion(buf(0xff, 0, 0, 0));
      fail("Should have thrown IndexOutOfBoundsException");
    } catch (IndexOutOfBoundsException expected) {
      // 0xff prefix, but too short to read the deciding byte
    }

    assertEquals(ZMTP10, ZMTP20WireFormat.detectProtocolVersion(buf(0x07)));
    assertEquals(ZMTP10, ZMTP20WireFormat
        .detectProtocolVersion(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 1, 0)));
    assertEquals(ZMTP20, ZMTP20WireFormat
        .detectProtocolVersion(buf(0xff, 0, 0, 0, 0, 0, 0, 0, 1, 1)));
  }

  /**
   * A ZMTP/2.0 greeting whose 13th byte is not 0x00 is rejected with a {@link ZMTPException}.
   */
  @Test
  public void testReadZMTP2GreetingMalformed() {
    ByteBuf in = buf(0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0x7f, 0x01, 0x02, 0xf0, 0x01, 0x61);
    try {
      ZMTP20WireFormat.readGreeting(in);
      fail("13th byte is not 0x00, should throw exception");
    } catch (ZMTPException expected) {
      // pass
    }
  }
}