package org.rakam.kume; import org.rakam.kume.service.ServiceListBuilder; import org.rakam.kume.util.Tuple; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Iterators; import org.rakam.kume.service.Service; import org.rakam.kume.service.ServiceConstructor; import org.rakam.kume.transport.Packet; import org.rakam.kume.transport.LocalOperationContext; import org.rakam.kume.transport.Request; import org.rakam.kume.util.ThrowableNioEventLoopGroup; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import java.util.stream.IntStream; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; public class Cluster { final static Logger LOGGER = LoggerFactory.getLogger(Cluster.class); // Thread pool for handling requests and messages final ThrowableNioEventLoopGroup requestExecutor = new ThrowableNioEventLoopGroup("request-executor", (t1, e) -> LOGGER.error("error while executing request", e)); // Event loop for running cluster events. final protected ThrowableNioEventLoopGroup eventLoop = new ThrowableNioEventLoopGroup("event-executor", (t1, e) -> LOGGER.error("error while executing operation", e)); final protected List<Service> services; final private AtomicInteger messageSequence = new AtomicInteger(); final protected ServiceContext<InternalService> internalBus; final ConcurrentHashMap<Member, MemberChannel> clusterConnection = new ConcurrentHashMap<>(); final private Member localMember; final protected Map<String, Service> serviceNameMap; private final JoinerService joinerService; private final Transport transport; private Member master; private Set<Member> members; long lastContactedTimeMaster; private AtomicInteger currentTerm; final private List<MembershipListener> membershipListeners = Collections.synchronizedList(new ArrayList<>()); final private Map<Member, Long> heartbeatMap = new ConcurrentHashMap<>(); final private long clusterStartTime; private ScheduledFuture<?> heartbeatTask; private ConcurrentMap<InetSocketAddress, Integer> pendingUserVotes = CacheBuilder.newBuilder().expireAfterWrite(100, TimeUnit.SECONDS).<InetSocketAddress, Integer>build().asMap(); private MemberState memberState; private Map<Long, Request> pendingConsensusMessages = new ConcurrentHashMap<>(); private AtomicLong lastCommitIndex = new AtomicLong(); public Cluster(Collection<Member> members, ImmutableList<ServiceListBuilder.Constructor> services, TransportConstructor transportConstructor, InetSocketAddress serverAddress, JoinerService joinerService, boolean mustJoinCluster, boolean client) { clusterStartTime = System.currentTimeMillis(); this.members = new HashSet<>(members); Runtime.getRuntime().addShutdownHook(new Thread() { public void run() { // try { // server.waitForClose(); // } catch (InterruptedException e) { eventLoop.shutdownGracefully(); requestExecutor.shutdownGracefully(); transport.close(); // } } }); this.services = new ArrayList<>(services.size() + 16); InternalService internalService = new InternalService(new ServiceContext<>(this, 0, "internal"), this); this.services.add(internalService); localMember = new Member(serverAddress, client); this.transport = transportConstructor.newInstance(requestExecutor, this.services, localMember); master = localMember; if (mustJoinCluster) { joinCluster(); } internalBus = internalService.getContext(); IntStream.range(0, services.size()) .mapToObj(idx -> { ServiceListBuilder.Constructor c = services.get(idx); ServiceContext bus = new ServiceContext(this, idx + 1, c.name); return c.constructor.newInstance(bus); }).forEach(this.services::add); serviceNameMap = IntStream.range(0, services.size()) .mapToObj(idx -> new Tuple<>(services.get(idx).name, this.services.get(idx+1))) .collect(Collectors.toConcurrentMap(x -> x._1, x -> x._2)); scheduleClusteringTask(); if(!client) { transport.initialize(); } this.joinerService = joinerService; if (joinerService != null) { joinerService.onStart(new ClusterMembership() { @Override public void addMember(Member member) { addMemberInternal(member); } @Override public void removeMember(Member member) { throw new UnsupportedOperationException("not implemented"); } }); } members.stream().forEach(this::getConnection); } private void joinCluster() { CompletableFuture<Boolean> latch = new CompletableFuture<>(); AtomicInteger count = new AtomicInteger(); eventLoop.scheduleAtFixedRate(() -> { if (getMembers().size() > 0) { latch.complete(true); // this is a trick that stops this task. the exception will be swallowed. throw new RuntimeException("found cluster"); } if (count.incrementAndGet() >= 20) latch.complete(false); }, 0, 1, TimeUnit.SECONDS); memberState = memberState.FOLLOWER; if (!latch.join()) { throw new IllegalStateException("Could not found a cluster. You may disable mustJoinCluster.set(false) for creating new cluster."); } } private void addMemberInternal(Member member) { if (!members.contains(member) && !member.equals(localMember)) { LOGGER.info("Discovered new member {}", member); // we may create the connection before executing this method. if (!clusterConnection.containsKey(member)) { MemberChannel channel; try { channel = transport.connect(member); } catch (InterruptedException e) { LOGGER.error("Couldn't connect new server", e); return; } clusterConnection.put(member, channel); } members.add(member); if (isMaster()) heartbeatMap.put(member, System.currentTimeMillis()); if (!member.isClient()) membershipListeners.forEach(x -> eventLoop.execute(() -> x.memberAdded(member))); } } private synchronized void addMembersInternal(Set<Member> newMembers) { if (!members.containsAll(newMembers)) { LOGGER.info("Discovered another cluster of {} members", members.size()); for (Member member : newMembers) { if (member.equals(localMember)) continue; // we may create the connection before executing this method. if (!clusterConnection.containsKey(member)) { MemberChannel channel; try { channel = transport.connect(member); } catch (InterruptedException e) { LOGGER.error("Couldn't connect new server", e); return; } clusterConnection.put(member, channel); } members.add(member); if (isMaster()) heartbeatMap.put(member, System.currentTimeMillis()); } membershipListeners.forEach(x -> eventLoop.execute(() -> x.clusterMerged(newMembers))); } } public Transport getTransport() { return transport; } private void scheduleClusteringTask() { heartbeatTask = eventLoop.scheduleAtFixedRate(() -> { long time = System.currentTimeMillis(); if (!localMember.isClient() && isMaster()) { // heartbeatMap.forEach((member, lastResponse) -> { // if (time - lastResponse > 20000) { // removeMemberAsMaster(member, true); // } // }); members.forEach(member -> internalBus.send(member, new HeartbeatRequest(localMember))); } else { if (time - lastContactedTimeMaster > 500) { eventLoop.schedule(() -> { if (time - lastContactedTimeMaster > 500) { memberState = MemberState.CANDIDATE; voteElection(); } }, 150 + new Random().nextInt(150), TimeUnit.MILLISECONDS); } else { Member localMember = getLocalMember(); internalBus.send(master, (masterCluster, ctx) -> masterCluster.cluster.heartbeatMap.put(localMember, System.currentTimeMillis())); } } }, 200, 200, TimeUnit.MILLISECONDS); } public long startTime() { return clusterStartTime; } public void voteElection() { Collection<Member> clusterMembers = getMembers(); Map<Member, Boolean> map = new ConcurrentHashMap<>(); CompletableFuture<Boolean> future = new CompletableFuture<>(); int cursor = currentTerm.incrementAndGet(); Map<Member, CompletableFuture<Boolean>> m = internalBus.askAllMembers((service, ctx) -> { ctx.reply(service.cluster.currentTerm.incrementAndGet() == cursor - 1); }); m.forEach((member, resultFuture) -> resultFuture.thenAccept(result -> { map.put(member, result); Map<Boolean, Long> stream = map.entrySet().stream() .collect(Collectors.groupingBy(Map.Entry::getValue, Collectors.counting())); if (stream.getOrDefault(true, 0l) > clusterMembers.size() / 2) { future.complete(true); } else if (stream.getOrDefault(false, 0l) > clusterMembers.size() / 2) { future.complete(false); } })); if (future.join()) { memberState = MemberState.MASTER; Member localMember = this.localMember; internalBus.sendAllMembers((service, ctx) -> service.cluster.changeMaster(localMember)); } else { memberState = MemberState.FOLLOWER; } } public MemberState memberState() { return memberState; } private synchronized void changeMaster(Member masterMember) { master = masterMember; memberState = masterMember.equals(localMember) ? MemberState.MASTER : MemberState.FOLLOWER; } public synchronized void removeMemberAsMaster(Member member, boolean replicate) { if (!isMaster()) throw new IllegalStateException(); heartbeatMap.remove(member); members.remove(member); // if(replicate) { // internalBus.sendAllMembers((cluster, ctx) -> { // cluster.clusterConnection.remove(member); // Cluster.LOGGER.info("Member removed {}", member); // cluster.membershipListeners.forEach(l -> Throwables.propagate(() -> l.memberRemoved(member))); // }, true); // } } public void addMembershipListener(MembershipListener listener) { membershipListeners.add(listener); } public Set<Member> getMembers() { return ImmutableSet.copyOf(Iterables.concat(members, () -> Iterators.forArray(localMember))); } public <T extends Service> T getService(String serviceName) { checkNotNull(serviceName, "null is not allowed for service name"); return (T) serviceNameMap.get(serviceName); } public <T extends Service> T getService(String serviceName, Class<T> clazz) { return getService(serviceName); } public <T extends Service> T createOrGetService(String name, ServiceConstructor<T> ser) { checkNotNull(ser, "null is not allowed for service constructor"); Service existingService = serviceNameMap.get(name); if (existingService != null) return (T) existingService; int maxSize = Short.MAX_VALUE * 2; checkState(services.size() < maxSize, "Maximum number of allowed services is %s", maxSize); String finalName = name == null ? UUID.randomUUID().toString() : name; Boolean result = internalBus.replicateSafely(new AddServiceRequest(finalName, name, ser)).join(); if (!result) throw new IllegalArgumentException("there is already another service with same name"); Service service = serviceNameMap.get(finalName); if (service == null) throw new IllegalStateException("service couldn't created"); return (T) service; } /** * It uses Raft log replication protocol for consensus. * Most of the requests that Kume is planned to execute don't need consensus, * so unlike Raft implementation which waits the quorum it waits all nodes to execute the request. * Because there's no way to find out consistency issues without consensus methods like this one. * In Raft algorithm, since each log replication request uses consensus algorithm, it's easy to recover * from inconsistent states. * Since consensus is expensive compared to fire-and-forget fashion, use this method when you really need. * * @param request */ protected CompletableFuture<Boolean> replicateSafelyInternal(Request<?, Boolean> request, int serviceId) { AppendLogEntryRequest requestFromMaster = new AppendLogEntryRequest(request, serviceId); return askInternal(getMaster(), requestFromMaster, 0); } protected Map<Long, Request> pendingConsensusMessages() { return pendingConsensusMessages; } public <T extends Service> T createService(ServiceConstructor<T> ser) { return createOrGetService(null, ser); } public boolean destroyService(String serviceName) { checkNotNull(serviceName, "null is not allowed for service name"); Service service = serviceNameMap.remove(serviceName); if (service == null) return false; service.onClose(); int serviceId = services.indexOf(service); // we do not shift the array because if the indexes change, we have to ensure consensus among nodes. services.set(serviceId, null); return true; } public Member getLocalMember() { return localMember; } private void send(Member server, Object bytes, int service) { sendInternal(server, bytes, service); } public void sendAllMembersInternal(Object bytes, boolean includeThisMember, int service) { System.out.println(this + "internal -> " + members.size()); members.stream().filter(member -> !member.equals(localMember) && !member.isClient()) .forEach(member -> sendInternal(clusterConnection.get(member), bytes, service)); if (includeThisMember) { if (localMember.isClient()) { throw new IllegalArgumentException(); } Service s = services.get(service); LocalOperationContext ctx = new LocalOperationContext(null, service, localMember); s.handle(requestExecutor, ctx, bytes); } } public <R> Map<Member, CompletableFuture<R>> askAllMembersInternal(Object bytes, boolean includeThisMember, int service) { Map<Member, CompletableFuture<R>> map = new ConcurrentHashMap<>(); clusterConnection.forEach((member, conn) -> { if (!member.equals(localMember)) { map.put(member, askInternal(conn, bytes, service)); } }); if (includeThisMember) { CompletableFuture<R> f = new CompletableFuture<>(); Service s = services.get(service); LocalOperationContext ctx = new LocalOperationContext(f, service, localMember); s.handle(requestExecutor, ctx, bytes); map.put(localMember, f); } return map; } public void close() throws InterruptedException { for (MemberChannel entry : clusterConnection.values()) { entry.close(); } transport.close(); heartbeatTask.cancel(true); services.forEach(s -> s.onClose()); joinerService.onStart(new ClusterMembership() { @Override public void addMember(Member member) { addMemberInternal(member); } @Override public void removeMember(Member member) { throw new UnsupportedOperationException("not implemented"); } }); } public void sendInternal(MemberChannel channel, Object obj, int service) { Packet message = new Packet(obj, service); channel.ask(message); } public void sendInternal(Member member, Object obj, int service) { if (member.equals(localMember)) { LocalOperationContext ctx1 = new LocalOperationContext(null, service, localMember); services.get(service).handle(requestExecutor, ctx1, obj); } else { Packet message = new Packet(obj, service); getConnection(member).send(message); } } public void sendInternal(Member member, Request request, int service) { if (member.equals(localMember)) { LocalOperationContext ctx1 = new LocalOperationContext(null, service, localMember); services.get(service).handle(requestExecutor, ctx1, request); } else { Packet message = new Packet(request, service); getConnection(member).ask(message); } } public <R> void tryAskUntilDoneInternal(Member member, Request req, int numberOfTimes, int service, CompletableFuture future) { CompletableFuture<R> ask = askInternal(member, req, service); ask.whenComplete((val, ex) -> { if (ex != null) if (ex instanceof TimeoutException) { if (numberOfTimes == 0) { future.completeExceptionally(new TimeoutException()); } else { tryAskUntilDoneInternal(member, req, numberOfTimes, service, future); } } else { future.completeExceptionally(ex); } else future.complete(val); }); } public <R> CompletableFuture<R> askInternal(MemberChannel channel, Object obj, int service) { int andIncrement = messageSequence.getAndIncrement(); Packet message = new Packet(andIncrement, obj, service); return channel.ask(message); } public <R> CompletableFuture<R> askInternal(Member member, Object obj, int service) { if (member.equals(localMember)) { CompletableFuture<R> future = new CompletableFuture<>(); LocalOperationContext ctx1 = new LocalOperationContext(future, service, localMember); services.get(service).handle(requestExecutor, ctx1, obj); return future; } else { return askInternal(getConnection(member), obj, service); } } public <R> CompletableFuture<R> askInternal(Member member, Request request, int service) { if (member.equals(localMember)) { CompletableFuture<R> future = new CompletableFuture<>(); LocalOperationContext ctx1 = new LocalOperationContext(future, service, localMember); services.get(service).handle(requestExecutor, ctx1, request); return future; } else { return askInternal(getConnection(member), request, service); } } private MemberChannel getConnection(Member member) { MemberChannel channel = clusterConnection.get(member); if (channel == null) { if (!members.contains(member)) throw new IllegalArgumentException("the member doesn't exist in the cluster"); MemberChannel created; try { created = transport.connect(member); } catch (InterruptedException e) { throw Throwables.propagate(e); } synchronized (this) { clusterConnection.put(member, created); } return created; } return channel; } public boolean isMaster() { return localMember.equals(master); } public Member getMaster() { return master; } public List<Service> getServices() { return Collections.unmodifiableList(services); } public AtomicLong getLastCommitIndex() { return lastCommitIndex; } // protected synchronized void changeCluster(Set<Member> newClusterMembers, Member masterMember, boolean isNew) { // try { // pause(); // clusterConnection.clear(); // master = masterMember; // members = newClusterMembers; // messageHandlers.cleanUp(); // LOGGER.info("Joined a cluster of {} nodes.", members.size()); // if (!isNew) // membershipListeners.forEach(x -> eventLoop.execute(() -> x.clusterChanged())); // } finally { // resume(); // } // } }