package codechicken.lib.asm;

import com.google.common.base.Function;
import com.google.common.collect.Maps;
import org.objectweb.asm.tree.*;

import java.util.*;

import static org.objectweb.asm.tree.AbstractInsnNode.*;

/**
 * Utilities for comparing ASM instruction nodes and for locating an instruction sequence
 * (the "needle") inside a larger instruction list (the "haystack").
 * <p>
 * Several of the per-type comparisons treat special values as wildcards that match anything:
 * a local-variable index or int operand of {@code -1}, a {@code null} LDC constant, and a type
 * descriptor of {@code "*"}. The wildcard may appear on either side of the comparison.
 */
public class InsnComparator {

    /**
     * Compares two local-variable instructions by variable index.
     * An index of {@code -1} on either side matches any index.
     */
    public static boolean varInsnEqual(VarInsnNode insn1, VarInsnNode insn2) {
        return insn1.var == -1 || insn2.var == -1 || insn1.var == insn2.var;
    }

    /** Compares two method instructions by owner, name and descriptor (no wildcards). */
    public static boolean methodInsnEqual(MethodInsnNode insn1, MethodInsnNode insn2) {
        return insn1.owner.equals(insn2.owner) && insn1.name.equals(insn2.name) && insn1.desc.equals(insn2.desc);
    }

    /** Compares two field instructions by owner, name and descriptor (no wildcards). */
    public static boolean fieldInsnEqual(FieldInsnNode insn1, FieldInsnNode insn2) {
        return insn1.owner.equals(insn2.owner) && insn1.name.equals(insn2.name) && insn1.desc.equals(insn2.desc);
    }

    /**
     * Compares two LDC instructions by constant value.
     * A {@code null} constant on either side matches any constant.
     */
    public static boolean ldcInsnEqual(LdcInsnNode insn1, LdcInsnNode insn2) {
        return insn1.cst == null || insn2.cst == null || insn1.cst.equals(insn2.cst);
    }

    /**
     * Compares two type instructions by descriptor.
     * A descriptor of {@code "*"} on either side matches any descriptor.
     */
    public static boolean typeInsnEqual(TypeInsnNode insn1, TypeInsnNode insn2) {
        return insn1.desc.equals("*") || insn2.desc.equals("*") || insn1.desc.equals(insn2.desc);
    }

    /** Compares two IINC instructions by variable index and increment (no wildcards). */
    public static boolean iincInsnEqual(IincInsnNode node1, IincInsnNode node2) {
        return node1.var == node2.var && node1.incr == node2.incr;
    }

    /**
     * Compares two int-operand instructions by operand.
     * An operand of {@code -1} on either side matches any operand.
     */
    public static boolean intInsnEqual(IntInsnNode node1, IntInsnNode node2) {
        return node1.operand == -1 || node2.operand == -1 || node1.operand == node2.operand;
    }

    /**
     * Tests whether two instructions match.
     * <p>
     * Both the opcode and the node type must be equal. For the node types that carry operands,
     * matching is then delegated to the type-specific comparison above. All other node types
     * (jumps, switches, labels, plain instructions, ...) match on opcode and type alone.
     *
     * @return {@code true} if the instructions match, honouring the wildcards described above
     */
    public static boolean insnEqual(AbstractInsnNode node1, AbstractInsnNode node2) {
        // Pseudo-instructions (labels, line numbers, frames) all share opcode -1, so the opcode
        // alone does not identify the node kind; also require the same node type.
        if (node1.getOpcode() != node2.getOpcode() || node1.getType() != node2.getType()) {
            return false;
        }

        switch (node2.getType()) {
            case VAR_INSN:
                return varInsnEqual((VarInsnNode) node1, (VarInsnNode) node2);
            case TYPE_INSN:
                return typeInsnEqual((TypeInsnNode) node1, (TypeInsnNode) node2);
            case FIELD_INSN:
                return fieldInsnEqual((FieldInsnNode) node1, (FieldInsnNode) node2);
            case METHOD_INSN:
                return methodInsnEqual((MethodInsnNode) node1, (MethodInsnNode) node2);
            case LDC_INSN:
                return ldcInsnEqual((LdcInsnNode) node1, (LdcInsnNode) node2);
            case IINC_INSN:
                return iincInsnEqual((IincInsnNode) node1, (IincInsnNode) node2);
            case INT_INSN:
                return intInsnEqual((IntInsnNode) node1, (IntInsnNode) node2);
            default:
                return true;
        }
    }

