/*
* (C) Copyright 2015-2016 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Contributors:
* ohun@live.cn (夜色)
*/
package com.mpush.client.connect;
import com.google.common.collect.Maps;
import com.mpush.api.Constants;
import com.mpush.api.connection.Connection;
import com.mpush.api.event.ConnectionCloseEvent;
import com.mpush.api.protocol.Command;
import com.mpush.api.protocol.Packet;
import com.mpush.api.spi.common.CacheManager;
import com.mpush.api.spi.common.CacheManagerFactory;
import com.mpush.common.CacheKeys;
import com.mpush.common.message.*;
import com.mpush.common.security.AesCipher;
import com.mpush.common.security.CipherBox;
import com.mpush.netty.connection.NettyConnection;
import com.mpush.tools.event.EventBus;
import com.mpush.tools.thread.NamedPoolThreadFactory;
import com.mpush.tools.thread.ThreadNames;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* Created by ohun on 2015/12/19.
*
* @author ohun@live.cn
*/
public final class ConnClientChannelHandler extends ChannelInboundHandlerAdapter {
private static final Logger LOGGER = LoggerFactory.getLogger(ConnClientChannelHandler.class);
private static final Timer HASHED_WHEEL_TIMER = new HashedWheelTimer(new NamedPoolThreadFactory(ThreadNames.T_CONN_TIMER));
public static final AttributeKey<ClientConfig> CONFIG_KEY = AttributeKey.newInstance("clientConfig");
public static final TestStatistics STATISTICS = new TestStatistics();
private static CacheManager cacheManager = CacheManagerFactory.create();
private final Connection connection = new NettyConnection();
private ClientConfig clientConfig;
private boolean perfTest;
private int hbTimeoutTimes;
public ConnClientChannelHandler() {
perfTest = true;
}
public ConnClientChannelHandler(ClientConfig clientConfig) {
this.clientConfig = clientConfig;
}
public Connection getConnection() {
return connection;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
connection.updateLastReadTime();
if (msg instanceof Packet) {
Packet packet = (Packet) msg;
Command command = Command.toCMD(packet.cmd);
if (command == Command.HANDSHAKE) {
int connectedNum = STATISTICS.connectedNum.incrementAndGet();
connection.getSessionContext().changeCipher(new AesCipher(clientConfig.getClientKey(), clientConfig.getIv()));
HandshakeOkMessage message = new HandshakeOkMessage(packet, connection);
message.decodeBody();
byte[] sessionKey = CipherBox.I.mixKey(clientConfig.getClientKey(), message.serverKey);
connection.getSessionContext().changeCipher(new AesCipher(sessionKey, clientConfig.getIv()));
connection.getSessionContext().setHeartbeat(message.heartbeat);
startHeartBeat(message.heartbeat - 1000);
LOGGER.info("handshake success, clientConfig={}, connectedNum={}", clientConfig, connectedNum);
bindUser(clientConfig);
if (!perfTest) {
saveToRedisForFastConnection(clientConfig, message.sessionId, message.expireTime, sessionKey);
}
} else if (command == Command.FAST_CONNECT) {
int connectedNum = STATISTICS.connectedNum.incrementAndGet();
String cipherStr = clientConfig.getCipher();
String[] cs = cipherStr.split(",");
byte[] key = AesCipher.toArray(cs[0]);
byte[] iv = AesCipher.toArray(cs[1]);
connection.getSessionContext().changeCipher(new AesCipher(key, iv));
FastConnectOkMessage message = new FastConnectOkMessage(packet, connection);
message.decodeBody();
connection.getSessionContext().setHeartbeat(message.heartbeat);
startHeartBeat(message.heartbeat - 1000);
bindUser(clientConfig);
LOGGER.info("fast connect success, clientConfig={}, connectedNum={}", clientConfig, connectedNum);
} else if (command == Command.KICK) {
KickUserMessage message = new KickUserMessage(packet, connection);
LOGGER.error("receive kick user msg userId={}, deviceId={}, message={},", clientConfig.getUserId(), clientConfig.getDeviceId(), message);
ctx.close();
} else if (command == Command.ERROR) {
ErrorMessage message = new ErrorMessage(packet, connection);
message.decodeBody();
LOGGER.error("receive an error packet=" + message);
} else if (command == Command.PUSH) {
int receivePushNum = STATISTICS.receivePushNum.incrementAndGet();
PushMessage message = new PushMessage(packet, connection);
message.decodeBody();
LOGGER.info("receive push message, content={}, receivePushNum={}"
, new String(message.content, Constants.UTF_8), receivePushNum);
if (message.needAck()) {
AckMessage.from(message).sendRaw();
LOGGER.info("send ack success for sessionId={}", message.getSessionId());
}
} else if (command == Command.HEARTBEAT) {
LOGGER.info("receive heartbeat pong...");
} else if (command == Command.OK) {
OkMessage message = new OkMessage(packet, connection);
message.decodeBody();
int bindUserNum = STATISTICS.bindUserNum.get();
if (message.cmd == Command.BIND.cmd) {
bindUserNum = STATISTICS.bindUserNum.incrementAndGet();
}
LOGGER.info("receive {}, bindUserNum={}", message, bindUserNum);
} else if (command == Command.HTTP_PROXY) {
HttpResponseMessage message = new HttpResponseMessage(packet, connection);
message.decodeBody();
LOGGER.info("receive http response, message={}, body={}",
message, message.body == null ? null : new String(message.body, Constants.UTF_8));
}
}
LOGGER.debug("receive package={}, chanel={}", msg, ctx.channel());
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
connection.close();
LOGGER.error("caught an ex, channel={}", ctx.channel(), cause);
}
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
int clientNum = STATISTICS.clientNum.incrementAndGet();
LOGGER.info("client connect channel={}, clientNum={}", ctx.channel(), clientNum);
for (int i = 0; i < 3; i++) {
if (clientConfig != null) break;
clientConfig = ctx.channel().attr(CONFIG_KEY).getAndSet(null);
if (clientConfig == null) TimeUnit.SECONDS.sleep(1);
}
if (clientConfig == null) {
throw new NullPointerException("client config is null, channel=" + ctx.channel());
}
connection.init(ctx.channel(), true);
if (perfTest) {
handshake();
} else {
tryFastConnect();
}
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
int clientNum = STATISTICS.clientNum.decrementAndGet();
connection.close();
EventBus.I.post(new ConnectionCloseEvent(connection));
LOGGER.info("client disconnect channel={}, clientNum={}", connection, clientNum);
}
private void tryFastConnect() {
Map<String, String> sessionTickets = getFastConnectionInfo(clientConfig.getDeviceId());
if (sessionTickets == null) {
handshake();
return;
}
String sessionId = sessionTickets.get("sessionId");
if (sessionId == null) {
handshake();
return;
}
String expireTime = sessionTickets.get("expireTime");
if (expireTime != null) {
long exp = Long.parseLong(expireTime);
if (exp < System.currentTimeMillis()) {
handshake();
return;
}
}
final String cipher = sessionTickets.get("cipherStr");
FastConnectMessage message = new FastConnectMessage(connection);
message.deviceId = clientConfig.getDeviceId();
message.sessionId = sessionId;
message.sendRaw(channelFuture -> {
if (channelFuture.isSuccess()) {
clientConfig.setCipher(cipher);
} else {
handshake();
}
});
LOGGER.debug("send fast connect message={}", message);
}
private void bindUser(ClientConfig client) {
BindUserMessage message = new BindUserMessage(connection);
message.userId = client.getUserId();
message.tags = "test";
message.send();
connection.getSessionContext().setUserId(client.getUserId());
LOGGER.debug("send bind user message={}", message);
}
private void saveToRedisForFastConnection(ClientConfig client, String sessionId, Long expireTime, byte[] sessionKey) {
Map<String, String> map = Maps.newHashMap();
map.put("sessionId", sessionId);
map.put("expireTime", expireTime + "");
map.put("cipherStr", connection.getSessionContext().cipher.toString());
String key = CacheKeys.getDeviceIdKey(client.getDeviceId());
cacheManager.set(key, map, 60 * 5); //5分钟
}
@SuppressWarnings("unchecked")
private Map<String, String> getFastConnectionInfo(String deviceId) {
String key = CacheKeys.getDeviceIdKey(deviceId);
return cacheManager.get(key, Map.class);
}
private void handshake() {
HandshakeMessage message = new HandshakeMessage(connection);
message.clientKey = clientConfig.getClientKey();
message.iv = clientConfig.getIv();
message.clientVersion = clientConfig.getClientVersion();
message.deviceId = clientConfig.getDeviceId();
message.osName = clientConfig.getOsName();
message.osVersion = clientConfig.getOsVersion();
message.timestamp = System.currentTimeMillis();
message.send();
LOGGER.debug("send handshake message={}", message);
}
private void startHeartBeat(final int heartbeat) throws Exception {
HASHED_WHEEL_TIMER.newTimeout(new TimerTask() {
@Override
public void run(Timeout timeout) throws Exception {
if (connection.isConnected() && healthCheck()) {
HASHED_WHEEL_TIMER.newTimeout(this, heartbeat, TimeUnit.MILLISECONDS);
}
}
}, heartbeat, TimeUnit.MILLISECONDS);
}
private boolean healthCheck() {
if (connection.isReadTimeout()) {
hbTimeoutTimes++;
LOGGER.warn("heartbeat timeout times={}, client={}", hbTimeoutTimes, connection);
} else {
hbTimeoutTimes = 0;
}
if (hbTimeoutTimes >= 2) {
LOGGER.warn("heartbeat timeout times={} over limit={}, client={}", hbTimeoutTimes, 2, connection);
hbTimeoutTimes = 0;
connection.close();
return false;
}
if (connection.isWriteTimeout()) {
LOGGER.info("send heartbeat ping...");
connection.send(Packet.HB_PACKET);
}
return true;
}
}