/*
* The MIT License (MIT)
*
* Copyright (c) 2016. Diorite (by Bartłomiej Mazur (aka GotoFinal))
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package org.diorite.inject.impl.controller;
import javax.annotation.Nullable;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Map.Entry;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.diorite.inject.InjectionController;
import org.diorite.inject.impl.controller.TransformerInjectTracker.PlaceholderType;
import org.diorite.inject.impl.data.WithMethods;
import org.diorite.inject.impl.utils.AsmUtils;
import org.diorite.inject.impl.utils.Constants;
import net.bytebuddy.description.ByteCodeElement;
/**
 * Rewrites the bytecode of a single class using the injection metadata collected by the controller.
 * <p>
 * The class bytes are parsed into an ASM tree ({@link #classNode}). {@link #run()} then:
 * <ol>
 * <li>links the controller's member metadata ({@link #classData}) with the tree nodes;</li>
 * <li>delegates to {@link TransformerFieldInjector} and {@link TransformerMethodInjector};</li>
 * <li>adds the class-level "before"/"after" method invokes to every constructor.</li>
 * </ol>
 * The transformed class can be serialized with {@link #getWriter()}.
 */
class Transformer implements Opcodes
{
    /** ASM tree of the transformed class; modified in place. */
    final ClassNode classNode;
    /** Injection metadata of the transformed class. */
    final ControllerClassData classData;
    /** Every constructor of the class, mapped to its super() invoke and return instructions. */
    final Map<MethodNode, TransformerInitMethodData> inits = new LinkedHashMap<>(3);
    /** Method pairs keyed by {@code descriptor + name}. */
    final Map<String, TransformerMethodPair> methods = new LinkedHashMap<>(5);
    /** Field pairs keyed by {@code descriptor + name}. */
    final Map<String, TransformerFieldPair> fields = new LinkedHashMap<>(5);

    // Controller members are chained (via TransformerMemberPair.prev/next) into two doubly linked lists.
    // Static and instance members go into separate lists, in the order returned by classData.getMembers().
    @SuppressWarnings("rawtypes") @Nullable
    TransformerMemberPair firstStatic = null;
    @SuppressWarnings("rawtypes") @Nullable
    TransformerMemberPair lastStatic = null;
    @SuppressWarnings("rawtypes") @Nullable
    TransformerMemberPair firstObject = null;
    @SuppressWarnings("rawtypes") @Nullable
    TransformerMemberPair lastObject = null;

    /**
     * Parses the given class bytes into an ASM tree.
     *
     * @param bytecode  raw class-file bytes to transform.
     * @param classData injection metadata of that class.
     */
    Transformer(byte[] bytecode, ControllerClassData classData)
    {
        this.classData = classData;
        this.classNode = new ClassNode(Opcodes.ASM6);
        ClassReader cr = new ClassReader(bytecode);
        cr.accept(this.classNode, 0);
    }

    /**
     * Serializes the (possibly transformed) class tree.
     * <p>
     * NOTE(review): COMPUTE_FRAMES relies on ClassWriter#getCommonSuperClass, which by default loads classes
     * through the ClassWriter's own class loader. Confirm that this works in the loading context used here.
     *
     * @return a writer that has already visited the whole class; call {@code toByteArray()} on it.
     */
    public ClassWriter getWriter()
    {
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        this.classNode.accept(cw);
        return cw;
    }

