package com.ustcinfo.rpc.protocol; import java.util.ArrayList; import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import com.ustcinfo.rpc.RequestWrapper; import com.ustcinfo.rpc.ResponseWrapper; import com.ustcinfo.rpc.annotation.Codecs; /** * Common RPC Protocol * * Request Protocol * VERSION(1B): * TYPE(1B): request/response * CODECTYPE(1B): serialize/deserialize type * KEEPED(1B): * KEEPED(1B): * KEEPED(1B): * ID(4B): request id * TIMEOUT(4B): request timeout * TARGETINSTANCELEN(4B): target service name length * METHODNAMELEN(4B): method name length * ARGSCOUNT(4B): method args count * ARG1TYPELEN(4B): method arg1 type len * ARG2TYPELEN(4B): method arg2 type len * ... * ARG1LEN(4B): method arg1 len * ARG2LEN(4B): method arg2 len * ... * TARGETINSTANCENAME * METHODNAME * ARG1TYPENAME * ARG2TYPENAME * ... * ARG1 * ARG2 * ... * * Response Protocol * VERSION(1B): * TYPE(1B): request/response * DATATYPE(1B): serialize/deserialize type * KEEPED(1B): * KEEPED(1B): * KEEPED(1B): * ID(4B): request id * BodyClassNameLen(4B): body className Len * LENGTH(4B): body length * BodyClassName * BODY if need than set * */ public class RPCProtocol implements Protocol { private static final Log LOGGER = LogFactory.getLog(RPCProtocol.class); private static final int REQUEST_HEADER_LEN = 1 * 6 + 5 * 4; private static final int RESPONSE_HEADER_LEN = 1 * 6 + 2 * 4; private static final byte VERSION = (byte)1; private static final byte REQUEST = (byte)0; private static final byte RESPONSE = (byte)1; /** * 客户端请求响应编码、服务器端响应编码 * */ public ByteBufferWrapper encode(Object message,ByteBufferWrapper bytebufferWrapper) throws Exception{ if(!(message instanceof RequestWrapper) && !(message instanceof ResponseWrapper)){ throw new Exception("only support send RequestWrapper && ResponseWrapper"); } int id = 0; byte type = REQUEST; if(message instanceof RequestWrapper) { try{ int requestArgTypesLen = 0; int requestArgsLen = 0; List<byte[]> requestArgTypes = new ArrayList<byte[]>(); List<byte[]> requestArgs = new ArrayList<byte[]>(); RequestWrapper wrapper = (RequestWrapper) message; byte[][] requestArgTypeStrings = wrapper.getArgTypes(); for (byte[] requestArgType : requestArgTypeStrings) { requestArgTypes.add(requestArgType); requestArgTypesLen += requestArgType.length; } Object[] requestObjects = wrapper.getRequestObjects(); if(requestObjects!=null){ for (Object requestArg : requestObjects) { byte[] requestArgByte = Codecs.getEncoder(wrapper.getCodecType()).encode(requestArg); requestArgs.add(requestArgByte); requestArgsLen += requestArgByte.length; } } byte[] targetInstanceNameByte = wrapper.getTargetInstanceName(); byte[] methodNameByte = wrapper.getMethodName(); id = wrapper.getId(); int timeout = wrapper.getTimeout(); int capacity = REQUEST_HEADER_LEN + requestArgs.size() * 4 * 2 + targetInstanceNameByte.length + methodNameByte.length + requestArgTypesLen + requestArgsLen; ByteBufferWrapper byteBuffer = bytebufferWrapper.get(capacity); byteBuffer.writeByte(VERSION); byteBuffer.writeByte(type); byteBuffer.writeByte((byte)wrapper.getCodecType()); byteBuffer.writeByte((byte)0); byteBuffer.writeByte((byte)0); byteBuffer.writeByte((byte)0); byteBuffer.writeInt(id); byteBuffer.writeInt(timeout); byteBuffer.writeInt(targetInstanceNameByte.length); byteBuffer.writeInt(methodNameByte.length); byteBuffer.writeInt(requestArgs.size()); for (byte[] requestArgType : requestArgTypes) { byteBuffer.writeInt(requestArgType.length); } for (byte[] requestArg : requestArgs) { byteBuffer.writeInt(requestArg.length); } byteBuffer.writeBytes(targetInstanceNameByte); byteBuffer.writeBytes(methodNameByte); for (byte[] requestArgType : requestArgTypes) { byteBuffer.writeBytes(requestArgType); } for (byte[] requestArg : requestArgs) { byteBuffer.writeBytes(requestArg); } return byteBuffer; } catch(Exception e){ LOGGER.error("encode request object error",e); throw e; } } else{ ResponseWrapper wrapper = (ResponseWrapper) message; byte[] body = new byte[0]; byte[] className = new byte[0]; try{ // no return object if(wrapper.getResponse() != null){ className = wrapper.getResponse().getClass().getName().getBytes(); body = Codecs.getEncoder(wrapper.getCodecType()).encode(wrapper.getResponse()); } if(wrapper.isError()){ className = wrapper.getException().getClass().getName().getBytes(); body = Codecs.getEncoder(wrapper.getCodecType()).encode(wrapper.getException()); } id = wrapper.getRequestId(); } catch(Exception e){ LOGGER.error("encode response object error", e); // still create responsewrapper,so client can get exception wrapper.setResponse(new Exception("serialize response object error",e)); className = Exception.class.getName().getBytes(); body = Codecs.getEncoder(wrapper.getCodecType()).encode(wrapper.getResponse()); } type = RESPONSE; int capacity = RESPONSE_HEADER_LEN + body.length; ByteBufferWrapper byteBuffer = bytebufferWrapper.get(capacity); byteBuffer.writeByte(VERSION); byteBuffer.writeByte(type); byteBuffer.writeByte((byte)wrapper.getCodecType()); byteBuffer.writeByte((byte)0); byteBuffer.writeByte((byte)0); byteBuffer.writeByte((byte)0); byteBuffer.writeInt(id); byteBuffer.writeInt(className.length); byteBuffer.writeInt(body.length); byteBuffer.writeBytes(className); byteBuffer.writeBytes(body); return byteBuffer; } } /** * 服务器请求解码、客户端响应解码 * */ public Object decode(ByteBufferWrapper wrapper, Object errorObject) throws Exception{ int originPos = wrapper.readerIndex(); if(wrapper.readableBytes() < 2){ wrapper.setReaderIndex(originPos); return errorObject; } byte version = wrapper.readByte(); if(version == (byte)1){ byte type = wrapper.readByte(); if(type == REQUEST){ if(wrapper.readableBytes() < REQUEST_HEADER_LEN -2){ wrapper.setReaderIndex(originPos); return errorObject; } int codecType = wrapper.readByte(); wrapper.readByte(); wrapper.readByte(); wrapper.readByte(); int requestId = wrapper.readInt(); int timeout = wrapper.readInt(); int targetInstanceLen = wrapper.readInt(); int methodNameLen = wrapper.readInt(); int argsCount = wrapper.readInt(); int argInfosLen = argsCount * 4 * 2; int expectedLenInfoLen = argInfosLen + targetInstanceLen + methodNameLen; if(wrapper.readableBytes() < expectedLenInfoLen){ wrapper.setReaderIndex(originPos); return errorObject; } int expectedLen = 0; int[] argsTypeLen = new int[argsCount]; for (int i = 0; i < argsCount; i++) { argsTypeLen[i] = wrapper.readInt(); expectedLen += argsTypeLen[i]; } int[] argsLen = new int[argsCount]; for (int i = 0; i < argsCount; i++) { argsLen[i] = wrapper.readInt(); expectedLen += argsLen[i]; } byte[] targetInstanceByte = new byte[targetInstanceLen]; wrapper.readBytes(targetInstanceByte); byte[] methodNameByte = new byte[methodNameLen]; wrapper.readBytes(methodNameByte); if(wrapper.readableBytes() < expectedLen){ wrapper.setReaderIndex(originPos); return errorObject; } byte[][] argTypes = new byte[argsCount][]; for (int i = 0; i < argsCount; i++) { byte[] argTypeByte = new byte[argsTypeLen[i]]; wrapper.readBytes(argTypeByte); argTypes[i] = argTypeByte; } Object[] args = new Object[argsCount]; for (int i = 0; i < argsCount; i++) { byte[] argByte = new byte[argsLen[i]]; wrapper.readBytes(argByte); args[i] = argByte; } RequestWrapper requestWrapper = new RequestWrapper(targetInstanceByte, methodNameByte, argTypes, args, timeout, requestId, codecType); int messageLen = REQUEST_HEADER_LEN + expectedLenInfoLen + expectedLen; requestWrapper.setMessageLen(messageLen); return requestWrapper; } else if(type == RESPONSE){ if(wrapper.readableBytes() < RESPONSE_HEADER_LEN -2){ wrapper.setReaderIndex(originPos); return errorObject; } int codecType = wrapper.readByte(); wrapper.readByte(); wrapper.readByte(); wrapper.readByte(); int requestId = wrapper.readInt(); int classNameLen = wrapper.readInt(); int bodyLen = wrapper.readInt(); if(wrapper.readableBytes() < classNameLen + bodyLen){ wrapper.setReaderIndex(originPos); return errorObject; } byte[] classNameBytes = null; classNameBytes = new byte[classNameLen]; wrapper.readBytes(classNameBytes); byte[] bodyBytes = new byte[bodyLen]; wrapper.readBytes(bodyBytes); ResponseWrapper responseWrapper = new ResponseWrapper(requestId,codecType); responseWrapper.setResponse(bodyBytes); responseWrapper.setResponseClassName(classNameBytes); int messageLen = RESPONSE_HEADER_LEN + classNameLen + bodyLen; responseWrapper.setMessageLen(messageLen); return responseWrapper; } else{ throw new UnsupportedOperationException("protocol type : "+type+" is not supported!"); } } else{ throw new UnsupportedOperationException("protocol version :"+version+" is not supported!"); } } }