package de.twenty11.unitprofile.agent; import java.lang.instrument.ClassFileTransformer; import java.lang.instrument.IllegalClassFormatException; import java.security.ProtectionDomain; import java.util.ArrayList; import java.util.List; import javassist.CannotCompileException; import javassist.ClassPool; import javassist.CtClass; import javassist.CtMethod; import javassist.NotFoundException; import javassist.bytecode.CodeAttribute; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import de.twenty11.unitprofile.callback.ProfilerCallback; import de.twenty11.unitprofile.domain.MethodDescriptor; import de.twenty11.unitprofile.domain.Transformation; import de.twenty11.unitprofiler.annotations.Profile; /** * finds (for profiling) annotated methods and uses them as root for instrumentation. * */ public class ProfilingClassFileTransformer implements ClassFileTransformer { private static final String PROFILE_ANNOTATION = "@" + Profile.class.getName(); private static final Logger logger = LoggerFactory.getLogger(ProfilingClassFileTransformer.class); /** * list of all classes which have been transformed */ private List<Transformation> transformations = new ArrayList<Transformation>(); /** * list of all methods (identified by objectName/methodName) which have been instrumented. * * http://docs.oracle.com/javase/6/docs/api/java/lang/instrument/Instrumentation.html: * "Instrumentation is the addition of byte-codes to methods for the purpose of gathering data to be utilized by tools" */ private List<MethodDescriptor> instrumentations = new ArrayList<MethodDescriptor>(); private CtClass profilerCallbackCtClass; private java.lang.instrument.Instrumentation instrumentation; public ProfilingClassFileTransformer(java.lang.instrument.Instrumentation inst) { this.instrumentation = inst; } @Override public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException { if (shouldNotBeProfiled(className)) { return classfileBuffer; } Transformation transformation = trackTransformations(className, classfileBuffer); byte[] byteCode = classfileBuffer; ClassPool classPool = ClassPool.getDefault(); classPool.importPackage("de.twenty11.unitprofile.callback"); try { CtClass ctClass = classPool.get(className.replace("/", ".")); if (profilerCallbackCtClass == null) { profilerCallbackCtClass = classPool.get(ProfilerCallback.class.getName()); } List<CtMethod> annotatedMethodsToProfile = findMethodsToProfile(ctClass); if (annotatedMethodsToProfile.size() > 0) { logInfoAboutAnnotatedMethodsFound(annotatedMethodsToProfile); } for (CtMethod m : annotatedMethodsToProfile) { startProfiling(ctClass, profilerCallbackCtClass, m); } for (CtMethod m : ctClass.getMethods()) { profile(m, ctClass); } byteCode = ctClass.toBytecode(); transformation.update(byteCode.length); // logger.debug("transformation updated '{}'", transformation); ctClass.detach(); } catch (NotFoundException nfe) { logger.warn("{}", nfe.getMessage()); } catch (Exception ex) { logger.error(ex.getMessage(), ex); } return byteCode; } private boolean shouldNotBeProfiled(String className) { if (className.startsWith("java/") || className.startsWith("javax/") || className.startsWith("sun/")) { return true; } if (className.startsWith("org/junit") || className.startsWith("junit/framework")) { return true; } if (className.startsWith("de/twenty11/unitprofile/agent")) { return true; } if (className.startsWith("de/twenty11/unitprofile/domain")) { return true; } if (className.startsWith("de/twenty11/unitprofile/callback")) { return true; } if (className.startsWith("org/apache/maven/surefire")) { return true; } if (className.startsWith("org/apache/commons")) { return true; } if (className.startsWith("org/springframework")) { return true; } if (className.startsWith("org/jacoco/agent")) { return true; } return false; } private Transformation trackTransformations(String className, byte[] classfileBuffer) { Transformation transformation = new Transformation(className, classfileBuffer.length); if (transformations.contains(transformation)) { logger.warn("re-transforming '{}'", transformation); } else { transformations.add(transformation); } return transformation; } private void logInfoAboutAnnotatedMethodsFound(List<CtMethod> annotatedMethodsToProfile) { logger.info("found " + annotatedMethodsToProfile.size() + " method(s) annotated for profiling: "); for (CtMethod ctMethod : annotatedMethodsToProfile) { logger.info(" * {}", ctMethod.getDeclaringClass().getName() + "#" + ctMethod.getName()); } logger.info(""); } private final void startProfiling(CtClass classWithProfilingAnnotatedMethod, CtClass profilerClass, final CtMethod m) throws Exception { if (!instrument(m)) { return; } int lineNumber = m.getMethodInfo().getLineNumber(0); String code = "{ProfilerCallback.start(\"" + m.getDeclaringClass().getName() + "\", \"" + m.getName() + "\", " + lineNumber + ");}"; logger.info("insertBefore: '{}'", code); m.insertBefore(code); m.insertAfter("{ProfilerCallback.stop(\"" + m.getDeclaringClass().getName() + "\", \"" + m.getName() + "\");}"); m.instrument(new ProfilingExprEditor(this, classWithProfilingAnnotatedMethod)); classWithProfilingAnnotatedMethod.instrument(new ProfilingExprEditor(this, classWithProfilingAnnotatedMethod)); } protected final void profile(final CtMethod m, CtClass cc) throws CannotCompileException { if (!instrument(m)) { return; } MethodDescriptor md = new MethodDescriptor(m.getDeclaringClass().getName(), m.getName(), m.getMethodInfo() .getLineNumber(0)); String insertBeforeCode = md.getInsertBefore(); // logger.debug(insertBeforeCode); m.insertBefore(insertBeforeCode); m.insertAfter(md.getInsertAfter()); // m.instrument(new ProfilingExprEditor(this, cc)); } public boolean isAlreadyInstrumented(MethodDescriptor instrumentation) { return instrumentations.contains(instrumentation); } public void addInstrumentation(MethodDescriptor instrumentation) { instrumentations.add(instrumentation); } public List<MethodDescriptor> getInstrumentations() { return instrumentations; } public java.lang.instrument.Instrumentation getInstrumentation() { return this.instrumentation; } public Transformation getTransformation(String classname) { for (Transformation transformation : transformations) { if (transformation.getClassName().equals(classname)) { return transformation; } } return null; } private List<CtMethod> findMethodsToProfile(CtClass cc) { List<CtMethod> methodsToProfile = new ArrayList<CtMethod>(); CtMethod[] declaredMethods = cc.getDeclaredMethods(); for (int i = 0; i < declaredMethods.length; i++) { // System.out.println(declaredMethods[i].toString()); Object[] annotations; try { annotations = declaredMethods[i].getAnnotations(); if (annotations == null) { continue; } for (int j = 0; j < annotations.length; j++) { if (annotations[j].toString().equals(PROFILE_ANNOTATION)) { methodsToProfile.add(declaredMethods[i]); } } } catch (ClassNotFoundException e) { e.printStackTrace(); } } return methodsToProfile; } private boolean instrument(CtMethod method) { if (method.getDeclaringClass().isFrozen()) { logger.warn("'{}' is 'frozen'", method.getDeclaringClass().getName()); return false; } String objectName = method.getDeclaringClass().getName(); MethodDescriptor instrumentation = new MethodDescriptor(objectName, method.getName(), method.getMethodInfo() .getLineNumber(0)); if (instrumentations.contains(instrumentation)) { return false; } // logger.debug("added to instrumentations: " + objectName + "#" + method.getName() + "(line " // + method.getMethodInfo().getLineNumber(0) + ")"); instrumentations.add(instrumentation); CodeAttribute ca = method.getMethodInfo().getCodeAttribute(); if (ca == null) { return false; } return true; } }