/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.utils.memory;
import java.nio.ByteBuffer;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import com.google.common.util.concurrent.Uninterruptibles;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.cassandra.utils.DynamicList;
import static org.junit.Assert.*;
/**
 * Long-running stress test for {@link BufferPool}.
 * <p>
 * {@code threadCount} worker threads repeatedly allocate buffers of normally distributed sizes, fill each with a
 * per-buffer 64-bit pattern, and validate that pattern before returning the buffer to the pool. A worker frees a
 * buffer either itself or by handing it to the next worker in a ring of single-producer/single-consumer queues,
 * so buffers are also released by threads that did not allocate them. Two extra tasks generate high-churn
 * allocate/put traffic without any validation.
 */
public class LongBufferPoolTest
{
    private static final Logger logger = LoggerFactory.getLogger(LongBufferPoolTest.class);

    @Test
    public void testAllocate() throws InterruptedException, ExecutionException
    {
        testAllocate(Runtime.getRuntime().availableProcessors() * 2, TimeUnit.MINUTES.toNanos(2L), 16 << 20);
    }

    /**
     * A pooled buffer paired with the 64-bit pattern it was filled with, so later corruption can be detected.
     */
    private static final class BufferCheck
    {
        final ByteBuffer buffer;
        final long val;
        // node holding this check in its owner's DynamicList, so it can be removed without a search
        DynamicList.Node<BufferCheck> listnode;

        private BufferCheck(ByteBuffer buffer, long val)
        {
            this.buffer = buffer;
            this.val = val;
        }

        /**
         * Checks (via {@code assert}, so only with assertions enabled) that every whole long from the buffer's
         * position to its limit still holds {@code val}.
         */
        void validate()
        {
            ByteBuffer read = buffer.duplicate();
            // >= 8 rather than > 8, so the final complete long is checked too
            while (read.remaining() >= 8)
                assert read.getLong() == val;
        }

        /**
         * Writes {@code val} into every whole long from the buffer's position to its limit; fewer than 8 trailing
         * bytes are left untouched.
         */
        void init()
        {
            ByteBuffer write = buffer.duplicate();
            // >= 8 rather than > 8, so the final complete long is written too
            while (write.remaining() >= 8)
                write.putLong(val);
        }
    }

