/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.rakam.kume;
import org.rakam.kume.network.ClientChannelAdapter;
import org.rakam.kume.network.TCPServerHandler;
import org.rakam.kume.service.Service;
import org.rakam.kume.transport.PacketDecoder;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.RemovalCause;
import com.google.common.cache.RemovalNotification;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import org.rakam.kume.transport.Packet;
import org.rakam.kume.transport.PacketEncoder;
import org.rakam.kume.util.ThrowableNioEventLoopGroup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public class NettyTransport implements Transport {
private final static Logger LOGGER = LoggerFactory.getLogger(NettyTransport.class);
private final Cache<Integer, CompletableFuture<Object>> messageHandlers = CacheBuilder.newBuilder()
.expireAfterWrite(200, TimeUnit.SECONDS)
.removalListener(this::removalListener).build();
// IO thread for TCP and UDP connections
final EventLoopGroup bossGroup = new NioEventLoopGroup(1);
// Processor thread pool that de-serializing/serializing incoming/outgoing packets
final EventLoopGroup workerGroup = new NioEventLoopGroup(4);
private TCPServerHandler server;
private final ThrowableNioEventLoopGroup requestExecutor;
private final List<Service> services;
private final Member localMember;
public NettyTransport(ThrowableNioEventLoopGroup requestExecutor, List<Service> services, Member localMember) {
this.requestExecutor = requestExecutor;
this.services = services;
this.localMember = localMember;
}
private void removalListener(RemovalNotification<Integer, CompletableFuture<Object>> notification) {
if (!notification.getCause().equals(RemovalCause.EXPLICIT))
notification.getValue().completeExceptionally(new TimeoutException());
}
@Override
public MemberChannel connect(Member member) throws InterruptedException {
Bootstrap b = new Bootstrap();
b.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
b.group(workerGroup)
.channel(NioSocketChannel.class)
.option(ChannelOption.TCP_NODELAY, true)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
p.addLast("frameDecoder", new LengthFieldBasedFrameDecoder(1048576, 0, 4, 0, 4));
p.addLast("packetDecoder", new PacketDecoder());
p.addLast("frameEncoder", new LengthFieldPrepender(4));
p.addLast("packetEncoder", new PacketEncoder());
p.addLast("server", new ClientChannelAdapter(messageHandlers));
}
});
ChannelFuture f = b.connect(member.getAddress()).sync()
.addListener(future -> {
if (!future.isSuccess()) {
LOGGER.error("Failed to connect server {}", member.getAddress());
}
}).sync();
return new NettyChannel(f.channel());
}
@Override
public void close() {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
try {
server.close();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
@Override
public void initialize() {
try {
this.server = new TCPServerHandler(bossGroup, workerGroup, requestExecutor, services, localMember.getAddress());
} catch (InterruptedException e) {
throw new IllegalStateException("Failed to bind TCP " + localMember.getAddress());
}
server.setAutoRead(true);
}
// @Override
// public void pause() {
// server.setAutoRead(false);
// }
//
// @Override
// public void resume() {
// server.setAutoRead(true);
// }
public class NettyChannel implements MemberChannel {
private final Channel channel;
public NettyChannel(Channel channel) {
this.channel = channel;
}
public CompletableFuture ask(Packet message) {
CompletableFuture future = new CompletableFuture<>();
messageHandlers.put(message.sequence, future);
channel.writeAndFlush(message);
return future;
}
@Override
public void send(Packet message) {
channel.writeAndFlush(message);
}
@Override
public void close() throws InterruptedException {
channel.close().sync();
}
}
}