/** * Copyright 2016-2017 Sixt GmbH & Co. Autovermietung KG * Licensed under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. You may obtain a * copy of the License at http://www.apache.org/licenses/LICENSE-2.0 * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.sixt.service.framework.rpc; import com.google.inject.Inject; import com.sixt.service.framework.FeatureFlags; import com.sixt.service.framework.ServiceProperties; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.Marker; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.locks.ReentrantReadWriteLock; import static net.logstash.logback.marker.Markers.append; public class LoadBalancer { private static final Logger logger = LoggerFactory.getLogger(LoadBalancer.class); protected ServiceProperties serviceProps; protected HttpClientWrapper httpClientWrapper; protected String serviceName; //we don't expect more than 3, so hashmap doesn't necessarily make sense protected List<AvailabilityZone> availabilityZones = new ArrayList<>(); protected ReentrantReadWriteLock mutex = new ReentrantReadWriteLock(); protected Semaphore notificationSemaphore = new Semaphore(0); protected AtomicBoolean haveEndpoints = new AtomicBoolean(false); @Inject public LoadBalancer(ServiceProperties serviceProps, HttpClientWrapper wrapper) { this.serviceProps = serviceProps; this.httpClientWrapper = wrapper; httpClientWrapper.setLoadBalancer(this); } public HttpClientWrapper getHttpClientWrapper() { return httpClientWrapper; } public void setServiceName(String serviceName) { this.serviceName = serviceName; } public void updateServiceEndpoints(LoadBalancerUpdate updates) { mutex.writeLock().lock(); try { Marker logMarker = append("serviceName", this.serviceName); for (ServiceEndpoint ep : updates.getNewServices()) { logger.debug(logMarker, "Endpoint for {} became available: {}", this.serviceName, ep.getHostAndPort()); addServiceEndpoint(ep); } for (ServiceEndpoint ep : updates.getDeletedServices()) { logger.debug(logMarker, "Endpoint for {} became unavailable: {}", this.serviceName, ep.getHostAndPort()); updateEndpointHealth(ep, CircuitBreakerState.State.UNHEALTHY); } for (ServiceEndpoint ep : updates.getUpdatedServices()) { logger.debug(logMarker, "Health of endpoint {} of {} changed to {}", ep.getHostAndPort(), this.serviceName, ep.getCircuitBreakerState()); updateEndpointHealth(ep, ep.getCircuitBreakerState()); } } finally { mutex.writeLock().unlock(); } } /** * We rely on the object that populates us to order availability zones * with primary first, then going nearest to furthest. (implying priority) * Only to be used from this class or tests */ protected void addServiceEndpoint(ServiceEndpoint endpoint) { boolean found = false; for (AvailabilityZone az : availabilityZones) { if (az.getName().equals(endpoint.getAvailZone())) { az.addServiceEndpoint(endpoint); found = true; } } if (! found) { AvailabilityZone az = new AvailabilityZone(); az.addServiceEndpoint(endpoint); availabilityZones.add(az); } haveEndpoints.set(true); notificationSemaphore.release(); } protected void updateEndpointHealth(ServiceEndpoint ep, CircuitBreakerState.State state) { for (AvailabilityZone az : availabilityZones) { if (az.getName().equals(ep.getAvailZone())) { az.updateEndpointHealth(ep, state); return; } } logger.error("updateEndpointHealth: availZone {} not found", ep.getAvailZone()); } protected int getAvailabilityZoneCount() { return availabilityZones.size(); } /** * Try to find an endpoint in our primary AZ. If none found, try further AZs. * Modifies state */ public ServiceEndpoint getHealthyInstance() { if (! haveEndpoints.get()) { //wait for the first one to come in try { notificationSemaphore.tryAcquire(1, TimeUnit.SECONDS); } catch (InterruptedException e) { } } mutex.readLock().lock(); try { for (AvailabilityZone az : availabilityZones) { ServiceEndpoint next = az.nextEndpoint(); if (next != null) { logger.debug("Returning instance {} for {}", next.getHostAndPort(), serviceName); return next; } } return null; } finally { mutex.readLock().unlock(); } } //modifies state public ServiceEndpoint getHealthyInstanceExclude(List<ServiceEndpoint> triedEndpoints) { mutex.readLock().lock(); try { Set<ServiceEndpoint> set = new HashSet<>(triedEndpoints); Set<ServiceEndpoint> seenInstances = new HashSet<>(); while (true) { ServiceEndpoint retval = getHealthyInstance(); if (FeatureFlags.shouldDisableRpcInstanceRetry(serviceProps)) { if (seenInstances.contains(retval)) { //we've made a complete loop return null; } if (set.contains(retval)) { seenInstances.add(retval); continue; } } return retval; } } finally { mutex.readLock().unlock(); } } public String getServiceName() { return serviceName; } public void waitForServiceInstance() { while (true) { if (getHealthyInstance() != null) { break; } try { notificationSemaphore.tryAcquire(1, 100, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { logger.warn("Thread was interrupted", e); break; } } logger.debug("Found service instance of {}", serviceName); } //intended only for debugging public List<AvailabilityZone> getAvailabilityZones() { return availabilityZones; } } class AvailabilityZone { private String name = ""; private ServiceEndpointList serviceEndpoints = new ServiceEndpointList(); public String getName() { return name; } public void addServiceEndpoint(ServiceEndpoint endpoint) { if (serviceEndpoints.isEmpty()) { this.name = endpoint.getAvailZone(); } serviceEndpoints.add(endpoint); } //modifies state public ServiceEndpoint nextEndpoint() { return serviceEndpoints.nextAvailable(); } public void updateEndpointHealth(ServiceEndpoint ep, CircuitBreakerState.State state) { serviceEndpoints.updateEndpointHealth(ep, state); } //intended only for debugging public ServiceEndpointList getServiceEndpoints() { return serviceEndpoints; } }