    /**
     * Runs the stress test.
     *
     * @param threadCount number of validating worker threads (two extra unchecked churn tasks are also started)
     * @param duration    how long to run, in nanoseconds
     * @param poolSize    value installed as {@code BufferPool.MEMORY_USAGE_THRESHOLD} while the test runs; the
     *                    previous value is restored afterwards, even if the test fails
     */
    public void testAllocate(int threadCount, long duration, int poolSize) throws InterruptedException, ExecutionException
    {
        final int avgBufferSize = 16 << 10;
        final int stdevBufferSize = 10 << 10; // picked to ensure exceeding buffer size is rare, but occurs
        final DateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss");

        System.out.println(String.format("%s - testing %d threads for %dm",
                                         dateFormat.format(new Date()),
                                         threadCount,
                                         TimeUnit.NANOSECONDS.toMinutes(duration)));

        final long until = System.nanoTime() + duration;
        final CountDownLatch latch = new CountDownLatch(threadCount);
        // sharedRecycle[i] is filled by worker i-1 and drained by worker i (indices mod threadCount), forming a ring;
        // generic array creation is not allowed in Java, hence the raw-typed array
        @SuppressWarnings("unchecked")
        final SPSCQueue<BufferCheck>[] sharedRecycle = new SPSCQueue[threadCount];
        final AtomicBoolean[] makingProgress = new AtomicBoolean[threadCount];
        for (int i = 0 ; i < sharedRecycle.length ; i++)
        {
            sharedRecycle[i] = new SPSCQueue<>();
            makingProgress[i] = new AtomicBoolean(true);
        }

        // threadCount validating workers plus the two unchecked churn tasks
        ExecutorService executorService = Executors.newFixedThreadPool(threadCount + 2);
        List<Future<Boolean>> ret = new ArrayList<>(threadCount);
        long prevPoolSize = BufferPool.MEMORY_USAGE_THRESHOLD;
        BufferPool.MEMORY_USAGE_THRESHOLD = poolSize;
        BufferPool.DEBUG = true;
        try
        {
            // sum(1..n) = n/2 * (n + 1); we set zero to CHUNK_SIZE, so have n=threadCount-1.
            // Computed in long: as an int, the product below overflows (for a 16MiB pool once threadCount > 16),
            // yielding negative targets that can leave a worker spinning indefinitely in its freeing loop.
            long targetSizeQuanta = ((long) threadCount * (threadCount - 1)) / 2;
            // fix targetSizeQuanta at 1/64th our poolSize, so that we only consciously exceed our pool size limit
            // NOTE(review): this multiplies by (rather than divides by) the triangular number, so the sum of the
            // per-thread targets grows with threadCount^2 and can far exceed poolSize; confirm this is intended.
            targetSizeQuanta = (targetSizeQuanta * poolSize) / 64;

            {
                // setup some high churn allocate/deallocate, without any checking
                final SPSCQueue<ByteBuffer> burn = new SPSCQueue<>();
                final CountDownLatch doneAdd = new CountDownLatch(1);
                executorService.submit(new TestUntil(until)
                {
                    int count = 0;

                    @Override
                    void testOne() throws Exception
                    {
                        // after handing ~poolSize/10 worth of chunks to the consumer, wait for it to catch up
                        if (count * BufferPool.CHUNK_SIZE >= poolSize / 10)
                        {
                            if (burn.exhausted)
                                count = 0;
                            else
                                Thread.yield();
                            return;
                        }

                        ByteBuffer buffer = BufferPool.tryGet(BufferPool.CHUNK_SIZE);
                        if (buffer == null)
                        {
                            Thread.yield();
                            return;
                        }

                        // NOTE(review): the buffer is put back here and put again by the consumer task below; this
                        // appears to deliberately exercise a repeated put of the same buffer - confirm BufferPool
                        // is meant to tolerate that.
                        BufferPool.put(buffer);
                        burn.add(buffer);
                        count++;
                    }

                    @Override
                    void cleanup()
                    {
                        doneAdd.countDown();
                    }
                });
                executorService.submit(new TestUntil(until)
                {
                    @Override
                    void testOne() throws Exception
                    {
                        ByteBuffer buffer = burn.poll();
                        if (buffer == null)
                        {
                            Thread.yield();
                            return;
                        }

                        BufferPool.put(buffer);
                    }

                    @Override
                    void cleanup()
                    {
                        // do not finish before the producer has stopped adding
                        Uninterruptibles.awaitUninterruptibly(doneAdd);
                    }
                });
            }

            for (int t = 0; t < threadCount; t++)
            {
                final int threadIdx = t;
                final long targetSize = t == 0 ? BufferPool.CHUNK_SIZE : targetSizeQuanta * t;

                ret.add(executorService.submit(new TestUntil(until)
                {
                    final SPSCQueue<BufferCheck> shareFrom = sharedRecycle[threadIdx];
                    final DynamicList<BufferCheck> checks = new DynamicList<>((int) Math.max(1, targetSize / (1 << 10)));
                    final SPSCQueue<BufferCheck> shareTo = sharedRecycle[(threadIdx + 1) % threadCount];

                    // rounded sizes of this worker's buffers no larger than CHUNK_SIZE, including those handed to
                    // shareTo whose release by the neighbour has not yet been observed (tracked in freeingSize)
                    int totalSize = 0;
                    int freeingSize = 0;
                    int size = 0;

                    @Override
                    void checkpoint()
                    {
                        if (!makingProgress[threadIdx].get())
                            makingProgress[threadIdx].set(true);
                    }

                    @Override
                    void testOne() throws Exception
                    {
                        // obtained here, on the executing worker thread: a ThreadLocalRandom captured on the
                        // submitting thread must not be used from other threads
                        final ThreadLocalRandom rand = ThreadLocalRandom.current();

                        long currentTargetSize = rand.nextInt(poolSize / 1024) == 0 ? 0 : targetSize;
                        int spinCount = 0;
                        while (totalSize > currentTargetSize - freeingSize)
                        {
                            // free buffers until we're below our target size
                            if (checks.size() == 0)
                            {
                                // if we're out of buffers to free, we're waiting on our neighbour to free them;
                                // first check if the consuming neighbour has caught up, and if so mark that free
                                if (shareTo.exhausted)
                                {
                                    totalSize -= freeingSize;
                                    freeingSize = 0;
                                }
                                else if (!recycleFromNeighbour())
                                {
                                    if (++spinCount > 1000 && System.nanoTime() > until)
                                        return;
                                    // otherwise, free one of our other neighbour's buffers if can; and otherwise yield
                                    Thread.yield();
                                }
                                continue;
                            }

                            // pick a random buffer, with preference going to earlier ones
                            BufferCheck check = sample();
                            checks.remove(check.listnode);
                            check.validate();

                            // buffers larger than CHUNK_SIZE were never added to totalSize, so subtract nothing
                            size = BufferPool.roundUpNormal(check.buffer.capacity());
                            if (size > BufferPool.CHUNK_SIZE)
                                size = 0;

                            // either share to free, or free immediately
                            if (rand.nextBoolean())
                            {
                                shareTo.add(check);
                                freeingSize += size;
                                // interleave this with potentially messing with the other neighbour's stuff
                                recycleFromNeighbour();
                            }
                            else
                            {
                                check.validate();
                                BufferPool.put(check.buffer);
                                totalSize -= size;
                            }
                        }

                        // allocate a new buffer
                        size = (int) Math.max(1, avgBufferSize + (stdevBufferSize * rand.nextGaussian()));
                        if (size <= BufferPool.CHUNK_SIZE)
                        {
                            totalSize += BufferPool.roundUpNormal(size);
                            allocate(size);
                        }
                        else if (rand.nextBoolean())
                        {
                            allocate(size);
                        }
                        else
                        {
                            // perform a burst allocation to exhaust all available memory
                            while (totalSize < poolSize)
                            {
                                size = (int) Math.max(1, avgBufferSize + (stdevBufferSize * rand.nextGaussian()));
                                if (size <= BufferPool.CHUNK_SIZE)
                                {
                                    allocate(size);
                                    totalSize += BufferPool.roundUpNormal(size);
                                }
                            }
                        }

                        // validate a random buffer we have stashed; checks can be empty here if the burst branch
                        // allocated nothing, and nextInt(0) would throw
                        if (checks.size() > 0)
                            checks.get(rand.nextInt(checks.size())).validate();

                        // free all of our neighbour's remaining shared buffers
                        while (recycleFromNeighbour());
                    }

                    @Override
                    void cleanup()
                    {
                        while (checks.size() > 0)
                        {
                            BufferCheck check = checks.get(0);
                            BufferPool.put(check.buffer);
                            checks.remove(check.listnode);
                        }
                        latch.countDown();
                    }

                    /**
                     * Validates and frees one buffer handed to us by the previous worker, if any.
                     *
                     * @return {@code true} if a buffer was freed, {@code false} if none was waiting
                     */
                    boolean recycleFromNeighbour()
                    {
                        BufferCheck check = shareFrom.poll();
                        if (check == null)
                            return false;
                        check.validate();
                        BufferPool.put(check.buffer);
                        return true;
                    }

                    /**
                     * Allocates a buffer of exactly {@code size} bytes, fills it with a random pattern and records it
                     * in {@code checks}. Does not update {@code totalSize}; callers do that.
                     */
                    BufferCheck allocate(int size)
                    {
                        ByteBuffer buffer = BufferPool.get(size);
                        assertNotNull(buffer);
                        BufferCheck check = new BufferCheck(buffer, ThreadLocalRandom.current().nextLong());
                        assertEquals(size, buffer.capacity());
                        assertEquals(0, buffer.position());
                        check.init();
                        check.listnode = checks.append(check);
                        return check;
                    }

                    /**
                     * Picks a random element of {@code checks}, favouring earlier ones. Requires {@code checks}
                     * to be non-empty.
                     */
                    BufferCheck sample()
                    {
                        // sample with preference to first elements:
                        // element at index n will be selected with likelihood (count - n) / sum1ToN(count)
                        int count = checks.size();

                        // pick a random number in [0, sum1toN(count))
                        int sampleRange = sum1toN(count);
                        int sampleIndex = ThreadLocalRandom.current().nextInt(sampleRange);

                        // then search for the N, such that [sum1ToN(N), sum1ToN(N+1)) contains this random number
                        int moveBy = Math.max(count / 4, 1);
                        int index = count / 2;
                        while (true)
                        {
                            int baseSampleIndex = sum1toN(index);
                            int endOfSampleIndex = sum1toN(index + 1);
                            if (sampleIndex >= baseSampleIndex)
                            {
                                if (sampleIndex < endOfSampleIndex)
                                    break;
                                index += moveBy;
                            }
                            else index -= moveBy;
                            moveBy = Math.max(moveBy / 2, 1);
                        }

                        // this gives us the inverse of our desired value, so just subtract it from the last index
                        index = count - (index + 1);

                        return checks.get(index);
                    }

                    /** Returns 1 + 2 + ... + n, i.e. n(n+1)/2. */
                    private int sum1toN(int n)
                    {
                        return (n * (n + 1)) / 2;
                    }
                }));
            }

            // every 10s until all workers finish, require each worker to have passed a checkpoint since last time
            boolean first = true;
            while (!latch.await(10L, TimeUnit.SECONDS))
            {
                if (!first)
                    BufferPool.assertAllRecycled();
                first = false;
                for (int i = 0 ; i < makingProgress.length ; i++)
                {
                    // JUnit assertion rather than `assert`, so the check does not depend on -ea
                    assertTrue("worker " + i + " made no progress in the last 10s", makingProgress[i].get());
                    makingProgress[i].set(false);
                }
            }

            // free anything still sitting in the hand-off queues
            for (SPSCQueue<BufferCheck> queue : sharedRecycle)
            {
                BufferCheck check;
                while ( null != (check = queue.poll()) )
                {
                    check.validate();
                    BufferPool.put(check.buffer);
                }
            }

            assertEquals(0, executorService.shutdownNow().size());

            for (Future<Boolean> r : ret)
                assertTrue(r.get());

            System.out.println(String.format("%s - finished.",
                                             dateFormat.format(new Date())));
        }
        finally
        {
            // runs on failure paths too, so a failed run does not leak threads or the overridden threshold
            executorService.shutdownNow();
            BufferPool.MEMORY_USAGE_THRESHOLD = prevPoolSize;
        }
    }

