e4ac93fb1b3356539b56c3782dbe409f7cfaf30e
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18 import StringIO
19 from time import time, sleep
20
21 from paramiko import RSAKey, SSHClient, AutoAddPolicy
22 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
23 from robot.api import logger
24 from scp import SCPClient, SCPException
25
26 from resources.libraries.python.OptionString import OptionString
27
28 __all__ = ["exec_cmd", "exec_cmd_no_error"]
29
30 # TODO: load priv key
31
32
33 def raise_from(raising, excepted):
34     """Function to be replaced by "raise from" in Python 3.
35
36     Neither "six" nor "future" offer good enough implementation right now.
37     chezsoi.org/lucas/blog/displaying-chained-exceptions-stacktraces-in-python-2
38
39     Current implementation just logs excepted error, and raises the new one.
40
41     :param raising: The exception to raise.
42     :param excepted: The exception we excepted and want to log.
43     :type raising: BaseException
44     :type excepted: BaseException
45     :raises: raising
46     """
47     logger.error("Excepted: {exc!r}\nRaising: {rai!r}".format(
48         exc=excepted, rai=raising))
49     raise raising
50
51
52 class SSHTimeout(Exception):
53     """This exception is raised when a timeout occurs."""
54     pass
55
56
57 class SSH(object):
58     """Contains methods for managing and using SSH connections."""
59
60     __MAX_RECV_BUF = 10*1024*1024
61     __existing_connections = {}
62
63     def __init__(self):
64         self._ssh = None
65         self._node = None
66
67     @staticmethod
68     def _node_hash(node):
69         """Get IP address and port hash from node dictionary.
70
71         :param node: Node in topology.
72         :type node: dict
73         :returns: IP address and port for the specified node.
74         :rtype: int
75         """
76
77         return hash(frozenset([node['host'], node['port']]))
78
79     def connect(self, node, attempts=5):
80         """Connect to node prior to running exec_command or scp.
81
82         If there already is a connection to the node, this method reuses it.
83
84         :param node: Node in topology.
85         :param attempts: Number of reconnect attempts.
86         :type node: dict
87         :type attempts: int
88         :raises IOError: If cannot connect to host.
89         """
90         self._node = node
91         node_hash = self._node_hash(node)
92         if node_hash in SSH.__existing_connections:
93             self._ssh = SSH.__existing_connections[node_hash]
94             if self._ssh.get_transport().is_active():
95                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
96             else:
97                 if attempts > 0:
98                     self._reconnect(attempts-1)
99                 else:
100                     raise IOError('Cannot connect to {host}'.
101                                   format(host=node['host']))
102         else:
103             try:
104                 start = time()
105                 pkey = None
106                 if 'priv_key' in node:
107                     pkey = RSAKey.from_private_key(
108                         StringIO.StringIO(node['priv_key']))
109
110                 self._ssh = SSHClient()
111                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
112
113                 self._ssh.connect(node['host'], username=node['username'],
114                                   password=node.get('password'), pkey=pkey,
115                                   port=node['port'])
116
117                 self._ssh.get_transport().set_keepalive(10)
118
119                 SSH.__existing_connections[node_hash] = self._ssh
120                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
121                              format(
122                                  peer=self._ssh.get_transport().getpeername(),
123                                  total=(time() - start),
124                                  ssh=self._ssh))
125             except SSHException as exc:
126                 raise_from(IOError('Cannot connect to {host}'.format(
127                     host=node['host'])), exc)
128             except NoValidConnectionsError as err:
129                 raise_from(IOError(
130                     'Unable to connect to port {port} on {host}'.format(
131                         port=node['port'], host=node['host'])), err)
132
133     def disconnect(self, node=None):
134         """Close SSH connection to the node.
135
136         :param node: The node to disconnect from. None means last connected.
137         :type node: dict or None
138         """
139         if node is None:
140             node = self._node
141         if node is None:
142             return
143         node_hash = self._node_hash(node)
144         if node_hash in SSH.__existing_connections:
145             logger.debug('Disconnecting peer: {host}, {port}'.
146                          format(host=node['host'], port=node['port']))
147             ssh = SSH.__existing_connections.pop(node_hash)
148             ssh.close()
149
150     def _reconnect(self, attempts=0):
151         """Close the SSH connection and open it again.
152
153         :param attempts: Number of reconnect attempts.
154         :type attempts: int
155         """
156         node = self._node
157         self.disconnect(node)
158         self.connect(node, attempts)
159         logger.debug('Reconnecting peer done: {host}, {port}'.
160                      format(host=node['host'], port=node['port']))
161
162     def exec_command(self, cmd, timeout=10, log_stdout_err=True):
163         """Execute SSH command on a new channel on the connected Node.
164
165         :param cmd: Command to run on the Node.
166         :param timeout: Maximal time in seconds to wait until the command is
167             done. If set to None then wait forever.
168         :param log_stdout_err: If True, stdout and stderr are logged. stdout
169             and stderr are logged also if the return code is not zero
170             independently of the value of log_stdout_err.
171         :type cmd: str or OptionString
172         :type timeout: int
173         :type log_stdout_err: bool
174         :returns: return_code, stdout, stderr
175         :rtype: tuple(int, str, str)
176         :raises SSHTimeout: If command is not finished in timeout time.
177         """
178         if isinstance(cmd, (list, tuple)):
179             cmd = OptionString(cmd)
180         cmd = str(cmd)
181         stdout = StringIO.StringIO()
182         stderr = StringIO.StringIO()
183         try:
184             chan = self._ssh.get_transport().open_session(timeout=5)
185             peer = self._ssh.get_transport().getpeername()
186         except (AttributeError, SSHException):
187             self._reconnect()
188             chan = self._ssh.get_transport().open_session(timeout=5)
189             peer = self._ssh.get_transport().getpeername()
190         chan.settimeout(timeout)
191
192         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
193                      .format(peer=peer, timeout=timeout, cmd=cmd))
194
195         start = time()
196         chan.exec_command(cmd)
197         while not chan.exit_status_ready() and timeout is not None:
198             if chan.recv_ready():
199                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
200
201             if chan.recv_stderr_ready():
202                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
203
204             if time() - start > timeout:
205                 raise SSHTimeout(
206                     'Timeout exception during execution of command: {cmd}\n'
207                     'Current contents of stdout buffer: {stdout}\n'
208                     'Current contents of stderr buffer: {stderr}\n'
209                     .format(cmd=cmd, stdout=stdout.getvalue(),
210                             stderr=stderr.getvalue())
211                 )
212
213             sleep(0.1)
214         return_code = chan.recv_exit_status()
215
216         while chan.recv_ready():
217             stdout.write(chan.recv(self.__MAX_RECV_BUF))
218
219         while chan.recv_stderr_ready():
220             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
221
222         end = time()
223         logger.trace('exec_command on {peer} took {total} seconds'.
224                      format(peer=peer, total=end-start))
225
226         logger.trace('return RC {rc}'.format(rc=return_code))
227         if log_stdout_err or int(return_code):
228             logger.trace('return STDOUT {stdout}'.
229                          format(stdout=stdout.getvalue()))
230             logger.trace('return STDERR {stderr}'.
231                          format(stderr=stderr.getvalue()))
232         return return_code, stdout.getvalue(), stderr.getvalue()
233
234     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30,
235                           log_stdout_err=True):
236         """Execute SSH command with sudo on a new channel on the connected Node.
237
238         :param cmd: Command to be executed.
239         :param cmd_input: Input redirected to the command.
240         :param timeout: Timeout.
241         :param log_stdout_err: If True, stdout and stderr are logged.
242         :type cmd: str
243         :type cmd_input: str
244         :type timeout: int
245         :type log_stdout_err: bool
246         :returns: return_code, stdout, stderr
247         :rtype: tuple(int, str, str)
248
249         :Example:
250
251         >>> from ssh import SSH
252         >>> ssh = SSH()
253         >>> ssh.connect(node)
254         >>> # Execute command without input (sudo -S cmd)
255         >>> ssh.exec_command_sudo("ifconfig eth0 down")
256         >>> # Execute command with input (sudo -S cmd <<< "input")
257         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
258         """
259         if isinstance(cmd, (list, tuple)):
260             cmd = OptionString(cmd)
261         if cmd_input is None:
262             command = 'sudo -S {c}'.format(c=cmd)
263         else:
264             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
265         return self.exec_command(command, timeout,
266                                  log_stdout_err=log_stdout_err)
267
268     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
269                          timeout=30):
270         """Execute command in LXC on a new SSH channel on the connected Node.
271
272         :param lxc_cmd: Command to be executed.
273         :param lxc_name: LXC name.
274         :param lxc_params: Additional parameters for LXC attach.
275         :param sudo: Run in privileged LXC mode. Default: privileged
276         :param timeout: Timeout.
277         :type lxc_cmd: str
278         :type lxc_name: str
279         :type lxc_params: str
280         :type sudo: bool
281         :type timeout: int
282         :returns: return_code, stdout, stderr
283         """
284         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
285             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
286
287         if sudo:
288             command = 'sudo -S {c}'.format(c=command)
289         return self.exec_command(command, timeout)
290
291     def interactive_terminal_open(self, time_out=45):
292         """Open interactive terminal on a new channel on the connected Node.
293
294         :param time_out: Timeout in seconds.
295         :returns: SSH channel with opened terminal.
296
297         .. warning:: Interruptingcow is used here, and it uses
298            signal(SIGALRM) to let the operating system interrupt program
299            execution. This has the following limitations: Python signal
300            handlers only apply to the main thread, so you cannot use this
301            from other threads. You must not use this in a program that
302            uses SIGALRM itself (this includes certain profilers)
303         """
304         chan = self._ssh.get_transport().open_session()
305         chan.get_pty()
306         chan.invoke_shell()
307         chan.settimeout(int(time_out))
308         chan.set_combine_stderr(True)
309
310         buf = ''
311         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
312             try:
313                 chunk = chan.recv(self.__MAX_RECV_BUF)
314                 if not chunk:
315                     break
316                 buf += chunk
317                 if chan.exit_status_ready():
318                     logger.error('Channel exit status ready')
319                     break
320             except socket.timeout as exc:
321                 raise_from(Exception('Socket timeout: {0}'.format(buf)), exc)
322         return chan
323
324     def interactive_terminal_exec_command(self, chan, cmd, prompt):
325         """Execute command on interactive terminal.
326
327         interactive_terminal_open() method has to be called first!
328
329         :param chan: SSH channel with opened terminal.
330         :param cmd: Command to be executed.
331         :param prompt: Command prompt, sequence of characters used to
332         indicate readiness to accept commands.
333         :returns: Command output.
334
335         .. warning:: Interruptingcow is used here, and it uses
336            signal(SIGALRM) to let the operating system interrupt program
337            execution. This has the following limitations: Python signal
338            handlers only apply to the main thread, so you cannot use this
339            from other threads. You must not use this in a program that
340            uses SIGALRM itself (this includes certain profilers)
341         """
342         chan.sendall('{c}\n'.format(c=cmd))
343         buf = ''
344         while not buf.endswith(prompt):
345             try:
346                 chunk = chan.recv(self.__MAX_RECV_BUF)
347                 if not chunk:
348                     break
349                 buf += chunk
350                 if chan.exit_status_ready():
351                     logger.error('Channel exit status ready')
352                     break
353             except socket.timeout as exc:
354                 raise_from(Exception(
355                     'Socket timeout during execution of command: '
356                     '{0}\nBuffer content:\n{1}'.format(cmd, buf)), exc)
357         tmp = buf.replace(cmd.replace('\n', ''), '')
358         for item in prompt:
359             tmp.replace(item, '')
360         return tmp
361
362     @staticmethod
363     def interactive_terminal_close(chan):
364         """Close interactive terminal SSH channel.
365
366         :param chan: SSH channel to be closed.
367         """
368         chan.close()
369
370     def scp(self, local_path, remote_path, get=False, timeout=30,
371             wildcard=False):
372         """Copy files from local_path to remote_path or vice versa.
373
374         connect() method has to be called first!
375
376         :param local_path: Path to local file that should be uploaded; or
377         path where to save remote file.
378         :param remote_path: Remote path where to place uploaded file; or
379         path to remote file which should be downloaded.
380         :param get: scp operation to perform. Default is put.
381         :param timeout: Timeout value in seconds.
382         :param wildcard: If path has wildcard characters. Default is false.
383         :type local_path: str
384         :type remote_path: str
385         :type get: bool
386         :type timeout: int
387         :type wildcard: bool
388         """
389         if not get:
390             logger.trace('SCP {0} to {1}:{2}'.format(
391                 local_path, self._ssh.get_transport().getpeername(),
392                 remote_path))
393         else:
394             logger.trace('SCP {0}:{1} to {2}'.format(
395                 self._ssh.get_transport().getpeername(), remote_path,
396                 local_path))
397         # SCPCLient takes a paramiko transport as its only argument
398         if not wildcard:
399             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
400         else:
401             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
402                             socket_timeout=timeout)
403         start = time()
404         if not get:
405             scp.put(local_path, remote_path)
406         else:
407             scp.get(remote_path, local_path)
408         scp.close()
409         end = time()
410         logger.trace('SCP took {0} seconds'.format(end-start))
411
412
413 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
414     """Convenience function to ssh/exec/return rc, out & err.
415
416     Returns (rc, stdout, stderr).
417
418     :param node: The node to execute command on.
419     :param cmd: Command to execute.
420     :param timeout: Timeout value in seconds. Default: 600.
421     :param sudo: Sudo privilege execution flag. Default: False.
422     :param disconnect: Close the opened SSH connection if True.
423     :type node: dict
424     :type cmd: str or OptionString
425     :type timeout: int
426     :type sudo: bool
427     :type disconnect: bool
428     :returns: RC, Stdout, Stderr.
429     :rtype: tuple(int, str, str)
430     """
431     if node is None:
432         raise TypeError('Node parameter is None')
433     if cmd is None:
434         raise TypeError('Command parameter is None')
435     if not cmd:
436         raise ValueError('Empty command parameter')
437
438     ssh = SSH()
439
440     if node.get('host_port') is not None:
441         ssh_node = dict()
442         ssh_node['host'] = '127.0.0.1'
443         ssh_node['port'] = node['port']
444         ssh_node['username'] = node['username']
445         ssh_node['password'] = node['password']
446         import pexpect
447         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
448         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
449         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
450             format(tnl=tnl, op=options, user=node['host_username'],
451                    host=node['host'], host_port=node['host_port'])
452         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
453                      format(ssh_cmd=ssh_cmd))
454         child = pexpect.spawn(ssh_cmd)
455         child.expect('.* password: ')
456         logger.trace(child.after)
457         child.sendline(node['host_password'])
458         child.expect('Welcome .*')
459         logger.trace(child.after)
460         logger.trace('Local port forwarding finished.')
461     else:
462         ssh_node = node
463
464     try:
465         ssh.connect(ssh_node)
466     except SSHException as err:
467         logger.error("Failed to connect to node" + repr(err))
468         return None, None, None
469
470     try:
471         if not sudo:
472             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
473         else:
474             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
475                                                                timeout=timeout)
476     except SSHException as err:
477         logger.error(repr(err))
478         return None, None, None
479     finally:
480         if disconnect:
481             ssh.disconnect()
482
483     return ret_code, stdout, stderr
484
485
486 def exec_cmd_no_error(
487         node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
488         retries=0, include_reason=False):
489     """Convenience function to ssh/exec/return out & err.
490
491     Verifies that return code is zero.
492     Supports retries, timeout is related to each try separately then. There is
493     sleep(1) before each retry.
494     Disconnect (if enabled) is applied after each try.
495
496     :param node: DUT node.
497     :param cmd: Command to be executed.
498     :param timeout: Timeout value in seconds. Default: 600.
499     :param sudo: Sudo privilege execution flag. Default: False.
500     :param message: Error message in case of failure. Default: None.
501     :param disconnect: Close the opened SSH connection if True.
502     :param retries: How many times to retry on failure.
503     :param include_reason: Whether default info should be appended to message.
504     :type node: dict
505     :type cmd: str or OptionString
506     :type timeout: int
507     :type sudo: bool
508     :type message: str
509     :type disconnect: bool
510     :type retries: int
511     :type include_reason: bool
512     :returns: Stdout, Stderr.
513     :rtype: tuple(str, str)
514     :raises RuntimeError: If bash return code is not 0.
515     """
516     for _ in range(retries + 1):
517         ret_code, stdout, stderr = exec_cmd(
518             node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
519         if ret_code == 0:
520             break
521         sleep(1)
522     else:
523         msg = 'Command execution failed: "{cmd}"\nRC: {rc}\n{stderr}'.format(
524             cmd=cmd, rc=ret_code, stderr=stderr)
525         logger.info(msg)
526         if message:
527             if include_reason:
528                 msg = message + '\n' + msg
529             else:
530                 msg = message
531         raise RuntimeError(msg)
532
533     return stdout, stderr
534
535
536 def scp_node(
537         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
538     """Copy files from local_path to remote_path or vice versa.
539
540     :param node: SUT node.
541     :param local_path: Path to local file that should be uploaded; or
542         path where to save remote file.
543     :param remote_path: Remote path where to place uploaded file; or
544         path to remote file which should be downloaded.
545     :param get: scp operation to perform. Default is put.
546     :param timeout: Timeout value in seconds.
547     :param disconnect: Close the opened SSH connection if True.
548     :type node: dict
549     :type local_path: str
550     :type remote_path: str
551     :type get: bool
552     :type timeout: int
553     :type disconnect: bool
554     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
555     """
556     ssh = SSH()
557
558     try:
559         ssh.connect(node)
560     except SSHException as exc:
561         raise_from(RuntimeError(
562             'Failed to connect to {host}!'.format(host=node['host'])), exc)
563     try:
564         ssh.scp(local_path, remote_path, get, timeout)
565     except SCPException as exc:
566         raise_from(RuntimeError(
567             'SCP execution failed on {host}!'.format(host=node['host'])), exc)
568     finally:
569         if disconnect:
570             ssh.disconnect()