package com.github.bingoohuang.springrestclient.generators;
import com.github.bingoohuang.springrestclient.annotations.*;
import com.github.bingoohuang.springrestclient.provider.*;
import com.github.bingoohuang.springrestclient.utils.Obj;
import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Maps;
import com.google.common.util.concurrent.UncheckedExecutionException;
import lombok.SneakyThrows;
import lombok.experimental.UtilityClass;
import lombok.val;
import org.apache.commons.lang3.StringUtils;
import org.springframework.context.ApplicationContext;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
/**
 * Factory that builds, configures and caches generated implementations of
 * REST client interfaces.
 *
 * <p>For each requested interface an implementation class is produced by
 * {@link ClassGenerator}, instantiated, and then populated with its
 * collaborators (sign / base-url / basic-auth providers, per-method status
 * mappings, fixed request parameters, success-property annotations and the
 * Spring application context). One instance per interface class is kept in
 * an unbounded Guava cache.
 *
 * <p>Lombok's {@code @UtilityClass} turns every member into a static member.
 */
@UtilityClass public class SpringRestClientFactory {
    /** One generated client instance per REST client interface. */
    private final Cache<Class<?>, Object> restClientCache = CacheBuilder.newBuilder().build();

    /**
     * Returns the client for the given interface, creating it on first use.
     *
     * <p>The instance is cached by interface class, so the {@code appContext}
     * of the call that actually creates the client is the one injected into it.
     *
     * @param restClientClass the REST client interface (checked with
     *                        {@code Obj.ensureInterface} before anything else)
     * @param appContext      Spring context used to look up provider beans
     * @param <T>             the interface type
     * @return the cached or freshly created client
     * @throws RuntimeException the cause of any failure raised while loading,
     *                          rethrown through {@code Throwables.propagate}
     */
    @SuppressWarnings("unchecked") // the cache only ever stores, under key K, the object built for K
    public <T> T getRestClient(final Class<T> restClientClass, final ApplicationContext appContext) {
        Obj.ensureInterface(restClientClass);
        try {
            return (T) restClientCache.get(restClientClass, new Callable<Object>() {
                @Override public Object call() throws Exception {
                    return load(restClientClass, appContext);
                }
            });
        } catch (ExecutionException e) {
            // Unwrap so callers see the loader's own failure, not the cache wrapper.
            Throwable cause = e.getCause();
            throw Throwables.propagate(cause);
        } catch (UncheckedExecutionException e) {
            Throwable cause = e.getCause();
            throw Throwables.propagate(cause);
        }
    }

    /**
     * Generates the implementation class for {@code restClientClass}, creates an
     * instance of it and injects all of its configuration. This method does not
     * consult the cache.
     *
     * @param restClientClass the REST client interface
     * @param appContext      Spring context used to look up provider beans
     * @return the fully configured client instance
     */
    @SneakyThrows
    public Object load(Class restClientClass, ApplicationContext appContext) {
        val generator = new ClassGenerator(restClientClass);
        val restClientImplClass = generator.generate();
        val object = Obj.createObject(restClientImplClass);

        setSignProvider(restClientImplClass, object, restClientClass, appContext);
        setBaseUrlProvider(restClientImplClass, object, restClientClass, appContext);
        setBasicAuthProvider(restClientImplClass, object, restClientClass, appContext);
        setStatusMappings(restClientImplClass, object, restClientClass);
        setFixedRequestParams(restClientImplClass, object, restClientClass);
        setSuccInResponseJSONProperty(restClientImplClass, object, restClientClass);
        setAppContext(restClientImplClass, object, restClientClass, appContext);

        return object;
    }

    /** Stores the application context in the field named by {@code MethodGenerator.appContext}. */
    private void setAppContext(Class<?> restClientImplClass, Object object, Class restClientClass, ApplicationContext appContext) {
        Obj.setField(restClientImplClass, object, MethodGenerator.appContext, appContext);
    }

