/** * Copyright 2014 The CmRaft Project * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.chicm.cmraft.rpc; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousSocketChannel; import java.nio.channels.Channels; import java.nio.channels.SocketChannel; import java.util.concurrent.ExecutionException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import com.chicm.cmraft.protobuf.generated.RaftProtos.RequestHeader; import com.chicm.cmraft.protobuf.generated.RaftProtos.ResponseHeader; import com.google.protobuf.BlockingService; import com.google.protobuf.CodedInputStream; import com.google.protobuf.CodedOutputStream; import com.google.protobuf.Message; import com.google.protobuf.Descriptors.MethodDescriptor; import com.google.protobuf.Message.Builder; /** * Static utility methods dealing with protobuf RPC packets. 
* <p>Wire format of one RPC packet, as produced by {@code writeRpc}: a 4-byte big-endian
* length field, followed by the varint-delimited header message and, when a body was
* supplied, the varint-delimited body message. The length field counts every byte that
* follows it (both size varints plus both serialized messages).
* @author chicm
*
*/
public class PacketUtils {
  static final Log LOG = LogFactory.getLog(PacketUtils.class);
  static final int DEFAULT_BYTEBUFFER_SIZE = 1000;
  static final int MESSAGE_LENGHT_FIELD_SIZE = 4;
  static final int DEFAULT_CHANNEL_READ_RETRIES = 5;

  /**
   * Encodes an int as 4 bytes, most significant byte first.
   *
   * @param val the value to encode
   * @return a new 4-element array holding {@code val} in big-endian order
   */
  public static byte[] int2Bytes(int val) {
    return new byte[] {
      (byte) (val >>> 24),
      (byte) (val >>> 16),
      (byte) (val >>> 8),
      (byte) val
    };
  }

  /**
   * Decodes the first 4 bytes of {@code bytes} as a big-endian int; the inverse of
   * {@link #int2Bytes(int)}.
   *
   * @param bytes array with at least 4 elements; elements past index 3 are ignored
   * @return the decoded value
   * @throws ArrayIndexOutOfBoundsException if {@code bytes} has fewer than 4 elements
   */
  public static int bytes2Int(byte[] bytes) {
    return ((bytes[0] & 0xFF) << 24)
        | ((bytes[1] & 0xFF) << 16)
        | ((bytes[2] & 0xFF) << 8)
        | (bytes[3] & 0xFF);
  }

  /**
   * Writes {@code n} to {@code os} as 4 big-endian bytes.
   *
   * @param n the value to write
   * @param os the destination stream
   * @throws IOException if the stream rejects the write
   */
  public static void writeIntToStream(int n, OutputStream os) throws IOException {
    os.write(int2Bytes(n));
  }

  /**
   * Computes the number of bytes the given messages occupy when each one is written
   * with {@code writeDelimitedTo}: the serialized size plus the size of its varint
   * length prefix. {@code null} entries are skipped.
   *
   * @param messages the messages to measure; individual entries may be null
   * @return the summed delimited size of all non-null messages
   */
  public static int getTotalSizeofMessages(Message ... messages) {
    int totalSize = 0;
    for (Message m : messages) {
      if (m == null) {
        continue;
      }
      int serializedSize = m.getSerializedSize();
      totalSize += serializedSize + CodedOutputStream.computeRawVarint32Size(serializedSize);
    }
    return totalSize;
  }

  /**
   * Writes one RPC packet (length field, header, optional body) to the channel and
   * blocks until all of it has been handed to the channel.
   *
   * @param channel the channel to write to
   * @param header the header message; must not be null
   * @param body the body message, or null to send a header-only packet
   * @return the value written into the length field (packet size without the 4-byte field)
   * @throws IOException if serializing the messages fails
   * @throws InterruptedException if the thread is interrupted while waiting for the write
   * @throws ExecutionException if the asynchronous write fails
   */
  public static int writeRpc(AsynchronousSocketChannel channel, Message header, Message body)
      throws IOException, InterruptedException, ExecutionException {
    int totalSize = getTotalSizeofMessages(header, body);
    return writeRpc(channel, header, body, totalSize);
  }

