Python3: resources and libraries
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18
19 from io import StringIO
20 from time import time, sleep
21
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient, SCPException
26
27 from resources.libraries.python.OptionString import OptionString
28
29 __all__ = [
30     u"exec_cmd", u"exec_cmd_no_error", u"SSH", u"SSHTimeout", u"scp_node"
31 ]
32
33 # TODO: load priv key
34
35
36 class SSHTimeout(Exception):
37     """This exception is raised when a timeout occurs."""
38
39
40 class SSH:
41     """Contains methods for managing and using SSH connections."""
42
43     __MAX_RECV_BUF = 10 * 1024 * 1024
44     __existing_connections = dict()
45
46     def __init__(self):
47         self._ssh = None
48         self._node = None
49
50     @staticmethod
51     def _node_hash(node):
52         """Get IP address and port hash from node dictionary.
53
54         :param node: Node in topology.
55         :type node: dict
56         :returns: IP address and port for the specified node.
57         :rtype: int
58         """
59         return hash(frozenset([node[u"host"], node[u"port"]]))
60
61     def connect(self, node, attempts=5):
62         """Connect to node prior to running exec_command or scp.
63
64         If there already is a connection to the node, this method reuses it.
65
66         :param node: Node in topology.
67         :param attempts: Number of reconnect attempts.
68         :type node: dict
69         :type attempts: int
70         :raises IOError: If cannot connect to host.
71         """
72         self._node = node
73         node_hash = self._node_hash(node)
74         if node_hash in SSH.__existing_connections:
75             self._ssh = SSH.__existing_connections[node_hash]
76             if self._ssh.get_transport().is_active():
77                 logger.debug(f"Reusing SSH: {self._ssh}")
78             else:
79                 if attempts > 0:
80                     self._reconnect(attempts-1)
81                 else:
82                     raise IOError(f"Cannot connect to {node['host']}")
83         else:
84             try:
85                 start = time()
86                 pkey = None
87                 if u"priv_key" in node:
88                     pkey = RSAKey.from_private_key(StringIO(node[u"priv_key"]))
89
90                 self._ssh = SSHClient()
91                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
92
93                 self._ssh.connect(
94                     node[u"host"], username=node[u"username"],
95                     password=node.get(u"password"), pkey=pkey,
96                     port=node[u"port"]
97                 )
98
99                 self._ssh.get_transport().set_keepalive(10)
100
101                 SSH.__existing_connections[node_hash] = self._ssh
102                 logger.debug(
103                     f"New SSH to {self._ssh.get_transport().getpeername()} "
104                     f"took {time() - start} seconds: {self._ssh}"
105                 )
106             except SSHException as exc:
107                 raise IOError(f"Cannot connect to {node[u'host']}") from exc
108             except NoValidConnectionsError as err:
109                 raise IOError(
110                     f"Unable to connect to port {node[u'port']} on "
111                     f"{node[u'host']}"
112                 ) from err
113
114     def disconnect(self, node=None):
115         """Close SSH connection to the node.
116
117         :param node: The node to disconnect from. None means last connected.
118         :type node: dict or None
119         """
120         if node is None:
121             node = self._node
122         if node is None:
123             return
124         node_hash = self._node_hash(node)
125         if node_hash in SSH.__existing_connections:
126             logger.debug(
127                 f"Disconnecting peer: {node[u'host']}, {node[u'port']}"
128             )
129             ssh = SSH.__existing_connections.pop(node_hash)
130             ssh.close()
131
132     def _reconnect(self, attempts=0):
133         """Close the SSH connection and open it again.
134
135         :param attempts: Number of reconnect attempts.
136         :type attempts: int
137         """
138         node = self._node
139         self.disconnect(node)
140         self.connect(node, attempts)
141         logger.debug(
142             f"Reconnecting peer done: {node[u'host']}, {node[u'port']}"
143         )
144
145     def exec_command(self, cmd, timeout=10, log_stdout_err=True):
146         """Execute SSH command on a new channel on the connected Node.
147
148         :param cmd: Command to run on the Node.
149         :param timeout: Maximal time in seconds to wait until the command is
150             done. If set to None then wait forever.
151         :param log_stdout_err: If True, stdout and stderr are logged. stdout
152             and stderr are logged also if the return code is not zero
153             independently of the value of log_stdout_err.
154         :type cmd: str or OptionString
155         :type timeout: int
156         :type log_stdout_err: bool
157         :returns: return_code, stdout, stderr
158         :rtype: tuple(int, str, str)
159         :raises SSHTimeout: If command is not finished in timeout time.
160         """
161         if isinstance(cmd, (list, tuple)):
162             cmd = OptionString(cmd)
163         cmd = str(cmd)
164         stdout = u""
165         stderr = u""
166         try:
167             chan = self._ssh.get_transport().open_session(timeout=5)
168             peer = self._ssh.get_transport().getpeername()
169         except (AttributeError, SSHException):
170             self._reconnect()
171             chan = self._ssh.get_transport().open_session(timeout=5)
172             peer = self._ssh.get_transport().getpeername()
173         chan.settimeout(timeout)
174
175         logger.trace(f"exec_command on {peer} with timeout {timeout}: {cmd}")
176
177         start = time()
178         chan.exec_command(cmd)
179         while not chan.exit_status_ready() and timeout is not None:
180             if chan.recv_ready():
181                 s_out = chan.recv(self.__MAX_RECV_BUF)
182                 stdout += s_out.decode(encoding=u'utf-8', errors=u'ignore') \
183                     if isinstance(s_out, bytes) else s_out
184
185             if chan.recv_stderr_ready():
186                 s_err = chan.recv_stderr(self.__MAX_RECV_BUF)
187                 stderr += s_err.decode(encoding=u'utf-8', errors=u'ignore') \
188                     if isinstance(s_err, bytes) else s_err
189
190             if time() - start > timeout:
191                 raise SSHTimeout(
192                     f"Timeout exception during execution of command: {cmd}\n"
193                     f"Current contents of stdout buffer: "
194                     f"{stdout}\n"
195                     f"Current contents of stderr buffer: "
196                     f"{stderr}\n"
197                 )
198
199             sleep(0.1)
200         return_code = chan.recv_exit_status()
201
202         while chan.recv_ready():
203             s_out = chan.recv(self.__MAX_RECV_BUF)
204             stdout += s_out.decode(encoding=u'utf-8', errors=u'ignore') \
205                 if isinstance(s_out, bytes) else s_out
206
207         while chan.recv_stderr_ready():
208             s_err = chan.recv_stderr(self.__MAX_RECV_BUF)
209             stderr += s_err.decode(encoding=u'utf-8', errors=u'ignore') \
210                 if isinstance(s_err, bytes) else s_err
211
212         end = time()
213         logger.trace(f"exec_command on {peer} took {end-start} seconds")
214
215         logger.trace(f"return RC {return_code}")
216         if log_stdout_err or int(return_code):
217             logger.trace(
218                 f"return STDOUT {stdout}"
219             )
220             logger.trace(
221                 f"return STDERR {stderr}"
222             )
223         return return_code, stdout, stderr
224
225     def exec_command_sudo(
226             self, cmd, cmd_input=None, timeout=30, log_stdout_err=True):
227         """Execute SSH command with sudo on a new channel on the connected Node.
228
229         :param cmd: Command to be executed.
230         :param cmd_input: Input redirected to the command.
231         :param timeout: Timeout.
232         :param log_stdout_err: If True, stdout and stderr are logged.
233         :type cmd: str
234         :type cmd_input: str
235         :type timeout: int
236         :type log_stdout_err: bool
237         :returns: return_code, stdout, stderr
238         :rtype: tuple(int, str, str)
239
240         :Example:
241
242         >>> from ssh import SSH
243         >>> ssh = SSH()
244         >>> ssh.connect(node)
245         >>> # Execute command without input (sudo -S cmd)
246         >>> ssh.exec_command_sudo(u"ifconfig eth0 down")
247         >>> # Execute command with input (sudo -S cmd <<< 'input')
248         >>> ssh.exec_command_sudo(u"vpp_api_test", u"dump_interface_table")
249         """
250         if isinstance(cmd, (list, tuple)):
251             cmd = OptionString(cmd)
252         if cmd_input is None:
253             command = f"sudo -E -S {cmd}"
254         else:
255             command = f"sudo -E -S {cmd} <<< \"{cmd_input}\""
256         return self.exec_command(
257             command, timeout, log_stdout_err=log_stdout_err
258         )
259
260     def exec_command_lxc(
261             self, lxc_cmd, lxc_name, lxc_params=u"", sudo=True, timeout=30):
262         """Execute command in LXC on a new SSH channel on the connected Node.
263
264         :param lxc_cmd: Command to be executed.
265         :param lxc_name: LXC name.
266         :param lxc_params: Additional parameters for LXC attach.
267         :param sudo: Run in privileged LXC mode. Default: privileged
268         :param timeout: Timeout.
269         :type lxc_cmd: str
270         :type lxc_name: str
271         :type lxc_params: str
272         :type sudo: bool
273         :type timeout: int
274         :returns: return_code, stdout, stderr
275         """
276         command = f"lxc-attach {lxc_params} --name {lxc_name} -- /bin/sh " \
277             f"-c \"{lxc_cmd}\""
278
279         if sudo:
280             command = f"sudo -E -S {command}"
281         return self.exec_command(command, timeout)
282
283     def interactive_terminal_open(self, time_out=45):
284         """Open interactive terminal on a new channel on the connected Node.
285
286         :param time_out: Timeout in seconds.
287         :returns: SSH channel with opened terminal.
288
289         .. warning:: Interruptingcow is used here, and it uses
290            signal(SIGALRM) to let the operating system interrupt program
291            execution. This has the following limitations: Python signal
292            handlers only apply to the main thread, so you cannot use this
293            from other threads. You must not use this in a program that
294            uses SIGALRM itself (this includes certain profilers)
295         """
296         chan = self._ssh.get_transport().open_session()
297         chan.get_pty()
298         chan.invoke_shell()
299         chan.settimeout(int(time_out))
300         chan.set_combine_stderr(True)
301
302         buf = u""
303         while not buf.endswith((u":~# ", u":~$ ", u"~]$ ", u"~]# ")):
304             try:
305                 chunk = chan.recv(self.__MAX_RECV_BUF)
306                 if not chunk:
307                     break
308                 buf += chunk
309                 if chan.exit_status_ready():
310                     logger.error(u"Channel exit status ready")
311                     break
312             except socket.timeout as exc:
313                 raise Exception(f"Socket timeout: {buf}") from exc
314         return chan
315
316     def interactive_terminal_exec_command(self, chan, cmd, prompt):
317         """Execute command on interactive terminal.
318
319         interactive_terminal_open() method has to be called first!
320
321         :param chan: SSH channel with opened terminal.
322         :param cmd: Command to be executed.
323         :param prompt: Command prompt, sequence of characters used to
324         indicate readiness to accept commands.
325         :returns: Command output.
326
327         .. warning:: Interruptingcow is used here, and it uses
328            signal(SIGALRM) to let the operating system interrupt program
329            execution. This has the following limitations: Python signal
330            handlers only apply to the main thread, so you cannot use this
331            from other threads. You must not use this in a program that
332            uses SIGALRM itself (this includes certain profilers)
333         """
334         chan.sendall(f"{cmd}\n")
335         buf = u""
336         while not buf.endswith(prompt):
337             try:
338                 chunk = chan.recv(self.__MAX_RECV_BUF)
339                 if not chunk:
340                     break
341                 buf += chunk
342                 if chan.exit_status_ready():
343                     logger.error(u"Channel exit status ready")
344                     break
345             except socket.timeout as exc:
346                 raise Exception(
347                     f"Socket timeout during execution of command: {cmd}\n"
348                     f"Buffer content:\n{buf}"
349                 ) from exc
350         tmp = buf.replace(cmd.replace(u"\n", u""), u"")
351         for item in prompt:
352             tmp.replace(item, u"")
353         return tmp
354
355     @staticmethod
356     def interactive_terminal_close(chan):
357         """Close interactive terminal SSH channel.
358
359         :param chan: SSH channel to be closed.
360         """
361         chan.close()
362
363     def scp(
364             self, local_path, remote_path, get=False, timeout=30,
365             wildcard=False):
366         """Copy files from local_path to remote_path or vice versa.
367
368         connect() method has to be called first!
369
370         :param local_path: Path to local file that should be uploaded; or
371         path where to save remote file.
372         :param remote_path: Remote path where to place uploaded file; or
373         path to remote file which should be downloaded.
374         :param get: scp operation to perform. Default is put.
375         :param timeout: Timeout value in seconds.
376         :param wildcard: If path has wildcard characters. Default is false.
377         :type local_path: str
378         :type remote_path: str
379         :type get: bool
380         :type timeout: int
381         :type wildcard: bool
382         """
383         if not get:
384             logger.trace(
385                 f"SCP {local_path} to "
386                 f"{self._ssh.get_transport().getpeername()}:{remote_path}"
387             )
388         else:
389             logger.trace(
390                 f"SCP {self._ssh.get_transport().getpeername()}:{remote_path} "
391                 f"to {local_path}"
392             )
393         # SCPCLient takes a paramiko transport as its only argument
394         if not wildcard:
395             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
396         else:
397             scp = SCPClient(
398                 self._ssh.get_transport(), sanitize=lambda x: x,
399                 socket_timeout=timeout
400             )
401         start = time()
402         if not get:
403             scp.put(local_path, remote_path)
404         else:
405             scp.get(remote_path, local_path)
406         scp.close()
407         end = time()
408         logger.trace(f"SCP took {end-start} seconds")
409
410
411 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
412     """Convenience function to ssh/exec/return rc, out & err.
413
414     Returns (rc, stdout, stderr).
415
416     :param node: The node to execute command on.
417     :param cmd: Command to execute.
418     :param timeout: Timeout value in seconds. Default: 600.
419     :param sudo: Sudo privilege execution flag. Default: False.
420     :param disconnect: Close the opened SSH connection if True.
421     :type node: dict
422     :type cmd: str or OptionString
423     :type timeout: int
424     :type sudo: bool
425     :type disconnect: bool
426     :returns: RC, Stdout, Stderr.
427     :rtype: tuple(int, str, str)
428     """
429     if node is None:
430         raise TypeError(u"Node parameter is None")
431     if cmd is None:
432         raise TypeError(u"Command parameter is None")
433     if not cmd:
434         raise ValueError(u"Empty command parameter")
435
436     ssh = SSH()
437
438     try:
439         ssh.connect(node)
440     except SSHException as err:
441         logger.error(f"Failed to connect to node {node[u'host']}\n{err!r}")
442         return None, None, None
443
444     try:
445         if not sudo:
446             ret_code, stdout, stderr = ssh.exec_command(cmd, timeout=timeout)
447         else:
448             ret_code, stdout, stderr = ssh.exec_command_sudo(
449                 cmd, timeout=timeout
450             )
451     except SSHException as err:
452         logger.error(repr(err))
453         return None, None, None
454     finally:
455         if disconnect:
456             ssh.disconnect()
457
458     return ret_code, stdout, stderr
459
460
461 def exec_cmd_no_error(
462         node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
463         retries=0, include_reason=False):
464     """Convenience function to ssh/exec/return out & err.
465
466     Verifies that return code is zero.
467     Supports retries, timeout is related to each try separately then. There is
468     sleep(1) before each retry.
469     Disconnect (if enabled) is applied after each try.
470
471     :param node: DUT node.
472     :param cmd: Command to be executed.
473     :param timeout: Timeout value in seconds. Default: 600.
474     :param sudo: Sudo privilege execution flag. Default: False.
475     :param message: Error message in case of failure. Default: None.
476     :param disconnect: Close the opened SSH connection if True.
477     :param retries: How many times to retry on failure.
478     :param include_reason: Whether default info should be appended to message.
479     :type node: dict
480     :type cmd: str or OptionString
481     :type timeout: int
482     :type sudo: bool
483     :type message: str
484     :type disconnect: bool
485     :type retries: int
486     :type include_reason: bool
487     :returns: Stdout, Stderr.
488     :rtype: tuple(str, str)
489     :raises RuntimeError: If bash return code is not 0.
490     """
491     for _ in range(retries + 1):
492         ret_code, stdout, stderr = exec_cmd(
493             node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect
494         )
495         if ret_code == 0:
496             break
497         sleep(1)
498     else:
499         msg = f"Command execution failed: '{cmd}'\nRC: {ret_code}\n{stderr}"
500         logger.info(msg)
501         if message:
502             msg = f"{message}\n{msg}" if include_reason else message
503         raise RuntimeError(msg)
504
505     return stdout, stderr
506
507
508 def scp_node(
509         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
510     """Copy files from local_path to remote_path or vice versa.
511
512     :param node: SUT node.
513     :param local_path: Path to local file that should be uploaded; or
514         path where to save remote file.
515     :param remote_path: Remote path where to place uploaded file; or
516         path to remote file which should be downloaded.
517     :param get: scp operation to perform. Default is put.
518     :param timeout: Timeout value in seconds.
519     :param disconnect: Close the opened SSH connection if True.
520     :type node: dict
521     :type local_path: str
522     :type remote_path: str
523     :type get: bool
524     :type timeout: int
525     :type disconnect: bool
526     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
527     """
528     ssh = SSH()
529
530     try:
531         ssh.connect(node)
532     except SSHException as exc:
533         raise RuntimeError(f"Failed to connect to {node[u'host']}!") from exc
534     try:
535         ssh.scp(local_path, remote_path, get, timeout)
536     except SCPException as exc:
537         raise RuntimeError(f"SCP execution failed on {node[u'host']}!") from exc
538     finally:
539         if disconnect:
540             ssh.disconnect()