    /**
     * Decides whether an instruction takes part in matching.
     * <p>
     * Line-number and frame nodes are never important. A label is important only if it is
     * contained in {@code controlFlowLabels}. Every other node is important.
     */
    public static boolean insnImportant(AbstractInsnNode insn, Set<LabelNode> controlFlowLabels) {
        switch (insn.getType()) {
            case LINE:
            case FRAME:
                return false;
            case LABEL:
                return controlFlowLabels.contains(insn);
            default:
                return true;
        }
    }

    /**
     * Collects the control-flow labels for a section.
     * <p>
     * The scan covers the section's entire underlying {@link InsnList}, not just the section's
     * range, so labels targeted from outside the section are still reported.
     */
    public static Set<LabelNode> getControlFlowLabels(InsnListSection list) {
        return getControlFlowLabels(list.list);
    }

    /**
     * Collects every label that is the target of a jump, a table switch or a lookup switch
     * (including the switch default labels) in {@code list}.
     *
     * @return a new mutable set of the target labels
     */
    public static Set<LabelNode> getControlFlowLabels(InsnList list) {
        Set<LabelNode> controlFlowLabels = new HashSet<LabelNode>();
        for (AbstractInsnNode insn = list.getFirst(); insn != null; insn = insn.getNext()) {
            switch (insn.getType()) {
                case JUMP_INSN:
                    controlFlowLabels.add(((JumpInsnNode) insn).label);
                    break;
                case TABLESWITCH_INSN:
                    TableSwitchInsnNode tsinsn = (TableSwitchInsnNode) insn;
                    controlFlowLabels.add(tsinsn.dflt);
                    controlFlowLabels.addAll(tsinsn.labels);
                    break;
                case LOOKUPSWITCH_INSN:
                    LookupSwitchInsnNode lsinsn = (LookupSwitchInsnNode) insn;
                    controlFlowLabels.add(lsinsn.dflt);
                    controlFlowLabels.addAll(lsinsn.labels);
                    break;
            }
        }
        return controlFlowLabels;
    }

    /**
     * Returns a new {@link InsnList} holding clones of only the important instructions of
     * {@code list}, as decided by {@link #insnImportant}.
     */
    public static InsnList getImportantList(InsnList list) {
        return getImportantList(new InsnListSection(list)).list;
    }

    /**
     * Builds a new section containing clones of only the important instructions of {@code list}.
     * An empty input is returned as-is.
     * <p>
     * Each control-flow label is replaced by a fresh {@link LabelNode} in the copy. Jumps and
     * switches in the copy refer to those fresh labels. The source list is left untouched.
     */
    public static InsnListSection getImportantList(InsnListSection list) {
        if (list.size() == 0) {
            return list;
        }

        Set<LabelNode> controlFlowLabels = getControlFlowLabels(list);
        // AbstractInsnNode.clone(map) resolves labels through this map (LabelNode.clone returns
        // map.get(this)). Mapping a label to itself would add the ORIGINAL node to the new list,
        // and InsnList.add would then rewrite its previous/next links, corrupting the source
        // list. Map every label to a fresh copy instead. A HashMap (not a lazily computed view)
        // guarantees each label always resolves to the same copy.
        Map<LabelNode, LabelNode> labelMap = new HashMap<LabelNode, LabelNode>();
        for (LabelNode label : controlFlowLabels) {
            labelMap.put(label, new LabelNode());
        }

        InsnListSection importantNodeList = new InsnListSection();
        for (AbstractInsnNode insn : list) {
            if (insnImportant(insn, controlFlowLabels)) {
                importantNodeList.add(insn.clone(labelMap));
            }
        }
        return importantNodeList;
    }

