package org.rakam.kume.util;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.hash.Funnel;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import com.google.common.hash.PrimitiveSink;
import org.rakam.kume.Member;
import org.rakam.kume.transport.serialization.SinkSerializable;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
/**
 * A consistent-hash ring that partitions the 64-bit token space into buckets.
 *
 * <p>Bucket {@code i} covers the tokens from its own token (inclusive) up to the token of bucket
 * {@code i + 1} (exclusive); the last bucket wraps around to the first. Each member added to the
 * ring contributes {@code bucketPerNode} buckets, and every bucket lists the members that replicate
 * it (aiming for {@code replicationFactor} of them). {@link #addNode} and {@link #removeNode} return
 * a new ring instead of modifying this one.
 */
public class ConsistentHashRing {
    private final int bucketPerNode;
    private final Bucket[] buckets;
    private static final HashFunction hashFunction = Hashing.murmur3_128();
    private final int replicationFactor;

    /** Feeds a {@link SinkSerializable} key into a hasher using the key's own {@code writeTo}. */
    private static final Funnel<SinkSerializable> keyFunnel = new Funnel<SinkSerializable>() {
        @Override
        public void funnel(SinkSerializable from, PrimitiveSink into) {
            from.writeTo(into);
        }
    };

    /**
     * Builds a ring by adding the members one at a time. Members are sorted first, so the resulting
     * layout does not depend on the iteration order of {@code members}.
     *
     * @param members           initial members; duplicates are ignored
     * @param bucketPerNode     number of buckets each member contributes; must be at least 1
     * @param replicationFactor number of members each bucket should be replicated on
     * @throws IllegalArgumentException if {@code members} is empty or {@code bucketPerNode < 1}
     */
    public ConsistentHashRing(Collection<Member> members, int bucketPerNode, int replicationFactor) {
        if (bucketPerNode < 1) {
            throw new IllegalArgumentException("bucketPerNode must be at least 1, was " + bucketPerNode);
        }
        if (members.isEmpty()) {
            throw new IllegalArgumentException("ring must contain at least one member");
        }
        this.bucketPerNode = bucketPerNode;
        this.replicationFactor = replicationFactor;
        // todo: find a way to construct buckets without adding elements one by one.
        List<Member> ordered = members.stream()
                .distinct() // adding the same member twice would corrupt the layout
                .sorted(this::compareMembers)
                .collect(Collectors.toList());
        Bucket[] list = null;
        for (Member member : ordered) {
            list = findBucketListForNewNode(member, list);
        }
        buckets = list;
    }

    /** Wraps an already computed (token-sorted) bucket array. */
    protected ConsistentHashRing(Bucket[] buckets, int bucketPerNode, int replicationFactor) {
        this.bucketPerNode = bucketPerNode;
        this.buckets = buckets;
        this.replicationFactor = replicationFactor;
    }

    /** Lists every bucket in ring order as {@code [start-end, N members]}. */
    @Override
    public String toString() {
        StringBuilder str = new StringBuilder();
        for (int i = 0; i < buckets.length; i++) {
            TokenRange range = getBucketRange(i);
            str.append('[').append(range.start).append('-').append(range.end)
                    .append(", ").append(buckets[i].members.size()).append(" members]");
        }
        return str.toString();
    }

    /**
     * Orders members by the concatenation of host string and port. The concatenation is ambiguous
     * (host "10.0.0.1" + port 23 equals host "10.0.0.12" + port 3), so equal keys fall back to
     * comparing hosts; that tie-break never changes the order of members whose keys differ.
     */
    private int compareMembers(Member o1, Member o2) {
        String o1Str = o1.getAddress().getHostString() + o1.getAddress().getPort();
        String o2Str = o2.getAddress().getHostString() + o2.getAddress().getPort();
        int cmp = o1Str.compareTo(o2Str);
        if (cmp != 0) {
            return cmp;
        }
        // equal keys with equal hosts imply equal ports, so the host is the only tie-break needed
        return o1.getAddress().getHostString().compareTo(o2.getAddress().getHostString());
    }

