package org.mockserver.proxy.unification;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.socks.SocksAuthScheme;
import io.netty.handler.codec.socks.SocksInitRequestDecoder;
import io.netty.handler.codec.socks.SocksMessageEncoder;
import io.netty.handler.codec.socks.SocksProtocolVersion;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.AttributeKey;
import org.mockserver.logging.LoggingHandler;
import org.mockserver.proxy.socks.SocksProxyHandler;
import org.mockserver.socket.SSLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.mockserver.proxy.error.Logging.shouldIgnoreException;
/**
* @author jamesdbloom
*/
@ChannelHandler.Sharable
public abstract class PortUnificationHandler extends SimpleChannelInboundHandler<ByteBuf> {
public static final AttributeKey<Boolean> SSL_ENABLED_UPSTREAM = AttributeKey.valueOf("PROXY_SSL_ENABLED_UPSTREAM");
public static final AttributeKey<Boolean> SSL_ENABLED_DOWNSTREAM = AttributeKey.valueOf("SSL_ENABLED_DOWNSTREAM");
@VisibleForTesting
public static Logger logger = LoggerFactory.getLogger(PortUnificationHandler.class);
public PortUnificationHandler() {
super(false);
}
public static void enabledSslUpstreamAndDownstream(Channel channel) {
channel.attr(PortUnificationHandler.SSL_ENABLED_UPSTREAM).set(Boolean.TRUE);
channel.attr(PortUnificationHandler.SSL_ENABLED_DOWNSTREAM).set(Boolean.TRUE);
}
public static boolean isSslEnabledUpstream(Channel channel) {
if (channel.attr(SSL_ENABLED_UPSTREAM).get() != null) {
return channel.attr(SSL_ENABLED_UPSTREAM).get();
} else {
return false;
}
}
public static void enabledSslDownstream(Channel channel) {
channel.attr(PortUnificationHandler.SSL_ENABLED_DOWNSTREAM).set(Boolean.TRUE);
}
public static void disableSslDownstream(Channel channel) {
channel.attr(PortUnificationHandler.SSL_ENABLED_DOWNSTREAM).set(Boolean.FALSE);
}
public static boolean isSslEnabledDownstream(Channel channel) {
if (channel.attr(SSL_ENABLED_DOWNSTREAM).get() != null) {
return channel.attr(SSL_ENABLED_DOWNSTREAM).get();
} else {
return false;
}
}
/**
* Closes the specified channel after all queued write requests are flushed.
*/
public static void closeOnFlush(Channel ch) {
if (ch != null && ch.isActive()) {
ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
}
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
// Will use the first five bytes to detect a protocol.
if (msg.readableBytes() < 3) {
return;
}
if (isSsl(msg)) {
enableSsl(ctx, msg);
} else if (isSocks(msg)) {
enableSocks(ctx, msg);
} else if (isHttp(msg)) {
switchToHttp(ctx, msg);
} else {
// Unknown protocol; discard everything and close the connection.
msg.clear();
ctx.close();
}
if (logger.isTraceEnabled()) {
if (ctx.pipeline().get(LoggingHandler.class) != null) {
ctx.pipeline().remove(LoggingHandler.class);
}
if (ctx.pipeline().get(SslHandler.class) != null) {
ctx.pipeline().addAfter("SslHandler#0", "LoggingHandler#0", new LoggingHandler(logger));
} else {
ctx.pipeline().addFirst(new LoggingHandler(logger));
}
}
}
private boolean isSsl(ByteBuf buf) {
return buf.readableBytes() >= 5 && SslHandler.isEncrypted(buf);
}
private boolean isSocks(ByteBuf msg) {
switch (SocksProtocolVersion.valueOf(msg.getByte(msg.readerIndex()))) {
case SOCKS5:
case SOCKS4a:
break;
default:
return false;
}
byte numberOfAuthenticationMethods = msg.getByte(msg.readerIndex() + 1);
for (int i = 0; i < numberOfAuthenticationMethods; i++) {
switch (SocksAuthScheme.valueOf(msg.getByte(msg.readerIndex() + 1 + i))) {
case NO_AUTH:
case AUTH_PASSWORD:
case AUTH_GSSAPI:
break;
default:
return false;
}
}
return true;
}
private boolean isHttp(ByteBuf msg) {
int letterOne = (int) msg.getUnsignedByte(msg.readerIndex());
int letterTwo = (int) msg.getUnsignedByte(msg.readerIndex() + 1);
int letterThree = (int) msg.getUnsignedByte(msg.readerIndex() + 2);
return letterOne == 'G' && letterTwo == 'E' && letterThree == 'T' || // GET
letterOne == 'P' && letterTwo == 'O' && letterThree == 'S' || // POST
letterOne == 'P' && letterTwo == 'U' && letterThree == 'T' || // PUT
letterOne == 'H' && letterTwo == 'E' && letterThree == 'A' || // HEAD
letterOne == 'O' && letterTwo == 'P' && letterThree == 'T' || // OPTIONS
letterOne == 'P' && letterTwo == 'A' && letterThree == 'T' || // PATCH
letterOne == 'D' && letterTwo == 'E' && letterThree == 'L' || // DELETE
letterOne == 'T' && letterTwo == 'R' && letterThree == 'A' || // TRACE
letterOne == 'C' && letterTwo == 'O' && letterThree == 'N'; // CONNECT
}
private void enableSsl(ChannelHandlerContext ctx, ByteBuf msg) {
ChannelPipeline pipeline = ctx.pipeline();
pipeline.addFirst(new SslHandler(SSLFactory.createServerSSLEngine()));
// re-unify (with SSL enabled)
PortUnificationHandler.enabledSslUpstreamAndDownstream(ctx.channel());
ctx.pipeline().fireChannelRead(msg);
}
private void enableSocks(ChannelHandlerContext ctx, ByteBuf msg) {
ChannelPipeline pipeline = ctx.pipeline();
pipeline.addFirst(new SocksProxyHandler());
pipeline.addFirst(new SocksMessageEncoder());
pipeline.addFirst(new SocksInitRequestDecoder());
// re-unify (with SOCKS enabled)
ctx.pipeline().fireChannelRead(msg);
}
private void switchToHttp(ChannelHandlerContext ctx, ByteBuf msg) {
ChannelPipeline pipeline = ctx.pipeline();
addLastIfNotPresent(pipeline, new HttpServerCodec());
addLastIfNotPresent(pipeline, new HttpContentDecompressor());
addLastIfNotPresent(pipeline, new HttpObjectAggregator(Integer.MAX_VALUE));
configurePipeline(ctx, pipeline);
pipeline.remove(this);
// pass message to next stage in pipeline
ctx.fireChannelRead(msg);
}
protected void addLastIfNotPresent(ChannelPipeline pipeline, ChannelHandler channelHandler) {
if (pipeline.get(channelHandler.getClass()) == null) {
pipeline.addLast(channelHandler);
}
}
protected abstract void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline);
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (!shouldIgnoreException(cause)) {
logger.warn("Exception caught by port unification handler -> closing pipeline " + ctx.channel(), cause);
}
closeOnFlush(ctx.channel());
}
}