package io.github.proxyhotswap;
import io.github.proxyhotswap.javassist.ClassPool;
import io.github.proxyhotswap.javassist.CtClass;
import io.github.proxyhotswap.javassist.CtField;
import io.github.proxyhotswap.javassist.CtMethod;
import io.github.proxyhotswap.javassist.Modifier;
import java.io.ByteArrayInputStream;
import java.lang.instrument.ClassDefinition;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.security.ProtectionDomain;
import java.util.Map;
import java.util.UUID;
/**
* @author Erki Ehtla
*
*/
public abstract class AbstractProxyTransformer implements ClassFileTransformer {
protected static final String INIT_FIELD_PREFIX = "initCalled";
protected static final ClassPool classPool = TransformationUtils.getClassPool();
protected Instrumentation inst;
protected Map<Class<?>, TransformationState> transformationStates;
public AbstractProxyTransformer(Instrumentation inst, Map<Class<?>, TransformationState> transformationStates) {
this.inst = inst;
this.transformationStates = transformationStates;
}
@Override
public byte[] transform(ClassLoader loader, String className, final Class<?> classBeingRedefined,
ProtectionDomain protectionDomain, final byte[] classfileBuffer) throws IllegalClassFormatException {
try {
if (classBeingRedefined == null || !isProxy(className, classBeingRedefined, classfileBuffer)) {
return null;
}
return transformRedefine(loader, className, classBeingRedefined, protectionDomain, classfileBuffer);
} catch (Exception e) {
removeClassState(classBeingRedefined);
TransformationUtils.logError(e);
throw new RuntimeException(e);
}
}
private boolean isTransformingNeeded(Class<?> classBeingRedefined) {
Class<?> superclass = classBeingRedefined.getSuperclass();
if (superclass != null && ClassfileBufferSigantureTransformer.hasClassChanged(superclass))
return true;
Class<?>[] interfaces = classBeingRedefined.getInterfaces();
for (Class<?> clazz : interfaces) {
if (ClassfileBufferSigantureTransformer.hasClassChanged(clazz))
return true;
}
return false;
}
protected abstract boolean isProxy(String className, Class<?> classBeingRedefined, byte[] classfileBuffer)
throws Exception;
protected byte[] transformRedefine(ClassLoader loader, String className, Class<?> classBeingRedefined,
ProtectionDomain protectionDomain, byte[] classfileBuffer) throws Exception {
switch (getTransformationstate(classBeingRedefined)) {
case NEW:
if (!isTransformingNeeded(classBeingRedefined)) {
return null;
}
setClassAsWaiting(classBeingRedefined);
// We can't do the transformation in this event, because we can't see the changes in the class
// definitons. Schedule a new redefinition event.
scheduleRedefinition(classBeingRedefined, classfileBuffer);
return null;
case WAITING:
removeClassState(classBeingRedefined);
return generateNewProxyClass(loader, className, classBeingRedefined);
default:
throw new RuntimeException("Unhandeled TransformationState!");
}
}
protected byte[] generateNewProxyClass(ClassLoader loader, String className, Class<?> classBeingRedefined)
throws Exception {
byte[] newByteCode = getNewByteCode(loader, className, classBeingRedefined);
CtClass cc = getCtClass(newByteCode, className);
String random = generateRandomString();
String initFieldName = INIT_FIELD_PREFIX + random;
addStaticInitStateField(cc, initFieldName);
String method = getInitCall(cc, random);
addInitCallToMethods(cc, initFieldName, method);
return cc.toBytecode();
}
protected CtClass getCtClass(byte[] newByteCode, String className) throws Exception {
return classPool.makeClass(new ByteArrayInputStream(newByteCode), false);
}
protected abstract String getInitCall(CtClass cc, String random) throws Exception;
protected abstract byte[] getNewByteCode(ClassLoader loader, String className, Class<?> classBeingRedefined)
throws Exception;
protected TransformationState getTransformationstate(final Class<?> classBeingRedefined) {
TransformationState transformationState = transformationStates.get(classBeingRedefined);
if (transformationState == null)
transformationState = TransformationState.NEW;
return transformationState;
}
protected String generateRandomString() {
return UUID.randomUUID().toString().replace("-", "");
}
protected void addInitCallToMethods(CtClass cc, String clinitFieldName, String initCall) throws Exception {
CtMethod[] methods = cc.getDeclaredMethods();
for (CtMethod ctMethod : methods) {
if (!ctMethod.isEmpty() && !Modifier.isStatic(ctMethod.getModifiers())) {
ctMethod.insertBefore("if(!" + clinitFieldName + "){" + initCall + "}");
}
}
}
protected void addStaticInitStateField(CtClass cc, String clinitFieldName) throws Exception {
CtField f = new CtField(CtClass.booleanType, clinitFieldName, cc);
f.setModifiers(Modifier.PRIVATE | Modifier.STATIC);
// init value "true" will be inside clinit, so the field wont actually be initialized to this
cc.addField(f, "true");
}
protected void scheduleRedefinition(final Class<?> classBeingRedefined, final byte[] classfileBuffer) {
new Thread() {
@Override
public void run() {
try {
inst.redefineClasses(new ClassDefinition(classBeingRedefined, classfileBuffer));
} catch (ClassNotFoundException | UnmodifiableClassException e) {
removeClassState(classBeingRedefined);
throw new RuntimeException(e);
}
}
}.start();
}
protected TransformationState setClassAsWaiting(Class<?> classBeingRedefined) {
return transformationStates.put(classBeingRedefined, TransformationState.WAITING);
}
protected TransformationState removeClassState(Class<?> classBeingRedefined) {
return transformationStates.remove(classBeingRedefined);
}
}