package jane.core;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.security.KeyStore;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Locale;
import java.util.Map;
import java.util.TimeZone;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.filterchain.IoFilterAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.core.write.DefaultWriteRequest;
import org.apache.mina.core.write.WriteRequest;
import org.apache.mina.filter.ssl.SslFilter;
/**
 * MINA protocol codec filter for HTTP.
 * <p>
 * Input (decoding): an {@link OctetsStream} holding the raw HTTP head and body of one complete request;
 * its position points to the start of the body, or to the end if there is no body.<br>
 * Output (encoding): {@link OctetsStream} (the data from position to the end), {@link Octets}, or byte[];
 * a {@link ByteBuffer} is written as one chunk of a chunked response. Other message types are passed on unchanged.<br>
 * Input helpers: read head fields, method, URL path, URL parameters, content charset and cookies,
 * with URL-decoding support.<br>
 * Output helpers: fixed-length output and chunked output.<br>
 * Not directly supported: mime, Connection:close/timeout, Accept-Encoding, Set-Cookie, Multi-Part, encodeUrl,
 * chunked request bodies, case-insensitive field-name matching.
 * <p>
 * NOTE(review): the decoder keeps per-connection state in instance fields ({@code _buf}, {@code _bodySize}),
 * so an instance presumably must not be shared between sessions; confirm how it is added to filter chains.
 */
public final class HttpCodec extends IoFilterAdapter
{
    private static final byte[] HEAD_END_MARK = "\r\n\r\n".getBytes(Const.stringCharsetUTF8);
    private static final byte[] CONT_LEN_MARK = "\r\nContent-Length: ".getBytes(Const.stringCharsetUTF8);
    private static final byte[] CONT_TYPE_MARK = "\r\nContent-Type: ".getBytes(Const.stringCharsetUTF8);
    private static final byte[] COOKIE_MARK = "\r\nCookie: ".getBytes(Const.stringCharsetUTF8);
    private static final byte[] CHUNK_OVER_MARK = "\r\n".getBytes(Const.stringCharsetUTF8);
    private static final byte[] CHUNK_END_MARK = "0\r\n\r\n".getBytes(Const.stringCharsetUTF8);
    private static final String DEF_CONT_CHARSET = "utf-8";
    // cookie names are tokens that may contain '-', '.', '$' etc., so match anything up to '=', ';' or whitespace
    private static final Pattern PATTERN_COOKIE = Pattern.compile("([^=;\\s]+)=(.*?)(; |$)");
    // parameter names are case-insensitive and the value may be a quoted-string (RFC 7231 section 3.1.1.1)
    private static final Pattern PATTERN_CHARSET = Pattern.compile("charset=\"?([\\w-]+)", Pattern.CASE_INSENSITIVE);
    private static final DateFormat _sdf = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
    private static String _dateStr; // cached formatted date for the second stored in _lastSec
    private static volatile long _lastSec; // epoch second that _dateStr was formatted for
    private OctetsStream _buf = new OctetsStream(1024); // decoder data buffer
    private long _bodySize; // body size required by the current request (<= 0: currently looking for a head)

    /**
     * Decoding error exception without stack trace information ({@link #fillInStackTrace} does nothing).
     */
    public static final class DecodeException extends Exception
    {
        private static final long serialVersionUID = 1L;

        public DecodeException(String cause)
        {
            super(cause);
        }

        @SuppressWarnings("sync-override")
        @Override
        public Throwable fillInStackTrace()
        {
            return this;
        }
    }

    static
    {
        _sdf.setTimeZone(TimeZone.getTimeZone("GMT"));
    }