    /**
     * Builds {@link #methods}, {@link #fields}, {@link #inits} and the static/object member chains.
     * <p>
     * Controller members are registered first, so their pairs carry controller data. Every method and field
     * found in the class tree is then attached to its pair, creating one without controller data when needed.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void createMappings()
    {
        // register every controller member and append it to the matching (static / object) chain
        for (ControllerMemberData<?> memberData : this.classData.getMembers())
        {
            ByteCodeElement member = memberData.getMember();
            String id = member.getDescriptor() + member.getName();
            TransformerMemberPair memberPair;
            if (memberData instanceof ControllerFieldData)
            {
                TransformerFieldPair fieldPair = new TransformerFieldPair((ControllerFieldData<?>) memberData);
                this.fields.put(id, fieldPair);
                memberPair = fieldPair;
            }
            else
            {
                TransformerMethodPair methodPair = new TransformerMethodPair((ControllerMethodData) memberData);
                this.methods.put(id, methodPair);
                memberPair = methodPair;
            }
            memberPair.isStatic = Modifier.isStatic(member.getModifiers());
            if (memberPair.isStatic)
            {
                if (this.lastStatic == null)
                {
                    this.firstStatic = memberPair;
                }
                else
                {
                    this.lastStatic.next = memberPair;
                    memberPair.prev = this.lastStatic;
                }
                this.lastStatic = memberPair;
            }
            else
            {
                if (this.lastObject == null)
                {
                    this.firstObject = memberPair;
                }
                else
                {
                    this.lastObject.next = memberPair;
                    memberPair.prev = this.lastObject;
                }
                this.lastObject = memberPair;
            }
        }

        // attach method nodes (and their index in the class) to pairs; also record every constructor
        List<MethodNode> methodNodes = this.classNode.methods;
        for (int i = 0, size = methodNodes.size(); i < size; i++)
        {
            MethodNode method = methodNodes.get(i);
            if (method.name.equals(InjectionController.CONSTRUCTOR_NAME))
            {
                MethodInsnNode superNode = this.findSuperNode(method);
                TransformerInitMethodData initPair = new TransformerInitMethodData(method, superNode);
                this.findReturns(initPair);
                this.inits.put(method, initPair);
            }
            String id = method.desc + method.name;
            TransformerMethodPair methodPair = this.methods.computeIfAbsent(id, k -> new TransformerMethodPair(null));
            methodPair.node = method;
            methodPair.index = i;
        }

        // attach field nodes (and their index in the class) to pairs
        List<FieldNode> fieldNodes = this.classNode.fields;
        for (int i = 0, size = fieldNodes.size(); i < size; i++)
        {
            FieldNode field = fieldNodes.get(i);
            String id = field.desc + field.name;
            TransformerFieldPair fieldPair = this.fields.computeIfAbsent(id, k -> new TransformerFieldPair(null));
            fieldPair.node = field;
            fieldPair.index = i;
        }
    }

    /**
     * Collects every return instruction of the constructor into {@code initPair.returns}.
     *
     * @param initPair constructor data whose {@code node} is scanned.
     */
    private void findReturns(TransformerInitMethodData initPair)
    {
        for (AbstractInsnNode node = initPair.node.instructions.getFirst(); node != null; node = node.getNext())
        {
            if ((node instanceof InsnNode) && AsmUtils.isReturnCode(node.getOpcode()))
            {
                initPair.returns.add((InsnNode) node);
            }
        }
    }

    /**
     * Finds the {@code super(...)} call of a constructor.
     * <p>
     * This is the first INVOKESPECIAL {@code <init>} on the superclass whose owner does not match a preceding
     * NEW instruction. Objects of the super type created inside the constructor (for example
     * {@code super(new Super())}) are therefore skipped.
     *
     * @param init constructor to scan.
     * @return the super-constructor invoke instruction.
     * @throws TransformerError if no direct superclass constructor call exists (e.g. the constructor only
     *                          delegates with {@code this(...)}).
     */
    @SuppressWarnings("unchecked")
    private MethodInsnNode findSuperNode(MethodNode init)
    {
        // types created with NEW whose matching <init> call has not been seen yet
        Collection<String> pendingNews = new LinkedList<>();
        ListIterator<AbstractInsnNode> iterator = init.instructions.iterator();
        while (iterator.hasNext())
        {
            AbstractInsnNode insn = iterator.next();
            if (insn.getOpcode() == Opcodes.NEW)
            {
                pendingNews.add(((TypeInsnNode) insn).desc);
            }
            else if (insn.getOpcode() == Opcodes.INVOKESPECIAL)
            {
                MethodInsnNode invoke = (MethodInsnNode) insn;
                if (! invoke.name.equals(InjectionController.CONSTRUCTOR_NAME))
                {
                    continue;
                }
                // an <init> that matches a pending NEW initializes that new object, not `this`
                if (! pendingNews.remove(invoke.owner) && invoke.owner.equals(this.classNode.superName))
                {
                    return invoke;
                }
            }
        }
        throw new TransformerError("Can't find super() invoke for constructor!");
    }

