package org.codehaus.groovy.gjit;

import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.Map.Entry;

import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.objectweb.asm.tree.VarInsnNode;
import org.objectweb.asm.tree.analysis.AnalyzerException;
import org.objectweb.asm.tree.analysis.BasicValue;
import org.objectweb.asm.tree.analysis.Value;
import org.objectweb.asm.util.AbstractVisitor;

/**
 * Rewrites the instruction list of one method, replacing calls into
 * {@code ScriptBytecodeAdapter}, {@code CallSite} and
 * {@code DefaultTypeTransformation} with primitive JVM instructions where a
 * known pattern is recognised.
 *
 * <p>The work is done in three steps, see {@link #transform()}: a pre-pass that
 * locates the local variable holding the call-site array, the analysis pass
 * (driven by the superclass, which calls {@link #process} and
 * {@link #postprocess} per instruction), and a post-pass that inserts the
 * boxing / unboxing calls recorded during the analysis.
 */
public class Transformer extends Analyzer implements Opcodes {

    private static final String SCRIPT_BYTECODE_ADAPTER = "org/codehaus/groovy/runtime/ScriptBytecodeAdapter";
    private static final String CALL_SITE_INTERFACE = "org/codehaus/groovy/runtime/callsite/CallSite";
    private static final String DEFAULT_TYPE_TRANSFORMATION = "org/codehaus/groovy/runtime/typehandling/DefaultTypeTransformation";

    private InsnList units;
    private MethodNode node;

    private enum Phase {
        PHASE_CALLSITE, PHASE_NEXT_1, PHASE_NEXT_2, PHASE_ERROR
    }

    private enum CallSiteState {
        START, FOUND_CALLSITE_INST, END, ERROR,
    };

    private Phase phase = Phase.PHASE_CALLSITE;

    /** Index of the local that receives the call-site array, or -1 if not found. */
    private int callSiteVar = -1;
    private ConstantPack pack;
    private String[] siteNames;
    /** Call-site index seen most recently by {@link #extractCallSiteName}; -1 until one is seen. */
    private int currentSiteIndex = -1;
    private String owner;
    /** Per-local primitive descriptor char ('I', 'J', ...) or 0 when the local is not retyped. */
    private int[] localTypes;

    /** Instruction producing a primitive value that must be boxed -> the primitive type. */
    private Map<AbstractInsnNode, Type> markForLaterBox = new HashMap<AbstractInsnNode, Type>();
    /** Instruction consuming a reference that must be unboxed first -> the expected primitive type. */
    private Map<AbstractInsnNode, Type> markForLaterUnbox = new HashMap<AbstractInsnNode, Type>();

    /** Set by {@link #correctNormalCall}; tells {@link #postprocess} to check the call's arguments. */
    boolean flag = false;
    /** Set by {@link #correctUnbox}; tells {@link #postprocess} to check the store/return operand. */
    boolean flag2 = false;
    /** Primitive type expected by the instruction that set {@link #flag2}. */
    private Type flag2Type;

    /**
     * @param owner     passed to the superclass {@code analyze} call as the owner of {@code mn}
     * @param mn        method whose instructions are rewritten in place
     * @param pack      source of the values substituted for {@code $const$*} static fields
     * @param siteNames call-site names, indexed by call-site index
     */
    public Transformer(String owner, MethodNode mn, ConstantPack pack, String[] siteNames) {
        super(new MyBasicInterpreter());
        this.owner = owner;
        this.node = mn;
        this.units = node.instructions;
        this.pack = pack;
        this.siteNames = siteNames;
        this.localTypes = new int[mn.maxLocals];
    }

    /**
     * Runs the pre-pass, the analysis and the post-pass, in that order.
     *
     * @throws AnalyzerException propagated from the analysis pass
     */
    public void transform() throws AnalyzerException {
        preTransform();
        this.analyze(this.owner, this.node);
        postTransform();
    }

