package com.chicm.cmraft.rpc;

import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.chicm.cmraft.core.RaftRpcService;
import com.chicm.cmraft.protobuf.generated.RaftProtos.RequestHeader;
import com.chicm.cmraft.protobuf.generated.RaftProtos.ResponseHeader;
import com.google.protobuf.BlockingService;
import com.google.protobuf.Message;
import com.google.protobuf.ServiceException;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Message.Builder;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;

/**
 * Netty channel initializer for the RPC server side.
 *
 * <p>For every accepted channel it installs this pipeline:
 * <ol>
 *   <li>a frame decoder/encoder using a 4-byte length prefix (frames up to
 *       {@link #MAX_PACKET_SIZE} bytes; the prefix is stripped on decode),</li>
 *   <li>{@link RpcRequestDecoder}, which turns a frame into an {@link RpcCall},</li>
 *   <li>{@link RpcResponseEncoder}, which serializes a response {@link RpcCall},</li>
 *   <li>{@link RpcRequestHandler}, run on the shared {@link #rpcGroup} executor
 *       group, which invokes the {@link BlockingService}.</li>
 * </ol>
 *
 * <p>Wire format inside a frame: a varint-delimited header message, optionally
 * followed by a varint-delimited request/response body
 * ({@code writeDelimitedTo}/{@code mergeDelimitedFrom}).
 */
public class ServerChannelHandler extends ChannelInitializer<Channel> {
  static final Log LOG = LogFactory.getLog(ServerChannelHandler.class);
  private static final int MAX_PACKET_SIZE = 1024 * 1024 * 100;
  private static final int RPC_WORKER_THREADS = 100;
  // Shared by all channels so the blocking service calls do not run on the I/O event loop.
  static final EventExecutorGroup rpcGroup = new DefaultEventExecutorGroup(RPC_WORKER_THREADS);

  private final BlockingService service;
  private final AtomicLong callCounter;

  /**
   * @param service the protobuf service that requests are dispatched to; it is
   *                also used to resolve method descriptors when decoding
   * @param counter incremented once per response written by a request handler
   */
  public ServerChannelHandler(BlockingService service, AtomicLong counter) {
    this.service = service;
    this.callCounter = counter;
  }

  @Override
  protected void initChannel(Channel ch) throws Exception {
    // maxFrameLength, lengthFieldOffset=0, lengthFieldLength=4, lengthAdjustment=0, strip 4 bytes
    ch.pipeline().addLast("FrameDecoder", new LengthFieldBasedFrameDecoder(MAX_PACKET_SIZE, 0, 4, 0, 4));
    ch.pipeline().addLast("FrameEncoder", new LengthFieldPrepender(4));
    ch.pipeline().addLast("MessageDecoder", new RpcRequestDecoder());
    ch.pipeline().addLast("MessageEncoder", new RpcResponseEncoder());
    ch.pipeline().addLast(rpcGroup, "RpcHandler", new RpcRequestHandler(service, callCounter));
    LOG.debug("initChannel");
  }

  /**
   * Invokes the service for each decoded {@link RpcCall} and writes back the
   * response, reusing the same {@code RpcCall} object with a response header.
   * No response is written when the service returns {@code null} or throws a
   * {@link ServiceException}.
   */
  class RpcRequestHandler extends ChannelInboundHandlerAdapter {
    private final BlockingService service;
    private final AtomicLong callCounter;

    RpcRequestHandler(BlockingService service, AtomicLong counter) {
      this.service = service;
      this.callCounter = counter;
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
      RpcCall call = (RpcCall) msg;
      if (call == null) {
        return;
      }
      if (LOG.isDebugEnabled()) {
        // String concatenation (not toString()) so a null address cannot throw.
        LOG.debug("RpcServer read, call ID: " + call.getCallId()
            + ", local server:" + ctx.channel().localAddress());
      }

      try {
        Message response = service.callBlockingMethod(call.getMd(), null, call.getMessage());
        if (response != null) {
          ResponseHeader header = ResponseHeader.newBuilder()
              .setId(call.getCallId())
              .setResponseName(call.getMd().getName())
              .build();
          call.setHeader(header);
          call.setMessage(response);
          ctx.writeAndFlush(call);
          callCounter.getAndIncrement();
        }
      } catch (ServiceException e) {
        LOG.error("Rpc Server channelRead exception:" + e.getMessage(), e);
      }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      LOG.info("Channel closed");
      super.channelInactive(ctx);
    }

    /** Logs the failure and closes the connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
      // String concatenation (not toString()) so a null address cannot throw.
      LOG.info("Connection closed by:" + ctx.channel().remoteAddress());
      LOG.error(cause.getMessage(), cause);
      ctx.close();
    }
  }

  /**
   * Decodes one frame into an {@link RpcCall}: a delimited {@link RequestHeader}
   * followed by a delimited request body whose type is the request prototype
   * of the method named in the header. Method descriptors are resolved
   * against the service this initializer was constructed with, so they are
   * the ones that service's {@code callBlockingMethod} accepts.
   */
  class RpcRequestDecoder extends MessageToMessageDecoder<ByteBuf> {
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) throws Exception {
      ByteBufInputStream in = new ByteBufInputStream(msg);

      RequestHeader.Builder hbuilder = RequestHeader.newBuilder();
      if (!hbuilder.mergeDelimitedFrom(in)) {
        throw new IllegalArgumentException("RPC request frame contains no request header");
      }
      RequestHeader header = hbuilder.build();

      BlockingService target = ServerChannelHandler.this.service;
      MethodDescriptor md = target.getDescriptorForType().findMethodByName(header.getRequestName());
      if (md == null) {
        // Previously surfaced as an NPE from getRequestPrototype(null).
        throw new IllegalArgumentException("Unknown RPC method: " + header.getRequestName());
      }

      Builder builder = target.getRequestPrototype(md).newBuilderForType();
      Message body = null;
      if (builder.mergeDelimitedFrom(in)) {
        body = builder.build();
      } else {
        // Body missing: the call is still forwarded with a null message, as before.
        LOG.error("Parsing packet failed!");
      }
      out.add(new RpcCall(header.getId(), header, body, md));
    }
  }

  /**
   * Serializes a response {@link RpcCall} as a delimited header followed by the
   * delimited response message (if any). On a serialization failure the
   * allocated buffer is released and the exception is propagated, so it is
   * neither leaked nor hidden behind a generic "no message produced" error.
   */
  class RpcResponseEncoder extends MessageToMessageEncoder<RpcCall> {
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcCall call, List<Object> out) throws Exception {
      // Initial capacity hint; presumably the exact delimited size — see PacketUtils.
      int totalSize = PacketUtils.getTotalSizeofMessages(call.getHeader(), call.getMessage());
      ByteBuf encoded = ctx.alloc().buffer(totalSize);
      boolean handedOff = false;
      try {
        ByteBufOutputStream os = new ByteBufOutputStream(encoded);
        call.getHeader().writeDelimitedTo(os);
        if (call.getMessage() != null) {
          call.getMessage().writeDelimitedTo(os);
        }
        out.add(encoded);
        handedOff = true;
        if (LOG.isDebugEnabled()) {
          LOG.debug("RpcServer encode response, call ID: " + call.getCallId());
        }
      } catch (Exception e) {
        LOG.error("Rpc Server encode exception:" + e.getMessage(), e);
        throw e;
      } finally {
        if (!handedOff) {
          // Ownership never passed to the pipeline; release to avoid a ByteBuf leak.
          encoded.release();
        }
      }
    }
  }
}