cee35868e4f0c389169ad4f6f3f8ddb2d993f982
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18 import StringIO
19 from time import time, sleep
20
21 from paramiko import RSAKey, SSHClient, AutoAddPolicy
22 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
23 from robot.api import logger
24 from scp import SCPClient, SCPException
25
26 from resources.libraries.python.OptionString import OptionString
27 from resources.libraries.python.PythonThree import raise_from
28
29 __all__ = ["exec_cmd", "exec_cmd_no_error"]
30
31 # TODO: load priv key
32
33
34 class SSHTimeout(Exception):
35     """This exception is raised when a timeout occurs."""
36     pass
37
38
39 class SSH(object):
40     """Contains methods for managing and using SSH connections."""
41
42     __MAX_RECV_BUF = 10*1024*1024
43     __existing_connections = {}
44
45     def __init__(self):
46         self._ssh = None
47         self._node = None
48
49     @staticmethod
50     def _node_hash(node):
51         """Get IP address and port hash from node dictionary.
52
53         :param node: Node in topology.
54         :type node: dict
55         :returns: IP address and port for the specified node.
56         :rtype: int
57         """
58
59         return hash(frozenset([node['host'], node['port']]))
60
61     def connect(self, node, attempts=5):
62         """Connect to node prior to running exec_command or scp.
63
64         If there already is a connection to the node, this method reuses it.
65
66         :param node: Node in topology.
67         :param attempts: Number of reconnect attempts.
68         :type node: dict
69         :type attempts: int
70         :raises IOError: If cannot connect to host.
71         """
72         self._node = node
73         node_hash = self._node_hash(node)
74         if node_hash in SSH.__existing_connections:
75             self._ssh = SSH.__existing_connections[node_hash]
76             if self._ssh.get_transport().is_active():
77                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
78             else:
79                 if attempts > 0:
80                     self._reconnect(attempts-1)
81                 else:
82                     raise IOError('Cannot connect to {host}'.
83                                   format(host=node['host']))
84         else:
85             try:
86                 start = time()
87                 pkey = None
88                 if 'priv_key' in node:
89                     pkey = RSAKey.from_private_key(
90                         StringIO.StringIO(node['priv_key']))
91
92                 self._ssh = SSHClient()
93                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
94
95                 self._ssh.connect(node['host'], username=node['username'],
96                                   password=node.get('password'), pkey=pkey,
97                                   port=node['port'])
98
99                 self._ssh.get_transport().set_keepalive(10)
100
101                 SSH.__existing_connections[node_hash] = self._ssh
102                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
103                              format(
104                                  peer=self._ssh.get_transport().getpeername(),
105                                  total=(time() - start),
106                                  ssh=self._ssh))
107             except SSHException as exc:
108                 raise_from(IOError('Cannot connect to {host}'.format(
109                     host=node['host'])), exc)
110             except NoValidConnectionsError as err:
111                 raise_from(IOError(
112                     'Unable to connect to port {port} on {host}'.format(
113                         port=node['port'], host=node['host'])), err)
114
115     def disconnect(self, node=None):
116         """Close SSH connection to the node.
117
118         :param node: The node to disconnect from. None means last connected.
119         :type node: dict or None
120         """
121         if node is None:
122             node = self._node
123         if node is None:
124             return
125         node_hash = self._node_hash(node)
126         if node_hash in SSH.__existing_connections:
127             logger.debug('Disconnecting peer: {host}, {port}'.
128                          format(host=node['host'], port=node['port']))
129             ssh = SSH.__existing_connections.pop(node_hash)
130             ssh.close()
131
132     def _reconnect(self, attempts=0):
133         """Close the SSH connection and open it again.
134
135         :param attempts: Number of reconnect attempts.
136         :type attempts: int
137         """
138         node = self._node
139         self.disconnect(node)
140         self.connect(node, attempts)
141         logger.debug('Reconnecting peer done: {host}, {port}'.
142                      format(host=node['host'], port=node['port']))
143
144     def exec_command(self, cmd, timeout=10, log_stdout_err=True):
145         """Execute SSH command on a new channel on the connected Node.
146
147         :param cmd: Command to run on the Node.
148         :param timeout: Maximal time in seconds to wait until the command is
149             done. If set to None then wait forever.
150         :param log_stdout_err: If True, stdout and stderr are logged. stdout
151             and stderr are logged also if the return code is not zero
152             independently of the value of log_stdout_err.
153         :type cmd: str or OptionString
154         :type timeout: int
155         :type log_stdout_err: bool
156         :returns: return_code, stdout, stderr
157         :rtype: tuple(int, str, str)
158         :raises SSHTimeout: If command is not finished in timeout time.
159         """
160         if isinstance(cmd, (list, tuple)):
161             cmd = OptionString(cmd)
162         cmd = str(cmd)
163         stdout = StringIO.StringIO()
164         stderr = StringIO.StringIO()
165         try:
166             chan = self._ssh.get_transport().open_session(timeout=5)
167             peer = self._ssh.get_transport().getpeername()
168         except (AttributeError, SSHException):
169             self._reconnect()
170             chan = self._ssh.get_transport().open_session(timeout=5)
171             peer = self._ssh.get_transport().getpeername()
172         chan.settimeout(timeout)
173
174         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
175                      .format(peer=peer, timeout=timeout, cmd=cmd))
176
177         start = time()
178         chan.exec_command(cmd)
179         while not chan.exit_status_ready() and timeout is not None:
180             if chan.recv_ready():
181                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
182
183             if chan.recv_stderr_ready():
184                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
185
186             if time() - start > timeout:
187                 raise SSHTimeout(
188                     'Timeout exception during execution of command: {cmd}\n'
189                     'Current contents of stdout buffer: {stdout}\n'
190                     'Current contents of stderr buffer: {stderr}\n'
191                     .format(cmd=cmd, stdout=stdout.getvalue(),
192                             stderr=stderr.getvalue())
193                 )
194
195             sleep(0.1)
196         return_code = chan.recv_exit_status()
197
198         while chan.recv_ready():
199             stdout.write(chan.recv(self.__MAX_RECV_BUF))
200
201         while chan.recv_stderr_ready():
202             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
203
204         end = time()
205         logger.trace('exec_command on {peer} took {total} seconds'.
206                      format(peer=peer, total=end-start))
207
208         logger.trace('return RC {rc}'.format(rc=return_code))
209         if log_stdout_err or int(return_code):
210             logger.trace('return STDOUT {stdout}'.
211                          format(stdout=stdout.getvalue()))
212             logger.trace('return STDERR {stderr}'.
213                          format(stderr=stderr.getvalue()))
214         return return_code, stdout.getvalue(), stderr.getvalue()
215
216     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30,
217                           log_stdout_err=True):
218         """Execute SSH command with sudo on a new channel on the connected Node.
219
220         :param cmd: Command to be executed.
221         :param cmd_input: Input redirected to the command.
222         :param timeout: Timeout.
223         :param log_stdout_err: If True, stdout and stderr are logged.
224         :type cmd: str
225         :type cmd_input: str
226         :type timeout: int
227         :type log_stdout_err: bool
228         :returns: return_code, stdout, stderr
229         :rtype: tuple(int, str, str)
230
231         :Example:
232
233         >>> from ssh import SSH
234         >>> ssh = SSH()
235         >>> ssh.connect(node)
236         >>> # Execute command without input (sudo -S cmd)
237         >>> ssh.exec_command_sudo("ifconfig eth0 down")
238         >>> # Execute command with input (sudo -S cmd <<< "input")
239         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
240         """
241         if isinstance(cmd, (list, tuple)):
242             cmd = OptionString(cmd)
243         if cmd_input is None:
244             command = 'sudo -S {c}'.format(c=cmd)
245         else:
246             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
247         return self.exec_command(command, timeout,
248                                  log_stdout_err=log_stdout_err)
249
250     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
251                          timeout=30):
252         """Execute command in LXC on a new SSH channel on the connected Node.
253
254         :param lxc_cmd: Command to be executed.
255         :param lxc_name: LXC name.
256         :param lxc_params: Additional parameters for LXC attach.
257         :param sudo: Run in privileged LXC mode. Default: privileged
258         :param timeout: Timeout.
259         :type lxc_cmd: str
260         :type lxc_name: str
261         :type lxc_params: str
262         :type sudo: bool
263         :type timeout: int
264         :returns: return_code, stdout, stderr
265         """
266         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
267             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
268
269         if sudo:
270             command = 'sudo -S {c}'.format(c=command)
271         return self.exec_command(command, timeout)
272
273     def interactive_terminal_open(self, time_out=45):
274         """Open interactive terminal on a new channel on the connected Node.
275
276         :param time_out: Timeout in seconds.
277         :returns: SSH channel with opened terminal.
278
279         .. warning:: Interruptingcow is used here, and it uses
280            signal(SIGALRM) to let the operating system interrupt program
281            execution. This has the following limitations: Python signal
282            handlers only apply to the main thread, so you cannot use this
283            from other threads. You must not use this in a program that
284            uses SIGALRM itself (this includes certain profilers)
285         """
286         chan = self._ssh.get_transport().open_session()
287         chan.get_pty()
288         chan.invoke_shell()
289         chan.settimeout(int(time_out))
290         chan.set_combine_stderr(True)
291
292         buf = ''
293         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
294             try:
295                 chunk = chan.recv(self.__MAX_RECV_BUF)
296                 if not chunk:
297                     break
298                 buf += chunk
299                 if chan.exit_status_ready():
300                     logger.error('Channel exit status ready')
301                     break
302             except socket.timeout as exc:
303                 raise_from(Exception('Socket timeout: {0}'.format(buf)), exc)
304         return chan
305
306     def interactive_terminal_exec_command(self, chan, cmd, prompt):
307         """Execute command on interactive terminal.
308
309         interactive_terminal_open() method has to be called first!
310
311         :param chan: SSH channel with opened terminal.
312         :param cmd: Command to be executed.
313         :param prompt: Command prompt, sequence of characters used to
314         indicate readiness to accept commands.
315         :returns: Command output.
316
317         .. warning:: Interruptingcow is used here, and it uses
318            signal(SIGALRM) to let the operating system interrupt program
319            execution. This has the following limitations: Python signal
320            handlers only apply to the main thread, so you cannot use this
321            from other threads. You must not use this in a program that
322            uses SIGALRM itself (this includes certain profilers)
323         """
324         chan.sendall('{c}\n'.format(c=cmd))
325         buf = ''
326         while not buf.endswith(prompt):
327             try:
328                 chunk = chan.recv(self.__MAX_RECV_BUF)
329                 if not chunk:
330                     break
331                 buf += chunk
332                 if chan.exit_status_ready():
333                     logger.error('Channel exit status ready')
334                     break
335             except socket.timeout as exc:
336                 raise_from(Exception(
337                     'Socket timeout during execution of command: '
338                     '{0}\nBuffer content:\n{1}'.format(cmd, buf)), exc)
339         tmp = buf.replace(cmd.replace('\n', ''), '')
340         for item in prompt:
341             tmp.replace(item, '')
342         return tmp
343
344     @staticmethod
345     def interactive_terminal_close(chan):
346         """Close interactive terminal SSH channel.
347
348         :param chan: SSH channel to be closed.
349         """
350         chan.close()
351
352     def scp(self, local_path, remote_path, get=False, timeout=30,
353             wildcard=False):
354         """Copy files from local_path to remote_path or vice versa.
355
356         connect() method has to be called first!
357
358         :param local_path: Path to local file that should be uploaded; or
359         path where to save remote file.
360         :param remote_path: Remote path where to place uploaded file; or
361         path to remote file which should be downloaded.
362         :param get: scp operation to perform. Default is put.
363         :param timeout: Timeout value in seconds.
364         :param wildcard: If path has wildcard characters. Default is false.
365         :type local_path: str
366         :type remote_path: str
367         :type get: bool
368         :type timeout: int
369         :type wildcard: bool
370         """
371         if not get:
372             logger.trace('SCP {0} to {1}:{2}'.format(
373                 local_path, self._ssh.get_transport().getpeername(),
374                 remote_path))
375         else:
376             logger.trace('SCP {0}:{1} to {2}'.format(
377                 self._ssh.get_transport().getpeername(), remote_path,
378                 local_path))
379         # SCPCLient takes a paramiko transport as its only argument
380         if not wildcard:
381             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
382         else:
383             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
384                             socket_timeout=timeout)
385         start = time()
386         if not get:
387             scp.put(local_path, remote_path)
388         else:
389             scp.get(remote_path, local_path)
390         scp.close()
391         end = time()
392         logger.trace('SCP took {0} seconds'.format(end-start))
393
394
395 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
396     """Convenience function to ssh/exec/return rc, out & err.
397
398     Returns (rc, stdout, stderr).
399
400     :param node: The node to execute command on.
401     :param cmd: Command to execute.
402     :param timeout: Timeout value in seconds. Default: 600.
403     :param sudo: Sudo privilege execution flag. Default: False.
404     :param disconnect: Close the opened SSH connection if True.
405     :type node: dict
406     :type cmd: str or OptionString
407     :type timeout: int
408     :type sudo: bool
409     :type disconnect: bool
410     :returns: RC, Stdout, Stderr.
411     :rtype: tuple(int, str, str)
412     """
413     if node is None:
414         raise TypeError('Node parameter is None')
415     if cmd is None:
416         raise TypeError('Command parameter is None')
417     if not cmd:
418         raise ValueError('Empty command parameter')
419
420     ssh = SSH()
421
422     if node.get('host_port') is not None:
423         ssh_node = dict()
424         ssh_node['host'] = '127.0.0.1'
425         ssh_node['port'] = node['port']
426         ssh_node['username'] = node['username']
427         ssh_node['password'] = node['password']
428         import pexpect
429         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
430         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
431         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
432             format(tnl=tnl, op=options, user=node['host_username'],
433                    host=node['host'], host_port=node['host_port'])
434         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
435                      format(ssh_cmd=ssh_cmd))
436         child = pexpect.spawn(ssh_cmd)
437         child.expect('.* password: ')
438         logger.trace(child.after)
439         child.sendline(node['host_password'])
440         child.expect('Welcome .*')
441         logger.trace(child.after)
442         logger.trace('Local port forwarding finished.')
443     else:
444         ssh_node = node
445
446     try:
447         ssh.connect(ssh_node)
448     except SSHException as err:
449         logger.error("Failed to connect to node" + repr(err))
450         return None, None, None
451
452     try:
453         if not sudo:
454             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
455         else:
456             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
457                                                                timeout=timeout)
458     except SSHException as err:
459         logger.error(repr(err))
460         return None, None, None
461     finally:
462         if disconnect:
463             ssh.disconnect()
464
465     return ret_code, stdout, stderr
466
467
468 def exec_cmd_no_error(
469         node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
470         retries=0, include_reason=False):
471     """Convenience function to ssh/exec/return out & err.
472
473     Verifies that return code is zero.
474     Supports retries, timeout is related to each try separately then. There is
475     sleep(1) before each retry.
476     Disconnect (if enabled) is applied after each try.
477
478     :param node: DUT node.
479     :param cmd: Command to be executed.
480     :param timeout: Timeout value in seconds. Default: 600.
481     :param sudo: Sudo privilege execution flag. Default: False.
482     :param message: Error message in case of failure. Default: None.
483     :param disconnect: Close the opened SSH connection if True.
484     :param retries: How many times to retry on failure.
485     :param include_reason: Whether default info should be appended to message.
486     :type node: dict
487     :type cmd: str or OptionString
488     :type timeout: int
489     :type sudo: bool
490     :type message: str
491     :type disconnect: bool
492     :type retries: int
493     :type include_reason: bool
494     :returns: Stdout, Stderr.
495     :rtype: tuple(str, str)
496     :raises RuntimeError: If bash return code is not 0.
497     """
498     for _ in range(retries + 1):
499         ret_code, stdout, stderr = exec_cmd(
500             node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
501         if ret_code == 0:
502             break
503         sleep(1)
504     else:
505         msg = 'Command execution failed: "{cmd}"\nRC: {rc}\n{stderr}'.format(
506             cmd=cmd, rc=ret_code, stderr=stderr)
507         logger.info(msg)
508         if message:
509             if include_reason:
510                 msg = message + '\n' + msg
511             else:
512                 msg = message
513         raise RuntimeError(msg)
514
515     return stdout, stderr
516
517
518 def scp_node(
519         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
520     """Copy files from local_path to remote_path or vice versa.
521
522     :param node: SUT node.
523     :param local_path: Path to local file that should be uploaded; or
524         path where to save remote file.
525     :param remote_path: Remote path where to place uploaded file; or
526         path to remote file which should be downloaded.
527     :param get: scp operation to perform. Default is put.
528     :param timeout: Timeout value in seconds.
529     :param disconnect: Close the opened SSH connection if True.
530     :type node: dict
531     :type local_path: str
532     :type remote_path: str
533     :type get: bool
534     :type timeout: int
535     :type disconnect: bool
536     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
537     """
538     ssh = SSH()
539
540     try:
541         ssh.connect(node)
542     except SSHException as exc:
543         raise_from(RuntimeError(
544             'Failed to connect to {host}!'.format(host=node['host'])), exc)
545     try:
546         ssh.scp(local_path, remote_path, get, timeout)
547     except SCPException as exc:
548         raise_from(RuntimeError(
549             'SCP execution failed on {host}!'.format(host=node['host'])), exc)
550     finally:
551         if disconnect:
552             ssh.disconnect()