/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tajo.rpc; import com.google.common.base.Preconditions; import com.google.protobuf.Descriptors; import com.google.protobuf.Message; import com.google.protobuf.ServiceException; import io.netty.bootstrap.Bootstrap; import io.netty.channel.*; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import io.netty.util.concurrent.GenericFutureListener; import org.apache.commons.lang.exception.ExceptionUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.tajo.rpc.RpcProtos.RpcResponse; import java.io.Closeable; import java.lang.reflect.Method; import java.net.ConnectException; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.nio.channels.UnresolvedAddressException; import java.util.Collection; import java.util.Properties; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import static org.apache.tajo.rpc.RpcConstants.*; public abstract class NettyClientBase<T> implements ProtoDeclaration, Closeable { public final static Log LOG = LogFactory.getLog(NettyClientBase.class); private final RpcConnectionKey key; /** Number to retry for connection and RPC invocation */ private final int maxRetryNum; /** Connection Timeout */ private final long connTimeoutMillis; private boolean enableMonitor; private final ConcurrentMap<RpcConnectionKey, ChannelEventListener> channelEventListeners = new ConcurrentHashMap<>(); private final ConcurrentMap<Integer, T> requests = new ConcurrentHashMap<>(); private Bootstrap bootstrap; private volatile ChannelFuture channelFuture; /** * Constructor of NettyClientBase * * @param rpcConnectionKey RpcConnectionKey * @param rpcParams Rpc connection parameters (see RpcConstants) * * @throws ClassNotFoundException * @throws NoSuchMethodException * @see RpcConstants */ public NettyClientBase(RpcConnectionKey rpcConnectionKey, Properties rpcParams) throws ClassNotFoundException, NoSuchMethodException { this.key = rpcConnectionKey; this.maxRetryNum = Integer.parseInt( rpcParams.getProperty(CLIENT_RETRY_NUM, String.valueOf(CLIENT_RETRY_NUM_DEFAULT))); this.connTimeoutMillis = Integer.parseInt( rpcParams.getProperty(CLIENT_CONNECTION_TIMEOUT, String.valueOf(CLIENT_CONNECTION_TIMEOUT_DEFAULT))); // Netty only takes integer value range and this is to avoid integer overflow. Preconditions.checkArgument(this.connTimeoutMillis <= Integer.MAX_VALUE, "Too long connection timeout"); } // should be called from sub class protected void init(ChannelInitializer<Channel> initializer, EventLoopGroup eventLoopGroup) { this.bootstrap = new Bootstrap(); this.bootstrap .group(eventLoopGroup) .channel(NioSocketChannel.class) .handler(initializer) .option(ChannelOption.ALLOCATOR, NettyUtils.ALLOCATOR) .option(ChannelOption.SO_REUSEADDR, true) .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connTimeoutMillis) .option(ChannelOption.SO_RCVBUF, 1048576 * 10) .option(ChannelOption.TCP_NODELAY, true); } public RpcConnectionKey getKey() { return key; } protected final Class<?> getServiceClass() throws ClassNotFoundException { String serviceClassName = getKey().protocolClass.getName() + "$" + getKey().protocolClass.getSimpleName() + "Service"; return Class.forName(serviceClassName); } @SuppressWarnings("unchecked") protected final <I> I getStub(Method stubMethod, Object rpcChannel) { try { return (I) stubMethod.invoke(null, rpcChannel); } catch (Exception e) { throw new RemoteException(e.getMessage(), e); } } protected static RpcProtos.RpcRequest buildRequest(int seqId, Descriptors.MethodDescriptor method, Message param) { RpcProtos.RpcRequest.Builder requestBuilder = RpcProtos.RpcRequest.newBuilder() .setId(seqId) .setMethodName(method.getName()); if (param != null) { requestBuilder.setRequestMessage(param.toByteString()); } return requestBuilder.build(); } /** * Repeat invoke rpc request until the connection attempt succeeds or exceeded retries */ protected void invoke(final RpcProtos.RpcRequest rpcRequest, final T callback, final int retry) { if(getChannel().eventLoop().isShuttingDown()) { LOG.warn("RPC is shutting down"); return; } ChannelPromise promise = getChannel().newPromise(); promise.addListener(new GenericFutureListener<ChannelFuture>() { @Override public void operationComplete(final ChannelFuture future) throws Exception { if (future.isSuccess()) { getHandler().registerCallback(rpcRequest.getId(), callback); } else { if (!future.channel().isActive() && retry < maxRetryNum) { /* schedule the current request for the retry */ LOG.warn(future.cause() + " Try to reconnect :" + getKey().addr); final EventLoop loop = future.channel().eventLoop(); loop.schedule(new Runnable() { @Override public void run() { doConnect(getKey().addr).addListener(new GenericFutureListener<ChannelFuture>() { @Override public void operationComplete(ChannelFuture future) throws Exception { invoke(rpcRequest, callback, retry + 1); } }); } }, RpcConstants.DEFAULT_PAUSE, TimeUnit.MILLISECONDS); } else { /* Max retry count has been exceeded or internal failure */ getHandler().registerCallback(rpcRequest.getId(), callback); getHandler().exceptionCaught(getChannel().pipeline().lastContext(), new RecoverableException(rpcRequest.getId(), future.cause())); } } } }); getChannel().writeAndFlush(rpcRequest, promise); } private static InetSocketAddress resolveAddress(InetSocketAddress address) { if (address.isUnresolved()) { return RpcUtils.createSocketAddr(address.getHostName(), address.getPort()); } return address; } private ChannelFuture doConnect(SocketAddress address) { return this.channelFuture = bootstrap.clone().connect(address); } private ConnectException makeConnectException(InetSocketAddress address, ChannelFuture future) { if (future.cause() instanceof UnresolvedAddressException) { return new ConnectException("Can't resolve host name: " + address.toString()); } else { return new ConnectTimeoutException(future.cause().getMessage()); } } public synchronized void connect() throws ConnectException { if (isConnected()) return; int retries = 0; InetSocketAddress address = key.addr; if (address.isUnresolved()) { address = resolveAddress(address); } /* do not call await() inside handler */ ChannelFuture f = doConnect(address).awaitUninterruptibly(); if (!f.isSuccess()) { if (maxRetryNum > 0) { doReconnect(address, f, ++retries); } else { throw makeConnectException(address, f); } } } private void doReconnect(final InetSocketAddress address, ChannelFuture future, int retries) throws ConnectException { for (; ; ) { if (maxRetryNum > retries) { retries++; if(getChannel().eventLoop().isShuttingDown()) { LOG.warn("RPC is shutting down"); return; } LOG.warn(getErrorMessage(ExceptionUtils.getMessage(future.cause())) + "\nTry to reconnect : " + getKey().addr); try { Thread.sleep(RpcConstants.DEFAULT_PAUSE); } catch (InterruptedException e) { } this.channelFuture = doConnect(address).awaitUninterruptibly(); if (this.channelFuture.isDone() && this.channelFuture.isSuccess()) { break; } } else { LOG.error("Max retry count has been exceeded. attempts=" + retries + " caused by: " + future.cause()); throw makeConnectException(address, future); } } } protected abstract NettyChannelInboundHandler getHandler(); public Channel getChannel() { return channelFuture == null ? null : channelFuture.channel(); } public boolean isConnected() { Channel channel = getChannel(); return channel != null && channel.isActive(); } public SocketAddress getRemoteAddress() { Channel channel = getChannel(); return channel == null ? null : channel.remoteAddress(); } public int getActiveRequests() { return requests.size(); } public boolean subscribeEvent(RpcConnectionKey key, ChannelEventListener listener) { return channelEventListeners.putIfAbsent(key, listener) == null; } public void removeSubscribers() { channelEventListeners.clear(); } public Collection<ChannelEventListener> getSubscribers() { return channelEventListeners.values(); } private String getErrorMessage(String message) { return "Exception [" + getKey().protocolClass.getCanonicalName() + "(" + getKey().addr + ")]: " + message; } @Override public void close() { Channel channel = getChannel(); if (channel != null && channel.isOpen()) { LOG.debug("Proxy will be disconnected from remote " + channel.remoteAddress()); /* channelInactive receives event and then client terminates all the requests */ channel.close().syncUninterruptibly(); } } protected abstract class NettyChannelInboundHandler extends SimpleChannelInboundHandler<RpcResponse> { protected void registerCallback(int seqId, T callback) { if (requests.putIfAbsent(seqId, callback) != null) { throw new RemoteException( getErrorMessage("Duplicate Sequence Id " + seqId)); } } @Override public void channelRegistered(ChannelHandlerContext ctx) throws Exception { MonitorClientHandler handler = ctx.pipeline().get(MonitorClientHandler.class); if (handler != null) { enableMonitor = true; } for (ChannelEventListener listener : getSubscribers()) { listener.channelRegistered(ctx); } super.channelRegistered(ctx); } @Override public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { for (ChannelEventListener listener : getSubscribers()) { listener.channelUnregistered(ctx); } super.channelUnregistered(ctx); } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { super.channelActive(ctx); LOG.debug("Connection established successfully : " + ctx.channel()); } @Override public void channelInactive(ChannelHandlerContext ctx) throws Exception { super.channelInactive(ctx); sendExceptions("Connection lost :" + getKey().addr); } @Override protected final void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception { T callback = requests.remove(response.getId()); if (callback == null) LOG.warn("Dangling rpc call"); else run(response, callback); } /** * A {@link #channelRead0} received a message. * @param response response proto of type {@link RpcResponse}. * @param callback callback of type {@link T}. * @throws Exception */ protected abstract void run(RpcResponse response, T callback) throws Exception; /** * Calls from exceptionCaught * @param requestId sequence id of request. * @param callback callback of type {@link T}. * @param message the error message to handle */ protected abstract void handleException(int requestId, T callback, String message); @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { Throwable rootCause = ExceptionUtils.getRootCause(cause); LOG.error(getErrorMessage(ExceptionUtils.getMessage(rootCause)), rootCause); if (cause instanceof RecoverableException) { sendException((RecoverableException) cause); } else { /* unrecoverable fatal error*/ sendExceptions(ExceptionUtils.getMessage(rootCause)); if (ctx.channel().isOpen()) { ctx.close(); } } } /** * Send an error to all callback */ private void sendExceptions(String message) { for (int requestId : requests.keySet()) { handleException(requestId, requests.remove(requestId), message); } } /** * Send an error to callback */ private void sendException(RecoverableException e) { T callback = requests.remove(e.getSeqId()); if (callback != null) { handleException(e.getSeqId(), callback, ExceptionUtils.getRootCauseMessage(e)); } } /** * Trigger timeout event */ @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (!enableMonitor && evt instanceof IdleStateEvent) { IdleStateEvent e = (IdleStateEvent) evt; /* If all requests is done and event is triggered, idle channel close. */ if (e.state() == IdleState.READER_IDLE && requests.isEmpty()) { ctx.close(); LOG.info("Idle connection closed successfully :" + ctx.channel()); } } else if (evt instanceof MonitorStateEvent) { MonitorStateEvent e = (MonitorStateEvent) evt; if (e.state() == MonitorStateEvent.MonitorState.PING_EXPIRED) { exceptionCaught(ctx, new ServiceException("Server has not respond: " + ctx.channel())); } } super.userEventTriggered(ctx, evt); } } }