SSH timeout problem
[csit.git] / resources / libraries / python / ssh.py
index a96a55f..f35b925 100644 (file)
 # limitations under the License.
 
 import StringIO
-from time import time
+from time import time, sleep
 
 import socket
 import paramiko
 from paramiko import RSAKey
+from paramiko.ssh_exception import SSHException
 from scp import SCPClient
 from interruptingcow import timeout
 from robot.api import logger
@@ -34,6 +35,7 @@ class SSH(object):
 
     def __init__(self):
         self._ssh = None
+        self._node = None
 
     @staticmethod
     def _node_hash(node):
@@ -44,6 +46,7 @@ class SSH(object):
 
         If there already is a connection to the node, this method reuses it.
         """
+        self._node = node
         node_hash = self._node_hash(node)
         if node_hash in SSH.__existing_connections:
             self._ssh = SSH.__existing_connections[node_hash]
@@ -79,53 +82,82 @@ class SSH(object):
         """
         node_hash = self._node_hash(node)
         if node_hash in SSH.__existing_connections:
+            logger.debug('Disconnecting peer: {}, {}'.
+                         format(node['host'], node['port']))
             ssh = SSH.__existing_connections.pop(node_hash)
             ssh.close()
 
+    def _reconnect(self):
+        node = self._node
+        self.disconnect(node)
+        self.connect(node)
+        logger.debug('Reconnecting peer done: {}'.
+                     format(self._ssh.get_transport().getpeername()))
+
     def exec_command(self, cmd, timeout=10):
         """Execute SSH command on a new channel on the connected Node.
 
-        Returns (return_code, stdout, stderr).
+        :param cmd: Command to run on the Node.
+        :param timeout: Maximal time in seconds to wait while the command is
+        done. If is None then wait forever.
+        :type cmd: str
+        :type timeout: int
+        :return return_code, stdout, stderr
+        :rtype: tuple(int, str, str)
+        :raise socket.timeout: If command is not finished in timeout time.
         """
+        start = time()
+        stdout = StringIO.StringIO()
+        stderr = StringIO.StringIO()
+        try:
+            chan = self._ssh.get_transport().open_session(timeout=5)
+        except AttributeError:
+            self._reconnect()
+            chan = self._ssh.get_transport().open_session(timeout=5)
+        except SSHException:
+            self._reconnect()
+            chan = self._ssh.get_transport().open_session(timeout=5)
+        chan.settimeout(timeout)
         logger.trace('exec_command on {0}: {1}'
                      .format(self._ssh.get_transport().getpeername(), cmd))
-        start = time()
-        chan = self._ssh.get_transport().open_session()
-        if timeout is not None:
-            chan.settimeout(int(timeout))
+
         chan.exec_command(cmd)
+        while not chan.exit_status_ready() and timeout is not None:
+            if chan.recv_ready():
+                stdout.write(chan.recv(self.__MAX_RECV_BUF))
+
+            if chan.recv_stderr_ready():
+                stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
+
+            if time() - start > timeout:
+                raise socket.timeout(
+                    'Timeout exception.\n'
+                    'Current contents of stdout buffer: {0}\n'
+                    'Current contents of stderr buffer: {1}\n'
+                    .format(stdout.getvalue(), stderr.getvalue())
+                )
+
+            sleep(0.1)
+        return_code = chan.recv_exit_status()
+
+        while chan.recv_ready():
+            stdout.write(chan.recv(self.__MAX_RECV_BUF))
+
+        while chan.recv_stderr_ready():
+            stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
+
         end = time()
         logger.trace('exec_command on {0} took {1} seconds'.format(
             self._ssh.get_transport().getpeername(), end-start))
 
-        stdout = ""
-        try:
-            while True:
-                buf = chan.recv(self.__MAX_RECV_BUF)
-                stdout += buf
-                if not buf:
-                    break
-        except socket.timeout:
-            logger.error('Caught timeout exception, current contents '
-                         'of buffer: {0}'.format(stdout))
-            raise
-
-        stderr = ""
-        while True:
-            buf = chan.recv_stderr(self.__MAX_RECV_BUF)
-            stderr += buf
-            if not buf:
-                break
-
-        return_code = chan.recv_exit_status()
         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
 
         logger.trace('return RC {}'.format(return_code))
-        logger.trace('return STDOUT {}'.format(stdout))
-        logger.trace('return STDERR {}'.format(stderr))
-        return return_code, stdout, stderr
+        logger.trace('return STDOUT {}'.format(stdout.getvalue()))
+        logger.trace('return STDERR {}'.format(stderr.getvalue()))
+        return return_code, stdout.getvalue(), stderr.getvalue()
 
-    def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
+    def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
         """Execute SSH command with sudo on a new channel on the connected Node.
 
         :param cmd: Command to be executed.
@@ -234,7 +266,7 @@ class SSH(object):
         logger.trace('SCP took {0} seconds'.format(end-start))
 
 
-def exec_cmd(node, cmd, timeout=None, sudo=False):
+def exec_cmd(node, cmd, timeout=600, sudo=False):
     """Convenience function to ssh/exec/return rc, out & err.
 
     Returns (rc, stdout, stderr).
@@ -251,7 +283,7 @@ def exec_cmd(node, cmd, timeout=None, sudo=False):
         ssh.connect(node)
     except Exception, e:
         logger.error("Failed to connect to node" + str(e))
-        return None
+        return None, None, None
 
     try:
         if not sudo:
@@ -261,12 +293,12 @@ def exec_cmd(node, cmd, timeout=None, sudo=False):
                                                                timeout=timeout)
     except Exception, e:
         logger.error(e)
-        return None
+        return None, None, None
 
     return ret_code, stdout, stderr
 
 
-def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
+def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
     """Convenience function to ssh/exec/return out & err.
 
     Verifies that return code is zero.