    /**
     * For every declared method, stores its {@link SuccInResponseJSONProperty}
     * annotation, falling back to the interface-level one, in the field
     * {@code <methodName> + MethodGenerator.SuccInResponseJSONProperty}.
     * The stored value is {@code null} when neither level is annotated.
     */
    private void setSuccInResponseJSONProperty(Class<?> restClientImplClass, Object object, Class<?> restClientClass) {
        for (Method method : restClientClass.getDeclaredMethods()) {
            SuccInResponseJSONProperty property = method.getAnnotation(SuccInResponseJSONProperty.class);
            if (property == null) {
                property = restClientClass.getAnnotation(SuccInResponseJSONProperty.class);
            }

            val fieldName = method.getName() + MethodGenerator.SuccInResponseJSONProperty;
            Obj.setField(restClientImplClass, object, fieldName, property);
        }
    }

    /**
     * For every declared method, stores its fixed request parameters in the field
     * {@code <methodName> + MethodGenerator.FixedRequestParams}.
     */
    private void setFixedRequestParams(Class<?> restClientImplClass, Object object, Class restClientClass) {
        for (Method method : restClientClass.getDeclaredMethods()) {
            val mappings = createFixedRequestParams(method, restClientClass);
            val fieldName = method.getName() + MethodGenerator.FixedRequestParams;
            Obj.setField(restClientImplClass, object, fieldName, mappings);
        }
    }

    /**
     * Collects the fixed request parameters that apply to {@code method}:
     * interface-level entries first, then method-level entries, which replace
     * interface-level entries of the same name.
     *
     * @return an unmodifiable map that iterates in declaration order
     */
    private Map<String, Object> createFixedRequestParams(Method method, Class<?> restClientClass) {
        // Insertion-ordered map: a plain HashMap would lose the declaration
        // order that putRequestParams deliberately follows.
        Map<String, Object> map = Maps.newLinkedHashMap();
        putRequestParams(map, restClientClass);
        putRequestParams(map, method);
        return Collections.unmodifiableMap(map);
    }

    /**
     * Copies every {@link FixedRequestParam} found on {@code annotatedElement},
     * either directly or inside a {@link FixedRequestParams} container, into
     * {@code map}.
     */
    private void putRequestParams(Map<String, Object> map, AnnotatedElement annotatedElement) {
        // Add the fixed request parameters in declaration order.
        for (Annotation annotation : annotatedElement.getAnnotations()) {
            if (annotation instanceof FixedRequestParam) {
                putFixedRequestParam(map, (FixedRequestParam) annotation);
            } else if (annotation instanceof FixedRequestParams) {
                val params = (FixedRequestParams) annotation;
                for (FixedRequestParam paramValue : params.value()) {
                    putFixedRequestParam(map, paramValue);
                }
            }
        }
    }

    /**
     * Adds one fixed parameter to {@code map}. A {@code clazz} other than
     * {@code void.class} takes precedence and is stored as the {@code Class}
     * itself; otherwise a non-empty {@code value} is stored as a string.
     *
     * @throws RuntimeException if neither {@code clazz} nor {@code value} is set
     */
    private void putFixedRequestParam(Map<String, Object> map, FixedRequestParam fixedRequestParam) {
        if (fixedRequestParam.clazz() != void.class) {
            map.put(fixedRequestParam.name(), fixedRequestParam.clazz());
        } else if (StringUtils.isNotEmpty(fixedRequestParam.value())) {
            map.put(fixedRequestParam.name(), fixedRequestParam.value());
        } else {
            throw new RuntimeException("bad config for @FixedRequestParam" + fixedRequestParam + " value or clazz should be assigned");
        }
    }