    /**
     * Base for the test's tasks. It calls {@link #checkpoint()}, then {@link #testOne()} 100 times, and repeats
     * until the {@code System.nanoTime()} deadline {@code until} has passed. {@link #cleanup()} always runs at
     * the end.
     */
    static abstract class TestUntil implements Callable<Boolean>
    {
        final long until;

        protected TestUntil(long until)
        {
            this.until = until;
        }

        abstract void testOne() throws Exception;
        void checkpoint() {}
        void cleanup() {}

        /**
         * @return {@code true} if the loop ran until the deadline, {@code false} if {@link #testOne()} threw an
         *         {@link Exception}; errors such as {@link AssertionError} propagate to the caller instead
         */
        @Override
        public Boolean call() throws Exception
        {
            try
            {
                while (System.nanoTime() < until)
                {
                    checkpoint();
                    for (int i = 0 ; i < 100 ; i++)
                        testOne();
                }
            }
            catch (Exception ex)
            {
                // trailing Throwable argument: SLF4J logs its full stack trace alongside the formatted message
                logger.error("Got exception {}, current chunk {}",
                             ex.getMessage(),
                             BufferPool.currentChunk(),
                             ex);
                return false;
            }
            finally
            {
                cleanup();
            }
            return true;
        }
    }

    public static void main(String[] args) throws InterruptedException, ExecutionException
    {
        new LongBufferPoolTest().testAllocate(Runtime.getRuntime().availableProcessors(), TimeUnit.HOURS.toNanos(2L), 16 << 20);
    }

    /**
     * A single producer, single consumer queue.
     * <p>
     * Only the producer touches {@code tail} and only the consumer touches {@code head}; the volatile
     * {@code next} link publishes each node from the producer to the consumer.
     */
    private static final class SPSCQueue<V>
    {
        static final class Node<V>
        {
            volatile Node<V> next;
            final V value;
            Node(V value)
            {
                this.value = value;
            }
        }

        // set by the consumer when a poll finds nothing, cleared by the producer on add (approximate, see poll)
        private volatile boolean exhausted = true;
        Node<V> head = new Node<>(null);
        Node<V> tail = head;

        /** Producer side: appends {@code value}. */
        void add(V value)
        {
            exhausted = false;
            tail = tail.next = new Node<>(value);
        }

        /** Consumer side: removes and returns the oldest value, or returns {@code null} if the queue is empty. */
        V poll()
        {
            Node<V> next = head.next;
            if (next == null)
            {
                // this is racey, but good enough for our purposes
                exhausted = true;
                return null;
            }
            head = next;
            return next.value;
        }
    }
}