package com.mpush.netty.http; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.http.*; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.net.URL; import java.net.URLDecoder; @ChannelHandler.Sharable /*package*/ class HttpClientHandler extends ChannelInboundHandlerAdapter { private static final Logger LOGGER = LoggerFactory.getLogger(NettyHttpClient.class); private final NettyHttpClient client; public HttpClientHandler(NettyHttpClient client) { this.client = client; } @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { RequestContext context = ctx.channel().attr(client.requestKey).getAndSet(null); try { if (context != null && context.tryDone()) { context.onException(cause); } } finally { client.pool.tryRelease(ctx.channel()); } LOGGER.error("http client caught an ex, info={}", context, cause); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { RequestContext context = ctx.channel().attr(client.requestKey).getAndSet(null); try { if (context != null && context.tryDone()) { LOGGER.debug("receive server response, request={}, response={}", context, msg); HttpResponse response = (HttpResponse) msg; if (isRedirect(response)) { if (context.onRedirect(response)) { String location = getRedirectLocation(context.request, response); if (location != null && location.length() > 0) { context.cancelled.set(false); context.request.setUri(location); client.request(context); return; } } } context.onResponse(response); } else { LOGGER.warn("receive server response but timeout, request={}, response={}", context, msg); } } finally { client.pool.tryRelease(ctx.channel()); ReferenceCountUtil.release(msg); } } private boolean isRedirect(HttpResponse response) { HttpResponseStatus status = response.status(); switch (status.code()) { case 300: case 301: case 302: case 303: case 305: case 307: return true; default: return false; } } private String getRedirectLocation(HttpRequest request, HttpResponse response) throws Exception { String hdr = URLDecoder.decode(response.headers().get(HttpHeaderNames.LOCATION), "UTF-8"); if (hdr != null) { if (hdr.toLowerCase().startsWith("http://") || hdr.toLowerCase().startsWith("https://")) { return hdr; } else { URL orig = new URL(request.uri()); String pth = orig.getPath() == null ? "/" : URLDecoder.decode(orig.getPath(), "UTF-8"); if (hdr.startsWith("/")) { pth = hdr; } else if (pth.endsWith("/")) { pth += hdr; } else { pth += "/" + hdr; } StringBuilder sb = new StringBuilder(orig.getProtocol()); sb.append("://").append(orig.getHost()); if (orig.getPort() > 0) { sb.append(":").append(orig.getPort()); } if (pth.charAt(0) != '/') { sb.append('/'); } sb.append(pth); return sb.toString(); } } return null; } @SuppressWarnings("unused") private HttpRequest copy(String uri, HttpRequest request) { HttpRequest nue = request; if (request instanceof DefaultFullHttpRequest) { DefaultFullHttpRequest dfr = (DefaultFullHttpRequest) request; FullHttpRequest rq; try { rq = dfr.copy(); } catch (IllegalReferenceCountException e) { // Empty byteBuf rq = dfr; } rq.setUri(uri); } else { DefaultHttpRequest dfr = new DefaultHttpRequest(request.protocolVersion(), request.method(), uri); dfr.headers().set(request.headers()); nue = dfr; } return nue; } }