/*
* Copyright (C) 2015 Noorq, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.noorq.casser.support;
import java.io.File;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.net.URLConnection;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Classpath scanning helper: finds the {@code .class} files that live under a
 * package, either in plain directories or inside jar files, and loads them
 * through the current thread's context class loader.
 */
public class PackageUtil {

    private static final Logger log = LoggerFactory.getLogger(PackageUtil.class);

    /** Separator between the jar location and the entry path in a jar URL. */
    public static final String JAR_URL_SEPARATOR = "!/";

    private static final String CLASS_SUFFIX = ".class";

    private static final String FILE_URL_PREFIX = "file:";

    /**
     * Recursively loads every {@code .class} file found under a directory.
     *
     * @param classes     set that receives the loaded classes
     * @param directory   directory to scan; sub-directories are treated as sub-packages
     * @param packageName package that corresponds to {@code directory}
     * @param classLoader loader used to load each class
     * @throws ClassNotFoundException if the directory cannot be listed or a class fails to load
     */
    private static void doFetchInPath(Set<Class<?>> classes, File directory,
            String packageName, ClassLoader classLoader)
            throws ClassNotFoundException {
        File[] dirContents = directory.listFiles();
        if (dirContents == null) {
            throw new ClassNotFoundException("invalid directory "
                    + directory.getAbsolutePath());
        }
        for (File file : dirContents) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                doFetchInPath(classes, file, qualify(packageName, fileName),
                        classLoader);
            } else if (fileName.endsWith(CLASS_SUFFIX)) {
                String simpleName = fileName.substring(0,
                        fileName.length() - CLASS_SUFFIX.length());
                classes.add(classLoader.loadClass(qualify(packageName,
                        simpleName)));
            }
        }
    }

    /**
     * Joins a package name and a simple name with a dot. An empty package name
     * (default package) yields the simple name alone instead of {@code ".Name"}.
     */
    private static String qualify(String packageName, String simpleName) {
        if (packageName.isEmpty()) {
            return simpleName;
        }
        return packageName + '.' + simpleName;
    }

    /**
     * Loads all classes located under the given package, looking at every
     * classpath resource the context class loader reports for it.
     *
     * @param packagePath dot-separated package name, e.g. {@code com.example.model}
     * @return the loaded classes; empty when the class loader reports no
     *         resource for the package
     * @throws ClassNotFoundException if there is no context class loader, the
     *         package resources cannot be enumerated, a reported directory does
     *         not exist, a class in a directory fails to load, or a jar cannot
     *         be read
     */
    public static Set<Class<?>> getClasses(String packagePath)
            throws ClassNotFoundException {
        Set<Class<?>> classes = new HashSet<Class<?>>();
        ClassLoader classLoader = Thread.currentThread()
                .getContextClassLoader();
        if (classLoader == null) {
            throw new ClassNotFoundException(
                    "class loader not found for current thread");
        }
        Enumeration<URL> resources;
        try {
            resources = classLoader.getResources(packagePath.replace('.', '/'));
        } catch (IOException e) {
            throw new ClassNotFoundException("invalid package " + packagePath,
                    e);
        }
        while (resources.hasMoreElements()) {
            URL url = resources.nextElement();
            if (url == null) {
                throw new ClassNotFoundException(packagePath
                        + " - package not found");
            }
            // Only "%20" is decoded here; other percent-escapes are left as they are.
            String dirPath = fastReplace(url.getFile(), "%20", " ");
            int jarSeparator = dirPath.indexOf(JAR_URL_SEPARATOR);
            if (jarSeparator == -1) {
                File directory = new File(dirPath);
                if (!directory.exists()) {
                    throw new ClassNotFoundException(packagePath
                            + " - invalid package");
                }
                doFetchInPath(classes, directory, packagePath, classLoader);
            } else {
                doFetchInJar(classes, url, dirPath, jarSeparator, classLoader);
            }
        }
        return classes;
    }

    /**
     * Loads every class stored in a jar below the entry named after the
     * {@code !/} separator. A class that fails to load is logged and skipped.
     *
     * @param classes      set that receives the loaded classes
     * @param url          resource URL pointing inside the jar
     * @param dirPath      {@code url.getFile()} with {@code %20} already decoded
     * @param jarSeparator index of {@link #JAR_URL_SEPARATOR} within {@code dirPath}
     * @param classLoader  loader used to load each class
     * @throws ClassNotFoundException if the jar cannot be opened or read
     */
    private static void doFetchInJar(Set<Class<?>> classes, URL url,
            String dirPath, int jarSeparator, ClassLoader classLoader)
            throws ClassNotFoundException {
        String rootEntry = dirPath.substring(jarSeparator
                + JAR_URL_SEPARATOR.length());
        if (!"".equals(rootEntry) && !rootEntry.endsWith("/")) {
            // The trailing slash keeps "a/b" from also matching "a/bc/...".
            rootEntry = rootEntry + "/";
        }
        JarFile jarFile = null;
        try {
            URLConnection con = url.openConnection();
            if (con instanceof JarURLConnection) {
                JarURLConnection jarCon = (JarURLConnection) con;
                // Caching is switched off, so the JarFile obtained below is
                // private to this call and can be closed when we are done.
                jarCon.setUseCaches(false);
                jarFile = jarCon.getJarFile();
            } else {
                jarFile = openJarFile(dirPath.substring(0, jarSeparator));
            }
            for (Enumeration<JarEntry> entries = jarFile.entries(); entries
                    .hasMoreElements();) {
                JarEntry entry = entries.nextElement();
                String fileName = entry.getName();
                if (fileName.startsWith(rootEntry)
                        && fileName.endsWith(CLASS_SUFFIX)) {
                    fileName = fileName.replace('/', '.');
                    try {
                        classes.add(classLoader.loadClass(fileName.substring(
                                0, fileName.length() - CLASS_SUFFIX.length())));
                    } catch (ClassNotFoundException e) {
                        log.error("class load fail", e);
                    }
                }
            }
        } catch (IOException e) {
            throw new ClassNotFoundException("jar fail", e);
        } finally {
            if (jarFile != null) {
                try {
                    jarFile.close();
                } catch (IOException e) {
                    // Scanning already finished (or already failed); a close
                    // failure must not mask that outcome.
                    log.warn("jar close fail", e);
                }
            }
        }
    }

    /**
     * Opens a jar from the location part of a jar URL when the connection is
     * not a {@link JarURLConnection}. {@link JarFile#JarFile(String)} takes a
     * file name, so a leading {@code file:} prefix is dropped and the spaces
     * that were decoded from {@code %20} are kept as spaces.
     */
    private static JarFile openJarFile(String jarLocation) throws IOException {
        String jarPath = jarLocation;
        if (jarPath.startsWith(FILE_URL_PREFIX)) {
            jarPath = jarPath.substring(FILE_URL_PREFIX.length());
        }
        return new JarFile(jarPath);
    }

    /**
     * Replaces every occurrence of {@code oldPattern} in {@code inString} with
     * {@code newPattern}, scanning left to right without regular expressions.
     *
     * @param inString   text to process; {@code null} yields {@code null}
     * @param oldPattern literal text to search for; when {@code null} or empty
     *                   the input is returned unchanged
     * @param newPattern replacement text; when {@code null} the input is
     *                   returned unchanged
     * @return the text with all occurrences replaced
     */
    public static String fastReplace(String inString, String oldPattern,
            String newPattern) {
        if (inString == null) {
            return null;
        }
        if (oldPattern == null || newPattern == null) {
            return inString;
        }
        if (oldPattern.isEmpty()) {
            // indexOf("", pos) always matches at pos, so the scan below would
            // never advance and would append newPattern forever.
            return inString;
        }
        int index = inString.indexOf(oldPattern);
        if (index < 0) {
            return inString;
        }
        StringBuilder sbuf = new StringBuilder(inString.length());
        int pos = 0;
        int patLen = oldPattern.length();
        while (index >= 0) {
            sbuf.append(inString, pos, index);
            sbuf.append(newPattern);
            pos = index + patLen;
            index = inString.indexOf(oldPattern, pos);
        }
        sbuf.append(inString, pos, inString.length());
        return sbuf.toString();
    }
}