X-Git-Url: https://gerrit.fd.io/r/gitweb?a=blobdiff_plain;f=resources%2Flibraries%2Fpython%2Fssh.py;h=f59bd02e2542fed7e4853cd0d3556e69ae3bbc29;hb=7436d8bdf60bca9b80fb76781e1f709bbcd435da;hp=8acb54b3824f7dfed8c50e6d7c2d3abe832de272;hpb=145e2101dc39a701286bc51ddaa0dbd1a4cf022f;p=csit.git diff --git a/resources/libraries/python/ssh.py b/resources/libraries/python/ssh.py index 8acb54b382..f59bd02e25 100644 --- a/resources/libraries/python/ssh.py +++ b/resources/libraries/python/ssh.py @@ -29,6 +29,11 @@ __all__ = ["exec_cmd", "exec_cmd_no_error"] # TODO: load priv key +class SSHTimeout(Exception): + """This exception is raised when a timeout occurs.""" + pass + + class SSH(object): """Contains methods for managing and using SSH connections.""" @@ -51,40 +56,47 @@ class SSH(object): return hash(frozenset([node['host'], node['port']])) - def connect(self, node): + def connect(self, node, attempts=5): """Connect to node prior to running exec_command or scp. If there already is a connection to the node, this method reuses it. """ - self._node = node - node_hash = self._node_hash(node) - if node_hash in SSH.__existing_connections: - self._ssh = SSH.__existing_connections[node_hash] - logger.debug('reusing ssh: {0}'.format(self._ssh)) - else: - start = time() - pkey = None - if 'priv_key' in node: - pkey = RSAKey.from_private_key( - StringIO.StringIO(node['priv_key'])) - - self._ssh = paramiko.SSHClient() - self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - self._ssh.connect(node['host'], username=node['username'], - password=node.get('password'), pkey=pkey, - port=node['port']) - - self._ssh.get_transport().set_keepalive(10) - - SSH.__existing_connections[node_hash] = self._ssh - - logger.trace('connect took {} seconds'.format(time() - start)) - logger.debug('new ssh: {0}'.format(self._ssh)) - - logger.debug('Connect peer: {0}'. - format(self._ssh.get_transport().getpeername())) - logger.debug('Connections: {0}'.format(str(SSH.__existing_connections))) + try: + self._node = node + node_hash = self._node_hash(node) + if node_hash in SSH.__existing_connections: + self._ssh = SSH.__existing_connections[node_hash] + logger.debug('reusing ssh: {0}'.format(self._ssh)) + else: + start = time() + pkey = None + if 'priv_key' in node: + pkey = RSAKey.from_private_key( + StringIO.StringIO(node['priv_key'])) + + self._ssh = paramiko.SSHClient() + self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + self._ssh.connect(node['host'], username=node['username'], + password=node.get('password'), pkey=pkey, + port=node['port']) + + self._ssh.get_transport().set_keepalive(10) + + SSH.__existing_connections[node_hash] = self._ssh + + logger.trace('connect took {} seconds'.format(time() - start)) + logger.debug('new ssh: {0}'.format(self._ssh)) + + logger.debug('Connect peer: {0}'. + format(self._ssh.get_transport().getpeername())) + logger.debug('Connections: {0}'. + format(str(SSH.__existing_connections))) + except: + if attempts > 0: + self._reconnect(attempts-1) + else: + raise def disconnect(self, node): """Close SSH connection to the node. @@ -99,12 +111,12 @@ class SSH(object): ssh = SSH.__existing_connections.pop(node_hash) ssh.close() - def _reconnect(self): + def _reconnect(self, attempts=0): """Close the SSH connection and open it again.""" node = self._node self.disconnect(node) - self.connect(node) + self.connect(node, attempts) logger.debug('Reconnecting peer done: {}'. format(self._ssh.get_transport().getpeername())) @@ -118,7 +130,7 @@ class SSH(object): :type timeout: int :return return_code, stdout, stderr :rtype: tuple(int, str, str) - :raise socket.timeout: If command is not finished in timeout time. + :raise SSHTimeout: If command is not finished in timeout time. """ start = time() stdout = StringIO.StringIO() @@ -144,7 +156,7 @@ class SSH(object): stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF)) if time() - start > timeout: - raise socket.timeout( + raise SSHTimeout( 'Timeout exception.\n' 'Current contents of stdout buffer: {0}\n' 'Current contents of stderr buffer: {1}\n' @@ -195,7 +207,7 @@ class SSH(object): command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input) return self.exec_command(command, timeout) - def interactive_terminal_open(self, time_out=10): + def interactive_terminal_open(self, time_out=30): """Open interactive terminal on a new channel on the connected Node. :param time_out: Timeout in seconds. @@ -215,7 +227,7 @@ class SSH(object): chan.set_combine_stderr(True) buf = '' - while not buf.endswith(':~$ '): + while not buf.endswith((":~$ ", "~]$ ")): try: chunk = chan.recv(self.__MAX_RECV_BUF) if not chunk: @@ -228,8 +240,7 @@ class SSH(object): raise Exception('Socket timeout: {0}'.format(buf)) return chan - def interactive_terminal_exec_command(self, chan, cmd, prompt, - time_out=30): + def interactive_terminal_exec_command(self, chan, cmd, prompt): """Execute command on interactive terminal. interactive_terminal_open() method has to be called first! @@ -238,7 +249,6 @@ class SSH(object): :param cmd: Command to be executed. :param prompt: Command prompt, sequence of characters used to indicate readiness to accept commands. - :param time_out: Timeout in seconds. :return: Command output. .. warning:: Interruptingcow is used here, and it uses @@ -262,7 +272,9 @@ class SSH(object): except socket.timeout: raise Exception('Socket timeout: {0}'.format(buf)) tmp = buf.replace(cmd.replace('\n', ''), '') - return tmp.replace(prompt, '') + for item in prompt: + tmp.replace(item, '') + return tmp @staticmethod def interactive_terminal_close(chan): @@ -280,7 +292,7 @@ class SSH(object): logger.trace('SCP {0} to {1}:{2}'.format( local_path, self._ssh.get_transport().getpeername(), remote_path)) # SCPCLient takes a paramiko transport as its only argument - scp = SCPClient(self._ssh.get_transport()) + scp = SCPClient(self._ssh.get_transport(), socket_timeout=10) start = time() scp.put(local_path, remote_path) scp.close() @@ -303,7 +315,7 @@ def exec_cmd(node, cmd, timeout=600, sudo=False): ssh = SSH() try: ssh.connect(node) - except Exception as err: + except SSHException as err: logger.error("Failed to connect to node" + str(err)) return None, None, None @@ -313,7 +325,7 @@ def exec_cmd(node, cmd, timeout=600, sudo=False): else: (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd, timeout=timeout) - except Exception as err: + except SSHException as err: logger.error(err) return None, None, None