/**
This file is part of Waarp Project.
Copyright 2009, Frederic Bregier, and individual contributors by the @author
tags. See the COPYRIGHT.txt in the distribution for a full listing of
individual contributors.
All Waarp Project is free software: you can redistribute it and/or
modify it under the terms of the GNU General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Waarp is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with Waarp . If not, see <http://www.gnu.org/licenses/>.
*/
package org.waarp.common.crypto.ssl;
import java.util.NoSuchElementException;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.DefaultEventExecutor;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.waarp.common.logging.WaarpLogger;
import org.waarp.common.logging.WaarpLoggerFactory;
import org.waarp.common.utility.WaarpThreadFactory;
/**
 * Utilities for SSL support: registration of SSL channels, insertion of an {@link SslHandler} in
 * a pipeline, waiting for the SSL handshake, and closing SSL channels by first closing the SSL
 * layer ({@link SslHandler#close()}) then the channel itself.
 *
 * @author "Frederic Bregier"
 *
 */
public class WaarpSslUtility {
    /**
     * Internal Logger
     */
    private static final WaarpLogger logger = WaarpLoggerFactory.getLogger(WaarpSslUtility.class);
    /**
     * Maximum time (in ms) to wait for a connection attempt in
     * {@link #waitforChannelReady(ChannelFuture)}
     */
    private static final long CONNECT_WAIT_MILLIS = 10000;
    /**
     * Extra time (in ms) added to the SslHandler handshake timeout when waiting for the handshake
     */
    private static final long HANDSHAKE_EXTRA_WAIT_MILLIS = 100;
    /**
     * EventExecutor associated with Ssl utility
     */
    private static final EventExecutor SSL_EVENT_EXECUTOR = new DefaultEventExecutor(new WaarpThreadFactory("SSLEVENT"));
    /**
     * ChannelGroup for SSL
     */
    private static final ChannelGroup sslChannelGroup = new DefaultChannelGroup("SslChannelGroup", SSL_EVENT_EXECUTOR);

    /**
     * Register the Channel as an SSL channel (to be called when the SSL handshake will start soon),
     * so that {@link #forceCloseAllSslChannels()} can close it later.
     *
     * @param channel
     *            the channel to register
     */
    public static void addSslOpenedChannel(Channel channel) {
        sslChannelGroup.add(channel);
    }

    /**
     * Add a SslHandler as the first handler of a pipeline (named "SSL") when the channel is already
     * active.
     *
     * @param future
     *            might be null; if not null, the handler is only added once this future completes
     *            (whatever its outcome)
     * @param pipeline
     *            the pipeline to add the handler to
     * @param sslHandler
     *            the handler to add; must be a {@link SslHandler}
     * @param listener
     *            action once the handshake is done
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public static void addSslHandler(ChannelFuture future, final ChannelPipeline pipeline,
            final ChannelHandler sslHandler,
            final GenericFutureListener<? extends Future<? super Channel>> listener) {
        if (future == null) {
            addFirstSslHandler(pipeline, sslHandler, listener);
        } else {
            future.addListener(new GenericFutureListener() {
                public void operationComplete(Future completed) throws Exception {
                    addFirstSslHandler(pipeline, sslHandler, listener);
                }
            });
        }
        logger.debug("Checked Ssl Handler to be added: " + pipeline.channel());
    }

    /**
     * Insert the SslHandler first in the pipeline and attach the listener to its handshake future.
     */
    private static void addFirstSslHandler(ChannelPipeline pipeline, ChannelHandler sslHandler,
            GenericFutureListener<? extends Future<? super Channel>> listener) {
        logger.debug("Add SslHandler: " + pipeline.channel());
        pipeline.addFirst("SSL", sslHandler);
        ((SslHandler) sslHandler).handshakeFuture().addListener(listener);
    }

    /**
     * Wait up to the given time for the future, without losing a thread interruption.
     *
     * @return True if the future is done when this method returns
     */
    private static boolean awaitQuietly(Future<?> future, long timeoutMillis) {
        try {
            return future.await(timeoutMillis);
        } catch (InterruptedException e) {
            // Stop waiting but preserve the interruption status for the caller
            Thread.currentThread().interrupt();
            return future.isDone();
        }
    }

    /**
     * Wait for the handshake on the given channel (better to use addSslHandler when handler is
     * added after channel is active). The channel is closed if the handshake fails or does not
     * complete in time.
     *
     * @param channel
     * @return True if the Handshake is done correctly, or if the first handler of the pipeline is
     *         not a SslHandler; False if the handshake failed (the channel is then closed)
     */
    public static boolean waitForHandshake(Channel channel) {
        final ChannelHandler handler = channel.pipeline().first();
        if (!(handler instanceof SslHandler)) {
            // pipeline.first() returns null on an empty pipeline: avoid a NullPointerException
            logger.error("SSL Not found but connected: "
                    + (handler == null ? "no handler" : handler.getClass().getName()));
            return true;
        }
        logger.debug("Start handshake SSL: " + channel);
        final SslHandler sslHandler = (SslHandler) handler;
        // Get notified when SSL handshake is done, waiting slightly longer than the
        // handler's own handshake timeout.
        Future<Channel> handshakeFuture = sslHandler.handshakeFuture();
        awaitQuietly(handshakeFuture, sslHandler.getHandshakeTimeoutMillis() + HANDSHAKE_EXTRA_WAIT_MILLIS);
        logger.debug("Handshake: " + handshakeFuture.isSuccess() + ": " + channel, handshakeFuture.cause());
        if (!handshakeFuture.isSuccess()) {
            channel.close();
            return false;
        }
        return true;
    }

