/*
* Copyright 2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ratpack.server.internal;
import io.netty.buffer.ByteBuf;
import io.netty.channel.*;
import io.netty.handler.codec.http.*;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.stream.ChunkedNioStream;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ratpack.api.Nullable;
import ratpack.exec.Blocking;
import ratpack.exec.Execution;
import ratpack.exec.Promise;
import ratpack.file.internal.ResponseTransmitter;
import ratpack.func.Action;
import ratpack.handling.RequestOutcome;
import ratpack.handling.internal.DefaultRequestOutcome;
import ratpack.handling.internal.DoubleTransmissionException;
import ratpack.http.Request;
import ratpack.http.RequestBodyTooLargeException;
import ratpack.http.SentResponse;
import ratpack.http.internal.*;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
/**
 * Netty-backed {@link ResponseTransmitter} that writes the response for a single request.
 * <p>
 * Every transmission follows the same three phases:
 * <ol>
 *   <li>{@code pre}: claim the one-shot {@code transmitted} flag, decide on keep-alive and write the headers;</li>
 *   <li>write the body (fixed buffer, file, or a streamed {@link Subscriber});</li>
 *   <li>{@code post}: drain any unread request body, then either resume reading (keep-alive) or close the
 *   channel, and finally notify any registered outcome listeners.</li>
 * </ol>
 * A second transmission attempt is not written; it is logged together with a {@link DoubleTransmissionException}.
 */
public class DefaultResponseTransmitter implements ResponseTransmitter {

  /**
   * Channel attribute key, named after this class, for a transmitter reference stored on the channel.
   * {@code notifyListeners} resets it to null when outcome listeners have been registered.
   */
  static final AttributeKey<DefaultResponseTransmitter> ATTRIBUTE_KEY = AttributeKey.valueOf(DefaultResponseTransmitter.class.getName());

  private final static Logger LOGGER = LoggerFactory.getLogger(ResponseTransmitter.class);

  // Default writability callback, used until a streaming subscriber installs its own.
  private static final Runnable NOOP_RUNNABLE = () -> {
  };

  // Supplied by the caller; set to true by the first transmission attempt so later attempts are rejected.
  private final AtomicBoolean transmitted;
  private final Channel channel;
  private final Request ratpackRequest;
  private final HttpHeaders responseHeaders;
  // May be null, meaning there is no request body to drain.
  private final RequestBody requestBody;
  // True when an SslHandler is in the pipeline; zero-copy FileRegion transfer is then not used for files.
  private final boolean isSsl;
  // Created lazily on the first addOutcomeListener call.
  private List<Action<? super RequestOutcome>> outcomeListeners;
  // Starts from the request's keep-alive state; cleared whenever the connection must be closed afterwards.
  private boolean isKeepAlive;
  // Recorded when the first transmission attempt wins; reported in the RequestOutcome.
  private Instant stopTime;
  // Run by writabilityChanged(); a streaming subscriber replaces it to resume demand when writable.
  private Runnable onWritabilityChanged = NOOP_RUNNABLE;

  /**
   * Creates a transmitter for one request/response exchange.
   *
   * @param transmitted flag that records whether a response has already been sent
   * @param channel the channel the response is written to
   * @param nettyRequest the raw request, used to determine the initial keep-alive state
   * @param ratpackRequest the Ratpack request, reported to outcome listeners and used in log messages
   * @param responseHeaders the headers written with the response (may be modified before writing)
   * @param requestBody the request body to drain after responding, or null if there is none
   */
  public DefaultResponseTransmitter(
    AtomicBoolean transmitted,
    Channel channel,
    HttpRequest nettyRequest,
    Request ratpackRequest,
    HttpHeaders responseHeaders,
    @Nullable RequestBody requestBody
  ) {
    this.transmitted = transmitted;
    this.channel = channel;
    this.ratpackRequest = ratpackRequest;
    this.responseHeaders = responseHeaders;
    this.requestBody = requestBody;
    this.isKeepAlive = HttpUtil.isKeepAlive(nettyRequest);
    this.isSsl = channel.pipeline().get(SslHandler.class) != null;
  }