    /**
     * Runs the whole transformation on {@link #classNode}.
     * <p>
     * The type parameter is unused; it is kept only for source compatibility with existing callers.
     */
    public <T> void run()
    {
        this.createMappings();
        TransformerFieldInjector.run(this);
        TransformerMethodInjector.run(this);
        this.addGlobalInjectInvokes();
    }

    /**
     * Inserts the class-level "before" invokes right after each constructor's super() call. Inserts the
     * "after" invokes before each constructor return.
     */
    private void addGlobalInjectInvokes()
    {
        MethodNode codeBefore = new MethodNode();
        MethodNode codeAfter = new MethodNode();
        this.fillMethodInvokes(codeBefore, codeAfter, this.classData);
        InsnList before = codeBefore.instructions;
        InsnList after = codeAfter.instructions;
        for (Entry<MethodNode, TransformerInitMethodData> initEntry : this.inits.entrySet())
        {
            InsnList instructions = initEntry.getKey().instructions;
            TransformerInitMethodData initPair = initEntry.getValue();
            // InsnList.insert/insertBefore(..., InsnList) move the nodes and clear the argument list, so
            // every insertion site must receive its own copy.
            if (after.size() > 0)
            {
                for (InsnNode returnNode : initPair.returns)
                {
                    instructions.insertBefore(returnNode, copyInstructions(after));
                }
            }
            if (before.size() > 0)
            {
                instructions.insert(initPair.superInvoke, copyInstructions(before));
            }
        }
    }

    /**
     * Creates an independent deep copy of an instruction list. Labels are remapped to fresh
     * {@link LabelNode}s, so the copy can be inserted into another list.
     *
     * @param source list to copy; left unchanged.
     * @return the copied instructions.
     */
    private static InsnList copyInstructions(InsnList source)
    {
        Map<LabelNode, LabelNode> labels = new HashMap<>();
        for (AbstractInsnNode node = source.getFirst(); node != null; node = node.getNext())
        {
            if (node instanceof LabelNode)
            {
                labels.put((LabelNode) node, new LabelNode());
            }
        }
        InsnList copy = new InsnList();
        for (AbstractInsnNode node = source.getFirst(); node != null; node = node.getNext())
        {
            copy.add(node.clone(labels));
        }
        return copy;
    }

    /**
     * Generates invokes of the member's "before" and "after" methods. Each is looked up as a {@code ()V}
     * method of this class.
     *
     * @param codeBefore receives the invokes of {@code member.getBefore()}.
     * @param codeAfter  receives the invokes of {@code member.getAfter()}.
     * @param member     source of the method names.
     * @throws TransformerError if a named method is unknown or has no node in the class.
     */
    void fillMethodInvokes(MethodNode codeBefore, MethodNode codeAfter, WithMethods member)
    {
        this.printInvokes(codeBefore, member.getBefore(), "before");
        this.printInvokes(codeAfter, member.getAfter(), "after");
    }

    /**
     * Appends an invoke of each named {@code ()V} method of this class to {@code target}.
     *
     * @param target      method node receiving the instructions.
     * @param methodNames names of the methods to invoke.
     * @param phase       "before" or "after"; only used in the error message.
     */
    private void printInvokes(MethodNode target, Collection<String> methodNames, String phase)
    {
        for (String name : methodNames)
        {
            TransformerMethodPair methodPair = this.methods.get("()V" + name);
            if ((methodPair == null) || (methodPair.node == null))
            {
                throw new TransformerError("Can't find method for invoke " + phase + ": " + name + " in " + this.classNode.name);
            }
            boolean isStatic = Modifier.isStatic(methodPair.node.access);
            TransformerInvokerGenerator.printMethod(target, this.classNode.name, name, isStatic, - 1);
        }
    }

