PAPI: Reduce the amount of logged information
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18 import StringIO
19
20 from time import time, sleep
21
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient, SCPException
26
27
28 __all__ = ["exec_cmd", "exec_cmd_no_error"]
29
30 # TODO: load priv key
31
32
33 class SSHTimeout(Exception):
34     """This exception is raised when a timeout occurs."""
35     pass
36
37
38 class SSH(object):
39     """Contains methods for managing and using SSH connections."""
40
41     __MAX_RECV_BUF = 10*1024*1024
42     __existing_connections = {}
43
44     def __init__(self):
45         self._ssh = None
46         self._node = None
47
48     @staticmethod
49     def _node_hash(node):
50         """Get IP address and port hash from node dictionary.
51
52         :param node: Node in topology.
53         :type node: dict
54         :returns: IP address and port for the specified node.
55         :rtype: int
56         """
57
58         return hash(frozenset([node['host'], node['port']]))
59
60     def connect(self, node, attempts=5):
61         """Connect to node prior to running exec_command or scp.
62
63         If there already is a connection to the node, this method reuses it.
64
65         :param node: Node in topology.
66         :param attempts: Number of reconnect attempts.
67         :type node: dict
68         :type attempts: int
69         :raises IOError: If cannot connect to host.
70         """
71         self._node = node
72         node_hash = self._node_hash(node)
73         if node_hash in SSH.__existing_connections:
74             self._ssh = SSH.__existing_connections[node_hash]
75             if self._ssh.get_transport().is_active():
76                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
77             else:
78                 if attempts > 0:
79                     self._reconnect(attempts-1)
80                 else:
81                     raise IOError('Cannot connect to {host}'.
82                                   format(host=node['host']))
83         else:
84             try:
85                 start = time()
86                 pkey = None
87                 if 'priv_key' in node:
88                     pkey = RSAKey.from_private_key(
89                         StringIO.StringIO(node['priv_key']))
90
91                 self._ssh = SSHClient()
92                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
93
94                 self._ssh.connect(node['host'], username=node['username'],
95                                   password=node.get('password'), pkey=pkey,
96                                   port=node['port'])
97
98                 self._ssh.get_transport().set_keepalive(10)
99
100                 SSH.__existing_connections[node_hash] = self._ssh
101                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
102                              format(
103                                  peer=self._ssh.get_transport().getpeername(),
104                                  total=(time() - start),
105                                  ssh=self._ssh))
106             except SSHException:
107                 raise IOError('Cannot connect to {host}'.
108                               format(host=node['host']))
109             except NoValidConnectionsError as err:
110                 logger.error(repr(err))
111                 raise IOError('Unable to connect to port {port} on {host}'.
112                               format(port=node['port'], host=node['host']))
113
114     def disconnect(self, node=None):
115         """Close SSH connection to the node.
116
117         :param node: The node to disconnect from. None means last connected.
118         :type node: dict or None
119         """
120         if node is None:
121             node = self._node
122         if node is None:
123             return
124         node_hash = self._node_hash(node)
125         if node_hash in SSH.__existing_connections:
126             logger.debug('Disconnecting peer: {host}, {port}'.
127                          format(host=node['host'], port=node['port']))
128             ssh = SSH.__existing_connections.pop(node_hash)
129             ssh.close()
130
131     def _reconnect(self, attempts=0):
132         """Close the SSH connection and open it again.
133
134         :param attempts: Number of reconnect attempts.
135         :type attempts: int
136         """
137         node = self._node
138         self.disconnect(node)
139         self.connect(node, attempts)
140         logger.debug('Reconnecting peer done: {host}, {port}'.
141                      format(host=node['host'], port=node['port']))
142
143     def exec_command(self, cmd, timeout=10, log_stdout_err=True):
144         """Execute SSH command on a new channel on the connected Node.
145
146         :param cmd: Command to run on the Node.
147         :param timeout: Maximal time in seconds to wait until the command is
148             done. If set to None then wait forever.
149         :param log_stdout_err: If True, stdout and stderr are logged. stdout
150             and stderr are logged also if the return code is not zero
151             independently of the value of log_stdout_err.
152         :type cmd: str or OptionString
153         :type timeout: int
154         :type log_stdout_err: bool
155         :returns: return_code, stdout, stderr
156         :rtype: tuple(int, str, str)
157         :raises SSHTimeout: If command is not finished in timeout time.
158         """
159         cmd = str(cmd)
160         stdout = StringIO.StringIO()
161         stderr = StringIO.StringIO()
162         try:
163             chan = self._ssh.get_transport().open_session(timeout=5)
164             peer = self._ssh.get_transport().getpeername()
165         except AttributeError:
166             self._reconnect()
167             chan = self._ssh.get_transport().open_session(timeout=5)
168             peer = self._ssh.get_transport().getpeername()
169         except SSHException:
170             self._reconnect()
171             chan = self._ssh.get_transport().open_session(timeout=5)
172             peer = self._ssh.get_transport().getpeername()
173         chan.settimeout(timeout)
174
175         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
176                      .format(peer=peer, timeout=timeout, cmd=cmd))
177
178         start = time()
179         chan.exec_command(cmd)
180         while not chan.exit_status_ready() and timeout is not None:
181             if chan.recv_ready():
182                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
183
184             if chan.recv_stderr_ready():
185                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
186
187             if time() - start > timeout:
188                 raise SSHTimeout(
189                     'Timeout exception during execution of command: {cmd}\n'
190                     'Current contents of stdout buffer: {stdout}\n'
191                     'Current contents of stderr buffer: {stderr}\n'
192                     .format(cmd=cmd, stdout=stdout.getvalue(),
193                             stderr=stderr.getvalue())
194                 )
195
196             sleep(0.1)
197         return_code = chan.recv_exit_status()
198
199         while chan.recv_ready():
200             stdout.write(chan.recv(self.__MAX_RECV_BUF))
201
202         while chan.recv_stderr_ready():
203             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
204
205         end = time()
206         logger.trace('exec_command on {peer} took {total} seconds'.
207                      format(peer=peer, total=end-start))
208
209         logger.trace('return RC {rc}'.format(rc=return_code))
210         if log_stdout_err or int(return_code):
211             logger.trace('return STDOUT {stdout}'.
212                          format(stdout=stdout.getvalue()))
213             logger.trace('return STDERR {stderr}'.
214                          format(stderr=stderr.getvalue()))
215         return return_code, stdout.getvalue(), stderr.getvalue()
216
217     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30,
218                           log_stdout_err=True):
219         """Execute SSH command with sudo on a new channel on the connected Node.
220
221         :param cmd: Command to be executed.
222         :param cmd_input: Input redirected to the command.
223         :param timeout: Timeout.
224         :param log_stdout_err: If True, stdout and stderr are logged.
225         :type cmd: str
226         :type cmd_input: str
227         :type timeout: int
228         :type log_stdout_err: bool
229         :returns: return_code, stdout, stderr
230         :rtype: tuple(int, str, str)
231
232         :Example:
233
234         >>> from ssh import SSH
235         >>> ssh = SSH()
236         >>> ssh.connect(node)
237         >>> # Execute command without input (sudo -S cmd)
238         >>> ssh.exec_command_sudo("ifconfig eth0 down")
239         >>> # Execute command with input (sudo -S cmd <<< "input")
240         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
241         """
242         if cmd_input is None:
243             command = 'sudo -S {c}'.format(c=cmd)
244         else:
245             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
246         return self.exec_command(command, timeout,
247                                  log_stdout_err=log_stdout_err)
248
249     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
250                          timeout=30):
251         """Execute command in LXC on a new SSH channel on the connected Node.
252
253         :param lxc_cmd: Command to be executed.
254         :param lxc_name: LXC name.
255         :param lxc_params: Additional parameters for LXC attach.
256         :param sudo: Run in privileged LXC mode. Default: privileged
257         :param timeout: Timeout.
258         :type lxc_cmd: str
259         :type lxc_name: str
260         :type lxc_params: str
261         :type sudo: bool
262         :type timeout: int
263         :returns: return_code, stdout, stderr
264         """
265         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
266             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
267
268         if sudo:
269             command = 'sudo -S {c}'.format(c=command)
270         return self.exec_command(command, timeout)
271
272     def interactive_terminal_open(self, time_out=45):
273         """Open interactive terminal on a new channel on the connected Node.
274
275         :param time_out: Timeout in seconds.
276         :returns: SSH channel with opened terminal.
277
278         .. warning:: Interruptingcow is used here, and it uses
279            signal(SIGALRM) to let the operating system interrupt program
280            execution. This has the following limitations: Python signal
281            handlers only apply to the main thread, so you cannot use this
282            from other threads. You must not use this in a program that
283            uses SIGALRM itself (this includes certain profilers)
284         """
285         chan = self._ssh.get_transport().open_session()
286         chan.get_pty()
287         chan.invoke_shell()
288         chan.settimeout(int(time_out))
289         chan.set_combine_stderr(True)
290
291         buf = ''
292         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
293             try:
294                 chunk = chan.recv(self.__MAX_RECV_BUF)
295                 if not chunk:
296                     break
297                 buf += chunk
298                 if chan.exit_status_ready():
299                     logger.error('Channel exit status ready')
300                     break
301             except socket.timeout:
302                 logger.error('Socket timeout: {0}'.format(buf))
303                 raise Exception('Socket timeout: {0}'.format(buf))
304         return chan
305
306     def interactive_terminal_exec_command(self, chan, cmd, prompt):
307         """Execute command on interactive terminal.
308
309         interactive_terminal_open() method has to be called first!
310
311         :param chan: SSH channel with opened terminal.
312         :param cmd: Command to be executed.
313         :param prompt: Command prompt, sequence of characters used to
314         indicate readiness to accept commands.
315         :returns: Command output.
316
317         .. warning:: Interruptingcow is used here, and it uses
318            signal(SIGALRM) to let the operating system interrupt program
319            execution. This has the following limitations: Python signal
320            handlers only apply to the main thread, so you cannot use this
321            from other threads. You must not use this in a program that
322            uses SIGALRM itself (this includes certain profilers)
323         """
324         chan.sendall('{c}\n'.format(c=cmd))
325         buf = ''
326         while not buf.endswith(prompt):
327             try:
328                 chunk = chan.recv(self.__MAX_RECV_BUF)
329                 if not chunk:
330                     break
331                 buf += chunk
332                 if chan.exit_status_ready():
333                     logger.error('Channel exit status ready')
334                     break
335             except socket.timeout:
336                 logger.error('Socket timeout during execution of command: '
337                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
338                 raise Exception('Socket timeout during execution of command: '
339                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
340         tmp = buf.replace(cmd.replace('\n', ''), '')
341         for item in prompt:
342             tmp.replace(item, '')
343         return tmp
344
345     @staticmethod
346     def interactive_terminal_close(chan):
347         """Close interactive terminal SSH channel.
348
349         :param chan: SSH channel to be closed.
350         """
351         chan.close()
352
353     def scp(self, local_path, remote_path, get=False, timeout=30,
354             wildcard=False):
355         """Copy files from local_path to remote_path or vice versa.
356
357         connect() method has to be called first!
358
359         :param local_path: Path to local file that should be uploaded; or
360         path where to save remote file.
361         :param remote_path: Remote path where to place uploaded file; or
362         path to remote file which should be downloaded.
363         :param get: scp operation to perform. Default is put.
364         :param timeout: Timeout value in seconds.
365         :param wildcard: If path has wildcard characters. Default is false.
366         :type local_path: str
367         :type remote_path: str
368         :type get: bool
369         :type timeout: int
370         :type wildcard: bool
371         """
372         if not get:
373             logger.trace('SCP {0} to {1}:{2}'.format(
374                 local_path, self._ssh.get_transport().getpeername(),
375                 remote_path))
376         else:
377             logger.trace('SCP {0}:{1} to {2}'.format(
378                 self._ssh.get_transport().getpeername(), remote_path,
379                 local_path))
380         # SCPCLient takes a paramiko transport as its only argument
381         if not wildcard:
382             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
383         else:
384             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
385                             socket_timeout=timeout)
386         start = time()
387         if not get:
388             scp.put(local_path, remote_path)
389         else:
390             scp.get(remote_path, local_path)
391         scp.close()
392         end = time()
393         logger.trace('SCP took {0} seconds'.format(end-start))
394
395
396 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
397     """Convenience function to ssh/exec/return rc, out & err.
398
399     Returns (rc, stdout, stderr).
400
401     :param node: The node to execute command on.
402     :param cmd: Command to execute.
403     :param timeout: Timeout value in seconds. Default: 600.
404     :param sudo: Sudo privilege execution flag. Default: False.
405     :param disconnect: Close the opened SSH connection if True.
406     :type node: dict
407     :type cmd: str or OptionString
408     :type timeout: int
409     :type sudo: bool
410     :type disconnect: bool
411     :returns: RC, Stdout, Stderr.
412     :rtype: tuple(int, str, str)
413     """
414     if node is None:
415         raise TypeError('Node parameter is None')
416     if cmd is None:
417         raise TypeError('Command parameter is None')
418     if not cmd:
419         raise ValueError('Empty command parameter')
420
421     ssh = SSH()
422
423     if node.get('host_port') is not None:
424         ssh_node = dict()
425         ssh_node['host'] = '127.0.0.1'
426         ssh_node['port'] = node['port']
427         ssh_node['username'] = node['username']
428         ssh_node['password'] = node['password']
429         import pexpect
430         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
431         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
432         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
433             format(tnl=tnl, op=options, user=node['host_username'],
434                    host=node['host'], host_port=node['host_port'])
435         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
436                      format(ssh_cmd=ssh_cmd))
437         child = pexpect.spawn(ssh_cmd)
438         child.expect('.* password: ')
439         logger.trace(child.after)
440         child.sendline(node['host_password'])
441         child.expect('Welcome .*')
442         logger.trace(child.after)
443         logger.trace('Local port forwarding finished.')
444     else:
445         ssh_node = node
446
447     try:
448         ssh.connect(ssh_node)
449     except SSHException as err:
450         logger.error("Failed to connect to node" + repr(err))
451         return None, None, None
452
453     try:
454         if not sudo:
455             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
456         else:
457             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
458                                                                timeout=timeout)
459     except SSHException as err:
460         logger.error(repr(err))
461         return None, None, None
462     finally:
463         if disconnect:
464             ssh.disconnect()
465
466     return ret_code, stdout, stderr
467
468
469 def exec_cmd_no_error(
470         node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
471         retries=0):
472     """Convenience function to ssh/exec/return out & err.
473
474     Verifies that return code is zero.
475     Supports retries, timeout is related to each try separately then. There is
476     sleep(1) before each retry.
477     Disconnect (if enabled) is applied after each try.
478
479     :param node: DUT node.
480     :param cmd: Command to be executed.
481     :param timeout: Timeout value in seconds. Default: 600.
482     :param sudo: Sudo privilege execution flag. Default: False.
483     :param message: Error message in case of failure. Default: None.
484     :param disconnect: Close the opened SSH connection if True.
485     :param retries: How many times to retry on failure.
486     :type node: dict
487     :type cmd: str or OptionString
488     :type timeout: int
489     :type sudo: bool
490     :type message: str
491     :type disconnect: bool
492     :type retries: int
493     :returns: Stdout, Stderr.
494     :rtype: tuple(str, str)
495     :raises RuntimeError: If bash return code is not 0.
496     """
497     for _ in range(retries + 1):
498         ret_code, stdout, stderr = exec_cmd(
499             node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
500         if ret_code == 0:
501             break
502         sleep(1)
503     else:
504         msg = ('Command execution failed: "{cmd}"\n{stderr}'.
505                format(cmd=cmd, stderr=stderr) if message is None else message)
506         raise RuntimeError(msg)
507
508     return stdout, stderr
509
510 def scp_node(
511         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
512     """Copy files from local_path to remote_path or vice versa.
513
514     :param node: SUT node.
515     :param local_path: Path to local file that should be uploaded; or
516         path where to save remote file.
517     :param remote_path: Remote path where to place uploaded file; or
518         path to remote file which should be downloaded.
519     :param get: scp operation to perform. Default is put.
520     :param timeout: Timeout value in seconds.
521     :param disconnect: Close the opened SSH connection if True.
522     :type node: dict
523     :type local_path: str
524     :type remote_path: str
525     :type get: bool
526     :type timeout: int
527     :type disconnect: bool
528     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
529     """
530     ssh = SSH()
531
532     try:
533         ssh.connect(node)
534     except SSHException:
535         raise RuntimeError('Failed to connect to {host}!'
536                            .format(host=node['host']))
537     try:
538         ssh.scp(local_path, remote_path, get, timeout)
539     except SCPException:
540         raise RuntimeError('SCP execution failed on {host}!'
541                            .format(host=node['host']))
542     finally:
543         if disconnect:
544             ssh.disconnect()