Fix PyLint errors
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16
17 import socket
18 import StringIO
19
20 from time import time, sleep
21
22 from paramiko import RSAKey, SSHClient, AutoAddPolicy
23 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
24 from robot.api import logger
25 from scp import SCPClient
26
27
28 __all__ = ["exec_cmd", "exec_cmd_no_error"]
29
30 # TODO: load priv key
31
32
33 class SSHTimeout(Exception):
34     """This exception is raised when a timeout occurs."""
35     pass
36
37
38 class SSH(object):
39     """Contains methods for managing and using SSH connections."""
40
41     __MAX_RECV_BUF = 10*1024*1024
42     __existing_connections = {}
43
44     def __init__(self):
45         self._ssh = None
46         self._node = None
47
48     @staticmethod
49     def _node_hash(node):
50         """Get IP address and port hash from node dictionary.
51
52         :param node: Node in topology.
53         :type node: dict
54         :returns: IP address and port for the specified node.
55         :rtype: int
56         """
57
58         return hash(frozenset([node['host'], node['port']]))
59
60     def connect(self, node, attempts=5):
61         """Connect to node prior to running exec_command or scp.
62
63         If there already is a connection to the node, this method reuses it.
64
65         :param node: Node in topology.
66         :param attempts: Number of reconnect attempts.
67         :type node: dict
68         :type attempts: int
69         :raises IOError: If cannot connect to host.
70         """
71         self._node = node
72         node_hash = self._node_hash(node)
73         if node_hash in SSH.__existing_connections:
74             self._ssh = SSH.__existing_connections[node_hash]
75             if self._ssh.get_transport().is_active():
76                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
77             else:
78                 if attempts > 0:
79                     self._reconnect(attempts-1)
80                 else:
81                     raise IOError('Cannot connect to {host}'.
82                                   format(host=node['host']))
83         else:
84             try:
85                 start = time()
86                 pkey = None
87                 if 'priv_key' in node:
88                     pkey = RSAKey.from_private_key(
89                         StringIO.StringIO(node['priv_key']))
90
91                 self._ssh = SSHClient()
92                 self._ssh.set_missing_host_key_policy(AutoAddPolicy())
93
94                 self._ssh.connect(node['host'], username=node['username'],
95                                   password=node.get('password'), pkey=pkey,
96                                   port=node['port'])
97
98                 self._ssh.get_transport().set_keepalive(10)
99
100                 SSH.__existing_connections[node_hash] = self._ssh
101                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
102                              format(
103                                  peer=self._ssh.get_transport().getpeername(),
104                                  total=(time() - start),
105                                  ssh=self._ssh))
106             except SSHException:
107                 raise IOError('Cannot connect to {host}'.
108                               format(host=node['host']))
109             except NoValidConnectionsError as err:
110                 logger.error(repr(err))
111                 raise IOError('Unable to connect to port {port} on {host}'.
112                               format(port=node['port'], host=node['host']))
113
114     def disconnect(self, node):
115         """Close SSH connection to the node.
116
117         :param node: The node to disconnect from.
118         :type node: dict
119         """
120         node_hash = self._node_hash(node)
121         if node_hash in SSH.__existing_connections:
122             logger.debug('Disconnecting peer: {host}, {port}'.
123                          format(host=node['host'], port=node['port']))
124             ssh = SSH.__existing_connections.pop(node_hash)
125             ssh.close()
126
127     def _reconnect(self, attempts=0):
128         """Close the SSH connection and open it again.
129
130         :param attempts: Number of reconnect attempts.
131         :type attempts: int
132         """
133         node = self._node
134         self.disconnect(node)
135         self.connect(node, attempts)
136         logger.debug('Reconnecting peer done: {host}, {port}'.
137                      format(host=node['host'], port=node['port']))
138
139     def exec_command(self, cmd, timeout=10):
140         """Execute SSH command on a new channel on the connected Node.
141
142         :param cmd: Command to run on the Node.
143         :param timeout: Maximal time in seconds to wait until the command is
144         done. If set to None then wait forever.
145         :type cmd: str
146         :type timeout: int
147         :return return_code, stdout, stderr
148         :rtype: tuple(int, str, str)
149         :raise SSHTimeout: If command is not finished in timeout time.
150         """
151         stdout = StringIO.StringIO()
152         stderr = StringIO.StringIO()
153         try:
154             chan = self._ssh.get_transport().open_session(timeout=5)
155             peer = self._ssh.get_transport().getpeername()
156         except AttributeError:
157             self._reconnect()
158             chan = self._ssh.get_transport().open_session(timeout=5)
159             peer = self._ssh.get_transport().getpeername()
160         except SSHException:
161             self._reconnect()
162             chan = self._ssh.get_transport().open_session(timeout=5)
163             peer = self._ssh.get_transport().getpeername()
164         chan.settimeout(timeout)
165
166         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
167                      .format(peer=peer, timeout=timeout, cmd=cmd))
168
169         start = time()
170         chan.exec_command(cmd)
171         while not chan.exit_status_ready() and timeout is not None:
172             if chan.recv_ready():
173                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
174
175             if chan.recv_stderr_ready():
176                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
177
178             if time() - start > timeout:
179                 raise SSHTimeout(
180                     'Timeout exception during execution of command: {cmd}\n'
181                     'Current contents of stdout buffer: {stdout}\n'
182                     'Current contents of stderr buffer: {stderr}\n'
183                     .format(cmd=cmd, stdout=stdout.getvalue(),
184                             stderr=stderr.getvalue())
185                 )
186
187             sleep(0.1)
188         return_code = chan.recv_exit_status()
189
190         while chan.recv_ready():
191             stdout.write(chan.recv(self.__MAX_RECV_BUF))
192
193         while chan.recv_stderr_ready():
194             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
195
196         end = time()
197         logger.trace('exec_command on {peer} took {total} seconds'.
198                      format(peer=peer, total=end-start))
199
200         logger.trace('return RC {rc}'.format(rc=return_code))
201         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
202         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
203         return return_code, stdout.getvalue(), stderr.getvalue()
204
205     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
206         """Execute SSH command with sudo on a new channel on the connected Node.
207
208         :param cmd: Command to be executed.
209         :param cmd_input: Input redirected to the command.
210         :param timeout: Timeout.
211         :returns: return_code, stdout, stderr
212
213         :Example:
214
215         >>> from ssh import SSH
216         >>> ssh = SSH()
217         >>> ssh.connect(node)
218         >>> # Execute command without input (sudo -S cmd)
219         >>> ssh.exec_command_sudo("ifconfig eth0 down")
220         >>> # Execute command with input (sudo -S cmd <<< "input")
221         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
222         """
223         if cmd_input is None:
224             command = 'sudo -S {c}'.format(c=cmd)
225         else:
226             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
227         return self.exec_command(command, timeout)
228
229     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
230                          timeout=30):
231         """Execute command in LXC on a new SSH channel on the connected Node.
232
233         :param lxc_cmd: Command to be executed.
234         :param lxc_name: LXC name.
235         :param lxc_params: Additional parameters for LXC attach.
236         :param sudo: Run in privileged LXC mode. Default: privileged
237         :param timeout: Timeout.
238         :type lxc_cmd: str
239         :type lxc_name: str
240         :type lxc_params: str
241         :type sudo: bool
242         :type timeout: int
243         :returns: return_code, stdout, stderr
244         """
245         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
246             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
247
248         if sudo:
249             command = 'sudo -S {c}'.format(c=command)
250         return self.exec_command(command, timeout)
251
252     def interactive_terminal_open(self, time_out=45):
253         """Open interactive terminal on a new channel on the connected Node.
254
255         :param time_out: Timeout in seconds.
256         :returns: SSH channel with opened terminal.
257
258         .. warning:: Interruptingcow is used here, and it uses
259            signal(SIGALRM) to let the operating system interrupt program
260            execution. This has the following limitations: Python signal
261            handlers only apply to the main thread, so you cannot use this
262            from other threads. You must not use this in a program that
263            uses SIGALRM itself (this includes certain profilers)
264         """
265         chan = self._ssh.get_transport().open_session()
266         chan.get_pty()
267         chan.invoke_shell()
268         chan.settimeout(int(time_out))
269         chan.set_combine_stderr(True)
270
271         buf = ''
272         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
273             try:
274                 chunk = chan.recv(self.__MAX_RECV_BUF)
275                 if not chunk:
276                     break
277                 buf += chunk
278                 if chan.exit_status_ready():
279                     logger.error('Channel exit status ready')
280                     break
281             except socket.timeout:
282                 logger.error('Socket timeout: {0}'.format(buf))
283                 raise Exception('Socket timeout: {0}'.format(buf))
284         return chan
285
286     def interactive_terminal_exec_command(self, chan, cmd, prompt):
287         """Execute command on interactive terminal.
288
289         interactive_terminal_open() method has to be called first!
290
291         :param chan: SSH channel with opened terminal.
292         :param cmd: Command to be executed.
293         :param prompt: Command prompt, sequence of characters used to
294         indicate readiness to accept commands.
295         :returns: Command output.
296
297         .. warning:: Interruptingcow is used here, and it uses
298            signal(SIGALRM) to let the operating system interrupt program
299            execution. This has the following limitations: Python signal
300            handlers only apply to the main thread, so you cannot use this
301            from other threads. You must not use this in a program that
302            uses SIGALRM itself (this includes certain profilers)
303         """
304         chan.sendall('{c}\n'.format(c=cmd))
305         buf = ''
306         while not buf.endswith(prompt):
307             try:
308                 chunk = chan.recv(self.__MAX_RECV_BUF)
309                 if not chunk:
310                     break
311                 buf += chunk
312                 if chan.exit_status_ready():
313                     logger.error('Channel exit status ready')
314                     break
315             except socket.timeout:
316                 logger.error('Socket timeout during execution of command: '
317                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
318                 raise Exception('Socket timeout during execution of command: '
319                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
320         tmp = buf.replace(cmd.replace('\n', ''), '')
321         for item in prompt:
322             tmp.replace(item, '')
323         return tmp
324
325     @staticmethod
326     def interactive_terminal_close(chan):
327         """Close interactive terminal SSH channel.
328
329         :param chan: SSH channel to be closed.
330         """
331         chan.close()
332
333     def scp(self, local_path, remote_path, get=False, timeout=30,
334             wildcard=False):
335         """Copy files from local_path to remote_path or vice versa.
336
337         connect() method has to be called first!
338
339         :param local_path: Path to local file that should be uploaded; or
340         path where to save remote file.
341         :param remote_path: Remote path where to place uploaded file; or
342         path to remote file which should be downloaded.
343         :param get: scp operation to perform. Default is put.
344         :param timeout: Timeout value in seconds.
345         :param wildcard: If path has wildcard characters. Default is false.
346         :type local_path: str
347         :type remote_path: str
348         :type get: bool
349         :type timeout: int
350         :type wildcard: bool
351         """
352         if not get:
353             logger.trace('SCP {0} to {1}:{2}'.format(
354                 local_path, self._ssh.get_transport().getpeername(),
355                 remote_path))
356         else:
357             logger.trace('SCP {0}:{1} to {2}'.format(
358                 self._ssh.get_transport().getpeername(), remote_path,
359                 local_path))
360         # SCPCLient takes a paramiko transport as its only argument
361         if not wildcard:
362             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
363         else:
364             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
365                             socket_timeout=timeout)
366         start = time()
367         if not get:
368             scp.put(local_path, remote_path)
369         else:
370             scp.get(remote_path, local_path)
371         scp.close()
372         end = time()
373         logger.trace('SCP took {0} seconds'.format(end-start))
374
375
376 def exec_cmd(node, cmd, timeout=600, sudo=False):
377     """Convenience function to ssh/exec/return rc, out & err.
378
379     Returns (rc, stdout, stderr).
380
381     :param node: The node to execute command on.
382     :param cmd: Command to execute.
383     :param timeout: Timeout value in seconds. Default: 600.
384     :param sudo: Sudo privilege execution flag. Default: False.
385     :type node: dict
386     :type cmd: str
387     :type timeout: int
388     :type sudo: bool
389     :returns: RC, Stdout, Stderr.
390     :rtype: tuple(int, str, str)
391     """
392     if node is None:
393         raise TypeError('Node parameter is None')
394     if cmd is None:
395         raise TypeError('Command parameter is None')
396     if not cmd:
397         raise ValueError('Empty command parameter')
398
399     ssh = SSH()
400
401     if node.get('host_port') is not None:
402         ssh_node = dict()
403         ssh_node['host'] = '127.0.0.1'
404         ssh_node['port'] = node['port']
405         ssh_node['username'] = node['username']
406         ssh_node['password'] = node['password']
407         import pexpect
408         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
409         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
410         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
411             format(tnl=tnl, op=options, user=node['host_username'],
412                    host=node['host'], host_port=node['host_port'])
413         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
414                      format(ssh_cmd=ssh_cmd))
415         child = pexpect.spawn(ssh_cmd)
416         child.expect('.* password: ')
417         logger.trace(child.after)
418         child.sendline(node['host_password'])
419         child.expect('Welcome .*')
420         logger.trace(child.after)
421         logger.trace('Local port forwarding finished.')
422     else:
423         ssh_node = node
424
425     try:
426         ssh.connect(ssh_node)
427     except SSHException as err:
428         logger.error("Failed to connect to node" + repr(err))
429         return None, None, None
430
431     try:
432         if not sudo:
433             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
434         else:
435             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
436                                                                timeout=timeout)
437     except SSHException as err:
438         logger.error(repr(err))
439         return None, None, None
440
441     return ret_code, stdout, stderr
442
443
444 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False, message=None):
445     """Convenience function to ssh/exec/return out & err.
446
447     Verifies that return code is zero.
448
449     :param node: DUT node.
450     :param cmd: Command to be executed.
451     :param timeout: Timeout value in seconds. Default: 600.
452     :param sudo: Sudo privilege execution flag. Default: False.
453     :param message: Error message in case of failure. Default: None.
454     :type node: dict
455     :type cmd: str
456     :type timeout: int
457     :type sudo: bool
458     :type message: str
459     :returns: Stdout, Stderr.
460     :rtype: tuple(str, str)
461     :raises RuntimeError: If bash return code is not 0.
462     """
463     ret_code, stdout, stderr = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
464     msg = ('Command execution failed: "{cmd}"\n{stderr}'.
465            format(cmd=cmd, stderr=stderr) if message is None else message)
466     if ret_code != 0:
467         raise RuntimeError(msg)
468
469     return stdout, stderr