    /**
     * Tests whether {@code hash} falls in the range from {@code start} to {@code end}.
     *
     * <p>When {@code start <= end} both bounds are inclusive. When {@code start > end} the range wraps
     * past {@code Long.MAX_VALUE} back to {@code Long.MIN_VALUE} and both bounds are exclusive.
     * NOTE(review): the inclusivity differs between the two cases — confirm callers expect this.
     */
    public static boolean isTokenBetween(long hash, long start, long end) {
        if (start <= end) {
            return hash >= start && hash <= end;
        }
        // the range wraps around the end of the ring
        return hash > start || hash < end;
    }

    /** Returns a fresh hasher of the ring's hash function (murmur3, 128-bit). */
    public static Hasher newHasher() {
        return hashFunction.newHasher();
    }

    /** Maps the token range of every bucket of this ring to the bucket's members. */
    public Map<TokenRange, List<Member>> getBuckets() {
        return getBuckets(buckets);
    }

    /**
     * Maps the token range of every bucket in {@code buckets} to the bucket's members.
     * NOTE(review): {@code Collectors.toMap} throws if two ranges have the same start and end,
     * which can only happen with duplicate tokens.
     */
    private Map<TokenRange, List<Member>> getBuckets(Bucket[] buckets) {
        return IntStream.range(0, buckets.length)
                .mapToObj(i -> {
                    Bucket bucket = buckets[i];
                    TokenRange tokenRange = new TokenRange(i, bucket.token, getBucketFromRing(buckets, i + 1).token);
                    return new Tuple<>(tokenRange, bucket.members);
                }).collect(Collectors.toMap(Tuple::_1, Tuple::_2));
    }

    /** Returns bucket {@code i}, treating the index as circular (any int, including negatives). */
    public Bucket getBucket(int i) {
        return buckets[Math.floorMod(i, buckets.length)];
    }

    public int getBucketCount() {
        return buckets.length;
    }

    /** Number of members, derived from the bucket count since every member adds bucketPerNode buckets. */
    public int getMemberCount() {
        return buckets.length / bucketPerNode;
    }

    /** Returns the token range of bucket {@code i}; {@code i} must be a valid bucket index. */
    public TokenRange getBucketRange(int i) {
        return new TokenRange(i, buckets[i].token, getBucketFromRing(buckets, i + 1).token);
    }

    /**
     * Returns the index of the bucket owning token {@code l}: the bucket with the largest token that
     * is {@code <= l}. Tokens smaller than the first bucket's token wrap around to the last bucket.
     */
    public int findBucketIdFromToken(long l) {
        int low = 0;
        int high = buckets.length - 1;
        while (low <= high) {
            int mid = (low + high) >>> 1;
            long midToken = buckets[mid].token;
            if (midToken < l) {
                low = mid + 1;
            } else if (midToken > l) {
                high = mid - 1;
            } else {
                return mid;
            }
        }
        // high == -1 means l precedes every bucket token; on a ring that belongs to the last bucket
        return high >= 0 ? high : buckets.length - 1;
    }

    public Bucket findBucketFromToken(long l) {
        return buckets[findBucketIdFromToken(l)];
    }

    public static long hash(String hash) {
        return hashFunction.hashString(hash, Charset.forName("UTF-8")).asLong();
    }

    /**
     * Hashes a key of one of the supported types.
     *
     * @throws IllegalArgumentException if the key is not a String, Long, Integer or SinkSerializable
     */
    public static long hash(Object hash) {
        // TODO: man, it's ugly. we should find a nice way to fix this problem. maybe strategy pattern?
        if (hash instanceof String) {
            return hash((String) hash);
        }
        if (hash instanceof Long) {
            return hash((long) hash);
        }
        if (hash instanceof Integer) {
            return hash((int) hash);
        }
        if (hash instanceof SinkSerializable) {
            return hashFunction.hashObject((SinkSerializable) hash, keyFunnel).asLong();
        }
        throw new IllegalArgumentException("map key should be one of [String, Long, Integer or "
                + SinkSerializable.class.getName() + "], was "
                + (hash == null ? "null" : hash.getClass().getName()));
    }

    public static long hash(byte[] hash) {
        return hashFunction.hashBytes(hash).asLong();
    }