    /**
     * Inserts the boxing and unboxing instructions recorded during analysis.
     * Entries whose type has no wrapper class are skipped rather than emitted
     * with a {@code null} owner.
     */
    private void postTransform() {
        for (Entry<AbstractInsnNode, Type> entry : markForLaterBox.entrySet()) {
            String[] info = wrapperInfo(entry.getValue());
            if (info == null) {
                continue;
            }
            AbstractInsnNode s = entry.getKey();
            if (s.getOpcode() == SWAP) {
                s = s.getPrevious(); // work around for inserted SWAP,POP
            }
            if (s == null) {
                continue;
            }
            String boxType = info[0];
            String primType = info[1];
            units.insert(s, new MethodInsnNode(INVOKESTATIC, boxType, "valueOf",
                    "(" + primType + ")L" + boxType + ";"));
        }

        for (Entry<AbstractInsnNode, Type> entry : markForLaterUnbox.entrySet()) {
            String[] info = wrapperInfo(entry.getValue());
            if (info == null) {
                continue;
            }
            AbstractInsnNode s = entry.getKey();
            String boxType = info[0];
            String primType = info[1];
            String primTypeName = info[2];
            TypeInsnNode tnode = new TypeInsnNode(CHECKCAST, boxType);
            MethodInsnNode iv = new MethodInsnNode(INVOKEVIRTUAL, boxType, primTypeName + "Value", "()" + primType);
            units.insertBefore(s, tnode);
            units.insert(tnode, iv);
        }
    }

    /**
     * Maps a primitive type to its wrapper.
     *
     * @return {wrapper internal name, primitive descriptor, primitive name},
     *         or {@code null} when {@code t} is null or not a boxable primitive
     */
    private static String[] wrapperInfo(Type t) {
        if (t == null) {
            return null;
        }
        switch (t.getSort()) {
            case Type.BOOLEAN: return new String[] {"java/lang/Boolean", "Z", "boolean"};
            case Type.CHAR:    return new String[] {"java/lang/Character", "C", "char"};
            case Type.BYTE:    return new String[] {"java/lang/Byte", "B", "byte"};
            case Type.SHORT:   return new String[] {"java/lang/Short", "S", "short"};
            case Type.INT:     return new String[] {"java/lang/Integer", "I", "int"};
            case Type.LONG:    return new String[] {"java/lang/Long", "J", "long"};
            case Type.FLOAT:   return new String[] {"java/lang/Float", "F", "float"};
            case Type.DOUBLE:  return new String[] {"java/lang/Double", "D", "double"};
            default:           return null;
        }
    }

    /**
     * Tries each rewrite in a fixed order and reports what the first matching
     * one did to the instruction list. The order matters: earlier rewrites
     * consume patterns that later ones would otherwise partially match.
     */
    @Override
    public Action process(AbstractInsnNode s, Map<AbstractInsnNode, Frame> frames) {
        Frame frame = frames.get(s);
        if (extractCallSiteName(s)) {
            return Action.NONE;
        }
        if (eliminateBoxCastUnbox(s)) {
            return Action.REMOVE;
        }
        if (unwrapConst(s)) {
            return Action.REPLACE;
        }
        if (unwrapBoxOrUnbox(s)) {
            return Action.REMOVE;
        }
        if (unwrapBinaryPrimitiveCall(s, frame)) {
            return Action.REPLACE;
        }
        if (unwrapCompare(s, frame)) {
            return Action.REMOVE;
        }
        if (clearCast(s)) {
            return Action.REMOVE;
        }
        if (correctNormalCall(s)) {
            return Action.NONE;
        }
        if (correctLocalType(s)) {
            return Action.REPLACE;
        }
        // Only sets flag2/flag2Type as a side effect; the action is NONE either way.
        correctUnbox(s);
        return Action.NONE;
    }

    /**
     * Consumes the flags raised in {@link #process} for this instruction and
     * records where boxing / unboxing has to be inserted later.
     */
    @Override
    protected void postprocess(AbstractInsnNode insnNode, Interpreter interpreter) {
        if (flag) {
            flag = false;
            markPrimitiveArguments(insnNode, (MyBasicInterpreter) interpreter);
        }
        if (flag2) {
            flag2 = false;
            markReferenceOperand(insnNode, (MyBasicInterpreter) interpreter);
        }
    }