    /**
     * Waiting for the channel to be opened and ready (Client side) (blocking call)
     *
     * @param future
     *            a future on connect only
     * @return the channel if correctly connected and its SSL handshake succeeded, else return null
     */
    public static Channel waitforChannelReady(ChannelFuture future) {
        // Wait until the connection attempt succeeds or fails.
        awaitQuietly(future, CONNECT_WAIT_MILLIS);
        if (!future.isSuccess()) {
            logger.error("Channel not connected", future.cause());
            if (!future.isDone()) {
                // The attempt is abandoned: close it so that a late connection cannot leave an
                // orphaned channel that no caller holds a reference to.
                future.channel().close();
            }
            return null;
        }
        Channel channel = future.channel();
        return waitForHandshake(channel) ? channel : null;
    }

    /**
     * Utility to force all registered SSL channels to be closed, then shut down the SSL executor.
     */
    public static void forceCloseAllSslChannels() {
        for (Channel channel : sslChannelGroup) {
            closingSslChannel(channel);
        }
        sslChannelGroup.close();
        SSL_EVENT_EXECUTOR.shutdownGracefully();
    }

    /**
     * Utility method to close a channel in SSL mode correctly (if any)
     *
     * @param channel
     * @return the close future of the channel if it was active, else an already succeeded future
     */
    public static ChannelFuture closingSslChannel(Channel channel) {
        if (channel.isActive()) {
            removingSslHandler(null, channel, true);
            logger.debug("Close the channel and returns the ChannelFuture: " + channel.toString());
            return channel.closeFuture();
        }
        logger.debug("Already closed");
        return channel.newSucceededFuture();
    }

    /**
     * Remove the SslHandler (if any) cleanly. Does nothing if the channel is not active.
     *
     * @param future
     *            if not null, wait for this future to be done to remove the sslhandler
     * @param channel
     * @param close
     *            True to close the channel, else to only remove the SslHandler. If the first
     *            handler is not a SslHandler, the channel is closed regardless of this flag.
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public static void removingSslHandler(ChannelFuture future, final Channel channel, final boolean close) {
        if (!channel.isActive()) {
            return;
        }
        // NOTE(review): presumably re-enables reads so the SSL close exchange can progress; confirm
        channel.config().setAutoRead(true);
        ChannelHandler handler = channel.pipeline().first();
        if (!(handler instanceof SslHandler)) {
            channel.close();
            return;
        }
        final SslHandler sslHandler = (SslHandler) handler;
        if (future != null) {
            future.addListener(new GenericFutureListener() {
                public void operationComplete(Future completed) throws Exception {
                    logger.debug("Found SslHandler and wait for Ssl.close()");
                    closeSslThenFinish(sslHandler, channel, close, false);
                }
            });
        } else {
            logger.debug("Found SslHandler and wait for Ssl.close() : " + channel);
            closeSslThenFinish(sslHandler, channel, close, true);
        }
    }

    /**
     * Close the SSL layer, then either close the channel or only remove the SslHandler.
     *
     * @param logChannel
     *            True to include the channel in the debug message once SSL is closed
     */
    private static void closeSslThenFinish(final SslHandler sslHandler, final Channel channel,
            final boolean close, final boolean logChannel) {
        sslHandler.close().addListener(new GenericFutureListener<Future<? super Void>>() {
            public void operationComplete(Future<? super Void> sslClosed) throws Exception {
                logger.debug(logChannel ? "Ssl closed: " + channel : "Ssl closed");
                if (close) {
                    channel.close();
                } else {
                    channel.pipeline().remove(sslHandler);
                }
            }
        });
    }

    /**
     * Thread used to ensure we are not in IO thread when waiting
     *
     * @author "Frederic Bregier"
     *
     */
    private static class SSLTHREAD extends Thread {
        private final Channel channel;

        /**
         * @param channel
         *            the channel to close through {@link WaarpSslUtility#closingSslChannel(Channel)}
         */
        private SSLTHREAD(Channel channel) {
            this.channel = channel;
            this.setDaemon(true);
            this.setName("SSLTHREAD_" + this.getName());
        }

        @Override
        public void run() {
            closingSslChannel(channel);
        }
    }

    /**
     * Closing channel with SSL close at first step (done in a separate daemon thread so that the
     * listener does not block the calling thread)
     */
    public static ChannelFutureListener SSLCLOSE = new ChannelFutureListener() {
        public void operationComplete(ChannelFuture future) throws Exception {
            if (future.channel().isActive()) {
                SSLTHREAD thread = new SSLTHREAD(future.channel());
                thread.start();
            }
        }
    };

    /**
     * Wait for the channel with SSL to be closed. If it is not closed in time, the WaarpSslHandler
     * is removed (if present) and the channel is closed forcibly.
     *
     * @param channel
     * @param delay
     *            maximum time to wait (in milliseconds) for each waiting step
     * @return False if the channel had to be closed forcibly after removing the WaarpSslHandler,
     *         else True
     */
    public static boolean waitForClosingSslChannel(Channel channel, long delay) {
        try {
            if (!channel.closeFuture().await(delay)) {
                try {
                    channel.pipeline().remove(WaarpSslHandler.class);
                    logger.debug("try to close anyway");
                    channel.close().await(delay);
                    return false;
                } catch (NoSuchElementException e) {
                    // No WaarpSslHandler in the pipeline: give the ongoing close more time
                    channel.closeFuture().await(delay);
                }
            }
        } catch (InterruptedException e) {
            // Keep the historical return value, but do not lose the interruption status
            Thread.currentThread().interrupt();
        }
        return true;
    }
}