package com.yammer.telemetry.test; import com.google.common.io.ByteStreams; import com.yammer.telemetry.instrumentation.TelemetryTransformer; import javassist.ClassPool; import java.io.IOException; import java.io.InputStream; import java.lang.instrument.IllegalClassFormatException; import java.net.URL; import java.net.URLClassLoader; import static com.google.common.base.Preconditions.checkNotNull; public class TransformingClassLoader extends URLClassLoader { private final TelemetryTransformer transformer; private final ClassPool classPool; public TransformingClassLoader(TelemetryTransformer transformer) { super(new URL[] {}); this.transformer = checkNotNull(transformer); classPool = new ClassPool(null); classPool.appendSystemPath(); } @Override public Class<?> loadClass(String name) throws ClassNotFoundException { Class<?> loadedClass = super.findLoadedClass(name); if (loadedClass != null) return loadedClass; try (InputStream classStream = super.getResourceAsStream(name.replace('.', '/') + ".class")) { byte[] classfileBuffer = ByteStreams.toByteArray(classStream); byte[] transformedBytes = transformer.transform(this, name, classfileBuffer, classPool); if (transformedBytes == null) { if (name.startsWith("java")) { return super.loadClass(name); } return super.defineClass(name, classfileBuffer, 0, classfileBuffer.length); } else { return super.defineClass(name, transformedBytes, 0, transformedBytes.length); } } catch (IOException | IllegalClassFormatException | RuntimeException e) { throw new ClassNotFoundException(name, e); } } }