    /**
     * For a call left in place, marks every argument whose analysed value is a
     * non-reference of a type different from the declared parameter type.
     */
    private void markPrimitiveArguments(AbstractInsnNode insnNode, MyBasicInterpreter i) {
        if (!(insnNode instanceof MethodInsnNode)) {
            return;
        }
        MethodInsnNode m = (MethodInsnNode) insnNode;
        int firstArg;
        if (m.getOpcode() == INVOKESTATIC) {
            firstArg = 0;
        } else if (m.getOpcode() == INVOKEINTERFACE) {
            firstArg = 1; // values[0] is the receiver
        } else {
            return;
        }
        Type[] types = Type.getArgumentTypes(m.desc);
        Value[] values = i.use.get(insnNode);
        if (values == null) {
            return;
        }
        for (int j = firstArg; j < values.length && j - firstArg < types.length; j++) {
            Type t = types[j - firstArg];
            BasicValue bv = (BasicValue) values[j];
            if (bv.getType() == null || bv.isReference() || t.equals(bv.getType())) {
                continue;
            }
            if (bv instanceof DefValue) {
                markForLaterBox.put(((DefValue) bv).source, bv.getType());
            } else {
                // No defining instruction is known for this value; report and skip it.
                DebugUtils.print(m.owner + ".");
                DebugUtils.print(m.name);
                DebugUtils.println(m.desc);
                DebugUtils.println(bv);
            }
        }
    }

    /**
     * For a primitive store/return, marks the instruction for unboxing when
     * the value it consumes is a reference.
     */
    private void markReferenceOperand(AbstractInsnNode insnNode, MyBasicInterpreter i) {
        Value[] values = i.use.get(insnNode);
        DebugUtils.println(AbstractVisitor.OPCODES[insnNode.getOpcode()]);
        if (values == null || values.length == 0 || !(values[0] instanceof DefValue)) {
            return;
        }
        if (((DefValue) values[0]).isReference()) {
            markForLaterUnbox.put(insnNode, flag2Type);
        }
    }

    /**
     * Raises {@link #flag2} for primitive store and return instructions and
     * remembers the primitive type they expect.
     *
     * @return true when {@code s} is such an instruction
     */
    private boolean correctUnbox(AbstractInsnNode s) {
        flag2Type = null;
        switch (s.getOpcode()) {
            case ISTORE:
            case IRETURN:
                flag2Type = Type.INT_TYPE;
                break;
            case LRETURN:
            case LSTORE:
                flag2Type = Type.LONG_TYPE;
                break;
            case FRETURN:
            case FSTORE:
                flag2Type = Type.FLOAT_TYPE;
                break;
            case DRETURN:
            case DSTORE:
                flag2Type = Type.DOUBLE_TYPE;
                break;
            default:
                return false;
        }
        flag2 = true;
        return true;
    }

    /**
     * Raises {@link #flag} for {@code CallSite.call*} calls and for
     * {@code ScriptBytecodeAdapter} calls other than {@code unwrap}.
     */
    private boolean correctNormalCall(AbstractInsnNode s) {
        if (s.getOpcode() != INVOKEINTERFACE && s.getOpcode() != INVOKESTATIC) {
            return false;
        }
        MethodInsnNode iv = (MethodInsnNode) s;
        boolean callSiteCall = iv.owner.equals(CALL_SITE_INTERFACE) && iv.name.startsWith("call");
        boolean adapterCall = iv.owner.equals(SCRIPT_BYTECODE_ADAPTER) && !iv.name.equals("unwrap");
        if (callSiteCall || adapterCall) {
            flag = true;
            return true;
        }
        return false;
    }

    /**
     * Replaces an {@code ALOAD} of a local that was retyped to a primitive
     * with the matching primitive load.
     */
    private boolean correctLocalType(AbstractInsnNode s) {
        if (s.getOpcode() != ALOAD) {
            return false;
        }
        VarInsnNode v = (VarInsnNode) s;
        int type = localTypes[v.var];
        if (type == 0) {
            return false;
        }
        int opcode = typedOpcode(type, ILOAD);
        if (opcode != -1) {
            units.set(s, new VarInsnNode(opcode, v.var));
        }
        return true;
    }