    /** @return the pair for the invoked method of this class, or null if unknown. */
    @Nullable
    TransformerMethodPair getMethodPair(MethodInsnNode methodInsnNode)
    {
        return this.methods.get(methodInsnNode.desc + methodInsnNode.name);
    }

    /** @return the pair for the accessed field of this class, or null if unknown. */
    @Nullable
    TransformerFieldPair getFieldPair(FieldInsnNode fieldInsnNode)
    {
        return this.fields.get(fieldInsnNode.desc + fieldInsnNode.name);
    }

    /** @return the node of the invoked method, or null if the method is unknown or has no node. */
    @Nullable
    MethodNode getMethod(MethodInsnNode methodInsnNode)
    {
        TransformerMethodPair methodPair = this.getMethodPair(methodInsnNode);
        return (methodPair == null) ? null : methodPair.node;
    }

    /** @return the node of the accessed field, or null if the field is unknown or has no node. */
    @Nullable
    FieldNode getField(FieldInsnNode fieldInsnNode)
    {
        TransformerFieldPair fieldPair = this.getFieldPair(fieldInsnNode);
        return (fieldPair == null) ? null : fieldPair.node;
    }

    /**
     * Classifies an instruction as an injection placeholder call.
     *
     * @param node instruction to inspect.
     * @return the placeholder type, or {@link PlaceholderType#INVALID} for anything but a matching method call.
     */
    public static PlaceholderType isInjectPlaceholder(AbstractInsnNode node)
    {
        if (! (node instanceof MethodInsnNode))
        {
            return PlaceholderType.INVALID;
        }
        MethodInsnNode mNode = (MethodInsnNode) node;
        return isInjectPlaceholder(mNode.getOpcode(), mNode.owner, mNode.name, mNode.desc);
    }

    /**
     * Classifies a method call as an injection placeholder. A placeholder is an INVOKESTATIC of
     * {@code injectNullable} or {@code inject} with descriptor {@code ()Ljava/lang/Object;} on
     * {@link Constants#INJECTION_LIBRARY}.
     *
     * @return NULLABLE for {@code injectNullable}, NONNULL for {@code inject}, otherwise INVALID.
     */
    public static PlaceholderType isInjectPlaceholder(int opcode, String owner, String name, String desc)
    {
        if ((opcode == INVOKESTATIC) && owner.equals(Constants.INJECTION_LIBRARY.getInternalName()) && (desc.equals("()Ljava/lang/Object;")))
        {
            switch (name)
            {
                case "injectNullable":
                    return PlaceholderType.NULLABLE;
                case "inject":
                    return PlaceholderType.NONNULL;
                default:
                    return PlaceholderType.INVALID;
            }
        }
        return PlaceholderType.INVALID;
    }

    /**
     * Appends an invoke of each named {@code ()V} method to {@code mv}, threading the line number through
     * the generator.
     * <p>
     * NOTE(review): not called anywhere in this class; candidate for removal.
     *
     * @return the line number returned by the last generated invoke (or the input if none).
     */
    private static int printMethods(MethodNode mv, String clazz, Iterable<String> methods, Transformer transformer, int lineNumber)
    {
        for (String method : methods)
        {
            TransformerMethodPair methodPair = transformer.methods.get("()V" + method);
            if (methodPair == null)
            {
                throw new TransformerError("Unknown method: " + method + " for " + clazz);
            }
            if (methodPair.node == null)
            {
                throw new TransformerError("Node not set yet.");
            }
            lineNumber = TransformerInvokerGenerator.printMethod(mv, clazz, method, Modifier.isStatic(methodPair.node.access), lineNumber);
        }
        return lineNumber;
    }
}