SSH connect use port specified in node dict
[csit.git] / resources / libraries / python / ssh.py
index 72e41c7..e121519 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import paramiko
+from paramiko import RSAKey
+import StringIO
 from scp import SCPClient
 from time import time
 from robot.api import logger
 from interruptingcow import timeout
 from robot.utils.asserts import assert_equal, assert_not_equal
+from socket import timeout as socket_timeout
 
 __all__ = ["exec_cmd", "exec_cmd_no_error"]
 
@@ -28,9 +31,7 @@ class SSH(object):
     __existing_connections = {}
 
     def __init__(self):
-        self._ssh = paramiko.SSHClient()
-        self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        self._hostname = None
+        pass
 
     def _node_hash(self, node):
         return hash(frozenset([node['host'], node['port']]))
@@ -40,50 +41,78 @@ class SSH(object):
 
         If there already is a connection to the node, this method reuses it.
         """
-        self._hostname = node['host']
         node_hash = self._node_hash(node)
-        if node_hash in self.__existing_connections:
-            self._ssh = self.__existing_connections[node_hash]
+        if node_hash in SSH.__existing_connections:
+            self._ssh = SSH.__existing_connections[node_hash]
+            logger.debug('reusing ssh: {0}'.format(self._ssh))
         else:
             start = time()
+            pkey = None
+            if 'priv_key' in node:
+                pkey = RSAKey.from_private_key(
+                        StringIO.StringIO(node['priv_key']))
+
+            self._ssh = paramiko.SSHClient()
+            self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
             self._ssh.connect(node['host'], username=node['username'],
-                              password=node['password'])
-            self.__existing_connections[node_hash] = self._ssh
+                              password=node.get('password'), pkey=pkey,
+                              port=node['port'])
+
+            SSH.__existing_connections[node_hash] = self._ssh
+
             logger.trace('connect took {} seconds'.format(time() - start))
+            logger.debug('new ssh: {0}'.format(self._ssh))
+
+        logger.debug('Connect peer: {0}'.
+                format(self._ssh.get_transport().getpeername()))
+        logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
 
     def exec_command(self, cmd, timeout=10):
         """Execute SSH command on a new channel on the connected Node.
 
         Returns (return_code, stdout, stderr).
         """
-        logger.trace('exec_command on {0}: {1}'.format(self._hostname, cmd))
+        logger.trace('exec_command on {0}: {1}'
+                     .format(self._ssh.get_transport().getpeername(), cmd))
         start = time()
         chan = self._ssh.get_transport().open_session()
         if timeout is not None:
             chan.settimeout(int(timeout))
         chan.exec_command(cmd)
         end = time()
-        logger.trace('exec_command "{0}" on {1} took {2} seconds'.format(
-            cmd, self._hostname, end-start))
+        logger.trace('exec_command on {0} took {1} seconds'.format(
+            self._ssh.get_transport().getpeername(), end-start))
 
         stdout = ""
         while True:
-            buf = chan.recv(self.__MAX_RECV_BUF)
-            stdout += buf
-            if not buf:
+            try:
+                buf = chan.recv(self.__MAX_RECV_BUF)
+                stdout += buf
+                if not buf:
+                    break
+            except socket_timeout:
+                logger.trace('Channels stdout timeout occurred')
                 break
 
         stderr = ""
         while True:
-            buf = chan.recv_stderr(self.__MAX_RECV_BUF)
-            stderr += buf
-            if not buf:
+            try:
+                buf = chan.recv_stderr(self.__MAX_RECV_BUF)
+                stderr += buf
+                if not buf:
+                    break
+            except socket_timeout:
+                logger.trace('Channels stderr timeout occurred')
                 break
 
         return_code = chan.recv_exit_status()
         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
 
-        return (return_code, stdout, stderr)
+        logger.trace('return RC {}'.format(return_code))
+        logger.trace('return STDOUT {}'.format(stdout))
+        logger.trace('return STDERR {}'.format(stderr))
+        return return_code, stdout, stderr
 
     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
         """Execute SSH command with sudo on a new channel on the connected Node.
@@ -99,9 +128,9 @@ class SSH(object):
             >>> ssh = SSH()
             >>> ssh.connect(node)
             >>> #Execute command without input (sudo -S cmd)
-            >>> ssh.exex_command_sudo("ifconfig eth0 down")
+            >>> ssh.exec_command_sudo("ifconfig eth0 down")
             >>> #Execute command with input (sudo -S cmd <<< "input")
-            >>> ssh.exex_command_sudo("vpe_api_test", "dump_interface_table")
+            >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
         """
         if cmd_input is None:
             command = 'sudo -S {c}'.format(c=cmd)
@@ -182,7 +211,7 @@ class SSH(object):
         connect() method has to be called first!
         """
         logger.trace('SCP {0} to {1}:{2}'.format(
-            local_path, self._hostname, remote_path))
+            local_path, self._ssh.get_transport().getpeername(), remote_path))
         # SCPCLient takes a paramiko transport as its only argument
         scp = SCPClient(self._ssh.get_transport())
         start = time()
@@ -229,7 +258,7 @@ def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
 
     Returns (stdout, stderr).
     """
-    (rc, stdout, stderr) = exec_cmd(node,cmd, timeout=timeout, sudo=sudo)
+    (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
                  format(cmd, stderr))
     return (stdout, stderr)