9c7adc44e14093a4043398d54e02244cb3d96585
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18 import StringIO
19
20 from time import time, sleep
21
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient, SCPException
26
27
28 __all__ = ["exec_cmd", "exec_cmd_no_error"]
29
30 # TODO: load priv key
31
32
33 class SSHTimeout(Exception):
34     """This exception is raised when a timeout occurs."""
35     pass
36
37
38 class SSH(object):
39     """Contains methods for managing and using SSH connections."""
40
41     __MAX_RECV_BUF = 10*1024*1024
42     __existing_connections = {}
43
44     def __init__(self):
45         self._ssh = None
46         self._node = None
47
48     @staticmethod
49     def _node_hash(node):
50         """Get IP address and port hash from node dictionary.
51
52         :param node: Node in topology.
53         :type node: dict
54         :returns: IP address and port for the specified node.
55         :rtype: int
56         """
57
58         return hash(frozenset([node['host'], node['port']]))
59
60     def connect(self, node, attempts=5):
61         """Connect to node prior to running exec_command or scp.
62
63         If there already is a connection to the node, this method reuses it.
64
65         :param node: Node in topology.
66         :param attempts: Number of reconnect attempts.
67         :type node: dict
68         :type attempts: int
69         :raises IOError: If cannot connect to host.
70         """
71         self._node = node
72         node_hash = self._node_hash(node)
73         if node_hash in SSH.__existing_connections:
74             self._ssh = SSH.__existing_connections[node_hash]
75             if self._ssh.get_transport().is_active():
76                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
77             else:
78                 if attempts > 0:
79                     self._reconnect(attempts-1)
80                 else:
81                     raise IOError('Cannot connect to {host}'.
82                                   format(host=node['host']))
83         else:
84             try:
85                 start = time()
86                 pkey = None
87                 if 'priv_key' in node:
88                     pkey = RSAKey.from_private_key(
89                         StringIO.StringIO(node['priv_key']))
90
91                 self._ssh = SSHClient()
92                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
93
94                 self._ssh.connect(node['host'], username=node['username'],
95                                   password=node.get('password'), pkey=pkey,
96                                   port=node['port'])
97
98                 self._ssh.get_transport().set_keepalive(10)
99
100                 SSH.__existing_connections[node_hash] = self._ssh
101                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
102                              format(
103                                  peer=self._ssh.get_transport().getpeername(),
104                                  total=(time() - start),
105                                  ssh=self._ssh))
106             except SSHException:
107                 raise IOError('Cannot connect to {host}'.
108                               format(host=node['host']))
109             except NoValidConnectionsError as err:
110                 logger.error(repr(err))
111                 raise IOError('Unable to connect to port {port} on {host}'.
112                               format(port=node['port'], host=node['host']))
113
114     def disconnect(self, node=None):
115         """Close SSH connection to the node.
116
117         :param node: The node to disconnect from. None means last connected.
118         :type node: dict or None
119         """
120         if node is None:
121             node = self._node
122         if node is None:
123             return
124         node_hash = self._node_hash(node)
125         if node_hash in SSH.__existing_connections:
126             logger.debug('Disconnecting peer: {host}, {port}'.
127                          format(host=node['host'], port=node['port']))
128             ssh = SSH.__existing_connections.pop(node_hash)
129             ssh.close()
130
131     def _reconnect(self, attempts=0):
132         """Close the SSH connection and open it again.
133
134         :param attempts: Number of reconnect attempts.
135         :type attempts: int
136         """
137         node = self._node
138         self.disconnect(node)
139         self.connect(node, attempts)
140         logger.debug('Reconnecting peer done: {host}, {port}'.
141                      format(host=node['host'], port=node['port']))
142
143     def exec_command(self, cmd, timeout=10):
144         """Execute SSH command on a new channel on the connected Node.
145
146         :param cmd: Command to run on the Node.
147         :param timeout: Maximal time in seconds to wait until the command is
148         done. If set to None then wait forever.
149         :type cmd: str or OptionString
150         :type timeout: int
151         :return return_code, stdout, stderr
152         :rtype: tuple(int, str, str)
153         :raise SSHTimeout: If command is not finished in timeout time.
154         """
155         cmd = str(cmd)
156         stdout = StringIO.StringIO()
157         stderr = StringIO.StringIO()
158         try:
159             chan = self._ssh.get_transport().open_session(timeout=5)
160             peer = self._ssh.get_transport().getpeername()
161         except AttributeError:
162             self._reconnect()
163             chan = self._ssh.get_transport().open_session(timeout=5)
164             peer = self._ssh.get_transport().getpeername()
165         except SSHException:
166             self._reconnect()
167             chan = self._ssh.get_transport().open_session(timeout=5)
168             peer = self._ssh.get_transport().getpeername()
169         chan.settimeout(timeout)
170
171         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
172                      .format(peer=peer, timeout=timeout, cmd=cmd))
173
174         start = time()
175         chan.exec_command(cmd)
176         while not chan.exit_status_ready() and timeout is not None:
177             if chan.recv_ready():
178                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
179
180             if chan.recv_stderr_ready():
181                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
182
183             if time() - start > timeout:
184                 raise SSHTimeout(
185                     'Timeout exception during execution of command: {cmd}\n'
186                     'Current contents of stdout buffer: {stdout}\n'
187                     'Current contents of stderr buffer: {stderr}\n'
188                     .format(cmd=cmd, stdout=stdout.getvalue(),
189                             stderr=stderr.getvalue())
190                 )
191
192             sleep(0.1)
193         return_code = chan.recv_exit_status()
194
195         while chan.recv_ready():
196             stdout.write(chan.recv(self.__MAX_RECV_BUF))
197
198         while chan.recv_stderr_ready():
199             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
200
201         end = time()
202         logger.trace('exec_command on {peer} took {total} seconds'.
203                      format(peer=peer, total=end-start))
204
205         logger.trace('return RC {rc}'.format(rc=return_code))
206         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
207         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
208         return return_code, stdout.getvalue(), stderr.getvalue()
209
210     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
211         """Execute SSH command with sudo on a new channel on the connected Node.
212
213         :param cmd: Command to be executed.
214         :param cmd_input: Input redirected to the command.
215         :param timeout: Timeout.
216         :returns: return_code, stdout, stderr
217
218         :Example:
219
220         >>> from ssh import SSH
221         >>> ssh = SSH()
222         >>> ssh.connect(node)
223         >>> # Execute command without input (sudo -S cmd)
224         >>> ssh.exec_command_sudo("ifconfig eth0 down")
225         >>> # Execute command with input (sudo -S cmd <<< "input")
226         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
227         """
228         if cmd_input is None:
229             command = 'sudo -S {c}'.format(c=cmd)
230         else:
231             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
232         return self.exec_command(command, timeout)
233
234     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
235                          timeout=30):
236         """Execute command in LXC on a new SSH channel on the connected Node.
237
238         :param lxc_cmd: Command to be executed.
239         :param lxc_name: LXC name.
240         :param lxc_params: Additional parameters for LXC attach.
241         :param sudo: Run in privileged LXC mode. Default: privileged
242         :param timeout: Timeout.
243         :type lxc_cmd: str
244         :type lxc_name: str
245         :type lxc_params: str
246         :type sudo: bool
247         :type timeout: int
248         :returns: return_code, stdout, stderr
249         """
250         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
251             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
252
253         if sudo:
254             command = 'sudo -S {c}'.format(c=command)
255         return self.exec_command(command, timeout)
256
257     def interactive_terminal_open(self, time_out=45):
258         """Open interactive terminal on a new channel on the connected Node.
259
260         :param time_out: Timeout in seconds.
261         :returns: SSH channel with opened terminal.
262
263         .. warning:: Interruptingcow is used here, and it uses
264            signal(SIGALRM) to let the operating system interrupt program
265            execution. This has the following limitations: Python signal
266            handlers only apply to the main thread, so you cannot use this
267            from other threads. You must not use this in a program that
268            uses SIGALRM itself (this includes certain profilers)
269         """
270         chan = self._ssh.get_transport().open_session()
271         chan.get_pty()
272         chan.invoke_shell()
273         chan.settimeout(int(time_out))
274         chan.set_combine_stderr(True)
275
276         buf = ''
277         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
278             try:
279                 chunk = chan.recv(self.__MAX_RECV_BUF)
280                 if not chunk:
281                     break
282                 buf += chunk
283                 if chan.exit_status_ready():
284                     logger.error('Channel exit status ready')
285                     break
286             except socket.timeout:
287                 logger.error('Socket timeout: {0}'.format(buf))
288                 raise Exception('Socket timeout: {0}'.format(buf))
289         return chan
290
291     def interactive_terminal_exec_command(self, chan, cmd, prompt):
292         """Execute command on interactive terminal.
293
294         interactive_terminal_open() method has to be called first!
295
296         :param chan: SSH channel with opened terminal.
297         :param cmd: Command to be executed.
298         :param prompt: Command prompt, sequence of characters used to
299         indicate readiness to accept commands.
300         :returns: Command output.
301
302         .. warning:: Interruptingcow is used here, and it uses
303            signal(SIGALRM) to let the operating system interrupt program
304            execution. This has the following limitations: Python signal
305            handlers only apply to the main thread, so you cannot use this
306            from other threads. You must not use this in a program that
307            uses SIGALRM itself (this includes certain profilers)
308         """
309         chan.sendall('{c}\n'.format(c=cmd))
310         buf = ''
311         while not buf.endswith(prompt):
312             try:
313                 chunk = chan.recv(self.__MAX_RECV_BUF)
314                 if not chunk:
315                     break
316                 buf += chunk
317                 if chan.exit_status_ready():
318                     logger.error('Channel exit status ready')
319                     break
320             except socket.timeout:
321                 logger.error('Socket timeout during execution of command: '
322                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
323                 raise Exception('Socket timeout during execution of command: '
324                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
325         tmp = buf.replace(cmd.replace('\n', ''), '')
326         for item in prompt:
327             tmp.replace(item, '')
328         return tmp
329
330     @staticmethod
331     def interactive_terminal_close(chan):
332         """Close interactive terminal SSH channel.
333
334         :param chan: SSH channel to be closed.
335         """
336         chan.close()
337
338     def scp(self, local_path, remote_path, get=False, timeout=30,
339             wildcard=False):
340         """Copy files from local_path to remote_path or vice versa.
341
342         connect() method has to be called first!
343
344         :param local_path: Path to local file that should be uploaded; or
345         path where to save remote file.
346         :param remote_path: Remote path where to place uploaded file; or
347         path to remote file which should be downloaded.
348         :param get: scp operation to perform. Default is put.
349         :param timeout: Timeout value in seconds.
350         :param wildcard: If path has wildcard characters. Default is false.
351         :type local_path: str
352         :type remote_path: str
353         :type get: bool
354         :type timeout: int
355         :type wildcard: bool
356         """
357         if not get:
358             logger.trace('SCP {0} to {1}:{2}'.format(
359                 local_path, self._ssh.get_transport().getpeername(),
360                 remote_path))
361         else:
362             logger.trace('SCP {0}:{1} to {2}'.format(
363                 self._ssh.get_transport().getpeername(), remote_path,
364                 local_path))
365         # SCPCLient takes a paramiko transport as its only argument
366         if not wildcard:
367             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
368         else:
369             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
370                             socket_timeout=timeout)
371         start = time()
372         if not get:
373             scp.put(local_path, remote_path)
374         else:
375             scp.get(remote_path, local_path)
376         scp.close()
377         end = time()
378         logger.trace('SCP took {0} seconds'.format(end-start))
379
380
381 def exec_cmd(node, cmd, timeout=600, sudo=False, disconnect=False):
382     """Convenience function to ssh/exec/return rc, out & err.
383
384     Returns (rc, stdout, stderr).
385
386     :param node: The node to execute command on.
387     :param cmd: Command to execute.
388     :param timeout: Timeout value in seconds. Default: 600.
389     :param sudo: Sudo privilege execution flag. Default: False.
390     :param disconnect: Close the opened SSH connection if True.
391     :type node: dict
392     :type cmd: str or OptionString
393     :type timeout: int
394     :type sudo: bool
395     :type disconnect: bool
396     :returns: RC, Stdout, Stderr.
397     :rtype: tuple(int, str, str)
398     """
399     if node is None:
400         raise TypeError('Node parameter is None')
401     if cmd is None:
402         raise TypeError('Command parameter is None')
403     if not cmd:
404         raise ValueError('Empty command parameter')
405
406     ssh = SSH()
407
408     if node.get('host_port') is not None:
409         ssh_node = dict()
410         ssh_node['host'] = '127.0.0.1'
411         ssh_node['port'] = node['port']
412         ssh_node['username'] = node['username']
413         ssh_node['password'] = node['password']
414         import pexpect
415         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
416         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
417         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
418             format(tnl=tnl, op=options, user=node['host_username'],
419                    host=node['host'], host_port=node['host_port'])
420         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
421                      format(ssh_cmd=ssh_cmd))
422         child = pexpect.spawn(ssh_cmd)
423         child.expect('.* password: ')
424         logger.trace(child.after)
425         child.sendline(node['host_password'])
426         child.expect('Welcome .*')
427         logger.trace(child.after)
428         logger.trace('Local port forwarding finished.')
429     else:
430         ssh_node = node
431
432     try:
433         ssh.connect(ssh_node)
434     except SSHException as err:
435         logger.error("Failed to connect to node" + repr(err))
436         return None, None, None
437
438     try:
439         if not sudo:
440             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
441         else:
442             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
443                                                                timeout=timeout)
444     except SSHException as err:
445         logger.error(repr(err))
446         return None, None, None
447     finally:
448         if disconnect:
449             ssh.disconnect()
450
451     return ret_code, stdout, stderr
452
453
454 def exec_cmd_no_error(
455         node, cmd, timeout=600, sudo=False, message=None, disconnect=False,
456         retries=0):
457     """Convenience function to ssh/exec/return out & err.
458
459     Verifies that return code is zero.
460     Supports retries, timeout is related to each try separately then. There is
461     sleep(1) before each retry.
462     Disconnect (if enabled) is applied after each try.
463
464     :param node: DUT node.
465     :param cmd: Command to be executed.
466     :param timeout: Timeout value in seconds. Default: 600.
467     :param sudo: Sudo privilege execution flag. Default: False.
468     :param message: Error message in case of failure. Default: None.
469     :param disconnect: Close the opened SSH connection if True.
470     :param retries: How many times to retry on failure.
471     :type node: dict
472     :type cmd: str or OptionString
473     :type timeout: int
474     :type sudo: bool
475     :type message: str
476     :type disconnect: bool
477     :type retries: int
478     :returns: Stdout, Stderr.
479     :rtype: tuple(str, str)
480     :raises RuntimeError: If bash return code is not 0.
481     """
482     for _ in range(retries + 1):
483         ret_code, stdout, stderr = exec_cmd(
484             node, cmd, timeout=timeout, sudo=sudo, disconnect=disconnect)
485         if ret_code == 0:
486             break
487         sleep(1)
488     else:
489         msg = ('Command execution failed: "{cmd}"\n{stderr}'.
490                format(cmd=cmd, stderr=stderr) if message is None else message)
491         raise RuntimeError(msg)
492
493     return stdout, stderr
494
495 def scp_node(
496         node, local_path, remote_path, get=False, timeout=30, disconnect=False):
497     """Copy files from local_path to remote_path or vice versa.
498
499     :param node: SUT node.
500     :param local_path: Path to local file that should be uploaded; or
501         path where to save remote file.
502     :param remote_path: Remote path where to place uploaded file; or
503         path to remote file which should be downloaded.
504     :param get: scp operation to perform. Default is put.
505     :param timeout: Timeout value in seconds.
506     :param disconnect: Close the opened SSH connection if True.
507     :type node: dict
508     :type local_path: str
509     :type remote_path: str
510     :type get: bool
511     :type timeout: int
512     :type disconnect: bool
513     :raises RuntimeError: If SSH connection failed or SCP transfer failed.
514     """
515     ssh = SSH()
516
517     try:
518         ssh.connect(node)
519     except SSHException:
520         raise RuntimeError('Failed to connect to {host}!'
521                            .format(host=node['host']))
522     try:
523         ssh.scp(local_path, remote_path, get, timeout)
524     except SCPException:
525         raise RuntimeError('SCP execution failed on {host}!'
526                            .format(host=node['host']))
527     finally:
528         if disconnect:
529             ssh.disconnect()

©2016 FD.io a Linux Foundation Collaborative Project. All Rights Reserved.
Linux Foundation is a registered trademark of The Linux Foundation. Linux is a registered trademark of Linus Torvalds.
Please see our privacy policy and terms of use.