/*
* (C) Copyright 2015-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Contributors:
* ohun@live.cn (夜色)
*/
package com.mpush.client.gateway.connection;
import com.google.common.collect.Maps;
import com.google.common.eventbus.Subscribe;
import com.google.common.net.HostAndPort;
import com.mpush.api.connection.Connection;
import com.mpush.api.event.ConnectionConnectEvent;
import com.mpush.api.service.Listener;
import com.mpush.api.spi.common.ServiceDiscoveryFactory;
import com.mpush.api.srd.ServiceDiscovery;
import com.mpush.api.srd.ServiceNode;
import com.mpush.client.gateway.GatewayClient;
import com.mpush.common.message.BaseMessage;
import com.mpush.tools.event.EventBus;
import io.netty.channel.ChannelFuture;
import io.netty.util.AttributeKey;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.function.Function;
import static com.mpush.api.srd.ServiceNames.GATEWAY_SERVER;
import static com.mpush.tools.config.CC.mp.net.gateway_client_num;
/**
 * Connection factory that keeps a pool of TCP connections for every gateway server
 * announced under {@code GATEWAY_SERVER} by service discovery.
 * <p>
 * Pools are keyed by a {@code host:port} string. They are filled from the
 * {@link ConnectionConnectEvent} handler and read/modified by the send, broadcast and
 * service-discovery callbacks; because those paths presumably run on different threads,
 * every pool is a {@code CopyOnWriteArrayList}.
 * <p>
 * Created by yxx on 2016/5/17.
 *
 * @author ohun@live.cn
 */
public class GatewayTCPConnectionFactory extends GatewayConnectionFactory {

    /** Channel attribute carrying the {@code host:port} key the connection was requested for. */
    private final AttributeKey<String> attrKey = AttributeKey.valueOf("host_port");

    /** {@code host:port} -> established connections to that gateway server. */
    private final Map<String, List<Connection>> connections = Maps.newConcurrentMap();

    private GatewayClient client;

    /**
     * Registers on the event bus, starts the gateway client, subscribes to gateway server
     * changes and synchronously connects to every gateway server currently known.
     *
     * @param listener notified with {@code onSuccess()} once the initial connects were awaited
     */
    @Override
    protected void doStart(Listener listener) throws Throwable {
        EventBus.I.register(this);
        client = new GatewayClient();
        client.start().join();
        ServiceDiscovery discovery = ServiceDiscoveryFactory.create();
        discovery.subscribe(GATEWAY_SERVER, this);
        discovery.lookup(GATEWAY_SERVER).forEach(this::addConnection);
        listener.onSuccess();
    }

    /** Opens connections (without waiting) to a newly announced gateway server. */
    @Override
    public void onServiceAdded(String path, ServiceNode node) {
        asyncAddConnection(node);
    }

    /** Drops the existing pool of an updated gateway server and opens fresh connections. */
    @Override
    public void onServiceUpdated(String path, ServiceNode node) {
        removeClient(node);
        asyncAddConnection(node);
    }

    /** Closes and forgets all connections of a gateway server that disappeared. */
    @Override
    public void onServiceRemoved(String path, ServiceNode node) {
        removeClient(node);
        logger.warn("Gateway Server zkNode={} was removed.", node);
    }

    /**
     * Closes every pooled connection, stops the gateway client and unsubscribes from
     * service discovery.
     * <p>
     * NOTE(review): unlike {@link #doStart(Listener)} this method never calls
     * {@code listener.onSuccess()}; confirm whether the base class completes the listener.
     */
    @Override
    public void doStop(Listener listener) throws Throwable {
        connections.values().forEach(pool -> pool.forEach(Connection::close));
        // Forget the closed connections so later lookups do not try to revive them.
        connections.clear();
        if (client != null) {
            client.stop().join();
        }
        ServiceDiscoveryFactory.create().unsubscribe(GATEWAY_SERVER, this);
    }

    /**
     * Picks a random connected connection to the given gateway server.
     * <p>
     * A picked connection that is no longer connected is evicted, replaced asynchronously,
     * and another one is tried.
     *
     * @param hostAndPort pool key in {@code host:port} form
     * @return a connected connection, or {@code null} when no connection is pooled for the key
     */
    @Override
    public Connection getConnection(String hostAndPort) {
        while (true) {
            List<Connection> pool = connections.get(hostAndPort);
            if (pool == null) {
                return null;//TODO create client
            }
            // Work on a snapshot: size and index must refer to the same state even when
            // the pool is modified concurrently.
            Connection[] snapshot = pool.toArray(new Connection[0]);
            if (snapshot.length == 0) {
                return null;//TODO create client
            }
            int index = snapshot.length == 1 ? 0 : ThreadLocalRandom.current().nextInt(snapshot.length);
            Connection connection = snapshot[index];
            if (connection.isConnected()) {
                return connection;
            }
            // Evicts the dead connection, so the next round sees a smaller (or refreshed) pool.
            reconnect(connection, hostAndPort);
        }
    }

