package org.batfish.common.plugin; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.FileInputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Modifier; import java.net.URL; import java.net.URLClassLoader; import java.nio.file.FileVisitResult; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.nio.file.SimpleFileVisitor; import java.nio.file.attribute.BasicFileAttributes; import java.util.ArrayList; import java.util.Arrays; import java.util.Enumeration; import java.util.List; import java.util.jar.JarEntry; import java.util.jar.JarFile; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import org.apache.commons.io.IOUtils; import org.batfish.common.BatfishException; import org.batfish.common.BfConsts; import org.batfish.common.util.BatfishObjectInputStream; import com.thoughtworks.xstream.XStream; import com.thoughtworks.xstream.io.xml.DomDriver; public abstract class PluginConsumer implements IPluginConsumer { private static final String CLASS_EXTENSION = ".class"; /** * A byte-array containing the first 4 bytes of the header for a file that is * the output of java serialization */ private static final byte[] JAVA_SERIALIZED_OBJECT_HEADER = { (byte) 0xac, (byte) 0xed, (byte) 0x00, (byte) 0x05 }; private ClassLoader _currentClassLoader; private final List<Path> _pluginDirs; private final boolean _serializeToText; public PluginConsumer(boolean serializeToText, List<Path> pluginDirs) { _currentClassLoader = getClass().getClassLoader(); _serializeToText = serializeToText; _pluginDirs = new ArrayList<>(pluginDirs); String questionPluginDirStr = System .getProperty(BfConsts.PROP_QUESTION_PLUGIN_DIR); // try to place question plugin first if system property is defined if (questionPluginDirStr != null) { Path questionPluginDir = Paths.get(questionPluginDirStr); if (_pluginDirs.isEmpty() || !_pluginDirs.get(0).equals(questionPluginDir)) { _pluginDirs.add(0, questionPluginDir); } } return; } protected <S extends Serializable> S deserializeObject(byte[] data, Class<S> outputClass) { try { boolean isJavaSerializationData = isJavaSerializationData(data); ByteArrayInputStream bais = new ByteArrayInputStream(data); ObjectInputStream ois; if (!isJavaSerializationData) { XStream xstream = new XStream(new DomDriver("UTF-8")); xstream.setClassLoader(_currentClassLoader); ois = xstream.createObjectInputStream(bais); } else { ois = new BatfishObjectInputStream(bais, _currentClassLoader); } Object o = ois.readObject(); ois.close(); return outputClass.cast(o); } catch (IOException | ClassNotFoundException | ClassCastException e) { throw new BatfishException("Failed to deserialize object of type '" + outputClass.getCanonicalName() + "' from data", e); } } public <S extends Serializable> S deserializeObject(Path inputFile, Class<S> outputClass) { byte[] data = fromGzipFile(inputFile); return deserializeObject(data, outputClass); } protected byte[] fromGzipFile(Path inputFile) { try { FileInputStream fis = new FileInputStream(inputFile.toFile()); GZIPInputStream gis = new GZIPInputStream(fis); byte[] data = IOUtils.toByteArray(gis); return data; } catch (IOException e) { throw new BatfishException( "Failed to gunzip file: " + inputFile.toString(), e); } } public ClassLoader getCurrentClassLoader() { return _currentClassLoader; } public abstract PluginClientType getType(); private boolean isJavaSerializationData(byte[] fileBytes) { int headerLength = JAVA_SERIALIZED_OBJECT_HEADER.length; byte[] headerBytes = new byte[headerLength]; for (int i = 0; i < headerLength; i++) { headerBytes[i] = fileBytes[i]; } return Arrays.equals(headerBytes, JAVA_SERIALIZED_OBJECT_HEADER); } private void loadPluginJar(Path path) { /* * Adapted from * http://stackoverflow.com/questions/11016092/how-to-load-classes-at- * runtime-from-a-folder-or-jar Retrieved: 2016-08-31 Original Authors: * Kevin Crain http://stackoverflow.com/users/2688755/kevin-crain * Apfelsaft http://stackoverflow.com/users/1447641/apfelsaft License: * https://creativecommons.org/licenses/by-sa/3.0/ */ String pathString = path.toString(); if (pathString.endsWith(".jar")) { try { URL[] urls = { new URL("jar:file:" + pathString + "!/") }; URLClassLoader cl = URLClassLoader.newInstance(urls, _currentClassLoader); _currentClassLoader = cl; Thread.currentThread().setContextClassLoader(cl); JarFile jar = new JarFile(path.toFile()); Enumeration<JarEntry> entries = jar.entries(); while (entries.hasMoreElements()) { JarEntry element = entries.nextElement(); String name = element.getName(); if (element.isDirectory() || !name.endsWith(CLASS_EXTENSION)) { continue; } String className = name .substring(0, name.length() - CLASS_EXTENSION.length()) .replace("/", "."); try { cl.loadClass(className); Class<?> pluginClass = Class.forName(className, true, cl); if (!Plugin.class.isAssignableFrom(pluginClass) || Modifier.isAbstract(pluginClass.getModifiers())) { continue; } Constructor<?> pluginConstructor; try { pluginConstructor = pluginClass.getConstructor(); } catch (NoSuchMethodException | SecurityException e) { throw new BatfishException( "Could not find default constructor in plugin: '" + className + "'", e); } Object pluginObj; try { pluginObj = pluginConstructor.newInstance(); } catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException e) { throw new BatfishException("Could not instantiate plugin '" + className + "' from constructor", e); } Plugin plugin = (Plugin) pluginObj; plugin.initialize(this); } catch (ClassNotFoundException e) { jar.close(); throw new BatfishException( "Unexpected error loading classes from jar", e); } } jar.close(); } catch (IOException e) { throw new BatfishException( "Error loading plugin jar: '" + path.toString() + "'", e); } } } protected final void loadPlugins() { for (Path pluginDir : _pluginDirs) { if (Files.exists(pluginDir)) { try { Files.walkFileTree(pluginDir, new SimpleFileVisitor<Path>() { @Override public FileVisitResult visitFile(Path path, BasicFileAttributes attrs) throws IOException { loadPluginJar(path); return FileVisitResult.CONTINUE; } }); } catch (IOException e) { throw new BatfishException("Error walking through plugin dir: '" + pluginDir.toString() + "'", e); } } } } public void serializeObject(Serializable object, Path outputFile) { try { byte[] data = toGzipData(object); Files.write(outputFile, data); } catch (IOException e) { throw new BatfishException( "Failed to serialize object to gzip output file: " + outputFile.toString(), e); } } protected byte[] toGzipData(Serializable object) { try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); GZIPOutputStream gos = new GZIPOutputStream(baos); ObjectOutputStream oos; if (_serializeToText) { XStream xstream = new XStream(new DomDriver("UTF-8")); oos = xstream.createObjectOutputStream(gos); } else { oos = new ObjectOutputStream(gos); } oos.writeObject(object); oos.close(); byte[] data = baos.toByteArray(); return data; } catch (IOException e) { throw new BatfishException("Failed to convert object to gzip data", e); } } }