1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
6 # http://www.apache.org/licenses/LICENSE-2.0
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
14 """Library for SSH connection management."""
20 from time import time, sleep
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient, SCPException
# Names exported by a star-import of this module.
__all__ = ["exec_cmd", "exec_cmd_no_error"]
class SSHTimeout(Exception):
    """This exception is raised when a timeout occurs.

    It derives directly from Exception, so handlers written for the generic
    base class keep catching it.
    """
class SSH(object):
    """Contains methods for managing and using SSH connections.

    Opened clients are kept in a class-level dictionary keyed by the hash of
    the node's host and port, so a connection opened through one instance is
    reused by every other instance that connects to the same node.
    """

    # Maximal number of bytes requested from a channel in a single recv().
    __MAX_RECV_BUF = 10*1024*1024
    # Opened connections shared by all instances: node hash -> SSH client.
    __existing_connections = {}

    # NOTE(review): the initializer is not visible in the reviewed listing
    # and is restored here; confirm the attribute names against the file.
    def __init__(self):
        """Create a manager that is not bound to any node yet."""
        self._ssh = None
        # Last node passed to connect(); the default for disconnect() and
        # the target of _reconnect().
        self._node = None

    @staticmethod
    def _node_hash(node):
        """Get IP address and port hash from node dictionary.

        :param node: Node in topology.
        :type node: dict
        :returns: Hash computed from the host and port of the node.
        :rtype: int
        """
        return hash(frozenset([node['host'], node['port']]))

    def connect(self, node, attempts=5):
        """Connect to node prior to running exec_command or scp.

        If there already is a connection to the node, this method reuses it.
        A cached connection whose transport is gone or inactive is reopened
        while there are attempts left.

        :param node: Node in topology.
        :param attempts: Number of reconnect attempts.
        :type node: dict
        :type attempts: int
        :raises IOError: If cannot connect to host.
        """
        self._node = node
        node_hash = self._node_hash(node)

        if node_hash in SSH.__existing_connections:
            self._ssh = SSH.__existing_connections[node_hash]
            transport = self._ssh.get_transport()
            # A client without a transport is treated like an inactive one
            # instead of failing with AttributeError.
            if transport is not None and transport.is_active():
                logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
            elif attempts > 0:
                self._reconnect(attempts-1)
            else:
                raise IOError('Cannot connect to {host}'.
                              format(host=node['host']))
            return

        try:
            start = time()
            pkey = None
            if 'priv_key' in node:
                pkey = RSAKey.from_private_key(
                    StringIO.StringIO(node['priv_key']))

            self._ssh = SSHClient()
            self._ssh.set_missing_host_key_policy(AutoAddPolicy())

            self._ssh.connect(node['host'], username=node['username'],
                              password=node.get('password'), pkey=pkey,
                              port=node['port'])

            self._ssh.get_transport().set_keepalive(10)

            # Cache only after the connection is fully established.
            SSH.__existing_connections[node_hash] = self._ssh
            logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
                         format(
                             peer=self._ssh.get_transport().getpeername(),
                             total=(time() - start),
                             ssh=self._ssh))
        except SSHException as err:
            logger.error(repr(err))
            raise IOError('Cannot connect to {host}'.
                          format(host=node['host']))
        except NoValidConnectionsError as err:
            logger.error(repr(err))
            raise IOError('Unable to connect to port {port} on {host}'.
                          format(port=node['port'], host=node['host']))

    def disconnect(self, node=None):
        """Close SSH connection to the node.

        Nothing happens when no node is given and none was connected
        before, or when there is no cached connection to the node.

        :param node: The node to disconnect from. None means last connected.
        :type node: dict or None
        """
        if node is None:
            node = self._node
        if node is None:
            return

        node_hash = self._node_hash(node)
        if node_hash in SSH.__existing_connections:
            logger.debug('Disconnecting peer: {host}, {port}'.
                         format(host=node['host'], port=node['port']))
            ssh = SSH.__existing_connections.pop(node_hash)
            ssh.close()

    def _reconnect(self, attempts=0):
        """Close the SSH connection and open it again.

        :param attempts: Number of reconnect attempts.
        :type attempts: int
        """
        node = self._node
        self.disconnect(node)
        self.connect(node, attempts)
        logger.debug('Reconnecting peer done: {host}, {port}'.
                     format(host=node['host'], port=node['port']))

    def exec_command(self, cmd, timeout=10, log_stdout_err=True):
        """Execute SSH command on a new channel on the connected Node.

        :param cmd: Command to run on the Node.
        :param timeout: Maximal time in seconds to wait until the command is
            done. If set to None then wait forever.
        :param log_stdout_err: If True, stdout and stderr are logged. stdout
            and stderr are logged also if the return code is not zero
            independently of the value of log_stdout_err.
        :type cmd: str or OptionString
        :type timeout: int
        :type log_stdout_err: bool
        :returns: return_code, stdout, stderr
        :rtype: tuple(int, str, str)
        :raises SSHTimeout: If command is not finished in timeout time.
        """
        # NOTE(review): conversion restored from the documented
        # "str or OptionString" parameter type - confirm against the file.
        cmd = str(cmd)
        stdout = StringIO.StringIO()
        stderr = StringIO.StringIO()
        try:
            chan = self._ssh.get_transport().open_session(timeout=5)
            peer = self._ssh.get_transport().getpeername()
        except (AttributeError, SSHException):
            # AttributeError: there is no client or transport to use;
            # SSHException: the session could not be opened. Both are
            # handled the same way: reconnect once and try again.
            self._reconnect()
            chan = self._ssh.get_transport().open_session(timeout=5)
            peer = self._ssh.get_transport().getpeername()
        chan.settimeout(timeout)

        logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
                     .format(peer=peer, timeout=timeout, cmd=cmd))

        start = time()
        chan.exec_command(cmd)
        # With timeout=None the polling loop is skipped and the wait happens
        # in recv_exit_status() below.
        while not chan.exit_status_ready() and timeout is not None:
            if chan.recv_ready():
                stdout.write(chan.recv(self.__MAX_RECV_BUF))

            if chan.recv_stderr_ready():
                stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))

            if time() - start > timeout:
                raise SSHTimeout(
                    'Timeout exception during execution of command: {cmd}\n'
                    'Current contents of stdout buffer: {stdout}\n'
                    'Current contents of stderr buffer: {stderr}\n'
                    .format(cmd=cmd, stdout=stdout.getvalue(),
                            stderr=stderr.getvalue())
                )

            sleep(0.1)
        return_code = chan.recv_exit_status()

        # Collect whatever arrived between the last poll and the exit.
        while chan.recv_ready():
            stdout.write(chan.recv(self.__MAX_RECV_BUF))

        while chan.recv_stderr_ready():
            stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))

        end = time()
        logger.trace('exec_command on {peer} took {total} seconds'.
                     format(peer=peer, total=end-start))

        logger.trace('return RC {rc}'.format(rc=return_code))
        if log_stdout_err or int(return_code):
            logger.trace('return STDOUT {stdout}'.
                         format(stdout=stdout.getvalue()))
            logger.trace('return STDERR {stderr}'.
                         format(stderr=stderr.getvalue()))
        return return_code, stdout.getvalue(), stderr.getvalue()

    def exec_command_sudo(self, cmd, cmd_input=None, timeout=30,
                          log_stdout_err=True):
        """Execute SSH command with sudo on a new channel on the connected Node.

        :param cmd: Command to be executed.
        :param cmd_input: Input redirected to the command.
        :param timeout: Timeout.
        :param log_stdout_err: If True, stdout and stderr are logged.
        :type cmd: str
        :type cmd_input: str
        :type timeout: int
        :type log_stdout_err: bool
        :returns: return_code, stdout, stderr
        :rtype: tuple(int, str, str)
        :raises SSHTimeout: If command is not finished in timeout time.

        :Example:

        >>> from ssh import SSH
        >>> ssh = SSH()
        >>> ssh.connect(node)
        >>> # Execute command without input (sudo -S cmd)
        >>> ssh.exec_command_sudo("ifconfig eth0 down")
        >>> # Execute command with input (sudo -S cmd <<< "input")
        >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
        """
        if cmd_input is None:
            command = 'sudo -S {c}'.format(c=cmd)
        else:
            command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
        return self.exec_command(command, timeout,
                                 log_stdout_err=log_stdout_err)

    def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
                         timeout=30):
        """Execute command in LXC on a new SSH channel on the connected Node.

        :param lxc_cmd: Command to be executed.
        :param lxc_name: LXC name.
        :param lxc_params: Additional parameters for LXC attach.
        :param sudo: Run in privileged LXC mode. Default: privileged
        :param timeout: Timeout.
        :type lxc_cmd: str
        :type lxc_name: str
        :type lxc_params: str
        :type sudo: bool
        :type timeout: int
        :returns: return_code, stdout, stderr
        :rtype: tuple(int, str, str)
        :raises SSHTimeout: If command is not finished in timeout time.
        """
        # NOTE(review): the default of ``timeout`` is not visible in the
        # reviewed listing; 30 matches the sibling methods - confirm.
        command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
            .format(p=lxc_params, n=lxc_name, c=lxc_cmd)

        if sudo:
            command = 'sudo -S {c}'.format(c=command)
        return self.exec_command(command, timeout)

    def _recv_until(self, chan, endings, timeout_prefix):
        """Read from the channel until the received text ends with a prompt.

        Reading also stops when the channel returns no data or reports that
        the exit status is ready.

        :param chan: SSH channel with opened terminal.
        :param endings: Prompt, or tuple of prompts, that ends the reading.
        :param timeout_prefix: Text put in front of the received data in the
            error reported when the channel times out.
        :type chan: paramiko.channel.Channel
        :type endings: str or tuple
        :type timeout_prefix: str
        :returns: Everything received from the channel.
        :rtype: str
        :raises SSHTimeout: If the channel times out while waiting for data.
        """
        buf = ''
        while not buf.endswith(endings):
            try:
                chunk = chan.recv(self.__MAX_RECV_BUF)
            except socket.timeout:
                message = timeout_prefix + buf
                logger.error(message)
                # SSHTimeout derives from Exception, so callers catching
                # the generic exception raised here before still match.
                raise SSHTimeout(message)
            if not chunk:
                # Nothing more to read from the channel.
                break
            buf += chunk
            if chan.exit_status_ready():
                logger.error('Channel exit status ready')
                break
        return buf

    def interactive_terminal_open(self, time_out=45):
        """Open interactive terminal on a new channel on the connected Node.

        Waits until the output ends with one of the known shell prompts.
        The wait is bounded by the socket timeout set on the channel; no
        signal handling is involved.

        :param time_out: Timeout in seconds.
        :type time_out: int
        :returns: SSH channel with opened terminal.
        :rtype: paramiko.channel.Channel
        :raises SSHTimeout: If no prompt is received before the timeout.
        """
        chan = self._ssh.get_transport().open_session()
        chan.get_pty()
        chan.invoke_shell()
        chan.settimeout(int(time_out))
        chan.set_combine_stderr(True)

        self._recv_until(
            chan, (":~# ", ":~$ ", "~]$ ", "~]# "), 'Socket timeout: ')
        return chan

    def interactive_terminal_exec_command(self, chan, cmd, prompt):
        """Execute command on interactive terminal.

        interactive_terminal_open() method has to be called first!

        The echoed command and the prompt are removed from the returned
        text. The wait is bounded by the socket timeout set on the channel;
        no signal handling is involved.

        :param chan: SSH channel with opened terminal.
        :param cmd: Command to be executed.
        :param prompt: Command prompt, sequence of characters used to
            indicate readiness to accept commands. Several alternative
            prompts can be given as a tuple or list.
        :type chan: paramiko.channel.Channel
        :type cmd: str
        :type prompt: str or tuple or list
        :returns: Command output.
        :rtype: str
        :raises SSHTimeout: If the prompt is not received before the timeout.
        """
        # A single prompt is one item; iterating a string would walk its
        # characters instead.
        if isinstance(prompt, (tuple, list)):
            prompts = tuple(prompt)
        else:
            prompts = (prompt,)

        chan.sendall('{c}\n'.format(c=cmd))
        buf = self._recv_until(
            chan, prompts,
            'Socket timeout during execution of command: '
            '{0}\nBuffer content:\n'.format(cmd))

        tmp = buf.replace(cmd.replace('\n', ''), '')
        for item in prompts:
            # str.replace() returns a new string; keep the result.
            tmp = tmp.replace(item, '')
        return tmp

    @staticmethod
    def interactive_terminal_close(chan):
        """Close interactive terminal SSH channel.

        :param chan: SSH channel to be closed.
        :type chan: paramiko.channel.Channel
        """
        chan.close()

    def scp(self, local_path, remote_path, get=False, timeout=30,
            wildcard=False):
        """Copy files from local_path to remote_path or vice versa.

        connect() method has to be called first!

        :param local_path: Path to local file that should be uploaded; or
            path where to save remote file.
        :param remote_path: Remote path where to place uploaded file; or
            path to remote file which should be downloaded.
        :param get: scp operation to perform. Default is put.
        :param timeout: Timeout value in seconds.
        :param wildcard: If path has wildcard characters. Default is false.
        :type local_path: str
        :type remote_path: str
        :type get: bool
        :type timeout: int
        :type wildcard: bool
        """
        transport = self._ssh.get_transport()
        if not get:
            logger.trace('SCP {0} to {1}:{2}'.format(
                local_path, transport.getpeername(), remote_path))
        else:
            logger.trace('SCP {0}:{1} to {2}'.format(
                transport.getpeername(), remote_path, local_path))

        # SCPClient takes a paramiko transport as its first argument.
        if not wildcard:
            client = SCPClient(transport, socket_timeout=timeout)
        else:
            # Paths with wildcards are passed through unchanged.
            client = SCPClient(transport, sanitize=lambda x: x,
                               socket_timeout=timeout)
        start = time()
        try:
            if not get:
                client.put(local_path, remote_path)
            else:
                client.get(remote_path, local_path)
        finally:
            # Close the client also when the transfer fails.
            client.close()
        end = time()
        logger.trace('SCP took {0} seconds'.format(end-start))