    /**
     * Creates a message bound to a connection of the given gateway server and hands it to
     * {@code sender}.
     *
     * @param hostAndPort pool key in {@code host:port} form
     * @param creator     builds the message for the chosen connection
     * @param sender      consumes the created message
     * @return {@code false} when no connection is available, {@code true} otherwise
     */
    @Override
    public <M extends BaseMessage> boolean send(String hostAndPort, Function<Connection, M> creator, Consumer<M> sender) {
        Connection connection = getConnection(hostAndPort);
        if (connection == null) {
            return false;// gateway server not found, report the push as failed right away
        }
        sender.accept(creator.apply(connection));
        return true;
    }

    /**
     * Sends one message per gateway server, using the first pooled connection of each server.
     *
     * @param creator builds the message for a connection
     * @param sender  consumes each created message
     * @return {@code false} when no gateway server is known at all, {@code true} otherwise
     */
    @Override
    public <M extends BaseMessage> boolean broadcast(Function<Connection, M> creator, Consumer<M> sender) {
        if (connections.isEmpty()) {
            return false;
        }
        for (List<Connection> pool : connections.values()) {
            // findFirst() never indexes past the end, even if the pool shrinks concurrently.
            pool.stream().findFirst().ifPresent(first -> sender.accept(creator.apply(first)));
        }
        return true;
    }

    /**
     * Evicts a dead connection from its pool, closes it and asynchronously opens a replacement.
     * Does nothing when the pool is gone or somebody else already evicted the connection,
     * so a dead connection is replaced at most once.
     */
    private void reconnect(Connection connection, String hostAndPort) {
        HostAndPort h_p = HostAndPort.fromString(hostAndPort);
        List<Connection> pool = connections.get(hostAndPort);
        if (pool == null || !pool.remove(connection)) {
            return;
        }
        connection.close();
        addConnection(h_p.getHost(), h_p.getPort(), false);
    }

    /** Removes the pool of the given node (if any) and closes all of its connections. */
    private void removeClient(ServiceNode node) {
        if (node == null) {
            return;
        }
        List<Connection> clients = connections.remove(getHostAndPort(node.getHost(), node.getPort()));
        if (clients != null) {
            clients.forEach(Connection::close);
        }
    }

    /** Opens {@code gateway_client_num} connections to the node without waiting for them. */
    private void asyncAddConnection(ServiceNode node) {
        addConnections(node, false);
    }

    /** Opens {@code gateway_client_num} connections to the node, waiting for each attempt. */
    private void addConnection(ServiceNode node) {
        addConnections(node, true);
    }

    /** Shared loop behind the sync and async variants above. */
    private void addConnections(ServiceNode node, boolean sync) {
        for (int i = 0; i < gateway_client_num; i++) {
            addConnection(node.getHost(), node.getPort(), sync);
        }
    }

    /**
     * Starts one connect attempt. The pool key is stored on the channel so the connect event
     * handler can file the connection under the key it was requested with.
     *
     * @param sync when {@code true}, blocks until the attempt has completed
     */
    private void addConnection(String host, int port, boolean sync) {
        ChannelFuture future = client.connect(host, port);
        future.channel().attr(attrKey).set(getHostAndPort(host, port));
        future.addListener(f -> {
            if (!f.isSuccess()) {
                logger.error("create gateway connection ex, host={}, port={}", host, port, f.cause());
            }
        });
        if (sync) {
            future.awaitUninterruptibly();
        }
    }

    /**
     * Files a freshly connected connection into the pool of its gateway server. The key is
     * taken (and cleared) from the channel attribute, falling back to the remote address.
     */
    @Subscribe
    void on(ConnectionConnectEvent event) {
        Connection connection = event.connection;
        String hostAndPort = connection.getChannel().attr(attrKey).getAndSet(null);
        if (hostAndPort == null) {
            InetSocketAddress address = (InetSocketAddress) connection.getChannel().remoteAddress();
            hostAndPort = getHostAndPort(address.getAddress().getHostAddress(), address.getPort());
        }
        connections.computeIfAbsent(hostAndPort, key -> new CopyOnWriteArrayList<>()).add(connection);
        logger.info("one gateway client connect success, hostAndPort={}, conn={}", hostAndPort, connection);
    }

    /** Builds the {@code host:port} pool key. */
    private static String getHostAndPort(String host, int port) {
        return host + ":" + port;
    }
}