/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tajo.rpc;
import com.google.common.base.Preconditions;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import com.google.protobuf.ServiceException;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.concurrent.GenericFutureListener;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.tajo.rpc.RpcProtos.RpcResponse;
import java.io.Closeable;
import java.lang.reflect.Method;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.channels.UnresolvedAddressException;
import java.util.Collection;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import static org.apache.tajo.rpc.RpcConstants.*;
/**
 * Base class for Netty based RPC clients bound to a single remote endpoint.
 *
 * <p>It owns the Netty {@link Bootstrap} and the current {@link ChannelFuture}, provides a
 * blocking {@link #connect()} with retries, an asynchronous {@link #invoke} that reconnects and
 * re-sends on write failures, and tracks in-flight requests by sequence id until the matching
 * {@link RpcResponse} arrives. Subclasses must call {@link #init} and provide the inbound
 * handler through {@link #getHandler()}.
 *
 * @param <T> type of the per-request callback kept until the response (or an error) arrives
 */
public abstract class NettyClientBase<T> implements ProtoDeclaration, Closeable {
  public final static Log LOG = LogFactory.getLog(NettyClientBase.class);

  private final RpcConnectionKey key;
  /**
   * Number to retry for connection and RPC invocation. For {@link #connect()} the initial
   * attempt counts toward this limit; for {@link #invoke} it bounds the {@code retry} counter
   * carried with each re-send.
   */
  private final int maxRetryNum;
  /** Connection timeout in milliseconds (validated to fit in an int for Netty). */
  private final long connTimeoutMillis;
  /** True once a {@link MonitorClientHandler} is seen in the pipeline; disables idle closing. */
  private boolean enableMonitor;
  private final ConcurrentMap<RpcConnectionKey, ChannelEventListener> channelEventListeners = new ConcurrentHashMap<>();
  /** In-flight requests: sequence id -> callback awaiting the response. */
  private final ConcurrentMap<Integer, T> requests = new ConcurrentHashMap<>();

  private Bootstrap bootstrap;
  private volatile ChannelFuture channelFuture;

  /**
   * Constructor of NettyClientBase
   *
   * @param rpcConnectionKey RpcConnectionKey
   * @param rpcParams Rpc connection parameters (see RpcConstants)
   *
   * @throws ClassNotFoundException declared for subclasses; not thrown by this constructor
   * @throws NoSuchMethodException declared for subclasses; not thrown by this constructor
   * @throws IllegalArgumentException if a numeric parameter is malformed or the connection
   *         timeout exceeds {@link Integer#MAX_VALUE}
   * @see RpcConstants
   */
  public NettyClientBase(RpcConnectionKey rpcConnectionKey, Properties rpcParams)
      throws ClassNotFoundException, NoSuchMethodException {
    this.key = rpcConnectionKey;

    this.maxRetryNum = Integer.parseInt(
        rpcParams.getProperty(CLIENT_RETRY_NUM, String.valueOf(CLIENT_RETRY_NUM_DEFAULT)));
    // Parse as long: with Integer.parseInt an oversized value would fail with a
    // NumberFormatException before the range check below could ever report it.
    this.connTimeoutMillis = Long.parseLong(
        rpcParams.getProperty(CLIENT_CONNECTION_TIMEOUT, String.valueOf(CLIENT_CONNECTION_TIMEOUT_DEFAULT)));

    // Netty only takes integer value range and this is to avoid integer overflow.
    Preconditions.checkArgument(this.connTimeoutMillis <= Integer.MAX_VALUE, "Too long connection timeout");
  }