def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
    """Convenience function to ssh/exec/return rc, out & err.

    Returns (rc, stdout, stderr).

    When the node defines 'host_port', a local port forwarding through the
    host is started first and the command is sent to 127.0.0.1 instead.

    :param node: The node to execute command on.
    :param cmd: Command to execute.
    :param timeout: Timeout value in seconds. Default: 600.
    :param sudo: Sudo privilege execution flag. Default: False.
    :param disconnect: Close the opened SSH connection if True.
    :type node: dict
    :type cmd: str or OptionString
    :type timeout: int
    :type sudo: bool
    :type disconnect: bool
    :returns: RC, Stdout, Stderr. All three are None when an SSHException
        is caught while connecting or executing.
    :rtype: tuple(int, str, str)
    :raises TypeError: If node or cmd is None.
    :raises ValueError: If cmd is empty.
    """
    if node is None:
        raise TypeError('Node parameter is None')
    if cmd is None:
        raise TypeError('Command parameter is None')
    if not cmd:
        raise ValueError('Empty command parameter')

    ssh = SSH()

    if node.get('host_port') is not None:
        # The node is reached through a tunnel: SSH goes to the local end.
        ssh_node = {
            'host': '127.0.0.1',
            'port': node['port'],
            'username': node['username'],
            'password': node['password'],
        }
        # Imported here so that pexpect is needed only for tunnelled nodes.
        import pexpect
        options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
        tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
        ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
            format(tnl=tnl, op=options, user=node['host_username'],
                   host=node['host'], host_port=node['host_port'])
        logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
                     format(ssh_cmd=ssh_cmd))
        # The spawned process is left running on purpose: it carries the
        # forwarded port used by the connection below.
        child = pexpect.spawn(ssh_cmd)
        child.expect('.* password: ')
        logger.trace(child.after)
        child.sendline(node['host_password'])
        child.expect('Welcome .*')
        logger.trace(child.after)
        logger.trace('Local port forwarding finished.')
    else:
        ssh_node = node

    # NOTE(review): SSH.connect() in this file reports failures as IOError,
    # which is not caught here and reaches the caller - confirm intent.
    try:
        ssh.connect(ssh_node)
    except SSHException as err:
        logger.error("Failed to connect to node" + repr(err))
        return None, None, None

    try:
        if not sudo:
            (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
        else:
            (ret_code, stdout, stderr) = ssh.exec_command_sudo(
                cmd, timeout=timeout)
    except SSHException as err:
        logger.error(repr(err))
        return None, None, None
    finally:
        # Runs on success, on the error return above and on any exception.
        if disconnect:
            ssh.disconnect()

    return ret_code, stdout, stderr
def exec_cmd_no_error(
        node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
        retries=0):
    """Convenience function to ssh/exec/return out & err.

    Verifies that return code is zero.
    Supports retries, timeout is related to each try separately then. There is
    sleep(1) before each retry.
    Disconnect (if enabled) is applied after each try.

    :param node: DUT node.
    :param cmd: Command to be executed.
    :param timeout: Timeout value in seconds. Default: 600.
    :param sudo: Sudo privilege execution flag. Default: False.
    :param message: Error message in case of failure. Default: None.
    :param disconnect: Close the opened SSH connection if True.
    :param retries: How many times to retry on failure. Negative values are
        treated as zero.
    :type node: dict
    :type cmd: str or OptionString
    :type timeout: int
    :type sudo: bool
    :type message: str
    :type disconnect: bool
    :type retries: int
    :returns: Stdout, Stderr.
    :rtype: tuple(str, str)
    :raises RuntimeError: If bash return code is not 0.
    """
    # At least one try is always made, so stderr is defined for the message.
    tries = max(retries, 0) + 1
    for attempt in range(tries):
        if attempt:
            # Pause only between tries, not after the last failed one.
            sleep(1)
        ret_code, stdout, stderr = exec_cmd(
            node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
        if ret_code == 0:
            return stdout, stderr

    msg = ('Command execution failed: "{cmd}"\n{stderr}'.
           format(cmd=cmd, stderr=stderr) if message is None else message)
    raise RuntimeError(msg)
# NOTE(review): the ``def`` line of this function is absent from the
# reviewed listing; the name ``scp_node`` is presumed - confirm against the
# file before applying.
def scp_node(
        node, local_path, remote_path, get=False, timeout=30, disconnect=False):
    """Copy files from local_path to remote_path or vice versa.

    :param node: SUT node.
    :param local_path: Path to local file that should be uploaded; or
        path where to save remote file.
    :param remote_path: Remote path where to place uploaded file; or
        path to remote file which should be downloaded.
    :param get: scp operation to perform. Default is put.
    :param timeout: Timeout value in seconds.
    :param disconnect: Close the opened SSH connection if True.
    :type node: dict
    :type local_path: str
    :type remote_path: str
    :type get: bool
    :type timeout: int
    :type disconnect: bool
    :raises RuntimeError: If SSH connection failed or SCP transfer failed.
    """
    ssh = SSH()

    try:
        ssh.connect(node)
    except SSHException:
        raise RuntimeError('Failed to connect to {host}!'
                           .format(host=node['host']))
    try:
        ssh.scp(local_path, remote_path, get, timeout)
    except SCPException:
        raise RuntimeError('SCP execution failed on {host}!'
                           .format(host=node['host']))
    finally:
        # Runs after a successful copy as well as after a failed one.
        if disconnect:
            ssh.disconnect()