/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tajo.rpc; import com.google.protobuf.Descriptors.MethodDescriptor; import com.google.protobuf.*; import io.netty.channel.ChannelHandler; import io.netty.channel.EventLoopGroup; import org.apache.tajo.rpc.RpcProtos.RpcResponse; import java.lang.reflect.Method; import java.util.Properties; import java.util.concurrent.atomic.AtomicInteger; import static org.apache.tajo.rpc.RpcConstants.*; public class AsyncRpcClient extends NettyClientBase<AsyncRpcClient.ResponseCallback> { private final Method stubMethod; private final ProxyRpcChannel rpcChannel; private final NettyChannelInboundHandler handler; /** * Intentionally make this method package-private, avoiding user directly * new an instance through this constructor. * * @param rpcConnectionKey RpcConnectionKey * @param eventLoopGroup Thread pool of netty's * @param rpcParams Rpc connection parameters (see RpcConstants) * * @throws ClassNotFoundException * @throws NoSuchMethodException */ AsyncRpcClient(EventLoopGroup eventLoopGroup, RpcConnectionKey rpcConnectionKey, Properties rpcParams) throws ClassNotFoundException, NoSuchMethodException { super(rpcConnectionKey, rpcParams); this.stubMethod = getServiceClass().getMethod("newStub", RpcChannel.class); this.rpcChannel = new ProxyRpcChannel(); this.handler = new ClientChannelInboundHandler(); final long socketTimeoutMills = Long.parseLong( rpcParams.getProperty(CLIENT_SOCKET_TIMEOUT, String.valueOf(CLIENT_SOCKET_TIMEOUT_DEFAULT))); // Enable proactive hang detection final boolean hangDetectionEnabled = Boolean.parseBoolean( rpcParams.getProperty(CLIENT_HANG_DETECTION, String.valueOf(CLIENT_HANG_DETECTION_DEFAULT))); init(new ProtoClientChannelInitializer(handler, RpcResponse.getDefaultInstance(), socketTimeoutMills, hangDetectionEnabled), eventLoopGroup); } @Override public <I> I getStub() { return getStub(stubMethod, rpcChannel); } @Override protected NettyChannelInboundHandler getHandler() { return handler; } private class ProxyRpcChannel implements RpcChannel { private final AtomicInteger sequence = new AtomicInteger(0); public void callMethod(final MethodDescriptor method, final RpcController controller, final Message param, final Message responseType, final RpcCallback<Message> done) { int nextSeqId = sequence.getAndIncrement(); RpcProtos.RpcRequest rpcRequest = buildRequest(nextSeqId, method, param); invoke(rpcRequest, new ResponseCallback(controller, responseType, done), 0); } } @ChannelHandler.Sharable private class ClientChannelInboundHandler extends NettyChannelInboundHandler { @Override protected void run(RpcResponse response, ResponseCallback callback) throws Exception { callback.run(response); } @Override protected void handleException(int requestId, ResponseCallback callback, String message) { RpcResponse.Builder responseBuilder = RpcResponse.newBuilder() .setErrorMessage(message + "") .setId(requestId); callback.run(responseBuilder.build()); } } static class ResponseCallback implements RpcCallback<RpcResponse> { private final RpcController controller; private final Message responsePrototype; private final RpcCallback<Message> callback; public ResponseCallback(RpcController controller, Message responsePrototype, RpcCallback<Message> callback) { this.controller = controller; this.responsePrototype = responsePrototype; this.callback = callback; } @Override public void run(RpcResponse rpcResponse) { // if hasErrorMessage is true, it means rpc-level errors. // it can be called the callback function with null response. if (rpcResponse.hasErrorMessage()) { if (controller != null) { this.controller.setFailed(rpcResponse.getErrorMessage()); } callback.run(null); } else { // if rpc call succeed Message responseMessage = null; if (rpcResponse.hasResponseMessage()) { try { responseMessage = responsePrototype.newBuilderForType().mergeFrom( rpcResponse.getResponseMessage()).build(); } catch (InvalidProtocolBufferException e) { if (controller != null) { this.controller.setFailed(e.getMessage()); } } } callback.run(responseMessage); } } } }