/*
* Copyright 2014 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*/
package io.netty.handler.codec.http2;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable;
import io.netty.util.AsciiString;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT;
import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyShort;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
/**
* Test for data decompression in the HTTP/2 codec.
*/
public class DataCompressionHttp2Test {
private static final AsciiString GET = new AsciiString("GET");
private static final AsciiString POST = new AsciiString("POST");
private static final AsciiString PATH = new AsciiString("/some/path");
@Mock
private Http2FrameListener serverListener;
@Mock
private Http2FrameListener clientListener;
private Http2ConnectionEncoder clientEncoder;
private ServerBootstrap sb;
private Bootstrap cb;
private Channel serverChannel;
private Channel clientChannel;
private volatile Channel serverConnectedChannel;
private CountDownLatch serverLatch;
private Http2Connection serverConnection;
private Http2Connection clientConnection;
private Http2ConnectionHandler clientHandler;
private ByteArrayOutputStream serverOut;
@Before
public void setup() throws InterruptedException, Http2Exception {
MockitoAnnotations.initMocks(this);
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
if (invocation.getArgument(4)) {
serverConnection.stream((Integer) invocation.getArgument(1)).close();
}
return null;
}
}).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class),
anyInt(), anyBoolean());
doAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock invocation) throws Throwable {
if (invocation.getArgument(7)) {
serverConnection.stream((Integer) invocation.getArgument(1)).close();
}
return null;
}
}).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class),
anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean());
}
@After
public void cleanup() throws IOException {
serverOut.close();
}
@After
public void teardown() throws InterruptedException {
if (clientChannel != null) {
clientChannel.close().sync();
clientChannel = null;
}
if (serverChannel != null) {
serverChannel.close().sync();
serverChannel = null;
}
final Channel serverConnectedChannel = this.serverConnectedChannel;
if (serverConnectedChannel != null) {
serverConnectedChannel.close().sync();
this.serverConnectedChannel = null;
}
Future<?> serverGroup = sb.config().group().shutdownGracefully(0, 0, MILLISECONDS);
Future<?> serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 0, MILLISECONDS);
Future<?> clientGroup = cb.config().group().shutdownGracefully(0, 0, MILLISECONDS);
serverGroup.sync();
serverChildGroup.sync();
clientGroup.sync();
}
@Test
public void justHeadersNoData() throws Exception {
bootstrapEnv(0);
final Http2Headers headers = new DefaultHttp2Headers().method(GET).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, true, newPromiseClient());
clientHandler.flush(ctxClient());
}
});
awaitServer();
verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0),
eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true));
}
@Test
public void gzipEncodingSingleEmptyMessage() throws Exception {
final String text = "";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
bootstrapEnv(data.readableBytes());
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
clientHandler.flush(ctxClient());
}
});
awaitServer();
assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data.release();
}
}
@Test
public void gzipEncodingSingleMessage() throws Exception {
final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
bootstrapEnv(data.readableBytes());
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
clientHandler.flush(ctxClient());
}
});
awaitServer();
assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data.release();
}
}
@Test
public void gzipEncodingMultipleMessages() throws Exception {
final String text1 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
final String text2 = "dddddddddddddddddddeeeeeeeeeeeeeeeeeeeffffffffffffffffffff";
final ByteBuf data1 = Unpooled.copiedBuffer(text1.getBytes());
final ByteBuf data2 = Unpooled.copiedBuffer(text2.getBytes());
bootstrapEnv(data1.readableBytes() + data2.readableBytes());
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
clientEncoder.writeData(ctxClient(), 3, data1.retain(), 0, false, newPromiseClient());
clientEncoder.writeData(ctxClient(), 3, data2.retain(), 0, true, newPromiseClient());
clientHandler.flush(ctxClient());
}
});
awaitServer();
assertEquals(text1 + text2, serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data1.release();
data2.release();
}
}
@Test
public void deflateEncodingWriteLargeMessage() throws Exception {
final int BUFFER_SIZE = 1 << 12;
final byte[] bytes = new byte[BUFFER_SIZE];
new Random().nextBytes(bytes);
bootstrapEnv(BUFFER_SIZE);
final ByteBuf data = Unpooled.wrappedBuffer(bytes);
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.DEFLATE);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
clientHandler.flush(ctxClient());
}
});
awaitServer();
assertEquals(data.resetReaderIndex().toString(CharsetUtil.UTF_8),
serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data.release();
}
}
private void bootstrapEnv(int serverOutSize) throws Exception {
final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1);
serverOut = new ByteArrayOutputStream(serverOutSize);
serverLatch = new CountDownLatch(1);
sb = new ServerBootstrap();
cb = new Bootstrap();
// Streams are created before the normal flow for this test, so these connection must be initialized up front.
serverConnection = new DefaultHttp2Connection(true);
clientConnection = new DefaultHttp2Connection(false);
serverConnection.addListener(new Http2ConnectionAdapter() {
@Override
public void onStreamClosed(Http2Stream stream) {
serverLatch.countDown();
}
});
doAnswer(new Answer<Integer>() {
@Override
public Integer answer(InvocationOnMock in) throws Throwable {
ByteBuf buf = (ByteBuf) in.getArguments()[2];
int padding = (Integer) in.getArguments()[3];
int processedBytes = buf.readableBytes() + padding;
buf.readBytes(serverOut, buf.readableBytes());
if (in.getArgument(4)) {
serverConnection.stream((Integer) in.getArgument(1)).close();
}
return processedBytes;
}
}).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(),
any(ByteBuf.class), anyInt(), anyBoolean());
final CountDownLatch serverChannelLatch = new CountDownLatch(1);
sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
sb.channel(NioServerSocketChannel.class);
sb.childHandler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
serverConnectedChannel = ch;
ChannelPipeline p = ch.pipeline();
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
serverConnection.remote().flowController(
new DefaultHttp2RemoteFlowController(serverConnection));
serverConnection.local().flowController(
new DefaultHttp2LocalFlowController(serverConnection).frameWriter(frameWriter));
Http2ConnectionEncoder encoder = new CompressorHttp2ConnectionEncoder(
new DefaultHttp2ConnectionEncoder(serverConnection, frameWriter));
Http2ConnectionDecoder decoder =
new DefaultHttp2ConnectionDecoder(serverConnection, encoder, new DefaultHttp2FrameReader());
Http2ConnectionHandler connectionHandler = new Http2ConnectionHandlerBuilder()
.frameListener(new DelegatingDecompressorFrameListener(serverConnection, serverListener))
.codec(decoder, encoder).build();
p.addLast(connectionHandler);
serverChannelLatch.countDown();
}
});
cb.group(new NioEventLoopGroup());
cb.channel(NioSocketChannel.class);
cb.handler(new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
clientConnection.remote().flowController(
new DefaultHttp2RemoteFlowController(clientConnection));
clientConnection.local().flowController(
new DefaultHttp2LocalFlowController(clientConnection).frameWriter(frameWriter));
clientEncoder = new CompressorHttp2ConnectionEncoder(
new DefaultHttp2ConnectionEncoder(clientConnection, frameWriter));
Http2ConnectionDecoder decoder =
new DefaultHttp2ConnectionDecoder(clientConnection, clientEncoder,
new DefaultHttp2FrameReader());
clientHandler = new Http2ConnectionHandlerBuilder()
.frameListener(new DelegatingDecompressorFrameListener(clientConnection, clientListener))
// By default tests don't wait for server to gracefully shutdown streams
.gracefulShutdownTimeoutMillis(0)
.codec(decoder, clientEncoder).build();
p.addLast(clientHandler);
p.addLast(new ChannelInboundHandlerAdapter() {
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof Http2ConnectionPrefaceWrittenEvent) {
prefaceWrittenLatch.countDown();
ctx.pipeline().remove(this);
}
}
});
}
});
serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();
ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
assertTrue(ccf.awaitUninterruptibly().isSuccess());
clientChannel = ccf.channel();
assertTrue(prefaceWrittenLatch.await(5, SECONDS));
assertTrue(serverChannelLatch.await(5, SECONDS));
}
private void awaitServer() throws Exception {
assertTrue(serverLatch.await(5, SECONDS));
serverOut.flush();
}
private ChannelHandlerContext ctxClient() {
return clientChannel.pipeline().firstContext();
}
private ChannelPromise newPromiseClient() {
return ctxClient().newPromise();
}
}