/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tez.dag.app; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.util.Arrays; import java.util.Set; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class PluginWrapperTestHelpers { private static final Logger LOG = LoggerFactory.getLogger(PluginWrapperTestHelpers.class); public static void testDelegation(Class<?> delegateClass, Class<?> rawClass, Set<String> skipMethods) throws Exception { TrackingAnswer answer = new TrackingAnswer(); Object mock = mock(rawClass, answer); Constructor ctor = delegateClass.getConstructor(rawClass); Object wrapper = ctor.newInstance(mock); // Run through all the methods on the wrapper, and invoke the methods. Constructs // arguments and return types for each of them. Method[] methods = delegateClass.getMethods(); for (Method method : methods) { if (method.getDeclaringClass().equals(delegateClass) && !skipMethods.contains(method.getName())) { assertTrue(method.getExceptionTypes().length == 1); assertEquals(Exception.class, method.getExceptionTypes()[0]); LOG.info("Checking method [{}] with parameterTypes [{}]", method.getName(), Arrays.toString(method.getParameterTypes())); Object[] params = constructMethodArgs(method); Object result = method.invoke(wrapper, params); // Validate the correct arguments are forwarded, and the real instance is invoked. assertEquals(method.getName(), answer.lastMethodName); assertArrayEquals(params, answer.lastArgs); // Validate the results. // Handle auto-boxing if (answer.compareAsPrimitive) { assertEquals(answer.lastRetValue, result); } else { assertTrue("Expected: " + System.identityHashCode(answer.lastRetValue) + ", actual=" + System.identityHashCode(result), answer.lastRetValue == result); } } } } public static Object[] constructMethodArgs(Method method) throws IllegalAccessException, InstantiationException { Class<?>[] paramTypes = method.getParameterTypes(); Object[] params = new Object[paramTypes.length]; for (int i = 0; i < paramTypes.length; i++) { params[i] = constructSingleArg(paramTypes[i]); } return params; } private static Object constructSingleArg(Class<?> clazz) { if (clazz.isPrimitive() || clazz.equals(String.class)) { return getValueForPrimitiveOrString(clazz); } else if (clazz.isEnum()) { if (clazz.getEnumConstants().length == 0) { return null; } else { return clazz.getEnumConstants()[0]; } } else if (clazz.isArray() && (clazz.getComponentType().isPrimitive() || clazz.getComponentType().equals(String.class))) { // Cannot mock. For now using null. Also does not handle deeply nested arrays. return null; } else { return mock(clazz); } } private static Object getValueForPrimitiveOrString(Class<?> clazz) { if (clazz.equals(String.class)) { return "teststring"; } else if (clazz.equals(byte.class)) { return 'b'; } else if (clazz.equals(short.class)) { return 2; } else if (clazz.equals(int.class)) { return 224; } else if (clazz.equals(long.class)) { return 445l; } else if (clazz.equals(float.class)) { return 2.24f; } else if (clazz.equals(double.class)) { return 4.57d; } else if (clazz.equals(boolean.class)) { return true; } else if (clazz.equals(char.class)) { return 'c'; } else if (clazz.equals(void.class)) { return null; } else { throw new RuntimeException("Unrecognized type: " + clazz.getName()); } } public static class TrackingAnswer implements Answer { public String lastMethodName; public Object[] lastArgs; public Object lastRetValue; boolean compareAsPrimitive; @Override public Object answer(InvocationOnMock invocation) throws Throwable { lastArgs = invocation.getArguments(); lastMethodName = invocation.getMethod().getName(); Class<?> retType = invocation.getMethod().getReturnType(); lastRetValue = constructSingleArg(retType); compareAsPrimitive = retType.isPrimitive() || retType.isEnum() || retType.equals(String.class); return lastRetValue; } } }