package com.yammer.telemetry.test;
import com.yammer.telemetry.instrumentation.ClassInstrumentationHandler;
import com.yammer.telemetry.instrumentation.TelemetryTransformer;
import org.junit.Before;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import static org.junit.Assert.fail;
/**
 * Helpers that run test methods of a class after loading that class through a
 * {@link TransformingClassLoader} built around a {@link TelemetryTransformer} configured with the
 * supplied {@link ClassInstrumentationHandler}s.
 */
public class TelemetryTestHelpers {
    /**
     * Runs every method declared on {@code clazz} that is annotated with {@link TransformedTest}.
     * Each method runs through {@link #runTransformed(Class, String, ClassInstrumentationHandler...)},
     * which uses a freshly created class loader. All failures are collected and reported together
     * in one {@code fail(...)} message.
     *
     * <p>Methods are run in name order. {@link Class#getDeclaredMethods()} returns methods in no
     * particular order, so sorting keeps execution order and the failure report reproducible.
     *
     * @param clazz    the test class whose declared methods are scanned for {@code @TransformedTest}
     * @param handlers instrumentation handlers registered on each run's transformer
     * @throws Exception declared for source compatibility with existing callers
     * @throws AssertionError (via {@code fail}) if any annotated test failed, or if none were found
     */
    public static void runTransformed(Class<?> clazz, ClassInstrumentationHandler... handlers) throws Exception {
        Method[] methods = clazz.getDeclaredMethods();
        // getDeclaredMethods() has no defined order; sort by name (then full signature for
        // overloads) so runs and reports are reproducible.
        Arrays.sort(methods, new Comparator<Method>() {
            @Override
            public int compare(Method a, Method b) {
                int byName = a.getName().compareTo(b.getName());
                return byName != 0 ? byName : a.toString().compareTo(b.toString());
            }
        });

        // Insertion-ordered so the report lists failures in execution order.
        Map<Method, Throwable> testFailures = new LinkedHashMap<>();
        int ran = 0;
        for (Method method : methods) {
            if (!method.isAnnotationPresent(TransformedTest.class)) {
                continue;
            }
            ran++;
            try {
                runTransformed(clazz, method.getName(), handlers);
            } catch (Throwable t) {
                // Catch Throwable rather than Exception: a bad transformation can surface as an
                // Error (e.g. a LinkageError such as VerifyError) while the transformed class is
                // loaded or instantiated. That failure should be reported with the others instead
                // of aborting the remaining tests.
                //noinspection ThrowableResultOfMethodCallIgnored
                testFailures.put(method, unwrap(t));
            }
        }

        if (!testFailures.isEmpty()) {
            StringWriter builder = new StringWriter();
            PrintWriter writer = new PrintWriter(builder);
            writer.println("Transformed tests failed:");
            for (Map.Entry<Method, Throwable> entry : testFailures.entrySet()) {
                writer.printf("%s:%n%s%n%n", entry.getKey(), entry.getValue());
                //noinspection ThrowableResultOfMethodCallIgnored
                entry.getValue().printStackTrace(writer);
                writer.println();
            }
            fail(builder.toString());
        }
        if (ran == 0) {
            fail("No tests were found within '" + clazz.getName() + "' that were annotated as '" + TransformedTest.class.getName() + "'");
        }
    }

    /**
     * Strips one level of {@link InvocationTargetException} so the reported failure is the
     * exception actually thrown by the reflectively invoked code.
     *
     * @param e the caught throwable
     * @return the wrapped cause when {@code e} is an {@code InvocationTargetException} with a
     *         non-null cause; otherwise {@code e} itself (never {@code null})
     */
    private static Throwable unwrap(Throwable e) {
        // Guard against a null cause so callers can always print the returned throwable.
        if (e instanceof InvocationTargetException && e.getCause() != null) {
            return e.getCause();
        }
        return e;
    }

    /**
     * Runs a single test method against a transformed copy of {@code clazz}, in four steps:
     * <ol>
     *   <li>Creates a new {@link TransformingClassLoader} backed by a {@link TelemetryTransformer}
     *       carrying {@code handlers}.</li>
     *   <li>Loads the class by name through that loader and instantiates it with its no-arg
     *       constructor.</li>
     *   <li>Invokes each no-arg {@link Before} method declared directly on {@code clazz}.</li>
     *   <li>Invokes the named test method, then closes the loader.</li>
     * </ol>
     *
     * <p>Only {@code @Before} methods declared on {@code clazz} itself are run; inherited ones are
     * not.
     *
     * @param clazz    the untransformed test class. Its declared methods are scanned for
     *                 {@code @Before}, and its name is used to load the transformed copy.
     * @param method   name of a no-arg method declared on {@code clazz}
     * @param handlers instrumentation handlers registered on the transformer
     * @throws InvocationTargetException if the constructor, a {@code @Before} method or the test
     *         method throws; the original failure is its cause
     * @throws Exception for other reflective failures (missing method, inaccessible member,
     *         non-instantiable class) or failures while closing the class loader
     */
    public static void runTransformed(Class<?> clazz, String method, ClassInstrumentationHandler... handlers) throws Exception {
        Set<String> befores = new HashSet<>();
        for (Method beforeMethod : clazz.getDeclaredMethods()) {
            if (beforeMethod.isAnnotationPresent(Before.class)) {
                befores.add(beforeMethod.getName());
            }
        }
        final TelemetryTransformer transformer = new TelemetryTransformer();
        for (ClassInstrumentationHandler handler : handlers) {
            transformer.addHandler(handler);
        }
        try (TransformingClassLoader loader = AccessController.doPrivileged(new PrivilegedAction<TransformingClassLoader>() {
            @Override
            public TransformingClassLoader run() {
                return new TransformingClassLoader(transformer);
            }
        })) {
            Class<?> aClass = loader.loadClass(clazz.getName());
            // Class.newInstance() is deprecated because it rethrows the constructor's checked
            // exceptions undeclared. This form wraps them in InvocationTargetException instead,
            // which the aggregating overload above unwraps when reporting.
            Object instance = aClass.getDeclaredConstructor().newInstance();
            for (String beforeMethod : befores) {
                Method before = aClass.getDeclaredMethod(beforeMethod);
                before.invoke(instance);
            }
            Method declaredMethod = aClass.getDeclaredMethod(method);
            declaredMethod.invoke(instance);
        }
    }
}