    /**
     * Returns the current time formatted for the HTTP "Date" header (GMT).
     * The formatted string is cached and reformatted only when the epoch second changes;
     * formatting is guarded by a lock on the shared {@link SimpleDateFormat}.
     */
    private static String getDate()
    {
        long t = System.currentTimeMillis();
        long sec = t / 1000;
        if(sec != _lastSec)
        {
            synchronized(_sdf)
            {
                if(sec != _lastSec)
                {
                    _dateStr = _sdf.format(new Date(t));
                    _lastSec = sec; // volatile write after _dateStr so readers seeing this second also see its string
                }
            }
        }
        return _dateStr;
    }

    /**
     * Creates a TLS {@link SslFilter} from JKS key and trust stores.
     * The streams are read but not closed here.
     * @param keyIs JKS key store content
     * @param keyPw password of the key store and of its keys
     * @param trustIs JKS trust store content
     * @param trustPw password of the trust store
     */
    public static SslFilter getSslFilter(InputStream keyIs, char[] keyPw, InputStream trustIs, char[] trustPw) throws Exception
    {
        KeyStore ks = KeyStore.getInstance("JKS");
        ks.load(keyIs, keyPw);
        KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        kmf.init(ks, keyPw);
        KeyStore ts = KeyStore.getInstance("JKS");
        ts.load(trustIs, trustPw);
        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        tmf.init(ts);
        SSLContext ctx = SSLContext.getInstance("TLS");
        ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
        return new SslFilter(ctx);
    }

    /**
     * Creates a TLS {@link SslFilter} using one JKS file as both key store and trust store.
     * @param keyFile path of the JKS file
     * @param keyPw password used for the key store, its keys and the trust store
     */
    public static SslFilter getSslFilter(String keyFile, String keyPw) throws Exception
    {
        byte[] key = Util.readFileData(keyFile);
        char[] pw = keyPw.toCharArray();
        return getSslFilter(new ByteArrayInputStream(key), pw, new ByteArrayInputStream(key), pw);
    }

    /**
     * URL-decodes a byte range and returns it as a UTF-8 string.
     * <p>
     * '+' becomes a space; "%XX" with two hex digits (either case) becomes that byte.
     * A '%' that is not followed by two hex digits is kept literally.
     * The range is clamped to the array; an empty range yields "".
     */
    public static String decodeUrl(byte[] src, int srcPos, int srcLen)
    {
        if(srcPos < 0) srcPos = 0;
        if(srcLen > src.length - srcPos) srcLen = src.length - srcPos; // overflow-safe form of srcPos + srcLen > src.length
        if(srcLen <= 0) return "";
        byte[] dst = new byte[srcLen];
        int dstPos = 0;
        for(int srcEnd = srcPos + srcLen; srcPos < srcEnd;)
        {
            int c = src[srcPos++];
            if(c == '+')
                c = ' ';
            else if(c == '%' && srcPos + 1 < srcEnd)
            {
                int hi = Character.digit(src[srcPos], 16);
                int lo = Character.digit(src[srcPos + 1], 16);
                if(hi >= 0 && lo >= 0) // only a valid escape is decoded; otherwise the '%' is output as-is
                {
                    c = (hi << 4) + lo;
                    srcPos += 2;
                }
            }
            dst[dstPos++] = (byte)c;
        }
        return new String(dst, 0, dstPos, Const.stringCharsetUTF8);
    }

    /**
     * Returns the request line (the head up to its first '\r'), or "" if none is found within the head.
     */
    public static String getHeadLine(OctetsStream head)
    {
        int p = head.find(0, head.position(), (byte)'\r');
        return p < 0 ? "" : new String(head.array(), 0, p, Const.stringCharsetUTF8);
    }

    /**
     * Returns the method of the request line (text before the first space), or "" if there is no space in the head.
     */
    public static String getHeadVerb(OctetsStream head)
    {
        int p = head.find(0, head.position(), (byte)' ');
        return p < 0 ? "" : new String(head.array(), 0, p, Const.stringCharsetUTF8);
    }

