/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.twitter.distributedlog.client.routing; import com.google.common.base.Optional; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; import com.google.common.collect.MapDifference; import com.google.common.collect.Maps; import com.google.common.hash.HashFunction; import com.google.common.hash.Hashing; import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.twitter.common.zookeeper.ServerSet; import com.twitter.distributedlog.service.DLSocketAddress; import com.twitter.finagle.ChannelException; import com.twitter.finagle.NoBrokersAvailableException; import com.twitter.finagle.stats.Counter; import com.twitter.finagle.stats.Gauge; import com.twitter.finagle.stats.NullStatsReceiver; import com.twitter.finagle.stats.StatsReceiver; import com.twitter.util.Function0; import org.apache.commons.lang3.tuple.Pair; import org.jboss.netty.util.HashedWheelTimer; import org.jboss.netty.util.Timeout; import org.jboss.netty.util.TimerTask; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.collection.Seq; import java.net.SocketAddress; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; public class ConsistentHashRoutingService extends ServerSetRoutingService { private static final Logger logger = LoggerFactory.getLogger(ConsistentHashRoutingService.class); @Deprecated public static ConsistentHashRoutingService of(ServerSetWatcher serverSetWatcher, int numReplicas) { return new ConsistentHashRoutingService(serverSetWatcher, numReplicas, 300, NullStatsReceiver.get()); } public static Builder newBuilder() { return new Builder(); } public static class Builder implements RoutingService.Builder { private ServerSet _serverSet; private boolean _resolveFromName = false; private int _numReplicas; private int _blackoutSeconds = 300; private StatsReceiver _statsReceiver = NullStatsReceiver.get(); private Builder() {} public Builder serverSet(ServerSet serverSet) { this._serverSet = serverSet; return this; } public Builder resolveFromName(boolean enabled) { this._resolveFromName = enabled; return this; } public Builder numReplicas(int numReplicas) { this._numReplicas = numReplicas; return this; } public Builder blackoutSeconds(int seconds) { this._blackoutSeconds = seconds; return this; } public Builder statsReceiver(StatsReceiver statsReceiver) { this._statsReceiver = statsReceiver; return this; } @Override public RoutingService build() { Preconditions.checkNotNull(_serverSet, "No serverset provided."); Preconditions.checkNotNull(_statsReceiver, "No stats receiver provided."); Preconditions.checkArgument(_numReplicas > 0, "Invalid number of replicas : " + _numReplicas); return new ConsistentHashRoutingService(new TwitterServerSetWatcher(_serverSet, _resolveFromName), _numReplicas, _blackoutSeconds, _statsReceiver); } } static class ConsistentHash { private final HashFunction hashFunction; private final int numOfReplicas; private final SortedMap<Long, SocketAddress> circle; // Stats protected final Counter hostAddedCounter; protected final Counter hostRemovedCounter; ConsistentHash(HashFunction hashFunction, int numOfReplicas, StatsReceiver statsReceiver) { this.hashFunction = hashFunction; this.numOfReplicas = numOfReplicas; this.circle = new TreeMap<Long, SocketAddress>(); this.hostAddedCounter = statsReceiver.counter0("adds"); this.hostRemovedCounter = statsReceiver.counter0("removes"); } private String replicaName(int shardId, int replica, String address) { if (shardId < 0) { shardId = UNKNOWN_SHARD_ID; } StringBuilder sb = new StringBuilder(100); sb.append("shard-"); sb.append(shardId); sb.append('-'); sb.append(replica); sb.append('-'); sb.append(address); return sb.toString(); } private Long replicaHash(int shardId, int replica, String address) { return hashFunction.hashUnencodedChars(replicaName(shardId, replica, address)).asLong(); } private Long replicaHash(int shardId, int replica, SocketAddress address) { return replicaHash(shardId, replica, address.toString()); } public synchronized void add(int shardId, SocketAddress address) { String addressStr = address.toString(); for (int i = 0; i < numOfReplicas; i++) { Long hash = replicaHash(shardId, i, addressStr); circle.put(hash, address); } hostAddedCounter.incr(); } public synchronized void remove(int shardId, SocketAddress address) { for (int i = 0; i < numOfReplicas; i++) { long hash = replicaHash(shardId, i, address); SocketAddress oldAddress = circle.get(hash); if (null != oldAddress && oldAddress.equals(address)) { circle.remove(hash); } } hostRemovedCounter.incr(); } public SocketAddress get(String key, RoutingContext rContext) { long hash = hashFunction.hashUnencodedChars(key).asLong(); return find(hash, rContext); } private synchronized SocketAddress find(long hash, RoutingContext rContext) { if (circle.isEmpty()) { return null; } Iterator<Map.Entry<Long, SocketAddress>> iterator = circle.tailMap(hash).entrySet().iterator(); while (iterator.hasNext()) { Map.Entry<Long, SocketAddress> entry = iterator.next(); if (!rContext.isTriedHost(entry.getValue())) { return entry.getValue(); } } // the tail map has been checked iterator = circle.headMap(hash).entrySet().iterator(); while (iterator.hasNext()) { Map.Entry<Long, SocketAddress> entry = iterator.next(); if (!rContext.isTriedHost(entry.getValue())) { return entry.getValue(); } } return null; } private synchronized Pair<Long, SocketAddress> get(long hash) { if (circle.isEmpty()) { return null; } if (!circle.containsKey(hash)) { SortedMap<Long, SocketAddress> tailMap = circle.tailMap(hash); hash = tailMap.isEmpty() ? circle.firstKey() : tailMap.firstKey(); } return Pair.of(hash, circle.get(hash)); } synchronized void dumpHashRing() { for (Map.Entry<Long, SocketAddress> entry : circle.entrySet()) { System.out.println(entry.getKey() + " : " + entry.getValue()); } } } class BlackoutHost implements TimerTask { final int shardId; final SocketAddress address; BlackoutHost(int shardId, SocketAddress address) { this.shardId = shardId; this.address = address; numBlackoutHosts.incrementAndGet(); } @Override public void run(Timeout timeout) throws Exception { numBlackoutHosts.decrementAndGet(); if (!timeout.isExpired()) { return; } Set<SocketAddress> removedList = new HashSet<SocketAddress>(); boolean joined; // add the shard back synchronized (shardId2Address) { SocketAddress curHost = shardId2Address.get(shardId); if (null != curHost) { // there is already new shard joint, so drop the host. logger.info("Blackout Shard {} ({}) was already replaced by {} permanently.", new Object[] { shardId, address, curHost }); joined = false; } else { join(shardId, address, removedList); joined = true; } } if (joined) { for (RoutingListener listener : listeners) { listener.onServerJoin(address); } } else { for (RoutingListener listener : listeners) { listener.onServerLeft(address); } } } } protected final HashedWheelTimer hashedWheelTimer; protected final HashFunction hashFunction = Hashing.md5(); protected final ConsistentHash circle; protected final Map<Integer, SocketAddress> shardId2Address = new HashMap<Integer, SocketAddress>(); protected final Map<SocketAddress, Integer> address2ShardId = new HashMap<SocketAddress, Integer>(); // blackout period protected final int blackoutSeconds; // stats protected final StatsReceiver statsReceiver; protected final AtomicInteger numBlackoutHosts; protected final Gauge numBlackoutHostsGauge; protected final Gauge numHostsGauge; private static final int UNKNOWN_SHARD_ID = -1; ConsistentHashRoutingService(ServerSetWatcher serverSetWatcher, int numReplicas, int blackoutSeconds, StatsReceiver statsReceiver) { super(serverSetWatcher); this.circle = new ConsistentHash(hashFunction, numReplicas, statsReceiver.scope("ring")); this.hashedWheelTimer = new HashedWheelTimer(new ThreadFactoryBuilder() .setNameFormat("ConsistentHashRoutingService-Timer-%d").build()); this.blackoutSeconds = blackoutSeconds; // stats this.statsReceiver = statsReceiver; this.numBlackoutHosts = new AtomicInteger(0); this.numBlackoutHostsGauge = this.statsReceiver.addGauge(gaugeName("num_blackout_hosts"), new Function0<Object>() { @Override public Object apply() { return (float) numBlackoutHosts.get(); } }); this.numHostsGauge = this.statsReceiver.addGauge(gaugeName("num_hosts"), new Function0<Object>() { @Override public Object apply() { return (float) address2ShardId.size(); } }); } private static Seq<String> gaugeName(String name) { return scala.collection.JavaConversions.asScalaBuffer(Arrays.asList(name)).toList(); } @Override public void startService() { super.startService(); this.hashedWheelTimer.start(); } @Override public void stopService() { this.hashedWheelTimer.stop(); super.stopService(); } @Override public Set<SocketAddress> getHosts() { synchronized (shardId2Address) { return ImmutableSet.copyOf(address2ShardId.keySet()); } } @Override public SocketAddress getHost(String key, RoutingContext rContext) throws NoBrokersAvailableException { SocketAddress host = circle.get(key, rContext); if (null != host) { return host; } throw new NoBrokersAvailableException("No host found for " + key + ", routing context : " + rContext); } @Override public void removeHost(SocketAddress host, Throwable reason) { removeHostInternal(host, Optional.of(reason)); } private void removeHostInternal(SocketAddress host, Optional<Throwable> reason) { synchronized (shardId2Address) { Integer shardId = address2ShardId.remove(host); if (null != shardId) { SocketAddress curHost = shardId2Address.get(shardId); if (null != curHost && curHost.equals(host)) { shardId2Address.remove(shardId); } circle.remove(shardId, host); if (reason.isPresent()) { if (reason.get() instanceof ChannelException) { logger.info("Shard {} ({}) left due to ChannelException, black it out for {} seconds (message = {})", new Object[] { shardId, host, blackoutSeconds, reason.get().toString() }); BlackoutHost blackoutHost = new BlackoutHost(shardId, host); hashedWheelTimer.newTimeout(blackoutHost, blackoutSeconds, TimeUnit.SECONDS); } else { logger.info("Shard {} ({}) left due to exception {}", new Object[] { shardId, host, reason.get().toString() }); } } else { logger.info("Shard {} ({}) left after server set change", shardId, host); } } else if (reason.isPresent()) { logger.info("Node {} left due to exception {}", host, reason.get().toString()); } else { logger.info("Node {} left after server set change", host); } } } /** * The caller should synchronize on <i>shardId2Address</i>. * @param shardId * Shard id of new host joined. * @param newHost * New host joined. * @param removedList * Old hosts to remove */ private void join(int shardId, SocketAddress newHost, Set<SocketAddress> removedList) { SocketAddress oldHost = shardId2Address.put(shardId, newHost); if (null != oldHost) { // remove the old host only when a new shard is kicked in to replace it. address2ShardId.remove(oldHost); circle.remove(shardId, oldHost); removedList.add(oldHost); logger.info("Shard {} ({}) left permanently.", shardId, oldHost); } address2ShardId.put(newHost, shardId); circle.add(shardId, newHost); logger.info("Shard {} ({}) joined to replace ({}).", new Object[] { shardId, newHost, oldHost }); } @Override protected synchronized void performServerSetChange(ImmutableSet<DLSocketAddress> serviceInstances) { Set<SocketAddress> joinedList = new HashSet<SocketAddress>(); Set<SocketAddress> removedList = new HashSet<SocketAddress>(); Map<Integer, SocketAddress> newMap = new HashMap<Integer, SocketAddress>(); synchronized (shardId2Address) { for (DLSocketAddress serviceInstance : serviceInstances) { if (serviceInstance.getShard() >= 0) { newMap.put(serviceInstance.getShard(), serviceInstance.getSocketAddress()); } else { Integer shard = address2ShardId.get(serviceInstance.getSocketAddress()); if (null == shard) { // Assign a random negative shardId int shardId; do { shardId = Math.min(-1 , (int)(Math.random() * Integer.MIN_VALUE)); } while (null != shardId2Address.get(shardId)); shard = shardId; } newMap.put(shard, serviceInstance.getSocketAddress()); } } } Map<Integer, SocketAddress> left; synchronized (shardId2Address) { MapDifference<Integer, SocketAddress> difference = Maps.difference(shardId2Address, newMap); left = difference.entriesOnlyOnLeft(); for (Map.Entry<Integer, SocketAddress> shardEntry : left.entrySet()) { int shard = shardEntry.getKey(); if (shard >= 0) { SocketAddress host = shardId2Address.get(shard); if (null != host) { // we don't remove those hosts that just disappered on serverset proactively, // since it might be just because serverset become flaky // address2ShardId.remove(host); // circle.remove(shard, host); logger.info("Shard {} ({}) left temporarily.", shard, host); } } else { // shard id is negative - they are resolved from finagle name, which instances don't have shard id // in this case, if they are removed from serverset, we removed them directly SocketAddress host = shardEntry.getValue(); if (null != host) { removeHostInternal(host, Optional.<Throwable>absent()); removedList.add(host); } } } // we need to find if any shards are replacing old shards for (Map.Entry<Integer, SocketAddress> shard : newMap.entrySet()) { SocketAddress oldHost = shardId2Address.get(shard.getKey()); SocketAddress newHost = shard.getValue(); if (!newHost.equals(oldHost)) { join(shard.getKey(), newHost, removedList); joinedList.add(newHost); } } } for (SocketAddress addr : removedList) { for (RoutingListener listener : listeners) { listener.onServerLeft(addr); } } for (SocketAddress addr : joinedList) { for (RoutingListener listener : listeners) { listener.onServerJoin(addr); } } } }