-# Copyright (c) 2018 Cisco and/or its affiliates.
+# Copyright (c) 2019 Cisco and/or its affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
"""Library for SSH connection management."""
+
+import socket
import StringIO
+
from time import time, sleep
-import socket
-import paramiko
-from paramiko import RSAKey
-from paramiko.ssh_exception import SSHException
-from scp import SCPClient
+from paramiko import RSAKey, SSHClient, AutoAddPolicy
+from paramiko.ssh_exception import SSHException, NoValidConnectionsError
from robot.api import logger
-from robot.utils.asserts import assert_equal
+from scp import SCPClient, SCPException
+
__all__ = ["exec_cmd", "exec_cmd_no_error"]
"""Connect to node prior to running exec_command or scp.
If there already is a connection to the node, this method reuses it.
+
+ :param node: Node in topology.
+ :param attempts: Number of reconnect attempts.
+ :type node: dict
+ :type attempts: int
+ :raises IOError: If cannot connect to host.
"""
- try:
- self._node = node
- node_hash = self._node_hash(node)
- if node_hash in SSH.__existing_connections:
- self._ssh = SSH.__existing_connections[node_hash]
- logger.debug('reusing ssh: {0}'.format(self._ssh))
+ self._node = node
+ node_hash = self._node_hash(node)
+ if node_hash in SSH.__existing_connections:
+ self._ssh = SSH.__existing_connections[node_hash]
+ if self._ssh.get_transport().is_active():
+ logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
else:
+ if attempts > 0:
+ self._reconnect(attempts-1)
+ else:
+ raise IOError('Cannot connect to {host}'.
+ format(host=node['host']))
+ else:
+ try:
start = time()
pkey = None
if 'priv_key' in node:
pkey = RSAKey.from_private_key(
StringIO.StringIO(node['priv_key']))
- self._ssh = paramiko.SSHClient()
- self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ self._ssh = SSHClient()
+ self._ssh.set_missing_host_key_policy(AutoAddPolicy())
self._ssh.connect(node['host'], username=node['username'],
password=node.get('password'), pkey=pkey,
self._ssh.get_transport().set_keepalive(10)
SSH.__existing_connections[node_hash] = self._ssh
-
- logger.trace('connect took {} seconds'.format(time() - start))
- logger.debug('new ssh: {0}'.format(self._ssh))
-
- logger.debug('Connect peer: {0}'.
- format(self._ssh.get_transport().getpeername()))
- logger.debug('Connections: {0}'.
- format(str(SSH.__existing_connections)))
- except RuntimeError as exc:
- if attempts > 0:
- self._reconnect(attempts-1)
- else:
- raise exc
-
- def disconnect(self, node):
+ logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
+ format(
+ peer=self._ssh.get_transport().getpeername(),
+ total=(time() - start),
+ ssh=self._ssh))
+ except SSHException:
+ raise IOError('Cannot connect to {host}'.
+ format(host=node['host']))
+ except NoValidConnectionsError as err:
+ logger.error(repr(err))
+ raise IOError('Unable to connect to port {port} on {host}'.
+ format(port=node['port'], host=node['host']))
+
+ def disconnect(self, node=None):
"""Close SSH connection to the node.
- :param node: The node to disconnect from.
- :type node: dict
+ :param node: The node to disconnect from. None means last connected.
+ :type node: dict or None
"""
+ if node is None:
+ node = self._node
+ if node is None:
+ return
node_hash = self._node_hash(node)
if node_hash in SSH.__existing_connections:
- logger.debug('Disconnecting peer: {}, {}'.
- format(node['host'], node['port']))
+ logger.debug('Disconnecting peer: {host}, {port}'.
+ format(host=node['host'], port=node['port']))
ssh = SSH.__existing_connections.pop(node_hash)
ssh.close()
def _reconnect(self, attempts=0):
- """Close the SSH connection and open it again."""
+ """Close the SSH connection and open it again.
+ :param attempts: Number of reconnect attempts.
+ :type attempts: int
+ """
node = self._node
self.disconnect(node)
self.connect(node, attempts)
- logger.debug('Reconnecting peer done: {}'.
- format(self._ssh.get_transport().getpeername()))
+ logger.debug('Reconnecting peer done: {host}, {port}'.
+ format(host=node['host'], port=node['port']))
def exec_command(self, cmd, timeout=10):
"""Execute SSH command on a new channel on the connected Node.
:param cmd: Command to run on the Node.
:param timeout: Maximal time in seconds to wait until the command is
done. If set to None then wait forever.
- :type cmd: str
+ :type cmd: str or OptionString
:type timeout: int
:return return_code, stdout, stderr
:rtype: tuple(int, str, str)
:raise SSHTimeout: If command is not finished in timeout time.
"""
+ cmd = str(cmd)
stdout = StringIO.StringIO()
stderr = StringIO.StringIO()
try:
command = 'sudo -S {c}'.format(c=command)
return self.exec_command(command, timeout)
- def interactive_terminal_open(self, time_out=30):
+ def interactive_terminal_open(self, time_out=45):
"""Open interactive terminal on a new channel on the connected Node.
- FIXME: Convert or document other possible exceptions, such as
- socket.error or SSHException.
+ :param time_out: Timeout in seconds.
+ :returns: SSH channel with opened terminal.
.. warning:: Interruptingcow is used here, and it uses
signal(SIGALRM) to let the operating system interrupt program
handlers only apply to the main thread, so you cannot use this
from other threads. You must not use this in a program that
uses SIGALRM itself (this includes certain profilers)
-
- :param time_out: Timeout in seconds.
- :returns: SSH channel with opened terminal.
- :raise IOError: If receive attempt results in socket.timeout.
"""
chan = self._ssh.get_transport().open_session()
chan.get_pty()
chan.set_combine_stderr(True)
buf = ''
- while not buf.endswith((":~$ ", "~]$ ", "~]# ")):
+ while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
try:
chunk = chan.recv(self.__MAX_RECV_BUF)
if not chunk:
break
except socket.timeout:
logger.error('Socket timeout: {0}'.format(buf))
- # TODO: Find out which exception would callers appreciate here.
- raise IOError('Socket timeout: {0}'.format(buf))
+ raise Exception('Socket timeout: {0}'.format(buf))
return chan
def interactive_terminal_exec_command(self, chan, cmd, prompt):
interactive_terminal_open() method has to be called first!
+ :param chan: SSH channel with opened terminal.
+ :param cmd: Command to be executed.
+ :param prompt: Command prompt, sequence of characters used to
+ indicate readiness to accept commands.
+ :returns: Command output.
+
.. warning:: Interruptingcow is used here, and it uses
signal(SIGALRM) to let the operating system interrupt program
execution. This has the following limitations: Python signal
handlers only apply to the main thread, so you cannot use this
from other threads. You must not use this in a program that
uses SIGALRM itself (this includes certain profilers)
-
- :param chan: SSH channel with opened terminal.
- :param cmd: Command to be executed.
- :param prompt: Command prompt, sequence of characters used to
- indicate readiness to accept commands.
- :returns: Command output.
- :raise IOError: If receive attempt results in socket.timeout.
"""
chan.sendall('{c}\n'.format(c=cmd))
buf = ''
except socket.timeout:
logger.error('Socket timeout during execution of command: '
'{0}\nBuffer content:\n{1}'.format(cmd, buf))
- # TODO: Find out which exception would callers appreciate here.
- raise IOError('Socket timeout during execution of command: '
- '{0}\nBuffer content:\n{1}'.format(cmd, buf))
+ raise Exception('Socket timeout during execution of command: '
+ '{0}\nBuffer content:\n{1}'.format(cmd, buf))
tmp = buf.replace(cmd.replace('\n', ''), '')
for item in prompt:
tmp.replace(item, '')
def interactive_terminal_close(chan):
"""Close interactive terminal SSH channel.
- :param: chan: SSH channel to be closed.
+ :param chan: SSH channel to be closed.
"""
chan.close()
- def scp(self, local_path, remote_path, get=False, timeout=10):
+ def scp(self, local_path, remote_path, get=False, timeout=30,
+ wildcard=False):
"""Copy files from local_path to remote_path or vice versa.
connect() method has to be called first!
path to remote file which should be downloaded.
:param get: scp operation to perform. Default is put.
:param timeout: Timeout value in seconds.
+ :param wildcard: If path has wildcard characters. Default is false.
:type local_path: str
:type remote_path: str
:type get: bool
:type timeout: int
+ :type wildcard: bool
"""
if not get:
logger.trace('SCP {0} to {1}:{2}'.format(
self._ssh.get_transport().getpeername(), remote_path,
local_path))
# SCPCLient takes a paramiko transport as its only argument
- scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
+ if not wildcard:
+ scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
+ else:
+ scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
+ socket_timeout=timeout)
start = time()
if not get:
scp.put(local_path, remote_path)
logger.trace('SCP took {0} seconds'.format(end-start))
-def exec_cmd(node, cmd, timeout=600, sudo=False):
+def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
"""Convenience function to ssh/exec/return rc, out & err.
- FIXME: Document :param, :type, :raise and similar.
Returns (rc, stdout, stderr).
+
+ :param node: The node to execute command on.
+ :param cmd: Command to execute.
+ :param timeout: Timeout value in seconds. Default: 600.
+ :param sudo: Sudo privilege execution flag. Default: False.
+ :param disconnect: Close the opened SSH connection if True.
+ :type node: dict
+ :type cmd: str or OptionString
+ :type timeout: int
+ :type sudo: bool
+ :type disconnect: bool
+ :returns: RC, Stdout, Stderr.
+ :rtype: tuple(int, str, str)
"""
if node is None:
raise TypeError('Node parameter is None')
if cmd is None:
raise TypeError('Command parameter is None')
- if len(cmd) == 0:
+ if not cmd:
raise ValueError('Empty command parameter')
ssh = SSH()
+
+ if node.get('host_port') is not None:
+ ssh_node = dict()
+ ssh_node['host'] = '127.0.0.1'
+ ssh_node['port'] = node['port']
+ ssh_node['username'] = node['username']
+ ssh_node['password'] = node['password']
+ import pexpect
+ options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
+ tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
+ ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
+ format(tnl=tnl, op=options, user=node['host_username'],
+ host=node['host'], host_port=node['host_port'])
+ logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
+ format(ssh_cmd=ssh_cmd))
+ child = pexpect.spawn(ssh_cmd)
+ child.expect('.* password: ')
+ logger.trace(child.after)
+ child.sendline(node['host_password'])
+ child.expect('Welcome .*')
+ logger.trace(child.after)
+ logger.trace('Local port forwarding finished.')
+ else:
+ ssh_node = node
+
try:
- ssh.connect(node)
+ ssh.connect(ssh_node)
except SSHException as err:
- logger.error("Failed to connect to node" + str(err))
+ logger.error("Failed to connect to node" + repr(err))
return None, None, None
try:
(ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
timeout=timeout)
except SSHException as err:
- logger.error(err)
+ logger.error(repr(err))
return None, None, None
+ finally:
+ if disconnect:
+ ssh.disconnect()
return ret_code, stdout, stderr
-def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
+def exec_cmd_no_error(
+ node, cmd, timeout=600, sudo=False, message=None, disconnect=False):
"""Convenience function to ssh/exec/return out & err.
Verifies that return code is zero.
- Returns (stdout, stderr).
+ :param node: DUT node.
+ :param cmd: Command to be executed.
+ :param timeout: Timeout value in seconds. Default: 600.
+ :param sudo: Sudo privilege execution flag. Default: False.
+ :param message: Error message in case of failure. Default: None.
+ :param disconnect: Close the opened SSH connection if True.
+ :type node: dict
+ :type cmd: str or OptionString
+ :type timeout: int
+ :type sudo: bool
+ :type message: str
+ :type disconnect: bool
+ :returns: Stdout, Stderr.
+ :rtype: tuple(str, str)
+ :raises RuntimeError: If bash return code is not 0.
"""
- (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
- assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
- format(cmd, stderr))
+ ret_code, stdout, stderr = exec_cmd(
+ node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
+ msg = ('Command execution failed: "{cmd}"\n{stderr}'.
+ format(cmd=cmd, stderr=stderr) if message is None else message)
+ if ret_code != 0:
+ raise RuntimeError(msg)
+
return stdout, stderr
+
+def scp_node(
+ node, local_path, remote_path, get=False, timeout=30, disconnect=False):
+ """Copy files from local_path to remote_path or vice versa.
+
+ :param node: SUT node.
+ :param local_path: Path to local file that should be uploaded; or
+ path where to save remote file.
+ :param remote_path: Remote path where to place uploaded file; or
+ path to remote file which should be downloaded.
+ :param get: scp operation to perform. Default is put.
+ :param timeout: Timeout value in seconds.
+ :param disconnect: Close the opened SSH connection if True.
+ :type node: dict
+ :type local_path: str
+ :type remote_path: str
+ :type get: bool
+ :type timeout: int
+ :type disconnect: bool
+ :raises RuntimeError: If SSH connection failed or SCP transfer failed.
+ """
+ ssh = SSH()
+
+ try:
+ ssh.connect(node)
+ except SSHException:
+ raise RuntimeError('Failed to connect to {host}!'
+ .format(host=node['host']))
+ try:
+ ssh.scp(local_path, remote_path, get, timeout)
+ except SCPException:
+ raise RuntimeError('SCP execution failed on {host}!'
+ .format(host=node['host']))
+ finally:
+ if disconnect:
+ ssh.disconnect()