/** * Copyright 2014 The CmRaft Project * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.chicm.cmraft.rpc; import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioSocketChannel; import java.io.IOException; import java.util.Random; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import com.chicm.cmraft.common.CmRaftConfiguration; import com.chicm.cmraft.common.Configuration; import com.chicm.cmraft.common.ServerInfo; import com.chicm.cmraft.protobuf.generated.RaftProtos.RaftService; import com.chicm.cmraft.protobuf.generated.RaftProtos.RequestHeader; import com.chicm.cmraft.protobuf.generated.RaftProtos.RaftService.BlockingInterface; import com.chicm.cmraft.protobuf.generated.RaftProtos.TestRpcRequest; import com.chicm.cmraft.util.BlockingHashMap; import com.google.common.base.Preconditions; import com.google.protobuf.BlockingRpcChannel; import com.google.protobuf.ByteString; import com.google.protobuf.Message; import com.google.protobuf.RpcController; import com.google.protobuf.ServiceException; import com.google.protobuf.Descriptors.MethodDescriptor; /** * RpcClient implements the BlockingRpcChannel interface with inner class. It translate RPC method calls to * RPC request packets and send them to RPC server. Then translate RPC response packets from * RPC server to returned objects for RPC method calls. * * @author chicm * */ public class RpcClient { static final Log LOG = LogFactory.getLog(RpcClient.class); private final static String RPC_TIMEOUT_KEY = "raft.rpc.timeout"; private final static int DEFAULT_RPC_TIMEOUT = 3000; private static volatile AtomicInteger client_call_id = new AtomicInteger(0); private BlockingInterface stub = null; private ChannelHandlerContext ctx = null; private BlockingHashMap<Integer, RpcCall> responsesMap = new BlockingHashMap<>(); private RpcClientEventListener listener = new RpcClientEventListenerImpl(); private volatile boolean connected = false; private int rpcTimeout; private ServerInfo remoteServer = null; public RpcClient(Configuration conf, ServerInfo remoteServer) { rpcTimeout = conf.getInt(RPC_TIMEOUT_KEY, DEFAULT_RPC_TIMEOUT); this.remoteServer = remoteServer; //todo: to change call id init value Random r = new Random(); client_call_id.set(r.nextInt(1000) * 100); } public boolean isConnected() { if(ctx == null) return false; if(!ctx.channel().isActive()) return false; return connected; } public synchronized boolean connect() throws IOException, InterruptedException, ExecutionException { if(isConnected()) return true; try { ctx = connectRemoteServer(); } catch(Exception e) { LOG.error("Failed connecting to:" + getRemoteServer() + " : " + e.getMessage()); try { if(ctx != null && ctx.channel().isOpen()) { ctx.close().sync(); } } catch(Exception e2) { LOG.error("Failed closing ctx, " + e2.getMessage()); } throw e; } BlockingRpcChannel c = createBlockingRpcChannel(); stub = RaftService.newBlockingStub(c); connected = true; return connected; } public ServerInfo getRemoteServer() { return remoteServer; } public synchronized void close() { try { LOG.info("Closing connection"); ctx.close().sync(); connected = false; } catch(Exception e) { LOG.error("Closing failed", e); } } public BlockingInterface getStub() throws Exception { if(!isConnected()) { if(!connect()) { return null; } } return stub; } private ChannelHandlerContext connectRemoteServer() throws InterruptedException { EventLoopGroup workerGroup = new NioEventLoopGroup(); try { ClientChannelHandler channelHandler = new ClientChannelHandler(listener); Bootstrap b = new Bootstrap(); b.group(workerGroup); b.channel(NioSocketChannel.class); b.option(ChannelOption.SO_KEEPALIVE, true); b.handler(channelHandler); ChannelFuture f = b.connect(getRemoteServer().getHost(), getRemoteServer().getPort()).sync(); LOG.debug("connected to: " + this.getRemoteServer() ); return channelHandler.getCtx(); // Wait until the connection is closed. //f.channel().closeFuture().sync(); } finally { //workerGroup.shutdownGracefully(); } } public static int generateCallId() { return client_call_id.incrementAndGet(); } public static int getCallId() { return client_call_id.get(); } private BlockingRpcChannel createBlockingRpcChannel() { return new BlockingRpcChannelImplementation(); } class BlockingRpcChannelImplementation implements BlockingRpcChannel { @Override public Message callBlockingMethod(MethodDescriptor md, RpcController controller, Message request, Message returnType) throws ServiceException { Message response = null; int callId = generateCallId(); try { RequestHeader.Builder builder = RequestHeader.newBuilder(); builder.setId(callId); builder.setRequestName(md.getName()); RequestHeader header = builder.build(); LOG.debug("SENDING RPC, CALLID:" + header.getId()); RpcCall call = new RpcCall(callId, header, request, md); long tm = System.currentTimeMillis(); ctx.writeAndFlush(call); RpcCall result = responsesMap.take(callId, rpcTimeout); response = result != null? result.getMessage() : null; if(response != null) { LOG.debug("response taken: " + callId); LOG.debug(String.format("RPC[%d] round trip takes %d ms", header.getId(), (System.currentTimeMillis() - tm))); } } catch(RpcTimeoutException e) { LOG.error("Rpc Timeout, call ID:" + callId + ", remote server:" + getRemoteServer()); LOG.error("Rpc Timeout, call:" + request); LOG.error("Rpc Timeout", e); ServiceException se = new ServiceException(e.getMessage(), e); throw se; } catch(Exception e) { LOG.error("ctx:" + ctx); LOG.error("callBlockingMethod exception", e); throw e; } return response; } } class RpcClientEventListenerImpl implements RpcClientEventListener { @Override public void channelClosed() { ctx.close(); connected = false; } @Override public void onRpcResponse(RpcCall call) { Preconditions.checkNotNull(call); responsesMap.put(call.getCallId(), call); } } /* * For testing purpose * @param args * @throws Exception */ public static void main(String[] args) throws Exception { if(args.length < 3) { System.out.println("usage: RpcClient <server host> <server port> <clients number> <threads number> <packetsize>"); return; } String host = args[0]; int port = Integer.parseInt(args[1]); int nclients = Integer.parseInt(args[2]); int nThreads = Integer.parseInt(args[3]); int nPacketSize = 1024; if(args.length >= 5) { nPacketSize = Integer.parseInt(args[4]); } for(int j =0; j < nclients; j++ ) { RpcClient client = new RpcClient(CmRaftConfiguration.create(), new ServerInfo(host, port)); for(int i = 0; i < nThreads; i++) { new Thread(new TestRpcWorker(client, nPacketSize)).start(); } } } static class TestRpcWorker implements Runnable{ private RpcClient client; private int packetSize; public TestRpcWorker(RpcClient client, int size) { this.client = client; this.packetSize = size; } @Override public void run() { client.sendRequest(packetSize); } } public void testRpc(int packetSize) throws Exception { TestRpcRequest.Builder builder = TestRpcRequest.newBuilder(); byte[] bytes = new byte[packetSize]; builder.setData(ByteString.copyFrom(bytes)); stub.testRpc(null, builder.build()); } private ThreadLocal<Long> startTime = new ThreadLocal<>(); /* * For testing purpose */ public void sendRequest(int packetSize) { if(!this.isConnected()) { try { if(!connect()) { LOG.error("INIT error"); return; } } catch(Exception e) { LOG.error("RpcClient init exception", e); return; } } LOG.info("client thread started"); long starttime = System.currentTimeMillis(); try { for(int i = 0; i < 5000000 ;i++) { startTime.set(System.currentTimeMillis()); testRpc(packetSize); if(i != 0 && i %1000 == 0 ) { long ms = System.currentTimeMillis() - starttime; LOG.debug("RPC CALL[ " + i + "] round trip time: " + ms); long curtm = System.currentTimeMillis(); long elipsetm = (curtm - starttime) /1000; if(elipsetm == 0) elipsetm =1; long tps = i / elipsetm; LOG.info("response id: " + i + " time: " + elipsetm + " TPS: " + tps); } } } catch (Exception e) { e.printStackTrace(System.out); } } }