package jane.core; import java.util.Collection; import org.apache.mina.core.buffer.IoBuffer; import org.apache.mina.core.filterchain.IoFilterAdapter; import org.apache.mina.core.session.IoSession; import org.apache.mina.core.write.DefaultWriteRequest; import org.apache.mina.core.write.WriteRequest; import jane.core.map.IntHashMap; /** * bean的mina协议编解码过滤器 */ public class BeanCodec extends IoFilterAdapter { protected static final IntHashMap<Integer> _maxSize = new IntHashMap<>(65536, 0.5f); // 所有注册beans的最大空间限制 protected static final IntHashMap<Bean<?>> _stubMap = new IntHashMap<>(65536, 0.5f); // 所有注册beans的存根对象 protected final OctetsStream _os = new OctetsStream(); // 用于解码器的数据缓存 protected int _ptype; // 当前数据缓存中获得的协议类型 protected int _psize = -1; // 当前数据缓存中获得的协议大小. -1表示没获取到 /** * 不带栈信息的解码错误异常 */ public static final class DecodeException extends Exception { private static final long serialVersionUID = 1L; public DecodeException(String cause) { super(cause); } @SuppressWarnings("sync-override") @Override public Throwable fillInStackTrace() { return this; } } /** * 重新注册所有的beans * <p> * 参数中的所有beans会被保存下来当存根(通过调用create方法创建对象)<br> * 警告: 此方法<b>必须</b>在开启任何<b>网络连接前</b>调用 */ public static void registerAllBeans(Collection<Bean<?>> beans) { _maxSize.clear(); _stubMap.clear(); for(Bean<?> bean : beans) { int type = bean.type(); if(type > 0) { _maxSize.put(type, bean.maxSize()); _stubMap.put(type, bean); } } } /** * 获取某个类型bean的最大空间限制(字节) */ public static int beanMaxSize(int type) { Integer size = _maxSize.get(type); return size != null ? size : -1; } /** * 根据类型创建一个默认初始化的bean */ public static Bean<?> createBean(int type) { Bean<?> bean = _stubMap.get(type); return bean != null ? bean.create() : null; } @Override public void filterWrite(NextFilter next, IoSession session, WriteRequest writeRequest) { Bean<?> bean = (Bean<?>)writeRequest.getMessage(); int type = bean.type(); if(type == 0) { Octets rawdata = ((RawBean)bean).getData(); int n = rawdata.remain(); if(n > 0) { next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(rawdata.array(), rawdata.position(), n), writeRequest.getFuture(), null)); } } else { OctetsStream os = new OctetsStream(bean.initSize() + 10); os.resize(10); bean.marshalProtocol(os); int p = os.marshalUIntBack(10, os.size() - 10); p = 10 - (p + os.marshalUIntBack(10 - p, type)); next.filterWrite(session, new DefaultWriteRequest(IoBuffer.wrap(os.array(), p, os.size() - p), writeRequest.getFuture(), null)); } } protected boolean decodeProtocol(OctetsStream os, NextFilter next, IoSession session) throws Exception { if(_psize < 0) { int pos = os.position(); try { _ptype = os.unmarshalUInt(); _psize = os.unmarshalUInt(); } catch(MarshalException.EOF e) { os.setPosition(pos); return false; } int maxSize = beanMaxSize(_ptype); if(maxSize < 0) maxSize = Const.maxRawBeanSize; if(_psize > maxSize) throw new DecodeException("bean maxSize overflow: type=" + _ptype + ",size=" + _psize + ",maxSize=" + maxSize); } if(_psize > os.remain()) return false; Bean<?> bean = createBean(_ptype); if(bean != null) { int pos = os.position(); bean.unmarshalProtocol(os); int realSize = os.position() - pos; if(realSize > _psize) throw new DecodeException("bean realSize overflow: type=" + _ptype + ",size=" + _psize + ",realSize=" + realSize); os.setPosition(pos + _psize); } else bean = new RawBean(_ptype, os.unmarshalRaw(_psize)); _psize = -1; next.messageReceived(session, bean); return true; } @Override public void messageReceived(NextFilter next, IoSession session, Object message) throws Exception { IoBuffer in = (IoBuffer)message; try { if(!_os.empty()) { int r = in.remaining(); int s = _os.size(); int n = Math.min(_psize < 0 ? 10 - s : _psize, r); // 前者情况因两个unmarshalUInt不会超过10字节,所以s肯定是<10的 _os.resize(s + n); in.get(_os.array(), s, n); r -= n; s += n; if(!decodeProtocol(_os, next, session)) // 能正好解出一个协议,或者因之前无法解出头部或者in的数据还不够导致失败 { if(r <= 0) return; // 如果in已经无数据可取就直接等下次,之前无法解出头部的话,in也肯定无数据了 n = _psize - _os.remain(); if(r < n) // 如果数据不够多就先累积到缓存里 { _os.resize(s + r); in.get(_os.array(), s, r); return; } _os.resize(s + n); in.get(_os.array(), s, n); decodeProtocol(_os, next, session); // 应该能正好解出一个协议 _os.clear(); if(r <= n) return; } else { n = _os.remain(); _os.clear(); if(n > 0) // 有很小的可能因为之前无法解出头部,而补足10字节却过多的情况,可以调整in的位置 in.position(in.position() - n); else if(r <= 0) return; } } int n = in.limit(); OctetsStream os; if(in.hasArray()) { os = OctetsStream.wrap(in.array(), in.position(), n); in.position(n); } else { n = in.remaining(); byte[] buf = new byte[n]; in.get(buf, 0, n); os = OctetsStream.wrap(buf); } while(decodeProtocol(os, next, session)) if(os.remain() <= 0) return; if(os.remain() <= 0) return; // 正好只解出头部的情况 _os.replace(os.array(), os.position(), os.remain()); _os.setPosition(0); } finally { in.free(); } } }