    /**
     * Returns the URL-decoded path of the request line without the query string,
     * e.g. "/path/name.html" for "GET /path/name.html?k=v&amp;a=b HTTP/1.1". Returns "" if the line is malformed.
     */
    public static String getHeadPath(OctetsStream head)
    {
        int e = head.position();
        int p = head.find(0, e, (byte)' ');
        if(p < 0) return "";
        int q = head.find(++p, e, (byte)' ');
        if(q < 0) return "";
        int r = head.find(p, q, (byte)'?');
        if(r >= p && r < q) q = r;
        return decodeUrl(head.array(), p, q - p);
    }

    /**
     * Returns the URL-decoded request target including the query string,
     * e.g. "/path/name.html?k=v&amp;a=b". Returns "" if the line is malformed.
     */
    public static String getHeadPathParams(OctetsStream head)
    {
        int e = head.position();
        int p = head.find(0, e, (byte)' ');
        if(p < 0) return "";
        int q = head.find(++p, e, (byte)' ');
        if(q < 0) return "";
        return decodeUrl(head.array(), p, q - p);
    }

    /**
     * Parses "k1=v1&amp;k2=v2..." in the given range into params, URL-decoding keys and values.
     * A segment without '=' is stored with the value "".
     * @return the number of '&amp;'-separated segments processed (including empty ones)
     */
    public static int getParams(Octets oct, int pos, int len, Map<String, String> params)
    {
        byte[] buf = oct.array();
        if(pos < 0) pos = 0;
        if(pos + len > buf.length) len = buf.length - pos;
        if(len <= 0) return 0;
        int end = pos + len;
        int n = 0;
        for(int p; pos < end; pos = p + 1, ++n)
        {
            p = oct.find(pos, end, (byte)'&');
            if(p < 0) p = end;
            int r = oct.find(pos, p, (byte)'=');
            if(r >= pos)
            {
                String k = decodeUrl(buf, pos, r - pos);
                String v = decodeUrl(buf, r + 1, p - r - 1);
                params.put(k, v);
            }
            else
                params.put(decodeUrl(buf, pos, p - pos), "");
        }
        return n;
    }

    /**
     * Parses the query string of the request line found in [pos, pos+len) into params.
     * @return the number of parameters read, or 0 if the request line has no query string
     */
    public static int getHeadParams(Octets oct, int pos, int len, Map<String, String> params)
    {
        byte[] buf = oct.array();
        if(pos < 0) pos = 0;
        if(pos + len > buf.length) len = buf.length - pos;
        if(len <= 0) return 0;
        int e = pos + len;
        int q = oct.find(pos, e, (byte)'\r'); // end of the request line
        if(q < 0) return 0;
        int p = oct.find(pos, q, (byte)'?');
        if(p < 0) return 0;
        q = oct.find(++p, q, (byte)' '); // the query string ends before " HTTP/x.y"
        if(q < p) return 0;
        return getParams(oct, p, q - p, params);
    }

    /**
     * Parses the query string of the request line of a decoded request (head is [0, position)).
     * @return the number of parameters read
     */
    public static int getHeadParams(OctetsStream os, Map<String, String> params)
    {
        return getHeadParams(os, 0, os.position(), params);
    }

    /**
     * Parses the body of a decoded request (from position to the end) as "k=v&amp;..." parameters.
     * @return the number of parameters read
     */
    public static int getBodyParams(OctetsStream os, Map<String, String> params)
    {
        return getParams(os, os.position(), os.remain(), params);
    }

