package com.linkedin.parseq.batching; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import com.linkedin.parseq.internal.ArgumentUtil; import com.linkedin.parseq.promise.PromiseException; import com.linkedin.parseq.promise.PromiseListener; import com.linkedin.parseq.promise.PromiseResolvedException; import com.linkedin.parseq.promise.PromiseUnresolvedException; import com.linkedin.parseq.promise.Promises; import com.linkedin.parseq.promise.SettablePromise; import com.linkedin.parseq.trace.ShallowTraceBuilder; public class BatchImpl<K, T> implements Batch<K, T> { private final Map<K, BatchEntry<T>> _map; private final int _batchSize; private BatchImpl(Map<K, BatchEntry<T>> map, int batchSize) { _map = map; _batchSize = batchSize; } @Override public void done(K key, T value) throws PromiseResolvedException { _map.get(key).getPromise().done(value); } @Override public void fail(K key, Throwable error) throws PromiseResolvedException { _map.get(key).getPromise().fail(error); } @Override public int failAll(Throwable error) { int alreadyResolved = 0; for (Entry<K, BatchEntry<T>> entry: _map.entrySet()) { try { entry.getValue().getPromise().fail(error); } catch (PromiseResolvedException e) { alreadyResolved++; } } return alreadyResolved; } @Override public Set<K> keys() { return _map.keySet(); } @Override public void foreach(final BiConsumer<K, SettablePromise<T>> consumer) { _map.forEach((key, entry) -> consumer.accept(key, entry.getPromise())); } @Override public String toString() { return "BatchImpl [entries=" + _map + "]"; } /** * Internal Promise delegate that decouples setting value on internal Promise from * publishing result on external promise. Used in batching implementation to make sure * that (external) Promise is resolved after all bacth-internal promises (including * duplicates ) are resolved. */ public static class BatchPromise<T> implements SettablePromise<T> { private final SettablePromise<T> _internal = Promises.settable(); private final SettablePromise<T> _external = Promises.settable(); @Override public T get() throws PromiseException { return _internal.get(); } @Override public Throwable getError() throws PromiseUnresolvedException { return _internal.getError(); } @Override public T getOrDefault(T defaultValue) throws PromiseUnresolvedException { return _internal.getOrDefault(defaultValue); } @Override public void await() throws InterruptedException { _internal.await(); } @Override public boolean await(long time, TimeUnit unit) throws InterruptedException { return _internal.await(time, unit); } @Override public void addListener(PromiseListener<T> listener) { _external.addListener(listener); } @Override public boolean isDone() { return _internal.isDone(); } @Override public boolean isFailed() { return _internal.isFailed(); } @Override public void done(T value) throws PromiseResolvedException { _internal.done(value); } @Override public void fail(Throwable error) throws PromiseResolvedException { _internal.fail(error); } public void trigger() { Promises.propagateResult(_internal, _external); } public SettablePromise<T> getInternal() { return _internal; } } public static class BatchEntry<T> { private final BatchPromise<T> _promise; private final List<ShallowTraceBuilder> _shallowTraceBuilders = new ArrayList<>(); private final long _creationTimeNano = System.nanoTime(); public BatchEntry(ShallowTraceBuilder shallowTraceBuilder, BatchPromise<T> promise) { _promise = promise; _shallowTraceBuilders.add(shallowTraceBuilder); } public BatchPromise<T> getPromise() { return _promise; } List<ShallowTraceBuilder> getShallowTraceBuilders() { return _shallowTraceBuilders; } void addShallowTraceBuilder(final ShallowTraceBuilder shallowTraceBuilder) { _shallowTraceBuilders.add(shallowTraceBuilder); } void addShallowTraceBuilders(final List<ShallowTraceBuilder> shallowTraceBuilders) { _shallowTraceBuilders.addAll(shallowTraceBuilders); } } static class BatchBuilder<K, T> { private final Map<K, BatchEntry<T>> _map = new HashMap<>(); private Batch<K, T> _batch = null; private final int _maxSize; private final BatchAggregationTimeMetric _batchAggregationTimeMetric; private int _batchSize = 0; public BatchBuilder(int maxSize, BatchAggregationTimeMetric batchAggregationTimeMetric) { ArgumentUtil.requirePositive(maxSize, "max batch size"); _maxSize = maxSize; _batchAggregationTimeMetric = batchAggregationTimeMetric; } private static final boolean safeToAddWithoutOverflow(int left, int right) { if (right > 0 ? left > Integer.MAX_VALUE - right : left < Integer.MIN_VALUE - right) { return false; } return true; } /** * Adds a batch entry, returns true if adding was successful. Returns false if adding * was not successful. Adding will be successful if builder is currently empty or * the batch size after adding the entry not exceed max batch size. * Caller must check result of this operation. */ boolean add(K key, BatchEntry<T> entry, int size) { if (_batch != null) { throw new IllegalStateException("BatchBuilder has already been used to build a batch"); } if (_batchSize == 0 || (safeToAddWithoutOverflow(_batchSize, size) && _batchSize + size <= _maxSize)) { //de-duplication BatchEntry<T> duplicate = _map.get(key); if (duplicate != null) { Promises.propagateResult(duplicate.getPromise().getInternal(), entry.getPromise()); duplicate.getPromise().addListener(p -> entry.getPromise().trigger()); duplicate.addShallowTraceBuilders(entry.getShallowTraceBuilders()); } else { _map.put(key, entry); } //this will not overflow _batchSize += size; return true; } else { return false; } } /** * Adds a batch entry, returns true if adding was successful. Returns false if adding * was not successful. Adding will be successful if builder is currently empty or * the batch size after adding the entry not exceed max batch size. * Caller must check result of this operation. */ boolean add(K key, ShallowTraceBuilder traceBuilder, BatchPromise<T> promise, int size) { return add(key, new BatchEntry<>(traceBuilder, promise), size); } public boolean isFull() { return _batchSize >= _maxSize; } public Batch<K, T> build() { if (_batch == null) { final long _currentTimeNano = System.nanoTime(); _map.values().forEach(entry -> { final long time = _currentTimeNano - entry._creationTimeNano; _batchAggregationTimeMetric.record(time > 0 ? time : 0); }); _batch = new BatchImpl<>(_map, _batchSize); } return _batch; } public int size() { return _map.size(); } public int batchSize() { return _batchSize; } } @Override public Collection<BatchEntry<T>> values() { return _map.values(); } @Override public Set<Entry<K, BatchEntry<T>>> entries() { return _map.entrySet(); } @Override public int keySize() { return _map.size(); } @Override public int batchSize() { return _batchSize; } }