package org.simpleflatmapper.test.junit;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

/**
 * A {@link URLClassLoader} that serves classes from an explicit set of library URLs
 * (plus the code locations of a set of "include" classes) before falling back to a
 * delegate class loader.
 *
 * <p>The parent passed to {@link URLClassLoader} is {@code Integer.class.getClassLoader()},
 * not the delegate, so the delegate is only consulted explicitly by
 * {@link #loadClass(String)}: either because the class name matches an exclude pattern,
 * or because this loader could not find the class itself.
 */
public class LibrarySetsClassLoader extends URLClassLoader {

    private static final String JAR_PREFIX = "jar:";
    private static final String FILE_PREFIX = "file:";
    /** Separator between the archive URL and the entry path inside a {@code jar:} URL. */
    private static final String JAR_ENTRY_SEPARATOR = "!/";

    private final ClassLoader classLoader;
    private final String[] libraries;
    private final Pattern[] excludes;

    /**
     * @param classLoader delegate used for excluded classes and as a fallback when this
     *                    loader cannot find a class
     * @param libraries   URL strings of the libraries this loader serves classes from
     * @param includes    classes whose containing jar or class directory is added to the
     *                    search path (duplicates of already-present URLs are skipped)
     * @param excludes    patterns matched with {@link java.util.regex.Matcher#find()} against
     *                    the class name; matching classes always come from {@code classLoader}.
     *                    {@code null} is treated as "no excludes"
     * @throws IOException if a library string or a derived location is not a valid URL
     * @throws IllegalArgumentException if the location of an include class cannot be determined
     */
    public LibrarySetsClassLoader(ClassLoader classLoader, String[] libraries, Class<?>[] includes, Pattern[] excludes) throws IOException {
        super(getUrls(libraries, includes), Integer.class.getClassLoader());
        this.classLoader = classLoader;
        // Defensive copies: later changes to the caller's arrays must not alter this loader.
        this.libraries = libraries.clone();
        // A null array used to be accepted here and only blow up on the first loadClass call.
        this.excludes = excludes == null ? new Pattern[0] : excludes.clone();
    }

    /**
     * Builds the search path: every library URL in order, followed by the location of
     * each include class that is not already on the list.
     */
    private static URL[] getUrls(String[] libraries, Class<?>[] includes) throws IOException {
        List<URL> urls = new ArrayList<URL>();
        for (String library : libraries) {
            urls.add(new URL(library));
        }
        for (Class<?> includeClass : includes) {
            URL url = findUrl(includeClass, includeClass.getClassLoader());
            if (!urls.contains(url)) {
                urls.add(url);
            }
        }
        return urls.toArray(new URL[0]);
    }

    /**
     * Resolves the jar file or class directory that {@code includeClass} is loaded from.
     *
     * @param includeClass the class to locate
     * @param classLoader  loader used to look up the class file resource; when {@code null}
     *                     (as {@link Class#getClassLoader()} may return for bootstrap classes)
     *                     the lookup goes through {@link ClassLoader#getSystemResource(String)}
     * @return the URL of the containing jar for {@code jar:} resources, or of the root
     *         directory for {@code file:} resources
     * @throws MalformedURLException if the derived location is not a valid URL
     * @throws IllegalArgumentException if the resource is missing or uses another scheme
     */
    public static URL findUrl(Class<?> includeClass, ClassLoader classLoader) throws MalformedURLException {
        String classResource = includeClass.getName().replace('.', '/') + ".class";
        URL urlClass = classLoader != null
                ? classLoader.getResource(classResource)
                : ClassLoader.getSystemResource(classResource);

        if (urlClass != null) {
            String url = urlClass.toString();
            if (url.startsWith(JAR_PREFIX)) {
                // Split on "!/" rather than a bare '!': the archive path itself may contain '!'.
                int separator = url.indexOf(JAR_ENTRY_SEPARATOR);
                if (separator != -1) {
                    String jarUrl = url.substring(JAR_PREFIX.length(), separator);
                    System.out.println(includeClass + " => " + urlClass + " => jarUrl = " + jarUrl);
                    return new URL(jarUrl);
                }
            } else if (url.startsWith(FILE_PREFIX)) {
                if (url.endsWith(classResource)) {
                    // Strip the package path and class file name to get the classes root directory.
                    String directoryUrl = url.substring(0, url.length() - classResource.length());
                    System.out.println(includeClass + " => directoryUrl = " + directoryUrl);
                    return new URL(directoryUrl);
                }
            }
        }
        throw new IllegalArgumentException("Could not find url for " + includeClass + " " + urlClass);
    }

    /**
     * Loads excluded classes from the delegate; everything else is looked up in this
     * loader first and in the delegate only if that fails.
     */
    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException {
        if (isExcluded(name)) {
            return classLoader.loadClass(name);
        }
        try {
            return super.loadClass(name);
        } catch (ClassNotFoundException e) {
            // Not available from the library set: fall back to the delegate.
            return classLoader.loadClass(name);
        }
    }

    /** Returns {@code true} if any exclude pattern is found anywhere in {@code name}. */
    private boolean isExcluded(String name) {
        for (Pattern p : excludes) {
            if (p.matcher(name).find()) {
                return true;
            }
        }
        return false;
    }

    @Override
    public String toString() {
        return "LibrarySetsClassLoader{" +
                "libraries=" + Arrays.toString(libraries) +
                '}';
    }
}