    /**
     * Returns the non-negative decimal value of a head field.
     * Spaces and tabs around the value are ignored.
     * @param key field marker including the leading "\r\n" and trailing ": ", e.g. "\r\nContent-Length: "
     * @return -1 if the field is missing, empty, or contains a non-digit character;
     *         values larger than {@link Long#MAX_VALUE} saturate to {@link Long#MAX_VALUE} instead of wrapping around
     */
    public static long getHeadLong(OctetsStream head, byte[] key)
    {
        int p = head.find(0, head.position(), key);
        if(p < 0) return -1;
        p += key.length;
        int e = head.find(p, (byte)'\r');
        if(e < 0) return -1;
        byte[] buf = head.array();
        while(p < e && (buf[p] == ' ' || buf[p] == '\t'))
            ++p;
        while(e > p && (buf[e - 1] == ' ' || buf[e - 1] == '\t'))
            --e;
        if(p >= e) return -1;
        long r = 0;
        for(; p < e; ++p)
        {
            int d = buf[p] - '0';
            if(d < 0 || d > 9) return -1;
            if(r > (Long.MAX_VALUE - d) / 10) return Long.MAX_VALUE; // would overflow: saturate so size limits still reject it
            r = r * 10 + d;
        }
        return r;
    }

    /**
     * Returns the URL-decoded value of a field in the HTTP request head, or "" if it is absent.
     * <p>
     * Note: if the field appears more than once, only the first occurrence is used.
     * @param key field marker, e.g. "\r\nReferer: ".getBytes()
     */
    public static String getHeadField(OctetsStream head, byte[] key)
    {
        int p = head.find(0, head.position(), key);
        if(p < 0) return "";
        p += key.length;
        int e = head.find(p, (byte)'\r'); // end of this field's line (p already points past the key)
        if(e < 0) return "";
        return decodeUrl(head.array(), p, e - p);
    }

    /**
     * Returns the URL-decoded value of the named field in the HTTP request head, or "" if it is absent.
     * @param key field name, e.g. "Referer"
     */
    public static String getHeadField(OctetsStream head, String key)
    {
        return getHeadField(head, ("\r\n" + key + ": ").getBytes(Const.stringCharsetUTF8));
    }

    /**
     * Returns the charset parameter of the Content-Type field, or "utf-8" if there is none.
     */
    public static String getHeadCharset(OctetsStream head)
    {
        String conttype = getHeadField(head, CONT_TYPE_MARK);
        if(conttype.isEmpty()) return DEF_CONT_CHARSET; // default charset
        Matcher mat = PATTERN_CHARSET.matcher(conttype);
        return mat.find() ? mat.group(1) : DEF_CONT_CHARSET;
    }

    /**
     * Reads all cookie key/value pairs of the HTTP request head into cookies.
     * <p>
     * Note: cookie values containing "; " are not supported.
     * @return the number of cookies read
     */
    public static int getHeadCookie(OctetsStream head, Map<String, String> cookies)
    {
        String cookie = getHeadField(head, COOKIE_MARK);
        if(cookie.isEmpty()) return 0;
        Matcher mat = PATTERN_COOKIE.matcher(cookie);
        int n = 0;
        for(; mat.find(); ++n)
            cookies.put(mat.group(1), mat.group(2));
        return n;
    }

    /**
     * Sends the HTTP response head.
     * <p>
     * Each character is written as its low 8 bits, so the status and extra heads should be ASCII/ISO-8859-1.
     * @param code the HTTP status string, e.g. "200 OK"
     * @param len
     * <li>len &lt; 0: chunked mode; send several {@link #sendChunk} calls followed by {@link #sendChunkEnd}
     * <li>len &gt;= 0: fixed length ("Content-Length: len"); send the data with {@link #send}
     * @param heads extra HTTP head lines, one line per element, not validated (use carefully); null for none
     * @return false if the session is closing, otherwise the result of the write
     */
    public static boolean sendHead(IoSession session, String code, long len, Iterable<String> heads)
    {
        if(session.isClosing()) return false;
        StringBuilder sb = new StringBuilder(1024);
        sb.append("HTTP/1.1 ").append(code).append('\r').append('\n');
        sb.append("Date: ").append(getDate()).append('\r').append('\n');
        if(len >= 0)
            sb.append("Content-Length: ").append(len).append('\r').append('\n');
        else
            sb.append("Transfer-Encoding: chunked").append('\r').append('\n');
        if(heads != null)
        {
            for(String head : heads)
                sb.append(head).append('\r').append('\n');
        }
        sb.append('\r').append('\n');
        int n = sb.length();
        byte[] out = new byte[n];
        for(int i = 0; i < n; ++i)
            out[i] = (byte)sb.charAt(i);
        return NetManager.write(session, out);
    }