  /**
   * Consumes any unread request body, then invokes {@code next} exactly once: with {@code null} on
   * success, or with the failure that stopped the drain.
   * <p>
   * Inside an active Ratpack execution the callback-style drain is bridged through a {@link Promise};
   * otherwise it is drained directly.
   */
  private void drainRequestBody(Consumer<Throwable> next) {
    if (requestBody == null || !requestBody.isUnread()) {
      next.accept(null);
    } else {
      if (Execution.isActive()) {
        Promise.async(down ->
          requestBody.drain(e -> {
            if (e == null) {
              down.success(null);
            } else {
              down.error(e);
            }
          })
        )
          .onError(next::accept)
          .then(n -> next.accept(null));
      } else {
        requestBody.drain(next);
      }
    }
  }

  /**
   * Claims the one-shot transmission slot and writes the response headers.
   * <p>
   * This method also decides connection handling:
   * <ul>
   *   <li>an explicit {@code Connection: close} header disables keep-alive;</li>
   *   <li>a request that was not keep-alive gets {@code Connection: close} added;</li>
   *   <li>chunked transfer encoding is added when a keep-alive response needs a body but has no
   *   {@code Content-Length} and is not already chunked.</li>
   * </ul>
   *
   * @return the future of the header write, or {@code null} in these cases: a response was already
   *     transmitted, the channel is closed, or preparing the headers threw
   */
  private ChannelFuture pre(HttpResponseStatus responseStatus) {
    if (transmitted.compareAndSet(false, true)) {
      stopTime = Instant.now();
      try {
        if (responseHeaders.contains(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE, true)) {
          isKeepAlive = false;
        } else if (!isKeepAlive) {
          forceCloseConnection();
        }

        HttpResponse headersResponse = new CustomHttpResponse(responseStatus, responseHeaders);

        // Without a length or chunked framing, a keep-alive client could not tell where the body ends.
        if (mustHaveBody(responseStatus)
          && isKeepAlive
          && HttpUtil.getContentLength(headersResponse, -1) == -1
          && !HttpUtil.isTransferEncodingChunked(headersResponse)) {
          HttpUtil.setTransferEncodingChunked(headersResponse, true);
        }

        if (channel.isOpen()) {
          return channel.writeAndFlush(headersResponse);
        } else {
          return null;
        }
      } catch (Exception e) {
        LOGGER.warn("Error finalizing response", e);
        return null;
      }
    } else {
      String msg = "attempt at double transmission for: " + ratpackRequest.getRawUri();
      LOGGER.warn(msg, new DoubleTransmissionException(msg));
      return null;
    }
  }

  /**
   * Returns whether a response with this status carries a message body.
   * 1xx, 204 and 304 responses never do.
   */
  private boolean mustHaveBody(HttpResponseStatus responseStatus) {
    int code = responseStatus.code();
    return (code < 100 || code >= 200) && code != 204 && code != 304;
  }

  /**
   * Transmits a complete, in-memory response body.
   * <p>
   * An empty buffer is released immediately and replaced with Netty's shared empty last-content marker.
   * Otherwise the buffer is wrapped as the final HTTP content.
   */
  @Override
  public void transmit(HttpResponseStatus responseStatus, ByteBuf body) {
    if (body.readableBytes() == 0) {
      body.release();
      transmit(responseStatus, LastHttpContent.EMPTY_LAST_CONTENT, false);
    } else {
      transmit(responseStatus, new DefaultLastHttpContent(body), false);
    }
  }

  /**
   * Writes headers and then {@code body}, followed by the completion housekeeping.
   *
   * @param body the body message to write after the headers
   * @param sendLastHttpContent {@code true} when {@code body} does not itself terminate the HTTP message,
   *     so an empty last-content marker must follow it (e.g. a raw {@link FileRegion})
   */
  private void transmit(final HttpResponseStatus responseStatus, Object body, boolean sendLastHttpContent) {
    ChannelFuture channelFuture = pre(responseStatus);
    if (channelFuture == null) {
      discard(body);
      isKeepAlive = false;
      post(responseStatus);
      return;
    }

    channelFuture.addListener(future -> {
      if (future.isSuccess() && channel.isOpen()) {
        if (sendLastHttpContent) {
          channel.write(body);
          post(responseStatus);
        } else {
          post(responseStatus, channel.writeAndFlush(body));
        }
      } else {
        discard(body);
        isKeepAlive = false;
        post(responseStatus);
      }
    });
  }

