/****************************************************************************** * * * Copyright 2017 Subterranean Security * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * * *****************************************************************************/ package com.subterranean_security.crimson.core.net; import java.net.ConnectException; import java.util.HashMap; import java.util.Map; import java.util.Observable; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import javax.net.ssl.SSLException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.subterranean_security.crimson.core.proto.MSG; import com.subterranean_security.crimson.core.proto.MSG.Message; import com.subterranean_security.crimson.universal.Universal; import io.netty.bootstrap.Bootstrap; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.nio.NioDatagramChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.protobuf.ProtobufDecoder; import io.netty.handler.codec.protobuf.ProtobufEncoder; import io.netty.handler.codec.protobuf.ProtobufVarint32FrameDecoder; import io.netty.handler.codec.protobuf.ProtobufVarint32LengthFieldPrepender; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; public class Connector extends Observable { private static final Logger log = LoggerFactory.getLogger(Connector.class); private int cvid; private Universal.Instance instance; private ConnectionType type; private ConnectionState state; private EventLoopGroup workerGroup; private BasicExecutor executor; private BasicHandler handler; /** * All incoming messages are dropped into this queue and wait for processing */ public BlockingQueue<Message> msgQueue; /** * When a response message is desired, a MessageFuture is placed into this * map. If the BasicExecutor cannot execute a message and a corresponding * entry in this map exists, the MessageFuture is removed and notified. */ private Map<Integer, MessageFuture> responseMap; public Connector(BasicExecutor executor, BasicHandler handler) { // initialize state state = ConnectionState.NOT_CONNECTED; workerGroup = new NioEventLoopGroup(); // initialize message buffers msgQueue = new LinkedBlockingQueue<Message>(); responseMap = new HashMap<Integer, MessageFuture>(); // initialize executor and handler this.executor = executor; this.handler = handler; executor.setConnector(this); handler.setConnector(this); executor.start(); } public Connector(BasicExecutor executor) { this(executor, new BasicHandler()); } public void connect(ConnectionType type, String host, int port) throws InterruptedException, ConnectException { if (getState() == ConnectionState.NOT_CONNECTED) { this.type = type; Bootstrap b = new Bootstrap(); switch (type) { case DATAGRAM: b.channel(NioDatagramChannel.class); break; case SOCKET: b.channel(NioSocketChannel.class); break; default: break; } b.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000); b.group(workerGroup).handler(new InitiatorInitializer(host, port)); b.connect(host, port).sync(); setState(ConnectionState.CONNECTED); } } public void addNewResponse(Message m) { if (responseMap.containsKey(m.getId())) { responseMap.remove(m.getId()).setMessage(m); } else { // dropping this message because no thread is waiting for it } } public MessageFuture getResponse(int id) { if (!responseMap.containsKey(id)) { responseMap.put(id, new MessageFuture()); } return responseMap.get(id); } public MessageFuture writeAndGetResponse(Message m) { write(m); return getResponse(m.getId()); } public void write(Message m) { handler.write(m); } public String getRemoteIP() { return handler.getRemoteIP(); } public int getRemotePort() { return handler.getRemotePort(); } public BasicHandler getHandler() { return handler; } public void close() { setState(ConnectionState.NOT_CONNECTED); workerGroup.shutdownGracefully(); executor.stop(); deleteObservers(); } public ConnectionState getState() { return state; } public void setState(ConnectionState state) { if (this.state != state) { log.debug("Connector state changed: {}->{}", this.state, state); this.state = state; setChanged(); notifyObservers(state); } } public ConnectionType getType() { return type; } public Universal.Instance getInstance() { return instance; } public void setInstance(Universal.Instance instance) { this.instance = instance; } public int getCvid() { return cvid; } public void setCvid(int cvid) { this.cvid = cvid; } public enum ConnectionType { SOCKET, DATAGRAM; } public enum ConnectionState { // TODO remove auth stages NOT_CONNECTED, CONNECTED, AUTHENTICATED, AUTH_STAGE1, AUTH_STAGE2; } class InitiatorInitializer extends ChannelInitializer<SocketChannel> { private SslContext sslCtx; private final String host; private final int port; public InitiatorInitializer(String host, int port) { try { this.sslCtx = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); } catch (SSLException e) { // TODO Auto-generated catch block e.printStackTrace(); } this.host = host; this.port = port; } @Override public void initChannel(SocketChannel ch) { ChannelPipeline p = ch.pipeline(); if (sslCtx != null) { p.addLast(sslCtx.newHandler(ch.alloc(), host, port)); } if (Universal.isNetDebug) { p.addLast(new LoggingHandler(Connector.class)); } p.addLast(new ProtobufVarint32FrameDecoder()); p.addLast(new ProtobufDecoder(MSG.Message.getDefaultInstance())); p.addLast(new ProtobufVarint32LengthFieldPrepender()); p.addLast(new ProtobufEncoder()); p.addLast(handler); } } }