/**
* Copyright 2014 Ricardo Padilha
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.dsys.snio.impl.channel;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.NoConnectionPendingException;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.concurrent.Callable;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLSession;
import net.dsys.commons.api.exception.Bug;
import net.dsys.commons.impl.future.SettableCallbackFuture;
import net.dsys.snio.api.buffer.MessageBufferConsumer;
import net.dsys.snio.api.buffer.MessageBufferProducer;
import net.dsys.snio.api.buffer.MessageBufferProvider;
import net.dsys.snio.api.codec.MessageCodec;
import net.dsys.snio.api.limit.RateLimiter;
/**
* Processor that moves messages between the message buffers of
* {@link AbstractProcessor} and a {@link SocketChannel}, passing every byte
* through an {@link SSLEngine}: outgoing messages are encoded by the codec and
* wrapped, incoming bytes are unwrapped and decoded by the codec.
*
* @author Ricardo Padilha
*/
final class SSLProcessor extends AbstractProcessor<ByteBuffer> {
/** Value of {@link #writeSequence} while no input sequence is held. */
private static final int NO_SEQUENCE = -1;
/** Encodes messages into {@link #preSendBuffer} and decodes them from {@link #postReceiveBuffer}. */
@Nonnull
private final MessageCodec codec;
/** Told how many bytes were received and how many are about to be sent. */
@Nonnull
private final RateLimiter limiter;
/** Engine used for the handshake and for wrapping/unwrapping data. */
@Nonnull
private final SSLEngine engine;
/** Requested send size, at least the codec frame length, rounded up to a power of two. */
@Nonnegative
private final int sendSize;
/** Requested receive size, at least the codec frame length, rounded up to a power of two. */
@Nonnegative
private final int receiveSize;
/** Direct buffer with bytes read from the socket and not yet unwrapped; allocated in readRegistered. */
private ByteBuffer receiveBuffer;
/** Direct buffer with wrapped bytes not yet written to the socket; allocated in writeRegistered. */
private ByteBuffer sendBuffer;
/** Heap buffer with codec-encoded messages waiting to be wrapped; allocated in writeRegistered. */
private ByteBuffer preSendBuffer;
/** Heap buffer with unwrapped bytes waiting to be decoded by the codec; allocated in readRegistered. */
private ByteBuffer postReceiveBuffer;
/** Sequence acquired from the channel input and not yet released, or {@link #NO_SEQUENCE}. */
private long writeSequence;
/** Future completed by the final shutdown step; set by shutdown(future, task). */
private volatile SettableCallbackFuture<Void> closeFuture;
/** Task invoked by the final shutdown step; set by shutdown(future, task). */
private volatile Callable<Void> closeTask;
/**
* Set by read/write when they observe that the connection is ending (the
* channel stops delivering data or the engine reports CLOSED). When set,
* shutdown(future, task) finishes immediately instead of waking the writer.
*/
private volatile boolean closedInternally;
/** Set by the final shutdown step; read and write return 0 once it is set. */
private volatile boolean closed;
/**
 * Creates a processor that frames messages with {@code codec} and secures
 * them with {@code engine}.
 * <p>
 * Both buffer sizes are raised to at least the codec frame length and then
 * rounded up to a power of two.
 *
 * @param codec codec used to encode and decode messages
 * @param limiter rate limiter notified of transferred byte counts
 * @param provider buffer provider handed to the superclass
 * @param sendBufferSize requested size of the send buffer
 * @param receiveBufferSize requested size of the receive buffer
 * @param engine SSL engine used for this connection
 * @throws NullPointerException if {@code codec}, {@code limiter} or {@code engine} is null
 * @throws IllegalArgumentException if a rounded size is smaller than one
 */
SSLProcessor(@Nonnull final MessageCodec codec, @Nonnull final RateLimiter limiter,
        @Nonnull final MessageBufferProvider<ByteBuffer> provider,
        @Nonnegative final int sendBufferSize, @Nonnegative final int receiveBufferSize,
        @Nonnull final SSLEngine engine) {
    super(provider);
    this.codec = java.util.Objects.requireNonNull(codec, "codec == null");
    this.limiter = java.util.Objects.requireNonNull(limiter, "limiter == null");
    this.engine = java.util.Objects.requireNonNull(engine, "engine == null");

    final int outgoing = nearestPowerOfTwo(Math.max(sendBufferSize, codec.getFrameLength()));
    final int incoming = nearestPowerOfTwo(Math.max(receiveBufferSize, codec.getFrameLength()));
    // A rounded size below one also covers the overflow of the power-of-two rounding.
    if (outgoing < 1) {
        throw new IllegalArgumentException("sendSize < 1");
    }
    if (incoming < 1) {
        throw new IllegalArgumentException("receiveSize < 1");
    }

    this.sendSize = outgoing;
    this.receiveSize = incoming;
    this.writeSequence = NO_SEQUENCE;
}
/**
 * {@inheritDoc}
 * <p>
 * Finishes the pending connect of the key's {@link SocketChannel}. Once the
 * connection is established the key stops watching for connect events,
 * starts watching for read events, and the attached {@code TCPChannel} is
 * registered. A failure is reported through the connect/read future.
 */
@Override
public void connect(final SelectionKey key) {
    final SocketChannel socket = (SocketChannel) key.channel();
    try {
        if (!socket.finishConnect()) {
            // still pending, wait for the next connect event
            return;
        }
        final int withoutConnect = key.interestOps() & ~SelectionKey.OP_CONNECT;
        key.interestOps(withoutConnect | SelectionKey.OP_READ);
        assert key.attachment() instanceof TCPChannel;
        final TCPChannel<?> owner = (TCPChannel<?>) key.attachment();
        owner.register();
    } catch (final IOException | NoConnectionPendingException e) {
        getConnectReadFuture().fail(e);
    }
}
/**
 * {@inheritDoc}
 * <p>
 * Allocates the receive-side buffers. The encrypted buffer is direct and at
 * least as large as the session's packet size; the decrypted buffer is a
 * heap buffer at least as large as the session's application size.
 *
 * @throws NullPointerException if {@code key} is null
 */
@Override
protected void readRegistered(final SelectionKey key) {
    if (key == null) {
        throw new NullPointerException("key == null");
    }
    final SSLSession tls = engine.getSession();
    final int packetBytes = tls.getPacketBufferSize();
    final int applicationBytes = tls.getApplicationBufferSize();
    // difference between an encrypted packet and the plaintext it carries
    final int overhead = packetBytes - applicationBytes;
    final int encryptedCapacity = Math.max(receiveSize, packetBytes);
    final int plainCapacity = Math.max(receiveSize - overhead, applicationBytes);
    this.receiveBuffer = ByteBuffer.allocateDirect(encryptedCapacity);
    this.postReceiveBuffer = ByteBuffer.allocate(plainCapacity);
}
/**
 * {@inheritDoc}
 * <p>
 * Allocates the send-side buffers. The plaintext buffer is a heap buffer at
 * least as large as the session's application size; the encrypted buffer is
 * direct and at least as large as the session's packet size.
 *
 * @throws NullPointerException if {@code key} is null
 */
@Override
protected void writeRegistered(final SelectionKey key) {
    if (key == null) {
        throw new NullPointerException("key == null");
    }
    final SSLSession tls = engine.getSession();
    final int packetBytes = tls.getPacketBufferSize();
    final int applicationBytes = tls.getApplicationBufferSize();
    // difference between an encrypted packet and the plaintext it carries
    final int overhead = packetBytes - applicationBytes;
    final int plainCapacity = Math.max(sendSize - overhead, applicationBytes);
    final int encryptedCapacity = Math.max(sendSize, packetBytes);
    this.preSendBuffer = ByteBuffer.allocate(plainCapacity);
    this.sendBuffer = ByteBuffer.allocateDirect(encryptedCapacity);
}
/**
 * {@inheritDoc}
 * <p>
 * Reads encrypted bytes from the socket, unwraps them with the SSL engine and
 * publishes every complete decoded message to the channel output buffer.
 * While the engine is handshaking, delegated tasks are run inline and the
 * writer is woken as soon as the engine needs to wrap.
 * <p>
 * NOTE(review): outside of the handshake only a single unwrap is attempted
 * per call; any further records already buffered appear to wait for the next
 * read event.
 *
 * @param key selection key whose channel is the {@link SocketChannel} to read from
 * @return {@code 0} if this processor is closed or no bytes were available;
 *         a negative value if the channel reached end-of-stream or the
 *         engine reported {@code CLOSED}; otherwise the number of bytes read
 * @throws IOException if reading or unwrapping fails, or if the thread is
 *         interrupted while acquiring an output slot
 */
@Override
public long read(final SelectionKey key) throws IOException {
    if (closed) {
        return 0;
    }
    final SocketChannel channel = (SocketChannel) key.channel();
    final MessageBufferProducer<ByteBuffer> chnOut = getChannelOutput();
    final MessageBufferProducer<ByteBuffer> appOut = getOutputBuffer();
    final long n = channel.read(receiveBuffer);
    if (n < 0) {
        // end-of-stream: the channel was closed from the other side
        closedInternally = true;
        return n;
    }
    if (n == 0) {
        // Nothing available right now. This is not a remote close, so the
        // close handshake must still be performed on shutdown.
        return 0;
    }
    limiter.receive(n);
    receiveBuffer.flip();
    // SSL handling
    boolean engineClosed = false;
    SSLEngineResult result = engine.unwrap(receiveBuffer, postReceiveBuffer);
    // state machine
    if (result.getHandshakeStatus() != NOT_HANDSHAKING) {
        // handshaking
        SSLEngineResult.Status status = result.getStatus();
        SSLEngineResult.HandshakeStatus hstatus = result.getHandshakeStatus();
        // Keep unwrapping only while the engine makes progress. Any status
        // other than OK (BUFFER_UNDERFLOW on a partial record, BUFFER_OVERFLOW,
        // CLOSED) cannot be resolved without new input or output space, and
        // looping on it would spin forever.
        while (status == SSLEngineResult.Status.OK
                && (hstatus == NEED_UNWRAP || hstatus == NEED_TASK)) {
            if (hstatus == NEED_TASK) {
                Runnable task;
                while ((task = engine.getDelegatedTask()) != null) {
                    task.run();
                }
            }
            result = engine.unwrap(receiveBuffer, postReceiveBuffer);
            status = result.getStatus();
            hstatus = result.getHandshakeStatus();
            assert result.bytesProduced() == 0;
        }
        if (hstatus == NEED_WRAP) {
            wakeupWriter();
        }
        // keep any unconsumed bytes (e.g. a partial record) for the next call
        if (receiveBuffer.remaining() > 0) {
            receiveBuffer.compact();
        } else {
            receiveBuffer.clear();
        }
    } else {
        switch (result.getStatus()) {
            case OK: {
                assert result.bytesConsumed() > 0 && result.bytesProduced() > 0;
                // normal decryption and delivery
                if (receiveBuffer.remaining() > 0) {
                    receiveBuffer.compact();
                } else {
                    receiveBuffer.clear();
                }
                break;
            }
            case BUFFER_UNDERFLOW:
            case BUFFER_OVERFLOW: {
                assert result.bytesConsumed() == 0 && result.bytesProduced() == 0;
                // We can't decrypt more until some bytes are delivered.
                // We have to "unflip" receiveBuffer, otherwise the flip above
                // will "zero" the buffer the next time read() is called.
                receiveBuffer.position(receiveBuffer.limit());
                receiveBuffer.limit(receiveBuffer.capacity());
                break;
            }
            case CLOSED: {
                // SSLEngine close handshake was completed
                closedInternally = true;
                engineClosed = true;
                break;
            }
            default: {
                // some status code that is not known
                throw new Bug("Unsupported SSLEngineResult.Status: " + result.getStatus());
            }
        }
    }
    // decode every complete message that is now available in plaintext
    postReceiveBuffer.flip();
    while (codec.hasNext(postReceiveBuffer)) {
        try {
            final long sequence = chnOut.acquire();
            try {
                final ByteBuffer msg = chnOut.get(sequence);
                msg.clear();
                codec.get(postReceiveBuffer, msg);
                msg.flip();
                chnOut.attach(sequence, appOut);
            } finally {
                chnOut.release(sequence);
            }
        } catch (final InterruptedException e) {
            throw new IOException(e);
        }
    }
    // keep a trailing partial message for the next call
    if (postReceiveBuffer.remaining() > 0) {
        postReceiveBuffer.compact();
    } else {
        postReceiveBuffer.clear();
    }
    if (engineClosed) {
        return -1;
    }
    return n;
}
/**
 * {@inheritDoc}
 * <p>
 * Encodes as many pending messages as fit into the plaintext buffer, wraps
 * them with the SSL engine and writes the resulting bytes to the socket.
 * While the engine is handshaking, delegated tasks are run inline and
 * wrapping continues for as long as the engine asks for it and makes
 * progress.
 * <p>
 * NOTE(review): the limiter is told the full number of pending encrypted
 * bytes before each write, so bytes left over from a partial write look like
 * they are reported again on the next call. Confirm against RateLimiter.
 * <p>
 * NOTE(review): the writer is disabled when the channel input is empty even
 * if plaintext is still buffered. Confirm that something wakes the writer
 * again in that case.
 *
 * @param key selection key whose channel is the {@link SocketChannel} to write to
 * @return {@code 0} if this processor is closed; {@code -1} if the engine
 *         reported {@code CLOSED} and everything was flushed; otherwise the
 *         number of bytes written to the socket
 * @throws IOException if a message does not fit into the plaintext buffer,
 *         if wrapping or writing fails, or if the thread is interrupted
 *         while acquiring an input slot
 */
@Override
public long write(final SelectionKey key) throws IOException {
    if (closed) {
        return 0;
    }
    final SocketChannel channel = (SocketChannel) key.channel();
    final MessageBufferConsumer<ByteBuffer> chnIn = getChannelInput();
    try {
        long k = chnIn.remaining();
        while (--k >= 0) {
            // a sequence left over from a previous call is retried first
            if (writeSequence == NO_SEQUENCE) {
                writeSequence = chnIn.acquire();
            }
            final ByteBuffer msg = chnIn.get(writeSequence);
            final int msglen = codec.getEncodedLength(msg);
            if (msglen > preSendBuffer.capacity()) {
                // this message is too big for the current buffer
                throw new IOException("codec.length(msg) > preSendBuffer.capacity()");
            }
            if (msglen > preSendBuffer.remaining()) {
                // no room now; keep the sequence and retry on the next call
                break;
            }
            codec.put(msg, preSendBuffer);
            msg.clear();
            chnIn.release(writeSequence);
            writeSequence = NO_SEQUENCE;
        }
    } catch (final InterruptedException e) {
        throw new IOException(e);
    }
    preSendBuffer.flip(); // ready to send
    // SSL handling
    boolean engineClosed = false;
    SSLEngineResult result = engine.wrap(preSendBuffer, sendBuffer);
    // state machine
    if (result.getHandshakeStatus() != NOT_HANDSHAKING) {
        // handshaking
        SSLEngineResult.Status status = result.getStatus();
        SSLEngineResult.HandshakeStatus hstatus = result.getHandshakeStatus();
        // Keep wrapping only while the engine makes progress. Stopping on any
        // status other than OK (BUFFER_OVERFLOW, CLOSED) and on any handshake
        // status other than NEED_WRAP/NEED_TASK prevents spinning forever on
        // an engine that is closed or no longer handshaking.
        while (status == SSLEngineResult.Status.OK
                && (hstatus == NEED_WRAP || hstatus == NEED_TASK)) {
            if (hstatus == NEED_TASK) {
                Runnable task;
                while ((task = engine.getDelegatedTask()) != null) {
                    task.run();
                }
            }
            result = engine.wrap(preSendBuffer, sendBuffer);
            status = result.getStatus();
            hstatus = result.getHandshakeStatus();
            assert result.bytesConsumed() == 0;
        }
        if (preSendBuffer.remaining() > 0) {
            preSendBuffer.compact();
        } else {
            preSendBuffer.clear();
        }
    } else {
        switch (result.getStatus()) {
            case OK: {
                // normal encryption and send
                if (preSendBuffer.remaining() > 0) {
                    preSendBuffer.compact();
                } else {
                    preSendBuffer.clear();
                }
                break;
            }
            case BUFFER_OVERFLOW: {
                assert result.bytesConsumed() == 0 && result.bytesProduced() == 0;
                // We can't encrypt more until some bytes are sent.
                // We have to "unflip" preSendBuffer, otherwise the flip above
                // will "zero" the buffer the next time write() is called.
                preSendBuffer.position(preSendBuffer.limit());
                preSendBuffer.limit(preSendBuffer.capacity());
                break;
            }
            case CLOSED: {
                // SSLEngine close handshake was completed
                closedInternally = true;
                engineClosed = true;
                break;
            }
            case BUFFER_UNDERFLOW:
            default: {
                // both cases are illegal here
                throw new Bug("Unsupported SSLEngineResult.Status: " + result.getStatus());
            }
        }
    }
    sendBuffer.flip();
    limiter.send(sendBuffer.remaining());
    final int n = channel.write(sendBuffer);
    if (sendBuffer.remaining() > 0) {
        // partial write: keep the rest and stay interested in writing
        sendBuffer.compact();
        return n;
    }
    sendBuffer.clear();
    if (chnIn.remaining() == 0) {
        disableWriter();
    }
    if (engineClosed) {
        return -1;
    }
    return n;
}
/**
 * {@inheritDoc}
 * <p>
 * Stores the future and task, closes the outbound side of the engine and
 * then either finishes the shutdown right away (the connection has already
 * ended, or neither the receive nor the send buffer was ever allocated) or
 * wakes the writer.
 *
 * @throws NullPointerException if {@code future} or {@code task} is null
 */
@Override
protected void shutdown(final SettableCallbackFuture<Void> future, final Callable<Void> task) {
    if (future == null) {
        throw new NullPointerException("future == null");
    }
    if (task == null) {
        throw new NullPointerException("task == null");
    }
    this.closeFuture = future;
    this.closeTask = task;
    engine.closeOutbound();

    boolean finishNow = closedInternally;
    if (!finishNow) {
        // not yet connected: no buffer was ever allocated
        finishNow = receiveBuffer == null && sendBuffer == null;
    }
    if (!finishNow) {
        wakeupWriter();
        return;
    }
    // not yet connected or disconnected remotely
    shutdown();
}
/**
 * Final shutdown step: closes the codec, runs the close task and completes
 * the close future.
 * <p>
 * The close task is always run, even if closing the codec fails, and the
 * processor is always marked as closed afterwards. The first failure is
 * reported through the close future; a later failure is attached to it as a
 * suppressed exception.
 */
private void shutdown() {
    Throwable failure = null;
    try {
        codec.close();
    } catch (final Throwable t) {
        failure = t;
    }
    try {
        // must run even when the codec failed to close
        closeTask.call();
    } catch (final Throwable t) {
        if (failure == null) {
            failure = t;
        } else {
            failure.addSuppressed(t);
        }
    }
    // no further I/O is attempted once shutdown has been tried
    closed = true;
    if (failure != null) {
        closeFuture.fail(failure);
        return;
    }
    try {
        closeFuture.success(null);
    } catch (final Throwable t) {
        closeFuture.fail(t);
    }
}
/**
 * Rounds {@code num} up to a power of two.
 * <p>
 * Values below two yield one. Values above 2^30 overflow to
 * {@link Integer#MIN_VALUE}, which callers must reject.
 *
 * @param num value to round up
 * @return nearest larger or equal power of two
 */
private static int nearestPowerOfTwo(final int num) {
    if (num <= 1) {
        return 1;
    }
    // highest bit of (num - 1), doubled, is the smallest power of two >= num
    return Integer.highestOneBit(num - 1) << 1;
}
}