package com.chicm.cmraft.rpc;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import com.chicm.cmraft.core.RaftRpcService;
import com.chicm.cmraft.protobuf.generated.RaftProtos.ResponseHeader;
import com.chicm.cmraft.util.BlockingHashMap;
import com.google.protobuf.BlockingService;
import com.google.protobuf.Message;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Message.Builder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
public class ClientChannelHandler extends ChannelInitializer<Channel> {
static final Log LOG = LogFactory.getLog(ClientChannelHandler.class);
private static final int MAX_PACKET_SIZE = 1024*1024*100;
private ChannelHandlerContext activeCtx;
private RpcClientEventListener listener;
//private long startTime = System.currentTimeMillis();
public ClientChannelHandler(RpcClientEventListener listener) {
this.listener = listener;
}
@Override
protected void initChannel(Channel ch) throws Exception {
ch.pipeline().addLast("FrameDecoder", new LengthFieldBasedFrameDecoder(MAX_PACKET_SIZE,0,4,0,4));
ch.pipeline().addLast("FrameEncoder", new LengthFieldPrepender(4));
ch.pipeline().addLast("MessageDecoder", new RpcResponseDecoder() );
ch.pipeline().addLast("MessageEncoder", new RpcRequestEncoder());
ch.pipeline().addLast("ClientHandler", new RpcResponseHandler());
LOG.debug("initChannel");
}
public ChannelHandlerContext getCtx() {
return activeCtx;
}
class RpcResponseHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
RpcCall call = (RpcCall) msg;
LOG.debug("client channel read, callid: " + call.getCallId());
listener.onRpcResponse(call);
}
/* (non-Javadoc)
* @see io.netty.channel.ChannelStateHandlerAdapter#channelActive(io.netty.channel.ChannelHandlerContext)
*/
@Override
public void channelActive(final ChannelHandlerContext ctx) { // (1)
activeCtx = ctx;
LOG.debug("Client Channel Active");
}
/* (non-Javadoc)
* @see io.netty.channel.ChannelStateHandlerAdapter#channelInactive(io.netty.channel.ChannelHandlerContext)
*/
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
super.channelInactive(ctx);
}
/* (non-Javadoc)
* @see io.netty.channel.ChannelStateHandlerAdapter#exceptionCaught(io.netty.channel.ChannelHandlerContext)
*/
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
LOG.error("Socket Exception: " + cause.getMessage(), cause);
listener.channelClosed();
}
}
class RpcResponseDecoder extends MessageToMessageDecoder<ByteBuf> {
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
throws Exception {
ByteBufInputStream in = new ByteBufInputStream(msg);
ResponseHeader.Builder hbuilder = ResponseHeader.newBuilder();
hbuilder.mergeDelimitedFrom(in);
ResponseHeader header = hbuilder.build();
BlockingService service = RaftRpcService.create().getService();
MethodDescriptor md = service.getDescriptorForType().findMethodByName(header.getResponseName());
Builder builder = service.getResponsePrototype(md).newBuilderForType();
Message body = null;
if (builder != null) {
if(builder.mergeDelimitedFrom(in)) {
body = builder.build();
} else {
LOG.error("Parse packet failed!!");
}
}
RpcCall call = new RpcCall(header.getId(), header, body, md);
out.add(call);
}
}
class RpcRequestEncoder extends MessageToMessageEncoder<RpcCall> {
@Override
protected void encode(ChannelHandlerContext ctx, RpcCall call, List<Object> out) throws Exception {
//System.out.println("RpcMessageEncoder");
int totalSize = PacketUtils.getTotalSizeofMessages(call.getHeader(), call.getMessage());
ByteBuf encoded = ctx.alloc().buffer(totalSize);
ByteBufOutputStream os = new ByteBufOutputStream(encoded);
try {
call.getHeader().writeDelimitedTo(os);
if (call.getMessage() != null) {
call.getMessage().writeDelimitedTo(os);
}
out.add(encoded);
LOG.debug("Rpc encode: " + call.getCallId());
} catch(Exception e) {
LOG.error("Rpc Encoder exception:" + e.getMessage(), e);
}
}
}
}