  /**
   * Configures the Netty bootstrap. Must be called by the subclass before {@link #connect()}.
   *
   * @param initializer channel initializer installing the pipeline handlers
   * @param eventLoopGroup event loop group serving this client's channels
   */
  protected void init(ChannelInitializer<Channel> initializer, EventLoopGroup eventLoopGroup) {
    this.bootstrap = new Bootstrap();
    this.bootstrap
        .group(eventLoopGroup)
        .channel(NioSocketChannel.class)
        .handler(initializer)
        .option(ChannelOption.ALLOCATOR, NettyUtils.ALLOCATOR)
        .option(ChannelOption.SO_REUSEADDR, true)
        .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connTimeoutMillis)
        .option(ChannelOption.SO_RCVBUF, 1048576 * 10)
        .option(ChannelOption.TCP_NODELAY, true);
  }

  public RpcConnectionKey getKey() {
    return key;
  }

  /**
   * Loads the protobuf service class nested in the protocol class, named
   * {@code <Protocol>$<ProtocolSimpleName>Service}.
   */
  protected final Class<?> getServiceClass() throws ClassNotFoundException {
    String serviceClassName = getKey().protocolClass.getName() + "$" +
        getKey().protocolClass.getSimpleName() + "Service";
    return Class.forName(serviceClassName);
  }

  /**
   * Invokes the static stub factory method with the given RPC channel.
   *
   * @throws RemoteException wrapping any reflection or invocation failure
   */
  @SuppressWarnings("unchecked")
  protected final <I> I getStub(Method stubMethod, Object rpcChannel) {
    try {
      return (I) stubMethod.invoke(null, rpcChannel);
    } catch (Exception e) {
      throw new RemoteException(e.getMessage(), e);
    }
  }

  /**
   * Builds an RPC request envelope.
   *
   * @param seqId sequence id used to match the response
   * @param method invoked method; only its name is sent
   * @param param request message, or {@code null} for none
   */
  protected static RpcProtos.RpcRequest buildRequest(int seqId,
                                                     Descriptors.MethodDescriptor method,
                                                     Message param) {
    RpcProtos.RpcRequest.Builder requestBuilder = RpcProtos.RpcRequest.newBuilder()
        .setId(seqId)
        .setMethodName(method.getName());

    if (param != null) {
      requestBuilder.setRequestMessage(param.toByteString());
    }

    return requestBuilder.build();
  }

  /**
   * Repeat invoke rpc request until the connection attempt succeeds or exceeded retries.
   *
   * <p>On a successful write the callback is registered for the response. If the write fails on
   * an inactive channel and {@code retry < maxRetryNum}, a reconnect is scheduled after
   * {@code RpcConstants.DEFAULT_PAUSE} ms and the request is re-sent with {@code retry + 1}.
   * Otherwise the callback is registered and failed through a {@link RecoverableException}.
   */
  protected void invoke(final RpcProtos.RpcRequest rpcRequest, final T callback, final int retry) {
    // Read the volatile channel future once so the shutdown check, the promise and the write
    // all refer to the same channel even if a concurrent reconnect replaces it.
    final Channel channel = getChannel();
    if (channel.eventLoop().isShuttingDown()) {
      LOG.warn("RPC is shutting down");
      return;
    }

    ChannelPromise promise = channel.newPromise();
    promise.addListener(new GenericFutureListener<ChannelFuture>() {
      @Override
      public void operationComplete(final ChannelFuture future) throws Exception {
        if (future.isSuccess()) {
          getHandler().registerCallback(rpcRequest.getId(), callback);
        } else {
          if (!future.channel().isActive() && retry < maxRetryNum) {
            /* schedule the current request for the retry */
            LOG.warn(future.cause() + " Try to reconnect :" + getKey().addr);

            final EventLoop loop = future.channel().eventLoop();
            loop.schedule(new Runnable() {
              @Override
              public void run() {
                doConnect(getKey().addr).addListener(new GenericFutureListener<ChannelFuture>() {
                  @Override
                  public void operationComplete(ChannelFuture future) throws Exception {
                    invoke(rpcRequest, callback, retry + 1);
                  }
                });
              }
            }, RpcConstants.DEFAULT_PAUSE, TimeUnit.MILLISECONDS);
          } else {
            /* Max retry count has been exceeded or internal failure */
            getHandler().registerCallback(rpcRequest.getId(), callback);
            getHandler().exceptionCaught(getChannel().pipeline().lastContext(),
                new RecoverableException(rpcRequest.getId(), future.cause()));
          }
        }
      }
    });
    channel.writeAndFlush(rpcRequest, promise);
  }

  private static InetSocketAddress resolveAddress(InetSocketAddress address) {
    if (address.isUnresolved()) {
      return RpcUtils.createSocketAddr(address.getHostName(), address.getPort());
    }
    return address;
  }

  /** Starts a new connection attempt and records its future as the current channel future. */
  private ChannelFuture doConnect(SocketAddress address) {
    return this.channelFuture = bootstrap.clone().connect(address);
  }

  /**
   * Translates a failed connect future into a {@link ConnectException}, keeping the original
   * failure as the cause so callers can inspect it.
   */
  private ConnectException makeConnectException(InetSocketAddress address, ChannelFuture future) {
    ConnectException exception;
    if (future.cause() instanceof UnresolvedAddressException) {
      exception = new ConnectException("Can't resolve host name: " + address.toString());
    } else {
      exception = new ConnectTimeoutException(future.cause().getMessage());
    }
    exception.initCause(future.cause());
    return exception;
  }

  /**
   * Connects to the remote address, blocking until connected or retries are exhausted.
   * Does nothing if already connected.
   *
   * @throws ConnectException if no attempt succeeds
   */
  public synchronized void connect() throws ConnectException {
    if (isConnected()) return;

    int retries = 0;
    InetSocketAddress address = key.addr;
    if (address.isUnresolved()) {
      address = resolveAddress(address);
    }

    /* do not call await() inside handler */
    ChannelFuture f = doConnect(address).awaitUninterruptibly();

    if (!f.isSuccess()) {
      if (maxRetryNum > 0) {
        doReconnect(address, f, ++retries);
      } else {
        throw makeConnectException(address, f);
      }
    }
  }

  /**
   * Retries the connection, pausing {@code RpcConstants.DEFAULT_PAUSE} ms between attempts,
   * until one succeeds, the event loop is shutting down, or {@code maxRetryNum} is reached.
   *
   * @param address resolved remote address
   * @param future the failed future of the previous attempt
   * @param retries number of attempts already made
   * @throws ConnectException built from the most recent failed attempt once retries run out
   */
  private void doReconnect(final InetSocketAddress address, ChannelFuture future, int retries)
      throws ConnectException {
    // Track the latest failure so logs and the final exception describe the last attempt.
    ChannelFuture lastFailure = future;
    boolean interrupted = false;
    try {
      for (; ; ) {
        if (maxRetryNum > retries) {
          retries++;

          if (getChannel().eventLoop().isShuttingDown()) {
            LOG.warn("RPC is shutting down");
            return;
          }

          LOG.warn(getErrorMessage(ExceptionUtils.getMessage(lastFailure.cause())) + "\nTry to reconnect : " + getKey().addr);
          try {
            Thread.sleep(RpcConstants.DEFAULT_PAUSE);
          } catch (InterruptedException e) {
            // Connecting is uninterruptible here; remember the interrupt and restore it on exit
            // instead of swallowing it.
            interrupted = true;
          }

          // doConnect() already stores the future in this.channelFuture.
          ChannelFuture attempt = doConnect(address).awaitUninterruptibly();
          if (attempt.isSuccess()) {
            break;
          }
          lastFailure = attempt;
        } else {
          LOG.error("Max retry count has been exceeded. attempts=" + retries + " caused by: " + lastFailure.cause());
          throw makeConnectException(address, lastFailure);
        }
      }
    } finally {
      if (interrupted) {
        Thread.currentThread().interrupt();
      }
    }
  }

  protected abstract NettyChannelInboundHandler getHandler();

  /** Returns the channel of the latest connection attempt, or {@code null} if none was made. */
  public Channel getChannel() {
    return channelFuture == null ? null : channelFuture.channel();
  }

  public boolean isConnected() {
    Channel channel = getChannel();
    return channel != null && channel.isActive();
  }

  public SocketAddress getRemoteAddress() {
    Channel channel = getChannel();
    return channel == null ? null : channel.remoteAddress();
  }

  /** Returns the number of requests still awaiting a response. */
  public int getActiveRequests() {
    return requests.size();
  }

  /**
   * Registers a channel event listener.
   *
   * @return {@code true} if registered, {@code false} if a listener already exists for the key
   */
  public boolean subscribeEvent(RpcConnectionKey key, ChannelEventListener listener) {
    return channelEventListeners.putIfAbsent(key, listener) == null;
  }

  public void removeSubscribers() {
    channelEventListeners.clear();
  }

  public Collection<ChannelEventListener> getSubscribers() {
    return channelEventListeners.values();
  }

  private String getErrorMessage(String message) {
    return "Exception [" + getKey().protocolClass.getCanonicalName() +
        "(" + getKey().addr + ")]: " + message;
  }

  /** Closes the current channel (if open) and waits for the close to complete. */
  @Override
  public void close() {
    Channel channel = getChannel();
    if (channel != null && channel.isOpen()) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Proxy will be disconnected from remote " + channel.remoteAddress());
      }
      /* channelInactive receives event and then client terminates all the requests */
      channel.close().syncUninterruptibly();
    }
  }

  /**
   * Inbound handler that matches responses to registered callbacks, forwards channel events to
   * subscribers, and fails pending requests on errors or connection loss.
   */
  protected abstract class NettyChannelInboundHandler extends SimpleChannelInboundHandler<RpcResponse> {

    /**
     * Registers a callback for the given sequence id.
     *
     * @throws RemoteException if a callback is already registered for {@code seqId}
     */
    protected void registerCallback(int seqId, T callback) {
      if (requests.putIfAbsent(seqId, callback) != null) {
        throw new RemoteException(
            getErrorMessage("Duplicate Sequence Id " + seqId));
      }
    }

    @Override
    public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
      MonitorClientHandler handler = ctx.pipeline().get(MonitorClientHandler.class);
      if (handler != null) {
        enableMonitor = true;
      }

      for (ChannelEventListener listener : getSubscribers()) {
        listener.channelRegistered(ctx);
      }
      super.channelRegistered(ctx);
    }

    @Override
    public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
      for (ChannelEventListener listener : getSubscribers()) {
        listener.channelUnregistered(ctx);
      }
      super.channelUnregistered(ctx);
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
      super.channelActive(ctx);
      LOG.debug("Connection established successfully : " + ctx.channel());
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
      super.channelInactive(ctx);
      sendExceptions("Connection lost :" + getKey().addr);
    }

    @Override
    protected final void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception {
      T callback = requests.remove(response.getId());
      if (callback == null) {
        LOG.warn("Dangling rpc call");
      } else {
        run(response, callback);
      }
    }

    /**
     * A {@link #channelRead0} received a message.
     * @param response response proto of type {@link RpcResponse}.
     * @param callback callback of type {@code T}.
     * @throws Exception
     */
    protected abstract void run(RpcResponse response, T callback) throws Exception;

    /**
     * Calls from exceptionCaught
     * @param requestId sequence id of request.
     * @param callback callback of type {@code T}.
     * @param message the error message to handle
     */
    protected abstract void handleException(int requestId, T callback, String message);

    /**
     * Fails the single request carried by a {@link RecoverableException}; for any other error,
     * fails all pending requests and closes the channel.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
        throws Exception {

      Throwable rootCause = ExceptionUtils.getRootCause(cause);
      if (rootCause == null) {
        // commons-lang 2.x returns null when there is no nested cause; fall back to the
        // exception itself so its message and stack trace are not lost.
        rootCause = cause;
      }
      LOG.error(getErrorMessage(ExceptionUtils.getMessage(rootCause)), rootCause);

      if (cause instanceof RecoverableException) {
        sendException((RecoverableException) cause);
      } else {
        /* unrecoverable fatal error*/
        sendExceptions(ExceptionUtils.getMessage(rootCause));
        if (ctx.channel().isOpen()) {
          ctx.close();
        }
      }
    }

    /**
     * Send an error to all callbacks
     */
    private void sendExceptions(String message) {
      for (Integer requestId : requests.keySet()) {
        T callback = requests.remove(requestId);
        // Guard against the entry having been removed concurrently since iteration began.
        if (callback != null) {
          handleException(requestId, callback, message);
        }
      }
    }

    /**
     * Send an error to callback
     */
    private void sendException(RecoverableException e) {
      T callback = requests.remove(e.getSeqId());

      if (callback != null) {
        handleException(e.getSeqId(), callback, ExceptionUtils.getRootCauseMessage(e));
      }
    }

    /**
     * Trigger timeout event. Without a monitor, a reader-idle event closes the channel when no
     * requests are pending; with a monitor, an expired ping is treated as a fatal error.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {

      if (!enableMonitor && evt instanceof IdleStateEvent) {
        IdleStateEvent e = (IdleStateEvent) evt;
        /* If all requests are done and the event is triggered, close the idle channel. */
        if (e.state() == IdleState.READER_IDLE && requests.isEmpty()) {
          ctx.close();
          LOG.info("Idle connection closed successfully :" + ctx.channel());
        }
      } else if (evt instanceof MonitorStateEvent) {
        MonitorStateEvent e = (MonitorStateEvent) evt;
        if (e.state() == MonitorStateEvent.MonitorState.PING_EXPIRED) {
          exceptionCaught(ctx, new ServiceException("Server has not respond: " + ctx.channel()));
        }
      }

      super.userEventTriggered(ctx, evt);
    }
  }
}