    /**
     * Selects the int/long/float/double variant of an opcode family whose four
     * members are numbered consecutively (xLOAD, xSTORE).
     *
     * @param type      primitive descriptor char as returned by {@link #getPrimitive}
     * @param intOpcode the int member of the family
     * @return the matching opcode, or -1 for an unknown {@code type}
     */
    private static int typedOpcode(int type, int intOpcode) {
        switch (type) {
            case 'I':
            case 'B':
            case 'S':
            case 'Z':
            case 'C':
                return intOpcode;
            case 'J':
                return intOpcode + 1;
            case 'F':
                return intOpcode + 2;
            case 'D':
                return intOpcode + 3;
            default:
                return -1;
        }
    }

    /**
     * Removes the five-instruction sequence
     * {@code box; $get$$class$...; castToType; CHECKCAST; *Unbox}
     * starting at {@code s}.
     */
    private boolean eliminateBoxCastUnbox(AbstractInsnNode s) {
        // INVOKESTATIC org/codehaus/groovy/runtime/typehandling/DefaultTypeTransformation.box(I)Ljava/lang/Object;
        // INVOKESTATIC TreeNode.$get$$class$java$lang$Integer()Ljava/lang/Class;
        // INVOKESTATIC org/codehaus/groovy/runtime/ScriptBytecodeAdapter.castToType(Ljava/lang/Object;Ljava/lang/Class;)Ljava/lang/Object;
        // CHECKCAST java/lang/Integer
        // INVOKESTATIC org/codehaus/groovy/runtime/typehandling/DefaultTypeTransformation.intUnbox(Ljava/lang/Object;)I
        if (s.getOpcode() != INVOKESTATIC) {
            return false;
        }
        AbstractInsnNode s1 = s.getNext();
        if (s1 == null || s1.getOpcode() != INVOKESTATIC) {
            return false;
        }
        AbstractInsnNode s2 = s1.getNext();
        if (s2 == null || s2.getOpcode() != INVOKESTATIC) {
            return false;
        }
        AbstractInsnNode s3 = s2.getNext();
        if (s3 == null || s3.getOpcode() != CHECKCAST) {
            return false;
        }
        AbstractInsnNode s4 = s3.getNext();
        if (s4 == null || s4.getOpcode() != INVOKESTATIC) {
            return false;
        }

        MethodInsnNode m = (MethodInsnNode) s;
        MethodInsnNode m1 = (MethodInsnNode) s1;
        MethodInsnNode m2 = (MethodInsnNode) s2;
        MethodInsnNode m4 = (MethodInsnNode) s4;
        boolean matches = m.name.equals("box")
                && m1.name.startsWith("$get$$class$")
                && m2.name.startsWith("castToType")
                && m4.name.endsWith("Unbox");
        if (!matches) {
            return false;
        }

        units.remove(s);
        units.remove(s1);
        units.remove(s2);
        units.remove(s3);
        units.remove(s4);
        return true;
    }

    /**
     * Maps a wrapper class (internal name, optionally in {@code L...;}
     * descriptor form) to its primitive descriptor char.
     *
     * @return the descriptor char, or 0 when {@code className} is not a wrapper class
     */
    private int getPrimitive(String className) {
        if (className.startsWith("L") && className.endsWith(";")) {
            className = className.substring(1, className.length() - 1);
        }
        if (className.equals("java/lang/Integer")) return 'I';
        if (className.equals("java/lang/Long")) return 'J';
        if (className.equals("java/lang/Boolean")) return 'Z';
        if (className.equals("java/lang/Byte")) return 'B';
        if (className.equals("java/lang/Character")) return 'C';
        if (className.equals("java/lang/Short")) return 'S';
        if (className.equals("java/lang/Float")) return 'F';
        if (className.equals("java/lang/Double")) return 'D';
        return 0;
    }