    public static long hash(long hash) {
        return hashFunction.hashLong(hash).asLong();
    }

    public static long hash(int hash) {
        return hashFunction.hashInt(hash).asLong();
    }

    /** Returns {@code buckets[i]} with {@code i} taken modulo the array length (negatives allowed). */
    private Bucket getBucketFromRing(Bucket[] buckets, int i) {
        return buckets[Math.floorMod(i, buckets.length)];
    }

    /**
     * Sums, over every bucket that lists {@code member}, the bucket's share of the token ring as a
     * percentage; a ring covered by a single member totals about 100.
     *
     * <p>Distances are computed as unsigned clockwise differences, so the wrap-around bucket and
     * buckets wider than half the ring no longer overflow.
     */
    public double getTotalRingRange(Member member) {
        double total = 0;
        int last = buckets.length - 1;
        for (int i = 0; i <= last; i++) {
            Bucket current = buckets[i];
            if (!current.members.contains(member)) {
                continue;
            }
            long halfSpan;
            if (i == last) {
                // tokens from current.token up to and excluding buckets[0].token, across the wrap;
                // the -1 matches the former (MAX - token) + (first - MIN) count and makes a lone
                // bucket (first == current) cover the full ring
                halfSpan = (buckets[0].token - current.token - 1) >>> 1;
            } else {
                halfSpan = (buckets[i + 1].token - current.token) >>> 1;
            }
            total += halfSpan / (Long.MAX_VALUE / 100.0);
        }
        return total;
    }

    /**
     * Computes the bucket array that results from adding {@code member} to a ring made of
     * {@code buckets}, or the initial array for a ring containing only {@code member} when
     * {@code buckets} is null.
     *
     * <p>The new member's buckets are placed at the midpoints of the {@code bucketPerNode} widest
     * ranges. The other replicas of each new bucket are the existing members with the smallest
     * estimated load. While the ring had fewer members than {@code replicationFactor}, the new
     * member is also added to every under-replicated bucket.
     */
    private Bucket[] findBucketListForNewNode(Member member, Bucket[] buckets) {
        if (buckets == null) {
            // Spread the buckets evenly over the whole 2^64 token space starting at MIN_VALUE.
            // divideUnsigned(-1L, n) == floor((2^64 - 1) / n): for even n it equals the former
            // MAX_VALUE / (n / 2), and unlike that formula it does not divide by zero for n == 1
            // nor space odd n unevenly (n == 3 used to give MIN, -1, MAX - 1).
            long step = Long.divideUnsigned(-1L, bucketPerNode);
            Bucket[] initial = new Bucket[bucketPerNode];
            for (int i = 0; i < bucketPerNode; i++) {
                initial[i] = new Bucket(Sets.newHashSet(member), Long.MIN_VALUE + step * i);
            }
            return initial;
        }

        Map<TokenRange, List<Member>> ranges = getBuckets(buckets);

        // Estimated load of every existing member: half the gap of each range it replicates.
        Map<Member, Long> load = new HashMap<>();
        ranges.entrySet().stream()
                .filter(x -> !x.getValue().contains(member))
                .forEach(x -> x.getValue().forEach(z -> load.merge(z, x.getKey().gap() / 2, Long::sum)));
        Set<Map.Entry<Member, Long>> memberSet = load.entrySet();

        // Widest ranges first. Equal gaps prefer the range whose owners carry the larger summed load,
        // then the smaller start token, so that every node makes the same choice.
        // NOTE(review): an earlier comment said "minimum" load; the comparison picks the maximum.
        TokenRange[] tokens = ranges.entrySet().stream()
                .sorted((o1, o2) -> {
                    int compare = Long.compare(o2.getKey().gap(), o1.getKey().gap());
                    if (compare == 0) {
                        long sum1 = o1.getValue().stream().mapToLong(x -> load.getOrDefault(x, 0L)).sum();
                        long sum2 = o2.getValue().stream().mapToLong(x -> load.getOrDefault(x, 0L)).sum();
                        compare = Long.compare(sum2, sum1);
                        if (compare == 0) {
                            return Long.compare(o1.getKey().start, o2.getKey().start);
                        }
                    }
                    return compare;
                })
                .limit(bucketPerNode) // one new bucket per slot this member owns (was hard-coded 8)
                .map(Map.Entry::getKey)
                .toArray(TokenRange[]::new);

        // sized by the ranges actually chosen so no null slot can reach the sort below
        Bucket[] newBucketList = Arrays.copyOf(buckets, buckets.length + tokens.length);
        for (int idx = 0; idx < tokens.length; idx++) {
            TokenRange current = tokens[idx];
            long gap = current.gap() / 2;
            Set<Member> members = new HashSet<>();
            members.add(member);
            for (int replica = 1; replica < replicationFactor; replica++) {
                memberSet.stream()
                        .min((x, y) -> Long.compare(x.getValue(), y.getValue()))
                        .ifPresent(least -> {
                            members.add(least.getKey());
                            // account for the load this member just took on
                            least.setValue(least.getValue() + gap);
                        });
            }
            newBucketList[buckets.length + idx] = new Bucket(members, splitToken(current, buckets.length));
        }
        Arrays.sort(newBucketList, (o1, o2) -> Long.compare(o1.token, o2.token));

        if (buckets.length / bucketPerNode < replicationFactor) {
            for (int i = 0; i < newBucketList.length; i++) {
                Bucket bucket = newBucketList[i];
                if (bucket.members.size() < replicationFactor) {
                    Set<Member> owners = new HashSet<>(bucket.members);
                    owners.add(member);
                    newBucketList[i] = new Bucket(owners, bucket.token);
                }
            }
        }
        return newBucketList;
    }