  /**
   * Serializes the packet into memory and writes it to the channel.
   * The total size goes first so the peer knows how many bytes make up the packet.
   */
  private static int writeRpc(AsynchronousSocketChannel channel, Message header, Message body,
      int totalSize) throws IOException, InterruptedException, ExecutionException {
    long t = System.currentTimeMillis();

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    writeIntToStream(totalSize, bos);
    header.writeDelimitedTo(bos);
    if (body != null) {
      body.writeDelimitedTo(bos);
    }

    // Wrap the serialized bytes directly instead of copying them into a direct buffer.
    ByteBuffer buf = ByteBuffer.wrap(bos.toByteArray());
    // A single asynchronous write may transfer only part of the buffer, so keep
    // writing until nothing is left; otherwise the peer would wait for missing bytes.
    while (buf.hasRemaining()) {
      channel.write(buf).get();
    }

    if (LOG.isTraceEnabled()) {
      LOG.trace("Write Rpc message to socket, takes " + (System.currentTimeMillis() - t)
          + " ms, size " + totalSize);
      LOG.trace("message:" + body);
    }
    return totalSize;
  }

  /**
   * Blocking {@link SocketChannel} variant of the packet writer. It is not called from
   * within this class.
   */
  private static int writeRpc_backup(SocketChannel channel, Message header, Message body,
      int totalSize) throws IOException {
    // writing total size so that server can read all request data in one read
    LOG.debug("total size:" + totalSize);
    long t = System.currentTimeMillis();
    OutputStream os = Channels.newOutputStream(channel);
    writeIntToStream(totalSize, os);
    header.writeDelimitedTo(os);
    if (body != null) {
      body.writeDelimitedTo(os);
    }
    os.flush();
    LOG.debug("" + (System.currentTimeMillis() - t) + " ms");
    LOG.debug("flushed:" + totalSize);
    return totalSize;
  }

  /**
   * Reads one RPC request packet from the channel and parses its header and body.
   *
   * @param channel the channel to read from
   * @param service the service whose descriptor resolves the method named in the header
   * @return the parsed call; its body is null when the packet carried only a header
   * @throws java.io.EOFException if the channel reaches end-of-stream before a whole
   *         packet has been read
   * @throws IOException if the length field is negative, the packet is malformed, or
   *         the header names a method the service does not declare
   * @throws InterruptedException if the thread is interrupted while waiting for a read
   * @throws ExecutionException if an asynchronous read fails
   */
  public static RpcCall parseRpcRequestFromChannel(AsynchronousSocketChannel channel,
      BlockingService service) throws InterruptedException, ExecutionException, IOException {
    long t = System.currentTimeMillis();
    byte[] data = readPayload(channel, readFrameLength(channel));

    CodedInputStream cis = CodedInputStream.newInstance(data);
    int headerSize = cis.readRawVarint32();
    int headerOffset = cis.getTotalBytesRead();
    // Skipping first validates that the announced header size fits inside the packet.
    cis.skipRawBytes(headerSize);
    RequestHeader header =
        RequestHeader.newBuilder().mergeFrom(data, headerOffset, headerSize).build();

    MethodDescriptor md =
        service.getDescriptorForType().findMethodByName(header.getRequestName());
    if (md == null) {
      throw new IOException("Unknown RPC method in request header: " + header.getRequestName());
    }
    Message body = parseBody(cis, data, service.getRequestPrototype(md));

    RpcCall call = new RpcCall(header.getId(), header, body, md);
    if (LOG.isTraceEnabled()) {
      LOG.trace("Parse Rpc request from socket: " + call.getCallId()
          + ", takes" + (System.currentTimeMillis() - t) + " ms");
    }
    return call;
  }

