package org.mockserver.proxy.http;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.socks.SocksInitRequestDecoder;
import io.netty.handler.codec.socks.SocksMessageEncoder;
import org.apache.commons.codec.binary.Hex;
import org.junit.Test;
import org.mockserver.proxy.relay.RelayConnectHandler;
import org.mockserver.proxy.socks.SocksProxyHandler;
import org.mockserver.proxy.unification.PortUnificationHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
 * Tests for {@link HttpProxyUnificationHandler}: how it reacts to the first inbound bytes
 * (SOCKS5, HTTP, or an unrecognised protocol) and, for SOCKS, how a failed CONNECT is reported.
 */
public class HttpProxyUnificationHandlerSOCKSErrorTest {

    /**
     * Reads the next outbound message from {@code channel} (expected to be a {@link ByteBuf}),
     * returns its contents as a hex string, and releases the buffer so the reference-counted
     * buffer is not leaked by the test.
     */
    private static String readOutboundHex(EmbeddedChannel channel) {
        ByteBuf outbound = (ByteBuf) channel.readOutbound();
        try {
            return ByteBufUtil.hexDump(outbound);
        } finally {
            if (outbound != null) {
                outbound.release();
            }
        }
    }

    @Test
    public void shouldHandleErrorsDuringSOCKSConnection() throws IOException, InterruptedException {
        // given - embedded channel
        short localPort = 1234;
        EmbeddedChannel embeddedChannel = new EmbeddedChannel(new HttpProxyUnificationHandler());
        embeddedChannel.attr(HttpProxy.HTTP_CONNECT_SOCKET).set(new InetSocketAddress(localPort));

        // and - mock logger (RelayConnectHandler.logger is a mutable static, so restore it afterwards
        // to avoid leaking the mock into other tests)
        Logger originalLogger = RelayConnectHandler.logger;
        RelayConnectHandler.logger = mock(Logger.class);
        try {
            // and - no SOCKS handlers
            assertThat(embeddedChannel.pipeline().get(SocksProxyHandler.class), is(nullValue()));
            assertThat(embeddedChannel.pipeline().get(SocksMessageEncoder.class), is(nullValue()));
            assertThat(embeddedChannel.pipeline().get(SocksInitRequestDecoder.class), is(nullValue()));

            // when - SOCKS INIT message
            embeddedChannel.writeInbound(Unpooled.wrappedBuffer(new byte[]{
                (byte) 0x05, // SOCKS5
                (byte) 0x02, // 2 authentication methods
                (byte) 0x00, // NO_AUTH
                (byte) 0x02, // AUTH_PASSWORD
            }));

            // then - INIT response
            assertThat(readOutboundHex(embeddedChannel), is(Hex.encodeHexString(new byte[]{
                (byte) 0x05, // SOCKS5
                (byte) 0x00, // NO_AUTH
            })));

            // and then - should add SOCKS handlers first
            if (LoggerFactory.getLogger(PortUnificationHandler.class).isTraceEnabled()) {
                assertThat(embeddedChannel.pipeline().names(), contains(
                    "LoggingHandler#0",
                    "SocksCmdRequestDecoder#0",
                    "SocksMessageEncoder#0",
                    "SocksProxyHandler#0",
                    "HttpProxyUnificationHandler#0",
                    "DefaultChannelPipeline$TailContext#0"
                ));
            } else {
                assertThat(embeddedChannel.pipeline().names(), contains(
                    "SocksCmdRequestDecoder#0",
                    "SocksMessageEncoder#0",
                    "SocksProxyHandler#0",
                    "HttpProxyUnificationHandler#0",
                    "DefaultChannelPipeline$TailContext#0"
                ));
            }

            // and when - SOCKS CONNECT command
            embeddedChannel.writeInbound(Unpooled.wrappedBuffer(new byte[]{
                (byte) 0x05, // SOCKS5
                (byte) 0x01, // command type CONNECT
                (byte) 0x00, // reserved (must be 0x00)
                (byte) 0x01, // address type IPv4
                (byte) 0x7f, (byte) 0x00, (byte) 0x00, (byte) 0x01, // ip address
                // port in network byte order: high byte then low byte
                // (masking with 0xFF00 and casting to byte always yielded 0x00 for the high byte)
                (byte) (localPort >> 8), (byte) localPort
            }));

            // then - CONNECT response
            assertThat(readOutboundHex(embeddedChannel), is(Hex.encodeHexString(new byte[]{
                (byte) 0x05, // SOCKS5
                (byte) 0x01, // general failure (caused by connection failure)
                (byte) 0x00, // reserved (must be 0x00)
                (byte) 0x01, // address type IPv4
                (byte) 0x7f, (byte) 0x00, (byte) 0x00, (byte) 0x01, // ip address
                (byte) (localPort >> 8), (byte) localPort // port, network byte order
            })));

            // then - channel is closed after error
            assertThat(embeddedChannel.isOpen(), is(false));
            verify(RelayConnectHandler.logger).warn(eq("Connection failed to 0.0.0.0/0.0.0.0:1234"), any(IllegalStateException.class));
        } finally {
            RelayConnectHandler.logger = originalLogger;
        }
    }

    @Test
    public void shouldSwitchToHttp() {
        // given
        EmbeddedChannel embeddedChannel = new EmbeddedChannel(new HttpProxyUnificationHandler());

        // and - no HTTP handlers
        assertThat(embeddedChannel.pipeline().get(HttpServerCodec.class), is(nullValue()));
        assertThat(embeddedChannel.pipeline().get(HttpContentDecompressor.class), is(nullValue()));
        assertThat(embeddedChannel.pipeline().get(HttpObjectAggregator.class), is(nullValue()));

        // when - basic HTTP request
        embeddedChannel.writeInbound(Unpooled.wrappedBuffer(
            "GET /somePath HTTP/1.1\r\nHost: some.random.host\r\n\r\n".getBytes(StandardCharsets.UTF_8)));

        // then - should add HTTP handlers last
        if (LoggerFactory.getLogger(PortUnificationHandler.class).isTraceEnabled()) {
            assertThat(embeddedChannel.pipeline().names(), contains(
                "LoggingHandler#0",
                "HttpServerCodec#0",
                "HttpContentDecompressor#0",
                "HttpObjectAggregator#0",
                "MockServerServerCodec#0",
                "HttpProxyHandler#0",
                "DefaultChannelPipeline$TailContext#0"
            ));
        } else {
            assertThat(embeddedChannel.pipeline().names(), contains(
                "HttpServerCodec#0",
                "HttpContentDecompressor#0",
                "HttpObjectAggregator#0",
                "MockServerServerCodec#0",
                "HttpProxyHandler#0",
                "DefaultChannelPipeline$TailContext#0"
            ));
        }
    }

    @Test
    public void shouldSupportUnknownProtocol() {
        // given
        EmbeddedChannel embeddedChannel = new EmbeddedChannel(new HttpProxyUnificationHandler());

        // and - channel open
        assertThat(embeddedChannel.isOpen(), is(true));

        // when - message in an unrecognised protocol
        embeddedChannel.writeInbound(Unpooled.wrappedBuffer("UNKNOWN_PROTOCOL".getBytes(StandardCharsets.UTF_8)));

        // then - should add no handlers
        assertThat(embeddedChannel.pipeline().names(), contains(
            "HttpProxyUnificationHandler#0",
            "DefaultChannelPipeline$TailContext#0"
        ));

        // and - close channel
        assertThat(embeddedChannel.isOpen(), is(false));
    }
}