from paramiko import RSAKey
from paramiko.ssh_exception import SSHException
from scp import SCPClient
-from interruptingcow import timeout as icTimeout
from robot.api import logger
from robot.utils.asserts import assert_equal
# TODO: load priv key
+class SSHTimeout(Exception):
+ """This exception is raised when a timeout occurs."""
+ pass
+
+
class SSH(object):
"""Contains methods for managing and using SSH connections."""
return hash(frozenset([node['host'], node['port']]))
- def connect(self, node):
+ def connect(self, node, attempts=5):
"""Connect to node prior to running exec_command or scp.
If there already is a connection to the node, this method reuses it.
"""
- self._node = node
- node_hash = self._node_hash(node)
- if node_hash in SSH.__existing_connections:
- self._ssh = SSH.__existing_connections[node_hash]
- logger.debug('reusing ssh: {0}'.format(self._ssh))
- else:
- start = time()
- pkey = None
- if 'priv_key' in node:
- pkey = RSAKey.from_private_key(
- StringIO.StringIO(node['priv_key']))
-
- self._ssh = paramiko.SSHClient()
- self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-
- self._ssh.connect(node['host'], username=node['username'],
- password=node.get('password'), pkey=pkey,
- port=node['port'])
-
- self._ssh.get_transport().set_keepalive(10)
-
- SSH.__existing_connections[node_hash] = self._ssh
-
- logger.trace('connect took {} seconds'.format(time() - start))
- logger.debug('new ssh: {0}'.format(self._ssh))
-
- logger.debug('Connect peer: {0}'.
- format(self._ssh.get_transport().getpeername()))
- logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
+ try:
+ self._node = node
+ node_hash = self._node_hash(node)
+ if node_hash in SSH.__existing_connections:
+ self._ssh = SSH.__existing_connections[node_hash]
+ logger.debug('reusing ssh: {0}'.format(self._ssh))
+ else:
+ start = time()
+ pkey = None
+ if 'priv_key' in node:
+ pkey = RSAKey.from_private_key(
+ StringIO.StringIO(node['priv_key']))
+
+ self._ssh = paramiko.SSHClient()
+ self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+ self._ssh.connect(node['host'], username=node['username'],
+ password=node.get('password'), pkey=pkey,
+ port=node['port'])
+
+ self._ssh.get_transport().set_keepalive(10)
+
+ SSH.__existing_connections[node_hash] = self._ssh
+
+ logger.trace('connect took {} seconds'.format(time() - start))
+ logger.debug('new ssh: {0}'.format(self._ssh))
+
+ logger.debug('Connect peer: {0}'.
+ format(self._ssh.get_transport().getpeername()))
+ logger.debug('Connections: {0}'.
+ format(str(SSH.__existing_connections)))
+ except:
+ if attempts > 0:
+ self._reconnect(attempts-1)
+ else:
+ raise
def disconnect(self, node):
"""Close SSH connection to the node.
ssh = SSH.__existing_connections.pop(node_hash)
ssh.close()
- def _reconnect(self):
+ def _reconnect(self, attempts=0):
"""Close the SSH connection and open it again."""
node = self._node
self.disconnect(node)
- self.connect(node)
+ self.connect(node, attempts)
logger.debug('Reconnecting peer done: {}'.
format(self._ssh.get_transport().getpeername()))
:type timeout: int
:return return_code, stdout, stderr
:rtype: tuple(int, str, str)
- :raise socket.timeout: If command is not finished in timeout time.
+ :raise SSHTimeout: If command is not finished in timeout time.
"""
start = time()
stdout = StringIO.StringIO()
stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
if time() - start > timeout:
- raise socket.timeout(
- 'Timeout exception.\n'
- 'Current contents of stdout buffer: {0}\n'
- 'Current contents of stderr buffer: {1}\n'
- .format(stdout.getvalue(), stderr.getvalue())
+ raise SSHTimeout(
+ 'Timeout exception during execution of command: {0}\n'
+ 'Current contents of stdout buffer: {1}\n'
+ 'Current contents of stderr buffer: {2}\n'
+ .format(cmd, stdout.getvalue(), stderr.getvalue())
)
sleep(0.1)
command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
return self.exec_command(command, timeout)
- def interactive_terminal_open(self, time_out=10):
+ def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
+ timeout=30):
+ """Execute command in LXC on a new SSH channel on the connected Node.
+
+ :param lxc_cmd: Command to be executed.
+ :param lxc_name: LXC name.
+ :param lxc_params: Additional parameters for LXC attach.
+ :param sudo: Run in privileged LXC mode. Default: privileged
+ :param timeout: Timeout.
+ :type lxc_cmd: str
+ :type lxc_name: str
+ :type lxc_params: str
+ :type sudo: bool
+ :type timeout: int
+ :return: return_code, stdout, stderr
+ """
+ command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
+ .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
+
+ if sudo:
+ command = 'sudo -S {c}'.format(c=command)
+ return self.exec_command(command, timeout)
+
+ def interactive_terminal_open(self, time_out=30):
"""Open interactive terminal on a new channel on the connected Node.
:param time_out: Timeout in seconds.
chan.get_pty()
chan.invoke_shell()
chan.settimeout(int(time_out))
+ chan.set_combine_stderr(True)
buf = ''
- try:
- with icTimeout(time_out, exception=RuntimeError):
- while not buf.endswith(':~$ '):
- if chan.recv_ready():
- buf = chan.recv(4096)
- except RuntimeError:
- raise Exception('Open interactive terminal timeout.')
+ while not buf.endswith((":~$ ", "~]$ ")):
+ try:
+ chunk = chan.recv(self.__MAX_RECV_BUF)
+ if not chunk:
+ break
+ buf += chunk
+ if chan.exit_status_ready():
+ logger.error('Channel exit status ready')
+ break
+ except socket.timeout:
+ logger.error('Socket timeout: {0}'.format(buf))
+ raise Exception('Socket timeout: {0}'.format(buf))
return chan
- @staticmethod
- def interactive_terminal_exec_command(chan, cmd, prompt,
- time_out=30):
+ def interactive_terminal_exec_command(self, chan, cmd, prompt):
"""Execute command on interactive terminal.
interactive_terminal_open() method has to be called first!
:param cmd: Command to be executed.
:param prompt: Command prompt, sequence of characters used to
indicate readiness to accept commands.
- :param time_out: Timeout in seconds.
:return: Command output.
.. warning:: Interruptingcow is used here, and it uses
"""
chan.sendall('{c}\n'.format(c=cmd))
buf = ''
- try:
- with icTimeout(time_out, exception=RuntimeError):
- while not buf.endswith(prompt):
- if chan.recv_ready():
- buf += chan.recv(4096)
- except RuntimeError:
- raise Exception("Exec '{c}' timeout.".format(c=cmd))
+ while not buf.endswith(prompt):
+ try:
+ chunk = chan.recv(self.__MAX_RECV_BUF)
+ if not chunk:
+ break
+ buf += chunk
+ if chan.exit_status_ready():
+ logger.error('Channel exit status ready')
+ break
+ except socket.timeout:
+ logger.error('Socket timeout during execution of command: '
+ '{0}\nBuffer content:\n{1}'.format(cmd, buf))
+ raise Exception('Socket timeout during execution of command: '
+ '{0}\nBuffer content:\n{1}'.format(cmd, buf))
tmp = buf.replace(cmd.replace('\n', ''), '')
- return tmp.replace(prompt, '')
+ for item in prompt:
+ tmp.replace(item, '')
+ return tmp
@staticmethod
def interactive_terminal_close(chan):
"""
chan.close()
- def scp(self, local_path, remote_path):
- """Copy files from local_path to remote_path.
+ def scp(self, local_path, remote_path, get=False, timeout=10):
+ """Copy files from local_path to remote_path or vice versa.
connect() method has to be called first!
+
+ :param local_path: Path to local file that should be uploaded; or
+ path where to save remote file.
+ :param remote_path: Remote path where to place uploaded file; or
+ path to remote file which should be downloaded.
+ :param get: scp operation to perform. Default is put.
+ :param timeout: Timeout value in seconds.
+ :type local_path: str
+ :type remote_path: str
+ :type get: bool
+ :type timeout: int
"""
- logger.trace('SCP {0} to {1}:{2}'.format(
- local_path, self._ssh.get_transport().getpeername(), remote_path))
+ if not get:
+ logger.trace('SCP {0} to {1}:{2}'.format(
+ local_path, self._ssh.get_transport().getpeername(),
+ remote_path))
+ else:
+ logger.trace('SCP {0}:{1} to {2}'.format(
+ self._ssh.get_transport().getpeername(), remote_path,
+ local_path))
# SCPCLient takes a paramiko transport as its only argument
- scp = SCPClient(self._ssh.get_transport())
+ scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
start = time()
- scp.put(local_path, remote_path)
+ if not get:
+ scp.put(local_path, remote_path)
+ else:
+ scp.get(remote_path, local_path)
scp.close()
end = time()
logger.trace('SCP took {0} seconds'.format(end-start))
ssh = SSH()
try:
ssh.connect(node)
- except Exception as err:
+ except SSHException as err:
logger.error("Failed to connect to node" + str(err))
return None, None, None
else:
(ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
timeout=timeout)
- except Exception as err:
+ except SSHException as err:
logger.error(err)
return None, None, None