    /**
     * Returns the token halfway along {@code range} in clockwise ring order.
     *
     * @param range    range to split
     * @param ringSize number of buckets in the ring the range belongs to
     */
    private static long splitToken(TokenRange range, int ringSize) {
        if (ringSize == 1) {
            // a lone bucket's range starts and ends at its own token and covers the whole ring;
            // adding MIN_VALUE (2^63 modulo 2^64) gives the diametrically opposite token
            return range.start + Long.MIN_VALUE;
        }
        // (end - start) read as unsigned is the clockwise distance, also for ranges that wrap past
        // MAX_VALUE or span half the ring or more; for narrower ranges this equals start + gap() / 2
        return range.start + ((range.end - range.start) >>> 1);
    }

    /** Returns a ring that also contains {@code member}, or an equal ring if it is already present. */
    public ConsistentHashRing addNode(Member member) {
        if (getMembers().contains(member)) {
            return new ConsistentHashRing(buckets, bucketPerNode, replicationFactor);
        }
        Bucket[] buckets = findBucketListForNewNode(member, this.buckets);
        return new ConsistentHashRing(buckets, bucketPerNode, replicationFactor);
    }

    /**
     * Returns a ring without {@code member}, or an equal ring if it is not present.
     *
     * <p>The {@code bucketPerNode} narrowest buckets replicated by the member are dropped. In the
     * remaining buckets the member is replaced by the least-loaded other member that does not already
     * replicate that bucket, provided at least {@code replicationFactor} members remain.
     *
     * @throws IllegalStateException if {@code member} is the only member of the ring
     */
    public ConsistentHashRing removeNode(Member member) {
        if (!getMembers().contains(member)) {
            return new ConsistentHashRing(buckets, bucketPerNode, replicationFactor);
        }
        if (getMemberCount() == 1) {
            throw new IllegalStateException("ring must contain at least one member");
        }
        Map<TokenRange, List<Member>> ranges = getBuckets();
        List<Bucket> remaining = Lists.newArrayList(this.buckets);
        int newMemberSize = getMemberCount() - 1;

        // Remove indexes in descending order so each removal leaves the remaining indexes valid.
        ranges.entrySet().stream()
                .filter(x -> x.getValue().contains(member))
                .sorted((x, y) -> Long.compare(x.getKey().gap(), y.getKey().gap()))
                .limit(bucketPerNode)
                .map(x -> x.getKey().id)
                .sorted((x, y) -> Integer.compare(y, x))
                .forEach(i -> remaining.remove((int) i));

        // Total gap replicated by each member; computed once since this ring does not change below.
        Map<Member, Long> memberTokenRange = new HashMap<>();
        if (newMemberSize >= replicationFactor) {
            ranges.forEach((range, rangeMembers) ->
                    rangeMembers.forEach(m -> memberTokenRange.merge(m, range.gap(), Long::sum)));
        }

        // NOTE(review): loads are not updated as replacements are assigned, so buckets with the same
        // remaining owners all receive the same replacement.
        Bucket[] newBuckets = remaining.stream()
                .map(bucket -> {
                    if (!bucket.members.contains(member)) {
                        return bucket;
                    }
                    Set<Member> owners = new HashSet<>(bucket.members);
                    owners.remove(member);
                    if (newMemberSize >= replicationFactor) {
                        memberTokenRange.entrySet().stream()
                                // skip existing owners: re-adding one would leave the bucket short a replica
                                .filter(x -> !x.getKey().equals(member) && !owners.contains(x.getKey()))
                                .min((o1, o2) -> Long.compare(o1.getValue(), o2.getValue()))
                                .ifPresent(x -> owners.add(x.getKey()));
                    }
                    return new Bucket(owners, bucket.token);
                })
                .toArray(Bucket[]::new);
        return new ConsistentHashRing(newBuckets, bucketPerNode, replicationFactor);
    }