    /**
     * Sends raw body data (fixed-length mode).
     */
    public static boolean send(IoSession session, byte[] data)
    {
        return NetManager.write(session, data);
    }

    /**
     * Sends raw body data (fixed-length mode); see {@link #filterWrite} for how Octets/OctetsStream are encoded.
     */
    public static boolean send(IoSession session, Octets data)
    {
        return NetManager.write(session, data);
    }

    /**
     * Sends one chunk in chunked mode; an empty chunk is skipped (returns true) so it is not taken as the end chunk.
     */
    public static boolean sendChunk(IoSession session, byte[] chunk)
    {
        int n = chunk.length;
        return n <= 0 || NetManager.write(session, ByteBuffer.wrap(chunk, 0, n));
    }

    /**
     * Sends the data from position to the end of chunk as one chunk; an empty chunk is skipped (returns true).
     */
    public static boolean sendChunk(IoSession session, Octets chunk)
    {
        int n = chunk.remain();
        if(n <= 0) return true;
        ByteBuffer buf = ByteBuffer.wrap(chunk.array(), chunk.position(), n);
        return NetManager.write(session, buf);
    }

    /**
     * Sends the UTF-8 encoding of chunk as one chunk; an empty string is skipped.
     */
    public static boolean sendChunk(IoSession session, String chunk)
    {
        return sendChunk(session, Octets.wrap(chunk.getBytes(Const.stringCharsetUTF8)));
    }

    /**
     * Sends the terminating zero-length chunk of a chunked response.
     */
    public static boolean sendChunkEnd(IoSession session)
    {
        return NetManager.write(session, CHUNK_END_MARK);
    }