  /**
   * Frees a response body that will never be handed to the channel.
   * <p>
   * {@link ReferenceCountUtil#release(Object)} does nothing for objects that are not reference counted.
   * The {@link HttpChunkedInput} used to stream files is one such object, so it is closed explicitly.
   * Otherwise the file channel it wraps would leak.
   */
  private static void discard(Object body) {
    if (body instanceof HttpChunkedInput) {
      try {
        ((HttpChunkedInput) body).close();
      } catch (Exception e) {
        LOGGER.warn("Error closing unsent response body", e);
      }
    } else {
      ReferenceCountUtil.release(body);
    }
  }

  /**
   * Transmits the contents of a file.
   * <p>
   * Zero-copy {@link DefaultFileRegion} transfer is used only when all of these hold:
   * <ul>
   *   <li>there is no TLS;</li>
   *   <li>the {@code Content-Encoding} is {@code identity};</li>
   *   <li>the file lives on the default file system.</li>
   * </ul>
   * The region length comes from the {@code Content-Length} header, or 0 if it is absent.
   * Otherwise the file is opened on a blocking thread and streamed as chunked HTTP content.
   */
  @Override
  public void transmit(HttpResponseStatus status, Path file) {
    String sizeString = responseHeaders.getAsString(HttpHeaderConstants.CONTENT_LENGTH);
    long size = sizeString == null ? 0 : Long.parseLong(sizeString);
    boolean compress = !responseHeaders.contains(HttpHeaderConstants.CONTENT_ENCODING, HttpHeaderConstants.IDENTITY, true);

    if (!isSsl && !compress && file.getFileSystem().equals(FileSystems.getDefault())) {
      FileRegion defaultFileRegion = new DefaultFileRegion(file.toFile(), 0, size);
      transmit(status, defaultFileRegion, true);
    } else {
      Blocking.get(() ->
        Files.newByteChannel(file)
      ).then(fileChannel ->
        transmit(status, new HttpChunkedInput(new ChunkedNioStream(fileChannel)), false)
      );
    }
  }

  /**
   * Returns a subscriber that streams each received buffer as an HTTP content chunk.
   * <p>
   * Headers are written when the subscriber is subscribed. Items are requested one at a time, and only
   * while the channel is writable. When the channel becomes writable again, the writability callback
   * resumes demand. The stream is cancelled if a chunk write fails or the channel closes.
   * Completion housekeeping runs at most once, whether the stream completes, errors or is cancelled.
   */
  @Override
  public Subscriber<ByteBuf> transmitter(HttpResponseStatus responseStatus) {
    return new Subscriber<ByteBuf>() {
      private Subscription subscription;
      // Guards the single call to post(...) across cancel(), onError() and onComplete().
      private final AtomicBoolean done = new AtomicBoolean();

      // Cancels the stream when writing a chunk fails.
      private final ChannelFutureListener cancelOnFailure = future -> {
        if (!future.isSuccess()) {
          cancel();
        }
      };

      // Cancels the stream when the channel closes mid-transmission.
      private final GenericFutureListener<Future<? super Void>> cancelOnCloseListener =
        c -> cancel();

      private void cancel() {
        channel.closeFuture().removeListener(cancelOnCloseListener);
        if (done.compareAndSet(false, true)) {
          subscription.cancel();
          post(responseStatus);
        }
      }

      @Override
      public void onSubscribe(Subscription subscription) {
        if (subscription == null) {
          throw new NullPointerException("'subscription' is null");
        }
        // Only one subscription is accepted; any later one is cancelled straight away.
        if (this.subscription != null) {
          subscription.cancel();
          return;
        }

        this.subscription = subscription;

        ChannelFuture channelFuture = pre(responseStatus);
        if (channelFuture == null) {
          subscription.cancel();
          isKeepAlive = false;
          notifyListeners(responseStatus);
        } else {
          channelFuture.addListener(f -> {
            if (f.isSuccess() && channel.isOpen()) {
              channel.closeFuture().addListener(cancelOnCloseListener);
              if (channel.isWritable()) {
                this.subscription.request(1);
              }
              onWritabilityChanged = () -> {
                if (channel.isWritable() && !done.get()) {
                  this.subscription.request(1);
                }
              };
            } else {
              cancel();
            }
          });
        }
      }

      @Override
      public void onNext(ByteBuf o) {
        o.touch();
        if (channel.isOpen()) {
          channel.writeAndFlush(new DefaultHttpContent(o)).addListener(cancelOnFailure);
          // Back-pressure: while unwritable, further demand waits for the writability callback.
          if (channel.isWritable()) {
            subscription.request(1);
          }
        } else {
          o.release();
          cancel();
        }
      }

      @Override
      public void onError(Throwable t) {
        if (t == null) {
          throw new NullPointerException("error is null");
        }
        LOGGER.warn("Exception thrown transmitting stream", t);
        // NOTE(review): this still writes the terminating last-content marker after a mid-stream error.
        if (done.compareAndSet(false, true)) {
          channel.closeFuture().removeListener(cancelOnCloseListener);
          post(responseStatus);
        }
      }

      @Override
      public void onComplete() {
        if (done.compareAndSet(false, true)) {
          channel.closeFuture().removeListener(cancelOnCloseListener);
          post(responseStatus);
        }
      }
    };
  }