    /**
     * Finds every non-overlapping occurrence of {@code needle} in {@code haystack}, scanning
     * left to right.
     * <p>
     * Unimportant haystack instructions are skipped while matching (see {@link #matches}).
     *
     * @return the matching sub-sections of {@code haystack}, in order; empty if none
     * @throws IllegalArgumentException if {@code needle} is empty (every position would match)
     */
    public static List<InsnListSection> find(InsnListSection haystack, InsnListSection needle) {
        if (needle.size() == 0) {
            // An empty needle yields zero-length matches and would never advance the scan.
            throw new IllegalArgumentException("Cannot search for an empty needle");
        }

        Set<LabelNode> controlFlowLabels = getControlFlowLabels(haystack);
        List<InsnListSection> list = new ArrayList<InsnListSection>();
        for (int start = 0; start <= haystack.size() - needle.size(); start++) {
            InsnListSection section = matches(haystack.drop(start), needle, controlFlowLabels);
            if (section != null) {
                list.add(section);
                // section.end indexes the underlying InsnList, while start is relative to
                // haystack. Convert it, then step back one to account for the loop's increment.
                start = section.end - haystack.start - 1;
            }
        }
        return list;
    }

    /** Convenience overload of {@link #find(InsnListSection, InsnListSection)} over a whole list. */
    public static List<InsnListSection> find(InsnList haystack, InsnListSection needle) {
        return find(new InsnListSection(haystack), needle);
    }

    /**
     * Tests whether {@code needle} matches at the very start of {@code haystack}.
     * <p>
     * Unimportant haystack instructions are skipped. Each important one must match the next
     * needle instruction under {@link #insnEqual}. Needle instructions are not filtered.
     *
     * @return the prefix of {@code haystack} covering the match (including any skipped
     *         instructions), or {@code null} if the needle does not match or the haystack runs
     *         out first
     */
    public static InsnListSection matches(InsnListSection haystack, InsnListSection needle, Set<LabelNode> controlFlowLabels) {
        int h = 0, n = 0;
        for (; h < haystack.size() && n < needle.size(); h++) {
            AbstractInsnNode insn = haystack.get(h);
            if (!insnImportant(insn, controlFlowLabels)) {
                continue;
            }
            if (!insnEqual(insn, needle.get(n))) {
                return null;
            }
            n++;
        }
        if (n != needle.size()) {
            return null;
        }

        return haystack.take(h);
    }

    /**
     * Finds the single occurrence of {@code needle} in {@code haystack}.
     *
     * @throws RuntimeException if the needle occurs zero times or more than once
     */
    public static InsnListSection findOnce(InsnListSection haystack, InsnListSection needle) {
        List<InsnListSection> list = find(haystack, needle);
        if (list.size() != 1) {
            throw new RuntimeException("Needle found " + list.size() + " times in Haystack:\n" + haystack + "\n\n" + needle);
        }

        return list.get(0);
    }

    /** Convenience overload of {@link #findOnce(InsnListSection, InsnListSection)} over a whole list. */
    public static InsnListSection findOnce(InsnList haystack, InsnListSection needle) {
        return findOnce(new InsnListSection(haystack), needle);
    }

    /**
     * Finds all occurrences of {@code needle} in {@code haystack}, requiring at least one.
     *
     * @throws RuntimeException if the needle does not occur at all
     */
    public static List<InsnListSection> findN(InsnListSection haystack, InsnListSection needle) {
        List<InsnListSection> list = find(haystack, needle);
        if (list.isEmpty()) {
            throw new RuntimeException("Needle not found in Haystack:\n" + haystack + "\n\n" + needle);
        }

        return list;
    }

    /** Convenience overload of {@link #findN(InsnListSection, InsnListSection)} over a whole list. */
    public static List<InsnListSection> findN(InsnList haystack, InsnListSection needle) {
        return findN(new InsnListSection(haystack), needle);
    }
}