/* * Copyright 2014 Avanza Bank AB * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.avanza.astrix.modules; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.Objects; import java.util.Set; import java.util.Stack; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArrayList; import javax.annotation.PreDestroy; import org.slf4j.Logger; import org.slf4j.LoggerFactory; class ModuleManager implements Modules { private final Logger log = LoggerFactory.getLogger(ModuleManager.class); private final ConcurrentMap<Class<?>, List<ModuleInstance>> moduleByExportedType = new ConcurrentHashMap<>(); private final List<ModuleInstance> moduleInstances = new CopyOnWriteArrayList<>(); private final ModuleInstancePostProcessors globalModuleBeanPostProcessors = new ModuleInstancePostProcessors(); void register(Module module) { ModuleInstance moduleInstance = new ModuleInstance(module, globalModuleBeanPostProcessors); moduleInstances.add(moduleInstance); for (Class<?> exportedBean : moduleInstance.getExports()) { getExportingModules(exportedBean).add(moduleInstance); } } private List<ModuleInstance> getExportingModules(Class<?> exportedType) { List<ModuleInstance> modules = moduleByExportedType.get(exportedType); if (modules == null) { modules = new LinkedList<ModuleManager.ModuleInstance>(); moduleByExportedType.put(exportedType, modules); } return modules; } @Override public <T> T getInstance(Class<T> type) { return new CircularModuleDependenciesAwareCreation().get(type); } @Override public <T> Collection<T> getAll(Class<T> type) { return new CircularModuleDependenciesAwareCreation().getAll(type); } private List<String> getModulesNames(List<ModuleInstance> modules) { List<String> result = new ArrayList<String>(modules.size()); for (ModuleInstance instance : modules) { result.add(instance.getName()); } return result; } public static class ModuleInstance { private final ModuleInjector moduleInjector; private final String moduleName; public ModuleInstance(Module module, ModuleInstancePostProcessor moduleInstancePostProcessor) { this.moduleName = getModuleName(module); this.moduleInjector = new ModuleInjector(moduleName); this.moduleInjector.registerBeanPostProcessor(moduleInstancePostProcessor); module.prepare(new ModuleContext() { @Override public <T> void bind(Class<T> type, Class<? extends T> providerType) { moduleInjector.bind(type, providerType); } @Override public <T> void bind(Class<T> type, T provider) { moduleInjector.bind(type, provider); } @Override public void export(Class<?> type) { if (!type.isInterface()) { throw new IllegalArgumentException(String.format("Its only allowed to export interface types. module=%s exportedType=%s", moduleName, type)); } moduleInjector.addExport(type); } @Override public <T> void importType(final Class<T> type) { if (!type.isInterface()) { throw new IllegalArgumentException(String.format("Its only allowed to interface types. module=%s importedType=%s", moduleName, type)); } moduleInjector.addImport(type); } }); } private String getModuleName(Module module) { return module.name(); } public String getName() { return this.moduleName; } public Set<Class<?>> getExports() { return this.moduleInjector.getExports(); } public <T> T getInstance(final Class<T> type, ImportedDependencies importedDependencies) { return this.moduleInjector.getBean(type, importedDependencies); } public void destroy() { this.moduleInjector.destroy(); } } @Override @PreDestroy public void destroy() { for (ModuleInstance moduleInstance : this.moduleInstances) { moduleInstance.destroy(); } } public static class CreationFrame { private ModuleInstance module; private Class<?> type; public CreationFrame(ModuleInstance module, Class<?> type) { this.module = module; this.type = type; } @Override public int hashCode() { return Objects.hash(module, type); } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; CreationFrame other = (CreationFrame) obj; return Objects.equals(module, other.module) && Objects.equals(type, other.type); } } public class CircularModuleDependenciesAwareCreation { private final Stack<CreationFrame> creationStack = new Stack<>(); public <T> T get(final Class<T> type) { List<ModuleInstance> exportingModules = moduleByExportedType.get(type); if (exportingModules == null) { throw new MissingProvider(type); } if (exportingModules.size() > 1) { log.warn("Type exported by multiple modules. Using first registered provider. Ignoring export. type={} usedModule={} ignoredModules={}", type.getName(), exportingModules.get(0).getName(), getModulesNames(exportingModules.subList(1, exportingModules.size()))); } ModuleInstance exportingModule = exportingModules.get(0); return getInstance(type, exportingModule); } public <T> Collection<T> getAll(Class<T> type) { List<T> result = new ArrayList<>(); List<ModuleInstance> exportingModules = moduleByExportedType.get(type); if (exportingModules == null) { return result; } for (ModuleInstance exportingModule : exportingModules) { result.add(getInstance(type, exportingModule)); } return result; } private <T> T getInstance(final Class<T> type, ModuleInstance exportingModule) { CreationFrame creationFrame = new CreationFrame(exportingModule, type); if (creationStack.contains(creationFrame)) { CircularDependency circularDependency = new CircularDependency(); circularDependency.addToDependencyTrace(type, exportingModule.getName()); throw circularDependency; } creationStack.add(creationFrame); T result = exportingModule.getInstance(type, new ImportedDependencies() { @Override public <D> Collection<D> getAll(Class<D> type) { return CircularModuleDependenciesAwareCreation.this.getAll(type); } @Override public <D> D get(Class<D> type) { return CircularModuleDependenciesAwareCreation.this.get(type); } }); creationStack.pop(); return result; } } public void registerBeanPostProcessor(ModuleInstancePostProcessor beanPostProcessor) { this.globalModuleBeanPostProcessors.add(beanPostProcessor); } }