package com.yammer.telemetry.agent.handlers;

import com.google.common.base.Optional;
import com.yammer.telemetry.tracing.Span;
import com.yammer.telemetry.tracing.SpanHelper;

import java.math.BigInteger;
import java.util.concurrent.*;

/**
 * A {@link ThreadPoolExecutor} that carries tracing information from the submitting thread
 * over to the worker thread that eventually runs the task.
 *
 * <p>When a task is created through {@code newTaskFor} while {@link SpanHelper#currentSpan()}
 * reports a span, the task is wrapped in an {@link InstrumentedRunnableFuture} that remembers
 * the trace id and span id of that span. Just before such a task runs, a span is obtained from
 * {@link SpanHelper#attachSpan} on the worker thread and annotated with {@code "Before"}; once
 * the task has finished, the span is annotated with {@code "After"} and ended.
 *
 * <p>Tasks that reach the pool without going through {@code newTaskFor} are only traced if they
 * already are an {@link InstrumentedRunnableFuture}.
 */
@SuppressWarnings("UnusedDeclaration")
public class InstrumentedThreadPoolExecutor extends ThreadPoolExecutor {
    /** Span opened in {@link #beforeExecute} for the task currently running on this thread. */
    private final ThreadLocal<Span> local = new ThreadLocal<>();

    /**
     * Creates the executor; all arguments are passed unchanged to the matching
     * {@link ThreadPoolExecutor} constructor.
     */
    public InstrumentedThreadPoolExecutor(int corePoolSize,
                                          int maximumPoolSize,
                                          long keepAliveTime,
                                          TimeUnit unit,
                                          BlockingQueue<Runnable> workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    }

    /**
     * Creates the executor with a custom thread factory; all arguments are passed unchanged to
     * the matching {@link ThreadPoolExecutor} constructor.
     */
    public InstrumentedThreadPoolExecutor(int corePoolSize,
                                          int maximumPoolSize,
                                          long keepAliveTime,
                                          TimeUnit unit,
                                          BlockingQueue<Runnable> workQueue,
                                          ThreadFactory threadFactory) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
    }

    /**
     * Creates the executor with a custom rejection handler; all arguments are passed unchanged
     * to the matching {@link ThreadPoolExecutor} constructor.
     */
    public InstrumentedThreadPoolExecutor(int corePoolSize,
                                          int maximumPoolSize,
                                          long keepAliveTime,
                                          TimeUnit unit,
                                          BlockingQueue<Runnable> workQueue,
                                          RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
    }

    /**
     * Creates the executor with a custom thread factory and rejection handler; all arguments
     * are passed unchanged to the matching {@link ThreadPoolExecutor} constructor.
     */
    public InstrumentedThreadPoolExecutor(int corePoolSize,
                                          int maximumPoolSize,
                                          long keepAliveTime,
                                          TimeUnit unit,
                                          BlockingQueue<Runnable> workQueue,
                                          ThreadFactory threadFactory,
                                          RejectedExecutionHandler handler) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
    }

    /**
     * Builds the task for a {@link Runnable} and, if a span is current, wraps it for tracing.
     * The task name is the runnable's class name followed by {@code ":"} and {@code value}.
     */
    @Override
    protected <T> RunnableFuture<T> newTaskFor(Runnable runnable, T value) {
        return taskFor(super.newTaskFor(runnable, value), runnable.getClass().getName() + ":" + value);
    }

    /**
     * Builds the task for a {@link Callable} and, if a span is current, wraps it for tracing.
     * The task name is the callable's class name.
     */
    @Override
    protected <T> RunnableFuture<T> newTaskFor(Callable<T> callable) {
        return taskFor(super.newTaskFor(callable), callable.getClass().getName());
    }

    /**
     * For an {@link InstrumentedRunnableFuture} that carries both a trace id and a span id,
     * obtains a span via {@link SpanHelper#attachSpan}, remembers it for
     * {@link #afterExecute} and adds a {@code "Before"} annotation. Other tasks are left alone.
     *
     * @param t the thread that will run the task
     * @param r the task about to run
     */
    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        if (r instanceof InstrumentedRunnableFuture) {
            InstrumentedRunnableFuture<?> instrumentedRunnable = (InstrumentedRunnableFuture<?>) r;
            BigInteger traceId = instrumentedRunnable.getTraceId();
            BigInteger spanId = instrumentedRunnable.getSpanId();
            if (traceId != null && spanId != null) {
                Span span = SpanHelper.attachSpan(traceId, spanId, instrumentedRunnable.getName());
                local.set(span);
                span.addAnnotation("Before", instrumentedRunnable.getName());
            }
        }
        // ThreadPoolExecutor recommends invoking super.beforeExecute at the end of an override
        // so that nested overridings unwind in the right order.
        super.beforeExecute(t, r);
    }

    /**
     * Adds an {@code "After"} annotation to the span opened in {@link #beforeExecute} and ends
     * it. The span is ended and the thread-local cleared even when the superclass hook or the
     * annotation call throws, so a failure here cannot leave a span open on a pooled thread.
     *
     * @param r the task that has completed
     * @param t the exception that terminated the task, or {@code null}
     */
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        // Only instrumented tasks ever get a span stored for them in beforeExecute.
        Span span = (r instanceof InstrumentedRunnableFuture) ? local.get() : null;
        try {
            super.afterExecute(r, t);
            if (span != null) {
                span.addAnnotation("After", ((InstrumentedRunnableFuture<?>) r).getName());
            }
        } finally {
            // Clear the slot first so it is released even if ending the span fails.
            local.remove();
            if (span != null) {
                span.end();
            }
        }
    }

    /**
     * Wraps {@code future} in an {@link InstrumentedRunnableFuture} when the calling thread has
     * a current span, recording a {@code "Task"} annotation on that span; otherwise returns
     * {@code future} unchanged.
     *
     * @param future the task produced by the superclass
     * @param name   the name recorded in annotations for this task
     * @return the instrumented wrapper, or {@code future} itself when no span is current
     */
    private <T> RunnableFuture<T> taskFor(RunnableFuture<T> future, String name) {
        Optional<Span> currentSpan = SpanHelper.currentSpan();
        if (!currentSpan.isPresent()) {
            return future;
        }
        Span parent = currentSpan.get();
        parent.addAnnotation("Task", name);
        return new InstrumentedRunnableFuture<>(future, name, parent.getTraceId(), parent.getSpanId());
    }
}