package org.approvaltests.namer; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; import com.spun.util.ObjectUtils; import com.spun.util.io.StackElementSelector; public class AttributeStackSelector implements StackElementSelector { private List<Class<? extends Annotation>> attributes; public AttributeStackSelector() { attributes = getAvailableAttributes(); } private List<Class<? extends Annotation>> getAvailableAttributes() { String classNames[] = {"org.testng.annotations.Test", "org.junit.Test"}; ArrayList<Class<? extends Annotation>> attributes = new ArrayList<Class<? extends Annotation>>(); for (String className : classNames) { Class<? extends Annotation> clazz = loadClass(className); if (clazz != null) { attributes.add(clazz); } } return attributes; } private Class<? extends Annotation> loadClass(String className) { Class<? extends Annotation> clazz = null; try { clazz = (Class<? extends Annotation>) ObjectUtils.loadClass(className); } catch (ClassNotFoundException e) { clazz = null; } return clazz; } @Override public StackTraceElement selectElement(StackTraceElement[] trace) throws Exception { boolean inTestCase = false; for (int i = 0; i < trace.length; i++) { if (isTestCase(trace[i])) { inTestCase = true; } else if (inTestCase) { return trace[i - 1]; } } throw new RuntimeException("Could not find Junit/TestNg TestCase you are running"); } private boolean isTestCase(StackTraceElement element) throws ClassNotFoundException { String fullClassName = element.getClassName(); Class<?> clazz = loadClass(fullClassName); if (clazz == null) { return false; } if (isJunit3Test(clazz)) { return true; } return isTestAttribute(clazz, element.getMethodName()); } private boolean isJunit3Test(Class<?> clazz) { Class<?> testcase = loadClass("junit.framework.TestCase"); return testcase != null && ObjectUtils.isThisInstanceOfThat(clazz, testcase); } private boolean isTestAttribute(Class<?> clazz, String methodName) throws ClassNotFoundException, SecurityException { Method method = getMethodByName(clazz, methodName); if (method == null) { return false; } for (Class<? extends Annotation> attribute : attributes) { if (method.isAnnotationPresent(attribute)) { return true; } } return false; } public Method getMethodByName(Class<?> clazz, String methodName) { Method method = null; try { Method[] declaredMethods = clazz.getDeclaredMethods(); for (Method m : declaredMethods) { if (m.getName().equals(methodName)) { method = m; } } } catch (Throwable e) { } return method; } @Override public void increment() { //ignore } }