package eu.dnetlib.iis.common; import java.io.IOException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.direct.Session; import net.schmizz.sshj.connection.channel.direct.Session.Command; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.verification.PromiscuousVerifier; /** * Class used for executing commands on remote machine through ssh protocol. * * @author madryk * */ public class SshSimpleConnection { private final static Logger log = LoggerFactory.getLogger(SshSimpleConnection.class); private final static int SSH_EXEC_TIMEOUT_IN_SEC = 5; private final static int SSH_MAX_RETRY_COUNT = 5; private final static int SSH_RETRY_COOLDOWN_IN_SEC = 10; private final SSHClient sshClient = new SSHClient(); //------------------------ LOGIC -------------------------- /** * Executes command on remote machine through ssh.<br/> * Internally uses {@link #execute(String, boolean)} with * parameter throwExceptionOnCommandError set to true */ public Command execute(String command) { return execute(command, true); } /** * Executes command on remote machine through ssh. * * @param command - command that will be executed on remote machine * @param throwExceptionOnCommandError - if true then any error in executing * command on remote machine will cause this method to throw exception * @return command execution results */ public Command execute(String command, boolean throwExceptionOnCommandError) { Command cmd = null; try { cmd = executeCommandWithRetries(command); if (throwExceptionOnCommandError && cmd.getExitStatus() != 0) { throw new RuntimeException("Error executing command: " + command + "\n" + SshExecUtils.readCommandError(cmd)); } } catch (IOException e) { throw new RuntimeException("Error in communication with remote machine: " + e.getMessage(), e); } return cmd; } /** * Downloads file (or directory) from remote to local machine * * @param remotePath - path on remote machine to file to download * @param localPath - path on local machine where file should be downloaded */ public void download(String remotePath, String localPath) { try { sshClient.newSCPFileTransfer().download(remotePath, localPath); } catch (IOException e) { throw new RuntimeException("Error in downloading file", e); } } //------------------------ PACKAGE-PRIVATE -------------------------- /** * Opens ssh connection. * Method must be executed before any attempt to read from remote host. */ void openConnection(String remoteHost, int sshPort, String remoteUser) { try { sshClient.addHostKeyVerifier(new PromiscuousVerifier()); sshClient.connect(remoteHost, sshPort); sshClient.authPublickey(remoteUser); } catch (IOException e) { throw new RuntimeException("Error in opening ssh connection", e); } } /** * Closes ssh connection * After executing this method any attempt to read from remote host * will fail until connection will be open again (see {@link #openConnection()}). */ void closeConnection() { try { sshClient.close(); } catch (IOException e) { throw new RuntimeException("Error in closing ssh connection", e); } } //------------------------ PRIVATE -------------------------- private Command executeCommandWithRetries(String command) throws TransportException, ConnectionException { int currentRetryCount = 0; Command cmd = null; while(true) { cmd = executeCommand(command); if (cmd == null) { if (currentRetryCount >= SSH_MAX_RETRY_COUNT) { throw new RuntimeException("Retry limit exceeded when trying to execute ssh command."); } ++currentRetryCount; log.debug("Timeout when trying to execute ssh command. Will try again in " + SSH_RETRY_COOLDOWN_IN_SEC + " seconds"); try { Thread.sleep(1000 * SSH_RETRY_COOLDOWN_IN_SEC); } catch (InterruptedException e1) { throw new RuntimeException(e1); } continue; } break; } return cmd; } private Command executeCommand(String command) throws TransportException, ConnectionException { Session sshSession = null; Command cmd = null; try { sshSession = sshClient.startSession(); cmd = sshSession.exec(command); cmd.join(SSH_EXEC_TIMEOUT_IN_SEC, TimeUnit.SECONDS); } catch (ConnectionException e) { if (e.getCause() instanceof TimeoutException) { return null; } throw e; } finally { if (sshSession != null) { sshSession.close(); } } return cmd; } }