/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tajo.rpc; import com.google.protobuf.*; import com.google.protobuf.Descriptors.MethodDescriptor; import io.netty.channel.ChannelHandler; import io.netty.channel.EventLoopGroup; import org.apache.tajo.rpc.RpcProtos.RpcResponse; import java.lang.reflect.Method; import java.net.InetSocketAddress; import java.util.Properties; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import static org.apache.tajo.rpc.RpcConstants.*; public class BlockingRpcClient extends NettyClientBase<BlockingRpcClient.ProtoCallFuture> { private final Method stubMethod; private final ProxyRpcChannel rpcChannel; private final NettyChannelInboundHandler handler; /** * Intentionally make this method package-private, avoiding user directly * new an instance through this constructor. * * @param rpcConnectionKey RpcConnectionKey * @param eventLoopGroup Thread pool of netty's * @param rpcParams Rpc connection parameters (see RpcConstants) * * @throws ClassNotFoundException * @throws NoSuchMethodException * @see RpcConstants */ public BlockingRpcClient(EventLoopGroup eventLoopGroup, RpcConnectionKey rpcConnectionKey, Properties rpcParams) throws ClassNotFoundException, NoSuchMethodException { super(rpcConnectionKey, rpcParams); this.stubMethod = getServiceClass().getMethod("newBlockingStub", BlockingRpcChannel.class); this.rpcChannel = new ProxyRpcChannel(); this.handler = new ClientChannelInboundHandler(); long socketTimeoutMills = Long.parseLong( rpcParams.getProperty(CLIENT_SOCKET_TIMEOUT, String.valueOf(CLIENT_SOCKET_TIMEOUT_DEFAULT))); // Enable proactive hang detection final boolean hangDetectionEnabled = Boolean.parseBoolean( rpcParams.getProperty(CLIENT_HANG_DETECTION, String.valueOf(CLIENT_HANG_DETECTION_DEFAULT))); init(new ProtoClientChannelInitializer(handler, RpcResponse.getDefaultInstance(), socketTimeoutMills, hangDetectionEnabled), eventLoopGroup); } @Override public <I> I getStub() { return getStub(stubMethod, rpcChannel); } @Override protected NettyChannelInboundHandler getHandler() { return handler; } private class ProxyRpcChannel implements BlockingRpcChannel { private final AtomicInteger sequence = new AtomicInteger(0); @Override public Message callBlockingMethod(final MethodDescriptor method, final RpcController controller, final Message param, final Message responsePrototype) throws TajoServiceException { int nextSeqId = sequence.getAndIncrement(); RpcProtos.RpcRequest rpcRequest = buildRequest(nextSeqId, method, param); ProtoCallFuture callFuture = new ProtoCallFuture(controller, responsePrototype); invoke(rpcRequest, callFuture, 0); try { return callFuture.get(); } catch (Throwable t) { if (t instanceof ExecutionException) { Throwable cause = t.getCause(); if (cause != null && cause instanceof TajoServiceException) { throw (TajoServiceException) cause; } } throw new TajoServiceException(t.getMessage()); } } } private TajoServiceException makeTajoServiceException(RpcResponse response, Throwable cause) { if (getChannel() != null) { return new TajoServiceException(response.getErrorMessage(), cause, getKey().protocolClass.getName(), RpcUtils.normalizeInetSocketAddress((InetSocketAddress) getChannel().remoteAddress())); } else { return new TajoServiceException(response.getErrorMessage()); } } @ChannelHandler.Sharable public class ClientChannelInboundHandler extends NettyChannelInboundHandler { @Override protected void run(RpcResponse rpcResponse, ProtoCallFuture callback) throws Exception { if (rpcResponse.hasErrorMessage()) { callback.setFailed(rpcResponse.getErrorMessage(), makeTajoServiceException(rpcResponse, new ServiceException(rpcResponse.getErrorTrace()))); } else { Message responseMessage = null; if (rpcResponse.hasResponseMessage()) { try { responseMessage = callback.returnType.newBuilderForType().mergeFrom(rpcResponse.getResponseMessage()) .build(); } catch (InvalidProtocolBufferException e) { callback.setFailed(e.getMessage(), e); } } callback.setResponse(responseMessage); } } @Override protected void handleException(int requestId, ProtoCallFuture callback, String message) { callback.setFailed(message + "", new TajoServiceException(message)); } } static class ProtoCallFuture implements Future<Message> { private Semaphore sem = new Semaphore(0); private boolean done = false; private Message response = null; private Message returnType; private RpcController controller; private ExecutionException ee; public ProtoCallFuture(RpcController controller, Message message) { this.controller = controller; this.returnType = message; } @Override public boolean cancel(boolean arg0) { return false; } @Override public Message get() throws InterruptedException, ExecutionException { if(!isDone()) sem.acquire(); if (ee != null) { throw ee; } return response; } @Override public Message get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { if(!isDone()) { if (!sem.tryAcquire(timeout, unit)) { throw new TimeoutException(); } } if (ee != null) { throw ee; } return response; } @Override public boolean isCancelled() { return false; } @Override public boolean isDone() { return done; } public void setResponse(Message response) { this.response = response; done = true; sem.release(); } public void setFailed(String errorText, Throwable t) { if (controller != null) { this.controller.setFailed(errorText); } ee = new ExecutionException(errorText, t); done = true; sem.release(); } } }