    /**
     * Removes a {@code $get$$class$java$lang$X; castToType; CHECKCAST} sequence
     * when the cast target is a primitive wrapper, retyping a directly
     * following {@code ASTORE}. Casts to other {@code java/lang} classes are
     * left in place, since the value would otherwise lose its checked type.
     */
    private boolean clearCast(AbstractInsnNode s) {
        // INVOKESTATIC TreeNode.$get$$class$java$lang$Integer()Ljava/lang/Class;
        // INVOKESTATIC org/codehaus/groovy/runtime/ScriptBytecodeAdapter.castToType(Ljava/lang/Object;Ljava/lang/Class;)Ljava/lang/Object;
        // CHECKCAST java/lang/Integer
        if (s.getOpcode() != INVOKESTATIC) {
            return false;
        }
        MethodInsnNode m = (MethodInsnNode) s;
        if (!m.name.startsWith("$get$$class$java$lang$")) {
            return false;
        }
        AbstractInsnNode s1 = s.getNext();
        if (s1 == null || s1.getOpcode() != INVOKESTATIC) {
            return false;
        }
        if (!((MethodInsnNode) s1).name.equals("castToType")) {
            return false;
        }
        AbstractInsnNode s2 = s1.getNext();
        if (s2 == null || s2.getOpcode() != CHECKCAST) {
            return false;
        }
        TypeInsnNode t2 = (TypeInsnNode) s2;
        if (!t2.desc.startsWith("java/lang")) {
            return false;
        }
        int type = getPrimitive(t2.desc);
        if (type == 0) {
            return false;
        }
        fixASTORE(type, s2.getNext());
        units.remove(s);
        units.remove(s1);
        units.remove(s2);
        // TODO: change the next instruction to deal with PRIMITIVE
        return true;
    }

    /**
     * If {@code nextS} is an {@code ASTORE}, turns it into the primitive store
     * for {@code type} and records the local's new type. Does nothing when
     * {@code nextS} is null or another instruction.
     */
    private void fixASTORE(int type, AbstractInsnNode nextS) {
        if (nextS == null || nextS.getOpcode() != ASTORE) {
            return;
        }
        VarInsnNode store = (VarInsnNode) nextS;
        int opcode = typedOpcode(type, ISTORE);
        if (opcode != -1) {
            units.set(nextS, new VarInsnNode(opcode, store.var));
        }
        localTypes[store.var] = type;
        if (type != 0) {
            // An unknown type must not overwrite the declared descriptor.
            correctLocalVarInfo(type, store);
        }
    }

    /**
     * Sets the descriptor of the first local-variable entry with the same
     * index as {@code v3} to the primitive descriptor {@code type}.
     */
    private void correctLocalVarInfo(int type, VarInsnNode v3) {
        List<?> vars = node.localVariables;
        if (vars == null) {
            return;
        }
        for (Object var : vars) {
            LocalVariableNode local = (LocalVariableNode) var;
            if (local.index == v3.var) {
                local.desc = String.valueOf((char) type);
                break;
            }
        }
    }

    private enum ComparingMethod {
        compareLessThan, compareGreaterThan, compareLessThanEqual, compareGreaterThanEqual
    };

