/*
* Copyright 2016, Stuart Douglas, and individual contributors as indicated
* by the @authors tag.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.fakereplace.manip;
import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.fakereplace.api.environment.CurrentEnvironment;
import org.fakereplace.core.Agent;
import org.fakereplace.core.Constants;
import org.fakereplace.data.BaseClassData;
import org.fakereplace.data.ClassDataStore;
import org.fakereplace.data.MethodData;
import org.fakereplace.logging.Logger;
import org.fakereplace.manip.data.FakeMethodCallData;
import org.fakereplace.manip.util.Boxing;
import org.fakereplace.manip.util.ManipulationDataStore;
import org.fakereplace.manip.util.ManipulationUtils;
import org.fakereplace.runtime.MethodIdentifierStore;
import org.fakereplace.util.DescriptorUtils;
import javassist.bytecode.BadBytecode;
import javassist.bytecode.Bytecode;
import javassist.bytecode.ClassFile;
import javassist.bytecode.CodeIterator;
import javassist.bytecode.ConstPool;
import javassist.bytecode.MethodInfo;
import javassist.bytecode.Opcode;
/**
* Manipulator that handles fake method call invocations
*/
public class FakeMethodCallManipulator implements ClassManipulator {
private final ManipulationDataStore<FakeMethodCallData> data = new ManipulationDataStore<>();
private final Logger log = Logger.getLogger(FakeMethodCallManipulator.class);
public void clearRewrites(String className, ClassLoader loader) {
data.remove(className, loader);
}
public void addFakeMethodCall(FakeMethodCallData methodInfo) {
data.add(methodInfo.getClassName(), methodInfo);
}
public boolean transformClass(ClassFile file, ClassLoader loader, boolean modifiableClass, final Set<MethodInfo> modifiedMethods) {
if(!Agent.isRetransformationStarted()) {
return false;
}
final Map<String, Set<FakeMethodCallData>> virtualToStaticMethod = data.getManipulationData(loader);
final Map<Integer, FakeMethodCallData> methodCallLocations = new HashMap<>();
final Map<Integer, AddedMethodInfo> newMethodInfoMap = new HashMap<>();
// first we need to scan the constant pool looking for
// CONSTANT_method_info_ref structures
ConstPool pool = file.getConstPool();
for (int i = 1; i < pool.getSize(); ++i) {
// we have a method call
if (pool.getTag(i) == ConstPool.CONST_Methodref || pool.getTag(i) == ConstPool.CONST_InterfaceMethodref) {
String className, methodDesc, methodName;
if (pool.getTag(i) == ConstPool.CONST_Methodref) {
className = pool.getMethodrefClassName(i);
methodDesc = pool.getMethodrefType(i);
methodName = pool.getMethodrefName(i);
} else {
className = pool.getInterfaceMethodrefClassName(i);
methodDesc = pool.getInterfaceMethodrefType(i);
methodName = pool.getInterfaceMethodrefName(i);
}
if(methodName.equals("<clinit>") || methodName.equals("<init>")) {
continue;
}
boolean handled = false;
if (virtualToStaticMethod.containsKey(className)) {
for (FakeMethodCallData data : virtualToStaticMethod.get(className)) {
if (methodName.equals(data.getMethodName()) && methodDesc.equals(data.getMethodDesc())) {
// store the location in the const pool of the method ref
methodCallLocations.put(i, data);
// we have found a method call
// now lets replace it
handled = true;
break;
}
}
}
if (!handled && !className.equals(file.getName()) && CurrentEnvironment.getEnvironment().isClassReplaceable(className, loader)) {
//may be an added method
//if the field does not actually exist yet we just assume it is about to come into existence
//and rewrite it anyway
BaseClassData data = ClassDataStore.instance().getBaseClassData(loader, className);
if(data != null) {
boolean noClassData = false;
MethodData method = null;
try {
Class<?> mainClass = loader.loadClass(className);
Set<Class> allClasses = new HashSet<>();
addToAllClasses(mainClass, allClasses);
for(Class clazz : allClasses) {
data = ClassDataStore.instance().getBaseClassData(clazz.getClassLoader(), clazz.getName());
if(data == null) {
noClassData = true;
break;
}
method = data.getMethodOrConstructor(methodName, methodDesc);
if(method != null) {
break;
}
}
} catch (ClassNotFoundException e) {
noClassData = true;
}
if (!noClassData) {
if (method == null) {
//this is a new method
//lets deal with it
int methodNo = MethodIdentifierStore.instance().getMethodNumber(methodName, methodDesc);
newMethodInfoMap.put(i, new AddedMethodInfo(methodNo, className, methodName, methodDesc));
} else if (!Modifier.isPublic(method.getAccessFlags())) {
boolean requiresVisibilityUpgrade = false;
if (Modifier.isPrivate(method.getAccessFlags())) {
requiresVisibilityUpgrade = true;
} else if (!Modifier.isProtected(method.getAccessFlags())) {
//we can't handle protected properly, because we need to know the class heirachy
//this is package local, so we check the package names
boolean thisDefault = !file.getName().contains(".");
boolean thatDefault = !className.contains(".");
if (thisDefault && !thatDefault) {
requiresVisibilityUpgrade = true;
} else if (thatDefault && !thisDefault) {
requiresVisibilityUpgrade = true;
} else if (!thatDefault) {
String thatPackage = className.substring(0, className.lastIndexOf("."));
String thisPackage = file.getName().substring(0, file.getName().lastIndexOf("."));
if (!thisPackage.equals(thatPackage)) {
requiresVisibilityUpgrade = true;
}
}
}
if (requiresVisibilityUpgrade) {
int methodNo = MethodIdentifierStore.instance().getMethodNumber(methodName, methodDesc);
newMethodInfoMap.put(i, new AddedMethodInfo(methodNo, className, methodName, methodDesc));
}
}
}
}
}
}
}
// this means we found an instance of the call, now we have to iterate
// through the methods and replace instances of the call
if (!methodCallLocations.isEmpty() || !newMethodInfoMap.isEmpty()) {
List<MethodInfo> methods = file.getMethods();
for (MethodInfo m : methods) {
try {
// ignore abstract methods
if (m.getCodeAttribute() == null) {
continue;
}
CodeIterator it = m.getCodeAttribute().iterator();
while (it.hasNext()) {
// loop through the bytecode
int index = it.next();
int op = it.byteAt(index);
// if the bytecode is a method invocation
if (op == CodeIterator.INVOKEVIRTUAL || op == CodeIterator.INVOKESTATIC || op == CodeIterator.INVOKEINTERFACE || op == CodeIterator.INVOKESPECIAL) {
int val = it.s16bitAt(index + 1);
// if the method call is one of the methods we are
// replacing
if(newMethodInfoMap.containsKey(val)) {
AddedMethodInfo methodInfo = newMethodInfoMap.get(val);
FakeMethodCallData data = new FakeMethodCallData(methodInfo.className, methodInfo.name, methodInfo.desc, op == Opcode.INVOKESTATIC ? FakeMethodCallData.Type.STATIC : op == Opcode.INVOKEINTERFACE ? FakeMethodCallData.Type.INTERFACE : FakeMethodCallData.Type.VIRTUAL, loader, methodInfo.number);
handleFakeMethodCall(file, modifiedMethods, m, it, index, op, data);
} else if (methodCallLocations.containsKey(val)) {
FakeMethodCallData data = methodCallLocations.get(val);
handleFakeMethodCall(file, modifiedMethods, m, it, index, op, data);
}
}
}
modifiedMethods.add(m);
m.getCodeAttribute().computeMaxStack();
} catch (Exception e) {
log.error("Bad byte code transforming " + file.getName(), e);
e.printStackTrace();
}
}
return true;
} else {
return false;
}
}
private void addToAllClasses(Class<?> clazz, Set<Class> allClasses) {
while ( clazz != null) {
allClasses.add(clazz);
for(Class<?> iface : clazz.getInterfaces()) {
addToAllClasses(iface, allClasses);
}
clazz = clazz.getSuperclass();
}
}
private void handleFakeMethodCall(ClassFile file, Set<MethodInfo> modifiedMethods, MethodInfo m, CodeIterator it, int index, int op, FakeMethodCallData data) throws BadBytecode {
//NOP out the whole thing
it.writeByte(CodeIterator.NOP, index );
it.writeByte(CodeIterator.NOP, index + 1);
it.writeByte(CodeIterator.NOP, index + 2);
if (op == CodeIterator.INVOKEINTERFACE) {
// INVOKEINTERFACE has some extra parameters
it.writeByte(CodeIterator.NOP, index + 3);
it.writeByte(CodeIterator.NOP, index + 4);
}
//now we write some bytecode to invoke it directly
final boolean staticMethod = data.getType() == FakeMethodCallData.Type.STATIC;
Bytecode byteCode = new Bytecode(file.getConstPool());
// stick the method number in the const pool then load it onto the
// stack
ManipulationUtils.pushParametersIntoArray(byteCode, data.getMethodDesc());
int scind = file.getConstPool().addIntegerInfo(data.getMethodNumber());
byteCode.addLdc(scind);
byteCode.add(Opcode.SWAP);
// invoke the added method
if (staticMethod) {
byteCode.addInvokestatic(data.getClassName(), Constants.ADDED_STATIC_METHOD_NAME, "(I[Ljava/lang/Object;)Ljava/lang/Object;");
} else if (data.getType() == FakeMethodCallData.Type.INTERFACE) {
byteCode.addInvokeinterface(data.getClassName(), Constants.ADDED_METHOD_NAME, "(I[Ljava/lang/Object;)Ljava/lang/Object;", 3);
} else {
byteCode.addInvokevirtual(data.getClassName(), Constants.ADDED_METHOD_NAME, "(I[Ljava/lang/Object;)Ljava/lang/Object;");
}
// cast it to the appropriate type and return it
String returnType = DescriptorUtils.getReturnType(data.getMethodDesc());
if(returnType.length() == 1 && !returnType.equals("V")) {
Boxing.unbox(byteCode, returnType.charAt(0));
} else if(returnType.equals("V")) {
byteCode.add(Opcode.POP);
} else {
byteCode.addCheckcast(returnType.substring(1, returnType.length() - 1));
}
it.insertEx(byteCode.get());
modifiedMethods.add(m);
}
private static class AddedMethodInfo {
final int number;
final String className;
final String name;
final String desc;
private AddedMethodInfo(int number, String className, String name, String desc) {
this.number = number;
this.className = className;
this.name = name;
this.desc = desc;
}
}
}