  /**
   * Reads one RPC response packet from the channel and parses its header and body.
   *
   * @param channel the channel to read from
   * @param service the service whose descriptor resolves the method named in the header
   * @return the parsed call; its body is null when the packet carried only a header
   * @throws java.io.EOFException if the channel reaches end-of-stream before a whole
   *         packet has been read
   * @throws IOException if the length field is negative, the packet is malformed, or
   *         the header names a method the service does not declare
   * @throws InterruptedException if the thread is interrupted while waiting for a read
   * @throws ExecutionException if an asynchronous read fails
   */
  public static RpcCall parseRpcResponseFromChannel(AsynchronousSocketChannel channel,
      BlockingService service) throws InterruptedException, ExecutionException, IOException {
    long t = System.currentTimeMillis();
    int nDataSize = readFrameLength(channel);
    LOG.debug("message size: " + nDataSize);
    byte[] data = readPayload(channel, nDataSize);

    CodedInputStream cis = CodedInputStream.newInstance(data);
    int headerSize = cis.readRawVarint32();
    int headerOffset = cis.getTotalBytesRead();
    // Skipping first validates that the announced header size fits inside the packet.
    cis.skipRawBytes(headerSize);
    ResponseHeader header =
        ResponseHeader.newBuilder().mergeFrom(data, headerOffset, headerSize).build();

    MethodDescriptor md =
        service.getDescriptorForType().findMethodByName(header.getResponseName());
    if (md == null) {
      throw new IOException("Unknown RPC method in response header: " + header.getResponseName());
    }
    Message body = parseBody(cis, data, service.getResponsePrototype(md));

    RpcCall call = new RpcCall(header.getId(), header, body, md);
    if (LOG.isTraceEnabled()) {
      LOG.trace("Parse Rpc response from socket: " + call.getCallId()
          + ", takes" + (System.currentTimeMillis() - t) + " ms");
    }
    return call;
  }

  /** Reads from the channel until {@code buf} is full; fails if the stream ends first. */
  private static void readFully(AsynchronousSocketChannel channel, ByteBuffer buf)
      throws InterruptedException, ExecutionException, IOException {
    while (buf.hasRemaining()) {
      int n = channel.read(buf).get();
      // read() reports -1 at end-of-stream; adding that to a running total would
      // never terminate, so treat it as a hard failure.
      if (n < 0) {
        LOG.error("SOCKET READ FAILED, len:" + buf.position());
        throw new java.io.EOFException("Channel closed after " + buf.position() + " of "
            + buf.capacity() + " expected bytes");
      }
    }
  }

  /** Reads the 4-byte big-endian length field that precedes every packet. */
  private static int readFrameLength(AsynchronousSocketChannel channel)
      throws InterruptedException, ExecutionException, IOException {
    ByteBuffer lengthField = ByteBuffer.allocate(MESSAGE_LENGHT_FIELD_SIZE);
    readFully(channel, lengthField);
    int size = bytes2Int(lengthField.array());
    if (size < 0) {
      throw new IOException("Invalid RPC packet length: " + size);
    }
    return size;
  }

  /** Reads exactly {@code size} payload bytes from the channel into a new array. */
  private static byte[] readPayload(AsynchronousSocketChannel channel, int size)
      throws InterruptedException, ExecutionException, IOException {
    byte[] payload = new byte[size];
    readFully(channel, ByteBuffer.wrap(payload));
    return payload;
  }

  /**
   * Parses the varint-delimited body that follows the header, using {@code prototype}
   * for the message type. Returns null when no bytes follow the header, which is what
   * the writer produces for a null body.
   */
  private static Message parseBody(CodedInputStream cis, byte[] data, Message prototype)
      throws IOException {
    if (cis.isAtEnd()) {
      return null;
    }
    int bodySize = cis.readRawVarint32();
    int bodyOffset = cis.getTotalBytesRead();
    // Validates that the announced body size fits inside the packet.
    cis.skipRawBytes(bodySize);
    return prototype.newBuilderForType().mergeFrom(data, bodyOffset, bodySize).build();
  }
}