    /** @return the constant named {@code name}, or null when there is none */
    private static ComparingMethod comparingMethodNamed(String name) {
        for (ComparingMethod candidate : ComparingMethod.values()) {
            if (candidate.name().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Fuses a {@code ScriptBytecodeAdapter.compareXxx(Object,Object)Z} call and
     * the {@code IFEQ}/{@code IFNE} that follows it into a single
     * {@code IF_ICMPxx} jump, then removes the call.
     *
     * <p>{@code IFEQ} branches when the comparison is false, so it maps to the
     * negated condition; {@code IFNE} maps to the condition itself. Any other
     * successor leaves the code untouched.
     */
    private boolean unwrapCompare(AbstractInsnNode s, Frame frame) {
        if (s.getOpcode() != Opcodes.INVOKESTATIC) {
            return false;
        }
        MethodInsnNode m = (MethodInsnNode) s;
        if (!m.owner.equals(SCRIPT_BYTECODE_ADAPTER)) {
            return false;
        }
        if (!m.name.startsWith("compare")) {
            return false;
        }
        if (!m.desc.equals("(Ljava/lang/Object;Ljava/lang/Object;)Z")) {
            return false;
        }
        ComparingMethod compare = comparingMethodNamed(m.name);
        if (compare == null) {
            return false;
        }
        AbstractInsnNode next = s.getNext();
        if (!(next instanceof JumpInsnNode)) {
            return false;
        }
        JumpInsnNode jump = (JumpInsnNode) next;
        boolean jumpWhenFalse;
        if (jump.getOpcode() == IFEQ) {
            jumpWhenFalse = true;
        } else if (jump.getOpcode() == IFNE) {
            jumpWhenFalse = false;
        } else {
            return false;
        }

        int opcode;
        switch (compare) {
            case compareGreaterThan:
                opcode = jumpWhenFalse ? IF_ICMPLE : IF_ICMPGT;
                break;
            case compareGreaterThanEqual:
                opcode = jumpWhenFalse ? IF_ICMPLT : IF_ICMPGE;
                break;
            case compareLessThan:
                opcode = jumpWhenFalse ? IF_ICMPGE : IF_ICMPLT;
                break;
            case compareLessThanEqual:
                opcode = jumpWhenFalse ? IF_ICMPGT : IF_ICMPLE;
                break;
            default:
                return false;
        }
        units.set(jump, new JumpInsnNode(opcode, jump.label));
        units.remove(s);
        return true;
    }

    /**
     * Locates the call-site array local, scanning from the first instruction.
     * The scan runs once; its outcome is stored in {@link #phase}.
     */
    private void preTransform() {
        AbstractInsnNode first = units.getFirst();
        if (first != null && phase == Phase.PHASE_CALLSITE) {
            phase = processPhaseCallSite(first);
        }
    }

    /**
     * Scans forward from {@code s0} for a {@code $getCallSiteArray} call and
     * the first {@code ASTORE} after it.
     *
     * @return {@code PHASE_NEXT_1} when the store was found (its local is
     *         recorded), {@code PHASE_ERROR} when the instructions run out first
     */
    private Phase processPhaseCallSite(AbstractInsnNode s0) {
        CallSiteState state = CallSiteState.START;
        for (AbstractInsnNode s = s0; s != null; s = s.getNext()) {
            if (state == CallSiteState.START) {
                state = detectCallSiteInst(state, s);
            } else {
                state = detectCallSiteVar(state, s);
            }
            // Decide before advancing, so a match on the last instruction is not lost.
            if (state == CallSiteState.END) {
                return Phase.PHASE_NEXT_1;
            }
        }
        return Phase.PHASE_ERROR;
    }

    /** Records the local of an {@code ASTORE} as the call-site variable and ends the scan. */
    private CallSiteState detectCallSiteVar(CallSiteState state, AbstractInsnNode s) {
        if (s.getOpcode() != ASTORE) {
            return state;
        }
        callSiteVar = ((VarInsnNode) s).var;
        return CallSiteState.END;
    }

    /** Advances the scan when {@code s} is a static call to {@code $getCallSiteArray}. */
    private CallSiteState detectCallSiteInst(CallSiteState state, AbstractInsnNode s) {
        if (s.getOpcode() != INVOKESTATIC) {
            return state;
        }
        if (((MethodInsnNode) s).name.equals("$getCallSiteArray")) {
            return CallSiteState.FOUND_CALLSITE_INST;
        }
        return state;
    }

    private enum BinOp {
        minus, plus, multiply, div, leftShift, rightShift
    }

    /** @return the constant named {@code name}, or null when there is none */
    private static BinOp binOpNamed(String name) {
        for (BinOp candidate : BinOp.values()) {
            if (candidate.name().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Replaces a two-argument {@code CallSite.call} whose current call-site
     * name is a {@link BinOp} with the arithmetic instruction for the type of
     * the first operand, followed by instructions that drop the call-site
     * receiver left underneath the result.
     *
     * <p>Shifts on float/double operands are not rewritten: no such JVM
     * instruction exists.
     */
    private boolean unwrapBinaryPrimitiveCall(AbstractInsnNode s, Frame frame) {
        if (s.getOpcode() != INVOKEINTERFACE) {
            return false;
        }
        MethodInsnNode iv = (MethodInsnNode) s;
        if (!iv.owner.equals(CALL_SITE_INTERFACE)) {
            return false;
        }
        if (!iv.name.equals("call")) {
            return false;
        }
        if (!iv.desc.equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;")) {
            return false;
        }
        if (siteNames == null || currentSiteIndex < 0 || currentSiteIndex >= siteNames.length) {
            return false; // no call-site index has been seen yet
        }
        BinOp op = binOpNamed(siteNames[currentSiteIndex]);
        if (op == null) {
            return false;
        }

        // TODO check type from "frame"
        DebugUtils.println("frame: " + frame);
        AbstractInsnNode previous = s.getPrevious();
        if (previous != null && previous.getOpcode() == LLOAD) {
            DebugUtils.println(">> Found it !!");
        }
        Value v2 = frame.getStack(frame.getStackSize() - 1); // peek
        Value v1 = frame.getStack(frame.getStackSize() - 2); // peek
        // TODO if(v1.sort != v2.sort) do something
        DebugUtils.println("v1:" + v1);
        DebugUtils.println("v2:" + v2);

        // The I/L/F/D variants of each arithmetic opcode are numbered consecutively.
        Type operandType = ((BasicValue) v1).getType();
        int offset = 0;
        if (Type.LONG_TYPE.equals(operandType)) {
            offset = 1;
        } else if (Type.FLOAT_TYPE.equals(operandType)) {
            offset = 2;
        } else if (Type.DOUBLE_TYPE.equals(operandType)) {
            offset = 3;
        }

        int base;
        switch (op) {
            case minus:
                base = ISUB;
                break;
            case plus:
                base = IADD;
                break;
            case multiply:
                base = IMUL;
                break;
            case div:
                base = IDIV;
                break;
            case leftShift:
                base = ISHL;
                break;
            case rightShift:
                base = ISHR;
                break;
            default:
                return false;
        }
        boolean shift = op == BinOp.leftShift || op == BinOp.rightShift;
        if (shift && offset > 1) {
            // Shift opcodes only have int and long forms; base + 2/3 would be a different shift.
            return false;
        }

        InsnNode arithmetic = new InsnNode(base + offset);
        units.set(s, arithmetic);
        if (v1.getSize() == 1) {
            // result sits on top of the receiver: SWAP, POP
            units.insert(arithmetic, new InsnNode(POP));
            units.insert(arithmetic, new InsnNode(SWAP));
        } else if (v1.getSize() == 2) {
            // two-slot result on top of the receiver: DUP2_X1, POP2, POP
            units.insert(arithmetic, new InsnNode(POP));
            units.insert(arithmetic, new InsnNode(POP2));
            units.insert(arithmetic, new InsnNode(DUP2_X1));
        }
        return true;
    }

    /**
     * Remembers the call-site index when {@code s} loads the call-site array
     * and is directly followed by an {@code LDC} of an {@code Integer}.
     */
    private boolean extractCallSiteName(AbstractInsnNode s) {
        if (s.getOpcode() != ALOAD) {
            return false;
        }
        if (((VarInsnNode) s).var != callSiteVar) {
            return false;
        }
        AbstractInsnNode next = s.getNext();
        if (next == null || next.getOpcode() != LDC) {
            return false;
        }
        Object constant = ((LdcInsnNode) next).cst;
        if (!(constant instanceof Integer)) {
            return false;
        }
        currentSiteIndex = ((Integer) constant).intValue();
        return true;
    }

    /**
     * Replaces a {@code GETSTATIC} of a {@code $const$*} field with an
     * {@code LDC} of the value held by the constant pack, retyping a directly
     * following {@code ASTORE} according to the field's descriptor.
     */
    private boolean unwrapConst(AbstractInsnNode s) {
        if (s.getOpcode() != GETSTATIC) {
            return false;
        }
        FieldInsnNode f = (FieldInsnNode) s;
        if (!f.name.startsWith("$const$")) {
            return false;
        }
        LdcInsnNode replacement = new LdcInsnNode(pack.get(f.name));
        fixASTORE(getPrimitive(f.desc), s.getNext());
        units.set(s, replacement);
        return true;
    }

    /**
     * Removes a static call to {@code DefaultTypeTransformation.box} or to one
     * of its {@code *Unbox} methods.
     */
    private boolean unwrapBoxOrUnbox(AbstractInsnNode s) {
        if (s.getOpcode() != INVOKESTATIC) {
            return false;
        }
        MethodInsnNode m = (MethodInsnNode) s;
        if (!m.owner.equals(DEFAULT_TYPE_TRANSFORMATION)) {
            return false;
        }
        if (m.name.equals("box") || m.name.endsWith("Unbox")) {
            units.remove(s);
            return true;
        }
        return false;
    }
}