    public int findBucketId(Object key) {
        return findBucketIdFromToken(hash(key));
    }

    public Bucket findBucket(Object key) {
        return findBucketFromToken(hash(key));
    }

    /** Returns a new set holding every member listed by any bucket. */
    public Set<Member> getMembers() {
        HashSet<Member> members = new HashSet<>();
        for (Bucket bucket : buckets) {
            members.addAll(bucket.members);
        }
        return members;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof ConsistentHashRing)) return false;
        ConsistentHashRing that = (ConsistentHashRing) o;
        if (bucketPerNode != that.bucketPerNode) return false;
        if (replicationFactor != that.replicationFactor) return false;
        if (!Arrays.equals(buckets, that.buckets)) return false;
        return true;
    }

    @Override
    public int hashCode() {
        int result = bucketPerNode;
        result = 31 * result + Arrays.hashCode(buckets);
        result = 31 * result + replicationFactor;
        return result;
    }

    /** A ring position: the token where a bucket starts and the members replicating it. */
    public static class Bucket {
        public long token;
        public ArrayList<Member> members;

        /** Copies {@code members} into a new list (in the set's iteration order). */
        public Bucket(Set<Member> members, long token) {
            this.members = new ArrayList<>(members);
            this.token = token;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof Bucket)) return false;
            Bucket bucket = (Bucket) o;
            if (token != bucket.token) return false;
            if (!members.equals(bucket.members)) return false;
            return true;
        }

        @Override
        public int hashCode() {
            int result = (int) (token ^ (token >>> 32));
            result = 31 * result + members.hashCode();
            return result;
        }
    }

    /** The tokens from {@code start} clockwise to {@code end} owned by bucket {@code id}. */
    public static class TokenRange {
        public final long start;
        public final long end;
        public final int id;

        private TokenRange(int id, long start, long end) {
            this.id = id;
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return "TokenRange{" +
                    "start=" + start +
                    ", end=" + end +
                    ", bucketId=" + id +
                    '}';
        }

        /**
         * Returns the clockwise distance from {@code start} to {@code end}, saturated at
         * {@code Long.MAX_VALUE}. For ranges narrower than half the ring this equals the former
         * {@code Math.abs(end - start)}; wider ranges used to report their complement (or a negative
         * value for exactly 2^63). A range with {@code start == end} reports 0.
         */
        public long gap() {
            long distance = end - start; // clockwise distance modulo 2^64
            return distance < 0 ? Long.MAX_VALUE : distance;
        }

        /** Ranges are equal when their start and end tokens match; the bucket id is ignored. */
        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof TokenRange)) return false;
            TokenRange that = (TokenRange) o;
            if (end != that.end) return false;
            if (start != that.start) return false;
            return true;
        }

        @Override
        public int hashCode() {
            int result = (int) (start ^ (start >>> 32));
            result = 31 * result + (int) (end ^ (end >>> 32));
            return result;
        }
    }
}