/*
* Copyright (c) 2015 Intellectual Reserve, Inc. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package cf.dropsonde.firehose;
import org.cloudfoundry.dropsonde.events.Envelope;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrameAggregator;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.CharsetUtil;
import rx.Subscriber;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import java.io.Closeable;
import java.net.URI;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
/**
* @author Mike Heath
*/
class NettyFirehoseOnSubscribe implements rx.Observable.OnSubscribe<Envelope>, Closeable {
public static final String HANDLER_NAME = "handler";
private Channel channel;
// Set this EventLoopGroup if an EventLoopGroup was not provided. This groups needs to be shutdown when the
// firehose client shuts down.
private final EventLoopGroup eventLoopGroup;
private final Bootstrap bootstrap;
public NettyFirehoseOnSubscribe(URI uri, String token, String subscriptionId, boolean skipTlsValidation, EventLoopGroup eventLoopGroup, Class<? extends SocketChannel> channelClass) {
try {
final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost();
final String scheme = uri.getScheme() == null ? "ws" : uri.getScheme();
final int port = getPort(scheme, uri.getPort());
final URI fullUri = uri.resolve("/firehose/" + subscriptionId);
final SslContext sslContext;
if ("wss".equalsIgnoreCase(scheme)) {
final SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
if (skipTlsValidation) {
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
} else {
TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init((KeyStore)null);
sslContextBuilder.trustManager(trustManagerFactory);
}
sslContext = sslContextBuilder.build();
} else {
sslContext = null;
}
bootstrap = new Bootstrap();
if (eventLoopGroup == null) {
this.eventLoopGroup = new NioEventLoopGroup();
bootstrap.group(this.eventLoopGroup);
} else {
this.eventLoopGroup = null;
bootstrap.group(eventLoopGroup);
}
bootstrap
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 15000)
.channel(channelClass == null ? NioSocketChannel.class : channelClass)
.remoteAddress(host, port)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel c) throws Exception {
final HttpHeaders headers = new DefaultHttpHeaders();
headers.add(HttpHeaders.Names.AUTHORIZATION, token);
final WebSocketClientHandler handler =
new WebSocketClientHandler(
WebSocketClientHandshakerFactory.newHandshaker(
fullUri, WebSocketVersion.V13, null, false, headers));
final ChannelPipeline pipeline = c.pipeline();
if (sslContext != null) {
pipeline.addLast(sslContext.newHandler(c.alloc(), host, port));
}
pipeline.addLast(new ReadTimeoutHandler(30));
pipeline.addLast(
new HttpClientCodec(),
new HttpObjectAggregator(8192));
pipeline.addLast(HANDLER_NAME, handler);
channel = c;
}
});
} catch (NoSuchAlgorithmException | SSLException | KeyStoreException e) {
throw new RuntimeException(e);
}
}
@Override
public void call(Subscriber<? super Envelope> subscriber) {
bootstrap.connect().addListener((ChannelFutureListener)future -> {
if (future.isSuccess()) {
future.channel().pipeline().get(WebSocketClientHandler.class).setSubscriber(subscriber);
} else {
subscriber.onError(future.cause());
}
});
}
@Override
public void close() {
if (channel != null) {
channel.close();
}
if (eventLoopGroup != null) {
eventLoopGroup.shutdownGracefully();
}
}
private int getPort(String scheme, int port) {
if (port == -1) {
if ("ws".equalsIgnoreCase(scheme)) {
return 80;
} else if ("wss".equalsIgnoreCase(scheme)) {
return 443;
} else {
return -1;
}
} else {
return port;
}
}
public boolean isConnected() {
return channel != null && channel.isActive();
}
private static class WebSocketClientHandler extends SimpleChannelInboundHandler<Object> {
private final WebSocketClientHandshaker handshaker;
private Subscriber<? super Envelope> subscriber;
public WebSocketClientHandler(WebSocketClientHandshaker handshaker) {
this.handshaker = handshaker;
}
@Override
public void channelActive(ChannelHandlerContext context) throws Exception {
handshaker.handshake(context.channel());
}
@Override
public void channelInactive(ChannelHandlerContext context) throws Exception {
subscriber.onCompleted();
}
@Override
public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
context.close();
subscriber.onError(cause);
}
@Override
protected void channelRead0(ChannelHandlerContext context, Object message) throws Exception {
final Channel channel = context.channel();
if (!handshaker.isHandshakeComplete()) {
handshaker.finishHandshake(channel, (FullHttpResponse) message);
channel.pipeline().addBefore(HANDLER_NAME, "websocket-frame-aggregator", new WebSocketFrameAggregator(64 * 1024));
subscriber.onStart();
return;
}
if (message instanceof FullHttpResponse) {
final FullHttpResponse response = (FullHttpResponse) message;
throw new IllegalStateException(
"Unexpected FullHttpResponse (getStatus=" + response.getStatus() +
", content=" + response.content().toString(CharsetUtil.UTF_8) + ')');
}
final WebSocketFrame frame = (WebSocketFrame) message;
if (frame instanceof PingWebSocketFrame) {
context.writeAndFlush(new PongWebSocketFrame(((PingWebSocketFrame)frame).retain().content()));
} else if (frame instanceof BinaryWebSocketFrame) {
final ByteBufInputStream input = new ByteBufInputStream(((BinaryWebSocketFrame)message).content());
final Envelope envelope = Envelope.ADAPTER.decode(input);
subscriber.onNext(envelope);
}
}
public void setSubscriber(Subscriber<? super Envelope> subscriber) {
this.subscriber = subscriber;
}
}
}