    /**
     * For every declared method, stores its status-to-exception mappings in the
     * field {@code <methodName> + MethodGenerator.StatusExceptionMappings}.
     */
    private void setStatusMappings(Class<?> restClientImplClass, Object object, Class restClientClass) {
        for (Method method : restClientClass.getDeclaredMethods()) {
            val mappings = createStatusExceptionMappings(method, restClientClass);
            val fieldName = method.getName() + MethodGenerator.StatusExceptionMappings;
            Obj.setField(restClientImplClass, object, fieldName, mappings);
        }
    }

    /**
     * Builds the status-code to exception-class map for {@code method}.
     * Interface-level mappings are applied first, so method-level mappings for
     * the same status replace them.
     *
     * @return an unmodifiable map from status code to exception class
     */
    private Map<Integer, Class<? extends Throwable>> createStatusExceptionMappings(Method method, Class<?> restClientClass) {
        Map<Integer, Class<? extends Throwable>> statusExceptionMappings = Maps.newHashMap();
        addStatusExceptionMappings(method, statusExceptionMappings, restClientClass.getAnnotation(RespStatusMappings.class));
        addStatusExceptionMappings(method, statusExceptionMappings, method.getAnnotation(RespStatusMappings.class));
        return Collections.unmodifiableMap(statusExceptionMappings);
    }

    /**
     * Validates and adds every mapping of {@code respStatusMappings} to the
     * target map. Does nothing when the annotation is absent.
     */
    private void addStatusExceptionMappings(Method method, Map<Integer, Class<? extends Throwable>> statusExceptionMappings, RespStatusMappings respStatusMappings) {
        if (respStatusMappings == null) return;

        for (RespStatusMapping respStatusMapping : respStatusMappings.value()) {
            Class<? extends Throwable> exceptionClass = respStatusMapping.exception();
            checkMethodException(method, exceptionClass);
            statusExceptionMappings.put(respStatusMapping.status(), exceptionClass);
        }
    }

    /**
     * Verifies that {@code method} is allowed to throw {@code exceptionClass}.
     * Unchecked exceptions ({@link RuntimeException} and {@link Error}
     * subclasses) are always allowed. A checked exception must be covered by
     * the method's {@code throws} clause, either by the exact type or by one
     * of its supertypes.
     *
     * @throws RuntimeException if a checked exception is not declared
     */
    private void checkMethodException(Method method, Class<? extends Throwable> exceptionClass) {
        if (RuntimeException.class.isAssignableFrom(exceptionClass)) return;
        if (Error.class.isAssignableFrom(exceptionClass)) return;

        // A checked exception must be declared; a declared supertype is sufficient.
        for (Class<?> declaredExceptionType : method.getExceptionTypes()) {
            if (declaredExceptionType.isAssignableFrom(exceptionClass)) return;
        }

        throw new RuntimeException(exceptionClass + " is checked exception and should be declared on the method " + method);
    }

    /** Resolves the base-url provider and stores it in the field named by {@code MethodGenerator.baseUrlProvider}. */
    private void setBaseUrlProvider(Class<?> restClientImplClass, Object object, Class restClientClass, ApplicationContext appContext) {
        val provider = createBaseUrlProvider(restClientClass, appContext);
        Obj.setField(restClientImplClass, object, MethodGenerator.baseUrlProvider, provider);
    }

    /** Resolves the basic-auth provider and stores it in the field named by {@code MethodGenerator.basicAuthProvider}. */
    private void setBasicAuthProvider(Class<?> restClientImplClass, Object object, Class restClientClass, ApplicationContext appContext) {
        val provider = createBasicAuthProvider(restClientClass, appContext);
        Obj.setField(restClientImplClass, object, MethodGenerator.basicAuthProvider, provider);
    }

    /** Resolves the sign provider and stores it in the field named by {@code MethodGenerator.signProvider}. */
    private void setSignProvider(Class<?> restClientImplClass, Object object, Class restClientClass, ApplicationContext appContext) {
        val provider = createSignProvider(restClientClass, appContext);
        Obj.setField(restClientImplClass, object, MethodGenerator.signProvider, provider);
    }