  /** Writes an empty last-content marker, then performs the completion housekeeping. */
  private void post(HttpResponseStatus responseStatus) {
    post(responseStatus, channel.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT));
  }

  /**
   * Runs the completion housekeeping once {@code lastContentFuture} completes, whether it succeeded or not.
   * <ol>
   *   <li>Any unread request body is drained.</li>
   *   <li>If the channel is still open, a keep-alive connection whose drain succeeded resumes reading
   *   and has its idle timeout reset; any other connection is closed.</li>
   *   <li>Outcome listeners are notified.</li>
   * </ol>
   */
  private void post(HttpResponseStatus responseStatus, ChannelFuture lastContentFuture) {
    lastContentFuture.addListener(v ->
      drainRequestBody(e -> {
        // Only a failed drain is worth reporting; e is null when draining succeeded.
        if (e != null && LOGGER.isDebugEnabled()) {
          if (e instanceof RequestBodyTooLargeException) {
            LOGGER.debug("Unread request body was too large to drain, will close connection (maxContentLength: {})", ((RequestBodyTooLargeException) e).getMaxContentLength());
          } else {
            LOGGER.debug("An error occurred draining the unread request body. The connection will be closed", e);
          }
        }

        if (channel.isOpen()) {
          if (isKeepAlive && e == null) {
            lastContentFuture.channel().read();
            ConnectionIdleTimeout.of(channel).reset();
          } else {
            lastContentFuture.channel().close();
          }
        }

        notifyListeners(responseStatus);
      })
    );
  }

  /**
   * Builds a {@link RequestOutcome} and passes it to each registered listener.
   * <p>
   * Does nothing when no listener is registered. An exception from one listener is logged and does not
   * prevent the remaining listeners from running.
   */
  private void notifyListeners(final HttpResponseStatus responseStatus) {
    if (outcomeListeners != null) {
      channel.attr(ATTRIBUTE_KEY).set(null);
      SentResponse sentResponse = new DefaultSentResponse(new NettyHeadersBackedHeaders(responseHeaders), new DefaultStatus(responseStatus));
      RequestOutcome requestOutcome = new DefaultRequestOutcome(ratpackRequest, sentResponse, stopTime);
      for (Action<? super RequestOutcome> outcomeListener : outcomeListeners) {
        try {
          outcomeListener.execute(requestOutcome);
        } catch (Exception e) {
          LOGGER.warn("request outcome listener " + outcomeListener + " threw exception", e);
        }
      }
    }
  }

  /**
   * Runs the current writability callback.
   * <p>
   * The callback is a no-op unless a streaming {@link #transmitter(HttpResponseStatus)} has installed
   * its demand-resuming callback. Presumably this is invoked by the channel handler on writability
   * changes; confirm with the caller.
   */
  public void writabilityChanged() {
    onWritabilityChanged.run();
  }

  /**
   * Registers a listener that receives the {@link RequestOutcome} after the response has been sent.
   * Listeners run in registration order.
   */
  @Override
  public void addOutcomeListener(Action<? super RequestOutcome> action) {
    if (outcomeListeners == null) {
      outcomeListeners = new ArrayList<>(1);
    }
    outcomeListeners.add(action);
  }

  /** Sets {@code Connection: close} on the response headers. */
  @Override
  public void forceCloseConnection() {
    responseHeaders.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
  }
}