fix(ipsec): Prepare IPsecUtil for upcoming changes
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18
19 from io import StringIO
20 from time import time, sleep
21
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient, SCPException
26
27 from resources.libraries.python.OptionString import OptionString
28
29 __all__ = [
30     u"exec_cmd", u"exec_cmd_no_error", u"SSH", u"SSHTimeout", u"scp_node"
31 ]
32
33 # TODO: load priv key
34
35
36 class SSHTimeout(Exception):
37     """This exception is raised when a timeout occurs."""
38
39
40 class SSH:
41     """Contains methods for managing and using SSH connections."""
42
43     __MAX_RECV_BUF = 10 * 1024 * 1024
44     __existing_connections = dict()
45
46     def __init__(self):
47         self._ssh = None
48         self._node = None
49
50     @staticmethod
51     def _node_hash(node):
52         """Get IP address and port hash from node dictionary.
53
54         :param node: Node in topology.
55         :type node: dict
56         :returns: IP address and port for the specified node.
57         :rtype: int
58         """
59         return hash(frozenset([node[u"host"], node[u"port"]]))
60
61     def connect(self, node, attempts=5):
62         """Connect to node prior to running exec_command or scp.
63
64         If there already is a connection to the node, this method reuses it.
65
66         :param node: Node in topology.
67         :param attempts: Number of reconnect attempts.
68         :type node: dict
69         :type attempts: int
70         :raises IOError: If cannot connect to host.
71         """
72         self._node = node
73         node_hash = self._node_hash(node)
74         if node_hash in SSH.__existing_connections:
75             self._ssh = SSH.__existing_connections[node_hash]
76             if self._ssh.get_transport().is_active():
77                 logger.debug(f"Reusing SSH: {self._ssh}")
78             else:
79                 if attempts > 0:
80                     self._reconnect(attempts-1)
81                 else:
82                     raise IOError(f"Cannot connect to {node['host']}")
83         else:
84             try:
85                 start = time()
86                 pkey = None
87                 if u"priv_key" in node:
88                     pkey = RSAKey.from_private_key(StringIO(node[u"priv_key"]))
89
90                 self._ssh = SSHClient()
91                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
92
93                 self._ssh.connect(
94                     node[u"host"], username=node[u"username"],
95                     password=node.get(u"password"), pkey=pkey,
96                     port=node[u"port"]
97                 )
98
99                 self._ssh.get_transport().set_keepalive(10)
100
101                 SSH.__existing_connections[node_hash] = self._ssh
102                 logger.debug(
103                     f"New SSH to {self._ssh.get_transport().getpeername()} "
104                     f"took {time() - start} seconds: {self._ssh}"
105                 )
106             except SSHException as exc:
107                 raise IOError(f"Cannot connect to {node[u'host']}") from exc
108             except NoValidConnectionsError as err:
109                 raise IOError(
110                     f"Unable to connect to port {node[u'port']} on "
111                     f"{node[u'host']}"
112                 ) from err
113
114     def disconnect(self, node=None):
115         """Close SSH connection to the node.
116
117         :param node: The node to disconnect from. None means last connected.
118         :type node: dict or None
119         """
120         if node is None:
121             node = self._node
122         if node is None:
123             return
124         node_hash = self._node_hash(node)
125         if node_hash in SSH.__existing_connections:
126             logger.debug(
127                 f"Disconnecting peer: {node[u'host']}, {node[u'port']}"
128             )
129             ssh = SSH.__existing_connections.pop(node_hash)
130             ssh.close()
131
132     def _reconnect(self, attempts=0):
133         """Close the SSH connection and open it again.
134
135         :param attempts: Number of reconnect attempts.
136         :type attempts: int
137         """
138         node = self._node
139         self.disconnect(node)
140         self.connect(node, attempts)
141         logger.debug(
142             f"Reconnecting peer done: {node[u'host']}, {node[u'port']}"
143         )
144
145     def exec_command(self, cmd, timeout=10, log_stdout_err=True):
146         """Execute SSH command on a new channel on the connected Node.
147
148         :param cmd: Command to run on the Node.
149         :param timeout: Maximal time in seconds to wait until the command is
150             done. If set to None then wait forever.
151         :param log_stdout_err: If True, stdout and stderr are logged. stdout
152             and stderr are logged also if the return code is not zero
153             independently of the value of log_stdout_err.
154         :type cmd: str or OptionString
155         :type timeout: int
156         :type log_stdout_err: bool
157         :returns: return_code, stdout, stderr
158         :rtype: tuple(int, str, str)
159         :raises SSHTimeout: If command is not finished in timeout time.
160         """
161         if isinstance(cmd, (list, tuple)):
162             cmd = OptionString(cmd)
163         cmd = str(cmd)
164         stdout = u""
165         stderr = u""
166         try:
167             chan = self._ssh.get_transport().open_session(timeout=5)
168             peer = self._ssh.get_transport().getpeername()
169         except (AttributeError, SSHException):
170             self._reconnect()
171             chan = self._ssh.get_transport().open_session(timeout=5)
172             peer = self._ssh.get_transport().getpeername()
173         chan.settimeout(timeout)
174
175         logger.trace(f"exec_command on {peer} with timeout {timeout}: {cmd}")
176
177         start = time()
178         chan.exec_command(cmd)
179         while not chan.exit_status_ready() and timeout is not None:
180             if chan.recv_ready():
181                 s_out = chan.recv(self.__MAX_RECV_BUF)
182                 stdout += s_out.decode(encoding=u'utf-8', errors=u'ignore') \
183                     if isinstance(s_out, bytes) else s_out
184
185             if chan.recv_stderr_ready():
186                 s_err = chan.recv_stderr(self.__MAX_RECV_BUF)
187                 stderr += s_err.decode(encoding=u'utf-8', errors=u'ignore') \
188                     if isinstance(s_err, bytes) else s_err
189
190             if time() - start > timeout:
191                 raise SSHTimeout(
192                     f"Timeout exception during execution of command: {cmd}\n"
193                     f"Current contents of stdout buffer: "
194                     f"{stdout}\n"
195                     f"Current contents of stderr buffer: "
196                     f"{stderr}\n"
197                 )
198
199             sleep(0.1)
200         return_code = chan.recv_exit_status()
201
202         while chan.recv_ready():
203             s_out = chan.recv(self.__MAX_RECV_BUF)
204             stdout += s_out.decode(encoding=u'utf-8', errors=u'ignore') \
205                 if isinstance(s_out, bytes) else s_out
206
207         while chan.recv_stderr_ready():
208             s_err = chan.recv_stderr(self.__MAX_RECV_BUF)
209             stderr += s_err.decode(encoding=u'utf-8', errors=u'ignore') \
210                 if isinstance(s_err, bytes) else s_err
211
212         end = time()
213         logger.trace(f"exec_command on {peer} took {end-start} seconds")
214
215         logger.trace(f"return RC {return_code}")
216         if log_stdout_err or int(return_code):
217             logger.trace(
218                 f"return STDOUT {stdout}"
219             )
220             logger.trace(
221                 f"return STDERR {stderr}"
222             )
223         return return_code, stdout, stderr
224
225     def exec_command_sudo(
226             self, cmd, cmd_input=None, timeout=30, log_stdout_err=True):
227         """Execute SSH command with sudo on a new channel on the connected Node.
228
229         :param cmd: Command to be executed.
230         :param cmd_input: Input redirected to the command.
231         :param timeout: Timeout.
232         :param log_stdout_err: If True, stdout and stderr are logged.
233         :type cmd: str
234         :type cmd_input: str
235         :type timeout: int
236         :type log_stdout_err: bool
237         :returns: return_code, stdout, stderr
238         :rtype: tuple(int, str, str)
239
240         :Example:
241
242         >>> from ssh import SSH
243         >>> ssh = SSH()
244         >>> ssh.connect(node)
245         >>> # Execute command without input (sudo -S cmd)
246         >>> ssh.exec_command_sudo(u"ifconfig eth0 down")
247         >>> # Execute command with input (sudo -S cmd <<< 'input')
248         >>> ssh.exec_command_sudo(u"vpp_api_test", u"dump_interface_table")
249         """
250         if isinstance(cmd, (list, tuple)):
251             cmd = OptionString(cmd)
252         if cmd_input is None:
253             command = f"sudo -E -S {cmd}"
254         else:
255             command = f"sudo -E -S {cmd} <<< \"{cmd_input}\""
256         return self.exec_command(
257             command, timeout, log_stdout_err=log_stdout_err
258         )
259
260     def exec_command_lxc(
261             self, lxc_cmd, lxc_name, lxc_params=u"", sudo=True, timeout=30):
262         """Execute command in LXC on a new SSH channel on the connected Node.
263
264         :param lxc_cmd: Command to be executed.
265         :param lxc_name: LXC name.
266         :param lxc_params: Additional parameters for LXC attach.
267         :param sudo: Run in privileged LXC mode. Default: privileged
268         :param timeout: Timeout.
269         :type lxc_cmd: str
270         :type lxc_name: str
271         :type lxc_params: str
272         :type sudo: bool
273         :type timeout: int
274         :returns: return_code, stdout, stderr
275         """
276         command = f"lxc-attach {lxc_params} --name {lxc_name} -- /bin/sh " \
277             f"-c \"{lxc_cmd}\""
278
279         if sudo:
280             command = f"sudo -E -S {command}"
281         return self.exec_command(command, timeout)
282
283     def interactive_terminal_open(self, time_out=45):
284         """Open interactive terminal on a new channel on the connected Node.
285
286         :param time_out: Timeout in seconds.
287         :returns: SSH channel with opened terminal.
288
289         .. warning:: Interruptingcow is used here, and it uses
290            signal(SIGALRM) to let the operating system interrupt program
291            execution. This has the following limitations: Python signal
292            handlers only apply to the main thread, so you cannot use this
293            from other threads. You must not use this in a program that
294            uses SIGALRM itself (this includes certain profilers)
295         """
296         chan = self._ssh.get_transport().open_session()
297         chan.get_pty()
298         chan.invoke_shell()
299         chan.settimeout(int(time_out))
300         chan.set_combine_stderr(True)
301
302         buf = u""
303         while not buf.endswith((u":~# ", u":~$ ", u"~]$ ", u"~]# ")):
304             try:
305                 s_out = chan.recv(self.__MAX_RECV_BUF)
306                 if not s_out:
307                     break
308                 buf += s_out.decode(encoding=u'utf-8', errors=u'ignore') \
309                     if isinstance(s_out, bytes) else s_out
310                 if chan.exit_status_ready():
311                     logger.error(u"Channel exit status ready")
312                     break
313             except socket.timeout as exc:
314                 raise Exception(f"Socket timeout: {buf}") from exc
315         return chan
316
317     def interactive_terminal_exec_command(self, chan, cmd, prompt):
318         """Execute command on interactive terminal.
319
320         interactive_terminal_open() method has to be called first!
321
322         :param chan: SSH channel with opened terminal.
323         :param cmd: Command to be executed.
324         :param prompt: Command prompt, sequence of characters used to
325         indicate readiness to accept commands.
326         :returns: Command output.
327
328         .. warning:: Interruptingcow is used here, and it uses
329            signal(SIGALRM) to let the operating system interrupt program
330            execution. This has the following limitations: Python signal
331            handlers only apply to the main thread, so you cannot use this
332            from other threads. You must not use this in a program that
333            uses SIGALRM itself (this includes certain profilers)
334         """
335         chan.sendall(f"{cmd}\n")
336         buf = u""
337         while not buf.endswith(prompt):
338             try:
339                 s_out = chan.recv(self.__MAX_RECV_BUF)
340                 if not s_out:
341                     break
342                 buf += s_out.decode(encoding=u'utf-8', errors=u'ignore') \
343                     if isinstance(s_out, bytes) else s_out
344                 if chan.exit_status_ready():
345                     logger.error(u"Channel exit status ready")
346                     break
347             except socket.timeout as exc:
348                 raise Exception(
349                     f"Socket timeout during execution of command: {cmd}\n"
350                     f"Buffer content:\n{buf}"
351                 ) from exc
352         tmp = buf.replace(cmd.replace(u"\n", u""), u"")
353         for item in prompt:
354             tmp.replace(item, u"")
355         return tmp
356
357     @staticmethod
358     def interactive_terminal_close(chan):
359         """Close interactive terminal SSH channel.
360
361         :param chan: SSH channel to be closed.
362         """
363         chan.close()
364
365     def scp(
366             self, local_path, remote_path, get=False, timeout=30,
367             wildcard=False):
368         """Copy files from local_path to remote_path or vice versa.
369
370         connect() method has to be called first!
371
372         :param local_path: Path to local file that should be uploaded; or
373         path where to save remote file.
374         :param remote_path: Remote path where to place uploaded file; or
375         path to remote file which should be downloaded.
376         :param get: scp operation to perform. Default is put.
377         :param timeout: Timeout value in seconds.
378         :param wildcard: If path has wildcard characters. Default is false.
379         :type local_path: str
380         :type remote_path: str
381         :type get: bool
382         :type timeout: int
383         :type wildcard: bool
384         """
385         if not get:
386             logger.trace(
387                 f"SCP {local_path} to "
388                 f"{self._ssh.get_transport().getpeername()}:{remote_path}"
389             )
390         else:
391             logger.trace(
392                 f"SCP {self._ssh.get_transport().getpeername()}:{remote_path} "
393                 f"to {local_path}"
394             )
395         # SCPCLient takes a paramiko transport as its only argument
396         if not wildcard:
397             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
398         else:
399             scp = SCPClient(
400                 self._ssh.get_transport(), sanitize=lambda x: x,
401                 socket_timeout=timeout
402             )
403         start = time()
404         if not get:
405             scp.put(local_path, remote_path)
406         else:
407             scp.get(remote_path, local_path)
408         scp.close()
409         end = time()
410         logger.trace(f"SCP took {end-start} seconds")
411
412
413 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
414     """Convenience function to ssh/exec/return rc, out & err.
415
416     Returns (rc, stdout, stderr).
417
418     :param node: The node to execute command on.
419     :param cmd: Command to execute.
420     :param timeout: Timeout value in seconds. Default: 600.
421     :param sudo: Sudo privilege execution flag. Default: False.
422     :param disconnect: Close the opened SSH connection if True.
423     :type node: dict
424     :type cmd: str or OptionString
425     :type timeout: int
426     :type sudo: bool
427     :type disconnect: bool
428     :returns: RC, Stdout, Stderr.
429     :rtype: tuple(int, str, str)
430     """
431     if node is None:
432         raise TypeError(u"Node parameter is None")
433     if cmd is None:
434         raise TypeError(u"Command parameter is None")
435     if not cmd:
436         raise ValueError(u"Empty command parameter")
437
438     ssh = SSH()
439
440     try:
441         ssh.connect(node)
442     except SSHException as err:
443         logger.error(f"Failed to connect to node {node[u'host']}\n{err!r}")
444         return None, None, None
445
446     try:
447         if not sudo:
448             ret_code, stdout, stderr = ssh.exec_command(cmd, timeout=timeout)
449         else:
450             ret_code, stdout, stderr = ssh.exec_command_sudo(
451                 cmd, timeout=timeout
452             )
453     except SSHException as err:
454         logger.error(repr(err))
455         return None, None, None
456     finally:
457         if disconnect:
458             ssh.disconnect()
459
460     return ret_code, stdout, stderr
461
462
463 def exec_cmd_no_error(
464         node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
465         retries=0, include_reason=False):
466     """Convenience function to ssh/exec/return out & err.
467
468     Verifies that return code is zero.
469     Supports retries, timeout is related to each try separately then. There is
470     sleep(1) before each retry.
471     Disconnect (if enabled) is applied after each try.
472
473     :param node: DUT node.
474     :param cmd: Command to be executed.
475     :param timeout: Timeout value in seconds. Default: 600.
476     :param sudo: Sudo privilege execution flag. Default: False.
477     :param message: Error message in case of failure. Default: None.
478     :param disconnect: Close the opened SSH connection if True.
479     :param retries: How many times to retry on failure.
480     :param include_reason: Whether default info should be appended to message.
481     :type node: dict
482     :type cmd: str or OptionString
483     :type timeout: int
484     :type sudo: bool
485     :type message: str
486     :type disconnect: bool
487     :type retries: int
488     :type include_reason: bool
489     :returns: Stdout, Stderr.
490     :rtype: tuple(str, str)
491     :raises RuntimeError: If bash return code is not 0.
492     """
493     for _ in range(retries + 1):
494         ret_code, stdout, stderr = exec_cmd(
495             node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect
496         )
497         if ret_code == 0:
498             break
499         sleep(1)
500     else:
501         msg = f"Command execution failed: '{cmd}'\nRC: {ret_code}\n{stderr}"
502         logger.info(msg)
503         if message:
504             msg = f"{message}\n{msg}" if include_reason else message
505         raise RuntimeError(msg)
506
507     return stdout, stderr
508
509
510 def scp_node(
511         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
512     """Copy files from local_path to remote_path or vice versa.
513
514     :param node: SUT node.
515     :param local_path: Path to local file that should be uploaded; or
516         path where to save remote file.
517     :param remote_path: Remote path where to place uploaded file; or
518         path to remote file which should be downloaded.
519     :param get: scp operation to perform. Default is put.
520     :param timeout: Timeout value in seconds.
521     :param disconnect: Close the opened SSH connection if True.
522     :type node: dict
523     :type local_path: str
524     :type remote_path: str
525     :type get: bool
526     :type timeout: int
527     :type disconnect: bool
528     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
529     """
530     ssh = SSH()
531
532     try:
533         ssh.connect(node)
534     except SSHException as exc:
535         raise RuntimeError(f"Failed to connect to {node[u'host']}!") from exc
536     try:
537         ssh.scp(local_path, remote_path, get, timeout)
538     except SCPException as exc:
539         raise RuntimeError(f"SCP execution failed on {node[u'host']}!") from exc
540     finally:
541         if disconnect:
542             ssh.disconnect()