    /**
     * Returns the {@link SpringRestClientEnabled} annotation of the interface.
     *
     * @throws RuntimeException if the interface does not carry the annotation
     *                          (previously this surfaced as a bare NullPointerException)
     */
    private SpringRestClientEnabled requireRestClientEnabled(Class<?> restClientClass) {
        SpringRestClientEnabled restClientEnabled = restClientClass.getAnnotation(SpringRestClientEnabled.class);
        if (restClientEnabled == null) {
            throw new RuntimeException("@SpringRestClientEnabled should be present on api " + restClientClass);
        }
        return restClientEnabled;
    }

    /**
     * Resolves the sign provider configured on {@code @SpringRestClientEnabled}:
     * a bean obtained through {@code Obj.getBean} if one is returned, otherwise
     * a new instance of the configured class.
     *
     * @return the provider, or {@code null} when no bean was found and the
     *         configured type is an interface
     * @throws RuntimeException if the configured class cannot be instantiated
     */
    private SignProvider createSignProvider(Class<?> restClientClass, ApplicationContext appContext) {
        val restClientEnabled = requireRestClientEnabled(restClientClass);
        val signProviderClass = restClientEnabled.signProvider();

        SignProvider bean = Obj.getBean(appContext, signProviderClass);
        if (bean != null) return bean;
        // An interface cannot be instantiated, so signing is left unconfigured.
        if (signProviderClass.isInterface()) return null;

        try {
            return signProviderClass.newInstance();
        } catch (Exception e) {
            throw new RuntimeException("signProvider configuration error for api " + restClientClass, e);
        }
    }

    /**
     * Resolves the base-url provider. Precedence: a non-empty {@code baseUrl}
     * on the annotation, then a bean obtained through {@code Obj.getBean},
     * then a new instance created by {@code Obj.createObject}.
     *
     * @throws RuntimeException if none of the above yields a provider
     */
    private BaseUrlProvider createBaseUrlProvider(Class<?> restClientClass, ApplicationContext appContext) {
        val restClientEnabled = requireRestClientEnabled(restClientClass);
        String baseUrl = restClientEnabled.baseUrl();
        if (!Strings.isNullOrEmpty(baseUrl)) {
            return new FixedBaseUrlProvider(baseUrl);
        }

        val providerClass = restClientEnabled.baseUrlProvider();
        BaseUrlProvider bean = Obj.getBean(appContext, providerClass);
        if (bean != null) return bean;

        if (providerClass.isInterface()) {
            throw new RuntimeException("base url should be configured for api " + restClientClass);
        }

        return Obj.createObject(providerClass, restClientClass);
    }

    /**
     * Resolves the basic-auth provider from the interface's {@link BasicAuth}
     * annotation. When the configured provider type is {@link BasicAuthProvider}
     * itself, a {@link DefaultBasicAuthProvider} is built from the annotation's
     * username and password. Otherwise a bean is looked up, and failing that an
     * instance is created by {@code Obj.createObject}.
     *
     * @return the provider, or {@code null} when the interface has no {@code @BasicAuth}
     * @throws RuntimeException if no bean was found and the configured type is an interface
     */
    private BasicAuthProvider createBasicAuthProvider(Class<?> restClientClass, ApplicationContext appContext) {
        val basicAuth = restClientClass.getAnnotation(BasicAuth.class);
        if (basicAuth == null) return null;

        val providerClass = basicAuth.basicAuthProvider();
        BasicAuthProvider bean;
        if (providerClass == BasicAuthProvider.class) {
            bean = new DefaultBasicAuthProvider(basicAuth.username(), basicAuth.password());
        } else {
            bean = Obj.getBean(appContext, providerClass);
        }
        if (bean != null) return bean;

        if (providerClass.isInterface()) {
            throw new RuntimeException("basicAuthProvider should be properly configured for api " + restClientClass);
        }

        return Obj.createObject(providerClass, restClientClass);
    }
}