966d1b0448ac729c9a5112202d866ad19ce433b7
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18 import StringIO
19
20 from time import time, sleep
21
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient, SCPException
26
27
28 __all__ = ["exec_cmd", "exec_cmd_no_error"]
29
30 # TODO: load priv key
31
32
33 class SSHTimeout(Exception):
34     """This exception is raised when a timeout occurs."""
35     pass
36
37
38 class SSH(object):
39     """Contains methods for managing and using SSH connections."""
40
41     __MAX_RECV_BUF = 10*1024*1024
42     __existing_connections = {}
43
44     def __init__(self):
45         self._ssh = None
46         self._node = None
47
48     @staticmethod
49     def _node_hash(node):
50         """Get IP address and port hash from node dictionary.
51
52         :param node: Node in topology.
53         :type node: dict
54         :returns: IP address and port for the specified node.
55         :rtype: int
56         """
57
58         return hash(frozenset([node['host'], node['port']]))
59
60     def connect(self, node, attempts=5):
61         """Connect to node prior to running exec_command or scp.
62
63         If there already is a connection to the node, this method reuses it.
64
65         :param node: Node in topology.
66         :param attempts: Number of reconnect attempts.
67         :type node: dict
68         :type attempts: int
69         :raises IOError: If cannot connect to host.
70         """
71         self._node = node
72         node_hash = self._node_hash(node)
73         if node_hash in SSH.__existing_connections:
74             self._ssh = SSH.__existing_connections[node_hash]
75             if self._ssh.get_transport().is_active():
76                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
77             else:
78                 if attempts > 0:
79                     self._reconnect(attempts-1)
80                 else:
81                     raise IOError('Cannot connect to {host}'.
82                                   format(host=node['host']))
83         else:
84             try:
85                 start = time()
86                 pkey = None
87                 if 'priv_key' in node:
88                     pkey = RSAKey.from_private_key(
89                         StringIO.StringIO(node['priv_key']))
90
91                 self._ssh = SSHClient()
92                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
93
94                 self._ssh.connect(node['host'], username=node['username'],
95                                   password=node.get('password'), pkey=pkey,
96                                   port=node['port'])
97
98                 self._ssh.get_transport().set_keepalive(10)
99
100                 SSH.__existing_connections[node_hash] = self._ssh
101                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
102                              format(
103                                  peer=self._ssh.get_transport().getpeername(),
104                                  total=(time() - start),
105                                  ssh=self._ssh))
106             except SSHException:
107                 raise IOError('Cannot connect to {host}'.
108                               format(host=node['host']))
109             except NoValidConnectionsError as err:
110                 logger.error(repr(err))
111                 raise IOError('Unable to connect to port {port} on {host}'.
112                               format(port=node['port'], host=node['host']))
113
114     def disconnect(self, node=None):
115         """Close SSH connection to the node.
116
117         :param node: The node to disconnect from. None means last connected.
118         :type node: dict or None
119         """
120         if node is None:
121             node = self._node
122         if node is None:
123             return
124         node_hash = self._node_hash(node)
125         if node_hash in SSH.__existing_connections:
126             logger.debug('Disconnecting peer: {host}, {port}'.
127                          format(host=node['host'], port=node['port']))
128             ssh = SSH.__existing_connections.pop(node_hash)
129             ssh.close()
130
131     def _reconnect(self, attempts=0):
132         """Close the SSH connection and open it again.
133
134         :param attempts: Number of reconnect attempts.
135         :type attempts: int
136         """
137         node = self._node
138         self.disconnect(node)
139         self.connect(node, attempts)
140         logger.debug('Reconnecting peer done: {host}, {port}'.
141                      format(host=node['host'], port=node['port']))
142
143     def exec_command(self, cmd, timeout=10):
144         """Execute SSH command on a new channel on the connected Node.
145
146         :param cmd: Command to run on the Node.
147         :param timeout: Maximal time in seconds to wait until the command is
148         done. If set to None then wait forever.
149         :type cmd: str or OptionString
150         :type timeout: int
151         :return return_code, stdout, stderr
152         :rtype: tuple(int, str, str)
153         :raise SSHTimeout: If command is not finished in timeout time.
154         """
155         cmd = str(cmd)
156         stdout = StringIO.StringIO()
157         stderr = StringIO.StringIO()
158         try:
159             chan = self._ssh.get_transport().open_session(timeout=5)
160             peer = self._ssh.get_transport().getpeername()
161         except AttributeError:
162             self._reconnect()
163             chan = self._ssh.get_transport().open_session(timeout=5)
164             peer = self._ssh.get_transport().getpeername()
165         except SSHException:
166             self._reconnect()
167             chan = self._ssh.get_transport().open_session(timeout=5)
168             peer = self._ssh.get_transport().getpeername()
169         chan.settimeout(timeout)
170
171         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
172                      .format(peer=peer, timeout=timeout, cmd=cmd))
173
174         start = time()
175         chan.exec_command(cmd)
176         while not chan.exit_status_ready() and timeout is not None:
177             if chan.recv_ready():
178                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
179
180             if chan.recv_stderr_ready():
181                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
182
183             if time() - start > timeout:
184                 raise SSHTimeout(
185                     'Timeout exception during execution of command: {cmd}\n'
186                     'Current contents of stdout buffer: {stdout}\n'
187                     'Current contents of stderr buffer: {stderr}\n'
188                     .format(cmd=cmd, stdout=stdout.getvalue(),
189                             stderr=stderr.getvalue())
190                 )
191
192             sleep(0.1)
193         return_code = chan.recv_exit_status()
194
195         while chan.recv_ready():
196             stdout.write(chan.recv(self.__MAX_RECV_BUF))
197
198         while chan.recv_stderr_ready():
199             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
200
201         end = time()
202         logger.trace('exec_command on {peer} took {total} seconds'.
203                      format(peer=peer, total=end-start))
204
205         logger.trace('return RC {rc}'.format(rc=return_code))
206         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
207         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
208         return return_code, stdout.getvalue(), stderr.getvalue()
209
210     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
211         """Execute SSH command with sudo on a new channel on the connected Node.
212
213         :param cmd: Command to be executed.
214         :param cmd_input: Input redirected to the command.
215         :param timeout: Timeout.
216         :returns: return_code, stdout, stderr
217
218         :Example:
219
220         >>> from ssh import SSH
221         >>> ssh = SSH()
222         >>> ssh.connect(node)
223         >>> # Execute command without input (sudo -S cmd)
224         >>> ssh.exec_command_sudo("ifconfig eth0 down")
225         >>> # Execute command with input (sudo -S cmd <<< "input")
226         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
227         """
228         if cmd_input is None:
229             command = 'sudo -S {c}'.format(c=cmd)
230         else:
231             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
232         return self.exec_command(command, timeout)
233
234     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
235                          timeout=30):
236         """Execute command in LXC on a new SSH channel on the connected Node.
237
238         :param lxc_cmd: Command to be executed.
239         :param lxc_name: LXC name.
240         :param lxc_params: Additional parameters for LXC attach.
241         :param sudo: Run in privileged LXC mode. Default: privileged
242         :param timeout: Timeout.
243         :type lxc_cmd: str
244         :type lxc_name: str
245         :type lxc_params: str
246         :type sudo: bool
247         :type timeout: int
248         :returns: return_code, stdout, stderr
249         """
250         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
251             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
252
253         if sudo:
254             command = 'sudo -S {c}'.format(c=command)
255         return self.exec_command(command, timeout)
256
257     def interactive_terminal_open(self, time_out=45):
258         """Open interactive terminal on a new channel on the connected Node.
259
260         :param time_out: Timeout in seconds.
261         :returns: SSH channel with opened terminal.
262
263         .. warning:: Interruptingcow is used here, and it uses
264            signal(SIGALRM) to let the operating system interrupt program
265            execution. This has the following limitations: Python signal
266            handlers only apply to the main thread, so you cannot use this
267            from other threads. You must not use this in a program that
268            uses SIGALRM itself (this includes certain profilers)
269         """
270         chan = self._ssh.get_transport().open_session()
271         chan.get_pty()
272         chan.invoke_shell()
273         chan.settimeout(int(time_out))
274         chan.set_combine_stderr(True)
275
276         buf = ''
277         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
278             try:
279                 chunk = chan.recv(self.__MAX_RECV_BUF)
280                 if not chunk:
281                     break
282                 buf += chunk
283                 if chan.exit_status_ready():
284                     logger.error('Channel exit status ready')
285                     break
286             except socket.timeout:
287                 logger.error('Socket timeout: {0}'.format(buf))
288                 raise Exception('Socket timeout: {0}'.format(buf))
289         return chan
290
291     def interactive_terminal_exec_command(self, chan, cmd, prompt):
292         """Execute command on interactive terminal.
293
294         interactive_terminal_open() method has to be called first!
295
296         :param chan: SSH channel with opened terminal.
297         :param cmd: Command to be executed.
298         :param prompt: Command prompt, sequence of characters used to
299         indicate readiness to accept commands.
300         :returns: Command output.
301
302         .. warning:: Interruptingcow is used here, and it uses
303            signal(SIGALRM) to let the operating system interrupt program
304            execution. This has the following limitations: Python signal
305            handlers only apply to the main thread, so you cannot use this
306            from other threads. You must not use this in a program that
307            uses SIGALRM itself (this includes certain profilers)
308         """
309         chan.sendall('{c}\n'.format(c=cmd))
310         buf = ''
311         while not buf.endswith(prompt):
312             try:
313                 chunk = chan.recv(self.__MAX_RECV_BUF)
314                 if not chunk:
315                     break
316                 buf += chunk
317                 if chan.exit_status_ready():
318                     logger.error('Channel exit status ready')
319                     break
320             except socket.timeout:
321                 logger.error('Socket timeout during execution of command: '
322                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
323                 raise Exception('Socket timeout during execution of command: '
324                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
325         tmp = buf.replace(cmd.replace('\n', ''), '')
326         for item in prompt:
327             tmp.replace(item, '')
328         return tmp
329
330     @staticmethod
331     def interactive_terminal_close(chan):
332         """Close interactive terminal SSH channel.
333
334         :param chan: SSH channel to be closed.
335         """
336         chan.close()
337
338     def scp(self, local_path, remote_path, get=False, timeout=30,
339             wildcard=False):
340         """Copy files from local_path to remote_path or vice versa.
341
342         connect() method has to be called first!
343
344         :param local_path: Path to local file that should be uploaded; or
345         path where to save remote file.
346         :param remote_path: Remote path where to place uploaded file; or
347         path to remote file which should be downloaded.
348         :param get: scp operation to perform. Default is put.
349         :param timeout: Timeout value in seconds.
350         :param wildcard: If path has wildcard characters. Default is false.
351         :type local_path: str
352         :type remote_path: str
353         :type get: bool
354         :type timeout: int
355         :type wildcard: bool
356         """
357         if not get:
358             logger.trace('SCP {0} to {1}:{2}'.format(
359                 local_path, self._ssh.get_transport().getpeername(),
360                 remote_path))
361         else:
362             logger.trace('SCP {0}:{1} to {2}'.format(
363                 self._ssh.get_transport().getpeername(), remote_path,
364                 local_path))
365         # SCPCLient takes a paramiko transport as its only argument
366         if not wildcard:
367             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
368         else:
369             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
370                             socket_timeout=timeout)
371         start = time()
372         if not get:
373             scp.put(local_path, remote_path)
374         else:
375             scp.get(remote_path, local_path)
376         scp.close()
377         end = time()
378         logger.trace('SCP took {0} seconds'.format(end-start))
379
380
381 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
382     """Convenience function to ssh/exec/return rc, out & err.
383
384     Returns (rc, stdout, stderr).
385
386     :param node: The node to execute command on.
387     :param cmd: Command to execute.
388     :param timeout: Timeout value in seconds. Default: 600.
389     :param sudo: Sudo privilege execution flag. Default: False.
390     :param disconnect: Close the opened SSH connection if True.
391     :type node: dict
392     :type cmd: str or OptionString
393     :type timeout: int
394     :type sudo: bool
395     :type disconnect: bool
396     :returns: RC, Stdout, Stderr.
397     :rtype: tuple(int, str, str)
398     """
399     if node is None:
400         raise TypeError('Node parameter is None')
401     if cmd is None:
402         raise TypeError('Command parameter is None')
403     if not cmd:
404         raise ValueError('Empty command parameter')
405
406     ssh = SSH()
407
408     if node.get('host_port') is not None:
409         ssh_node = dict()
410         ssh_node['host'] = '127.0.0.1'
411         ssh_node['port'] = node['port']
412         ssh_node['username'] = node['username']
413         ssh_node['password'] = node['password']
414         import pexpect
415         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
416         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
417         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
418             format(tnl=tnl, op=options, user=node['host_username'],
419                    host=node['host'], host_port=node['host_port'])
420         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
421                      format(ssh_cmd=ssh_cmd))
422         child = pexpect.spawn(ssh_cmd)
423         child.expect('.* password: ')
424         logger.trace(child.after)
425         child.sendline(node['host_password'])
426         child.expect('Welcome .*')
427         logger.trace(child.after)
428         logger.trace('Local port forwarding finished.')
429     else:
430         ssh_node = node
431
432     try:
433         ssh.connect(ssh_node)
434     except SSHException as err:
435         logger.error("Failed to connect to node" + repr(err))
436         return None, None, None
437
438     try:
439         if not sudo:
440             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
441         else:
442             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
443                                                                timeout=timeout)
444     except SSHException as err:
445         logger.error(repr(err))
446         return None, None, None
447     finally:
448         if disconnect:
449             ssh.disconnect()
450
451     return ret_code, stdout, stderr
452
453
454 def exec_cmd_no_error(
455         node, cmd, timeout=600, sudo=False, message=None, disconnect=False):
456     """Convenience function to ssh/exec/return out & err.
457
458     Verifies that return code is zero.
459
460     :param node: DUT node.
461     :param cmd: Command to be executed.
462     :param timeout: Timeout value in seconds. Default: 600.
463     :param sudo: Sudo privilege execution flag. Default: False.
464     :param message: Error message in case of failure. Default: None.
465     :param disconnect: Close the opened SSH connection if True.
466     :type node: dict
467     :type cmd: str or OptionString
468     :type timeout: int
469     :type sudo: bool
470     :type message: str
471     :type disconnect: bool
472     :returns: Stdout, Stderr.
473     :rtype: tuple(str, str)
474     :raises RuntimeError: If bash return code is not 0.
475     """
476     ret_code, stdout, stderr = exec_cmd(
477         node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
478     msg = ('Command execution failed: "{cmd}"\n{stderr}'.
479            format(cmd=cmd, stderr=stderr) if message is None else message)
480     if ret_code != 0:
481         raise RuntimeError(msg)
482
483     return stdout, stderr
484
485 def scp_node(
486         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
487     """Copy files from local_path to remote_path or vice versa.
488
489     :param node: SUT node.
490     :param local_path: Path to local file that should be uploaded; or
491         path where to save remote file.
492     :param remote_path: Remote path where to place uploaded file; or
493         path to remote file which should be downloaded.
494     :param get: scp operation to perform. Default is put.
495     :param timeout: Timeout value in seconds.
496     :param disconnect: Close the opened SSH connection if True.
497     :type node: dict
498     :type local_path: str
499     :type remote_path: str
500     :type get: bool
501     :type timeout: int
502     :type disconnect: bool
503     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
504     """
505     ssh = SSH()
506
507     try:
508         ssh.connect(node)
509     except SSHException:
510         raise RuntimeError('Failed to connect to {host}!'
511                            .format(host=node['host']))
512     try:
513         ssh.scp(local_path, remote_path, get, timeout)
514     except SCPException:
515         raise RuntimeError('SCP execution failed on {host}!'
516                            .format(host=node['host']))
517     finally:
518         if disconnect:
519             ssh.disconnect()