    /**
     * Encodes outgoing messages into IoBuffers.
     * <p>
     * byte[] / Octets (0..size) / OctetsStream (position..end) are written as-is; a ByteBuffer is framed as
     * one HTTP chunk ("size-in-hex\r\n" + data + "\r\n"). Empty byte[]/Octets/OctetsStream payloads are dropped
     * (NOTE(review): their write future is then never completed). Any other message type is passed on unchanged.
     */
    @Override
    public void filterWrite(NextFilter next, IoSession session, WriteRequest writeRequest)
    {
        Object message = writeRequest.getMessage();
        if(message instanceof byte[])
        {
            byte[] bytes = (byte[])message;
            if(bytes.length > 0)
                next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(bytes), writeRequest.getFuture(), null));
        }
        else if(message instanceof ByteBuffer)
        {
            next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(String.format("%x\r\n",
                    ((ByteBuffer)message).remaining()).getBytes(Const.stringCharsetUTF8)), null, null));
            next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap((ByteBuffer)message), null, null));
            // only the last part carries the caller's future, so it completes after the whole chunk is written
            next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(CHUNK_OVER_MARK), writeRequest.getFuture(), null));
        }
        else if(message instanceof OctetsStream) // must be tested before Octets
        {
            OctetsStream os = (OctetsStream)message;
            int n = os.remain();
            if(n > 0)
            {
                next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(os.array(), os.position(), n),
                        writeRequest.getFuture(), null));
            }
        }
        else if(message instanceof Octets)
        {
            Octets oct = (Octets)message;
            int n = oct.size();
            if(n > 0)
            {
                next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(oct.array(), 0, n),
                        writeRequest.getFuture(), null));
            }
        }
        else
            next.filterWrite(session, writeRequest); // e.g. an IoBuffer: pass it through instead of silently dropping it
    }

    /**
     * Decodes incoming bytes into complete requests and forwards each one as an {@link OctetsStream}
     * (see the class description). Handles several pipelined requests within one input buffer.
     * @throws DecodeException if the head exceeds Const.maxHttpHeadSize, is shorter than the smallest valid head,
     *         has an invalid Content-Length, or the body exceeds Const.maxHttpBodySize
     */
    @Override
    public void messageReceived(NextFilter next, IoSession session, Object message) throws Exception
    {
        IoBuffer in = (IoBuffer)message;
        try
        {
            // true when _buf holds bytes split off after a request body that have not been searched for a head yet
            boolean rescan = false;
            begin_: for(;;)
            {
                if(_bodySize <= 0)
                {
                    int p = _buf.size();
                    int r = in.remaining();
                    if(r > 0)
                    {
                        int s = (r < 1024 ? r : 1024); // take at most 1024 bytes at a time while searching for the head
                        _buf.resize(p + s);
                        in.get(_buf.array(), p, s);
                    }
                    else if(!rescan)
                        return;
                    if(rescan)
                    {
                        p = 0; // the leftover bytes were never searched, so search from the start
                        rescan = false;
                    }
                    for(;;)
                    {
                        // back up so a head-end mark straddling the previous and the new bytes is still found
                        p = _buf.find(Math.max(p - (HEAD_END_MARK.length - 1), 0), HEAD_END_MARK);
                        if(p < 0)
                        {
                            if(_buf.size() > Const.maxHttpHeadSize)
                                throw new DecodeException("http head size overflow: bufsize=" + _buf.size() + ",maxsize=" + Const.maxHttpHeadSize);
                            if(!in.hasRemaining()) return;
                            continue begin_;
                        }
                        p += HEAD_END_MARK.length;
                        if(p < 18) // the smallest possible head is "GET / HTTP/1.1\r\n\r\n"
                            throw new DecodeException("http head size too short: headsize=" + p);
                        _buf.setPosition(p);
                        _bodySize = getHeadLong(_buf, CONT_LEN_MARK); // body size from the head (chunked requests not supported)
                        if(_bodySize > 0) break; // has a body: continue with the body part below
                        if(_bodySize < 0 && _buf.find(0, p, CONT_LEN_MARK) >= 0)
                            throw new DecodeException("http head invalid Content-Length"); // present but malformed: don't guess the body boundary
                        OctetsStream os = new OctetsStream(_buf.array(), p, _buf.remain()); // split off the tail as the next buffer (at most 1024 bytes)
                        _buf.resize(p);
                        next.messageReceived(session, _buf);
                        _buf = os;
                        p = 0;
                    }
                    if(_bodySize > Const.maxHttpBodySize)
                        throw new DecodeException("http body size overflow: bodysize=" + _bodySize + ",maxsize=" + Const.maxHttpBodySize);
                }
                int r = in.remaining();
                int s = (int)_bodySize - _buf.remain(); // body bytes still missing (negative: surplus bytes buffered)
                int p = _buf.size();
                OctetsStream os;
                if(s > r) s = r; // take only what is available
                if(s >= 0) // buffered data is insufficient or exactly enough
                {
                    if(s > 0) // insufficient: fill up as much as possible
                    {
                        _buf.resize(p + s);
                        in.get(_buf.array(), p, s);
                    }
                    if(_buf.remain() < _bodySize) return; // still insufficient: wait for more input
                    os = new OctetsStream(1024); // exactly complete: allocate a new buffer
                }
                else
                {
                    os = new OctetsStream(_buf.array(), p += s, -s); // surplus: split off the tail as the next buffer (at most 1024 bytes)
                    _buf.resize(p);
                    rescan = true; // the tail may already contain the next pipelined request
                }
                next.messageReceived(session, _buf);
                _buf = os;
                _bodySize = 0; // next time start by matching an HTTP head
            }
        }
        finally
        {
            in.free();
        }
    }
}