VPP_Device - add baseline tests - part IIa)
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2019 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import paramiko
17 import socket
18 import StringIO
19
20 from paramiko import RSAKey
21 from paramiko.ssh_exception import SSHException, NoValidConnectionsError
22 from robot.api import logger
23 from scp import SCPClient
24 from time import time, sleep
25
26 __all__ = ["exec_cmd", "exec_cmd_no_error"]
27
28 # TODO: load priv key
29
30
31 class SSHTimeout(Exception):
32     """This exception is raised when a timeout occurs."""
33     pass
34
35
36 class SSH(object):
37     """Contains methods for managing and using SSH connections."""
38
39     __MAX_RECV_BUF = 10*1024*1024
40     __existing_connections = {}
41
42     def __init__(self):
43         self._ssh = None
44         self._node = None
45
46     @staticmethod
47     def _node_hash(node):
48         """Get IP address and port hash from node dictionary.
49
50         :param node: Node in topology.
51         :type node: dict
52         :returns: IP address and port for the specified node.
53         :rtype: int
54         """
55
56         return hash(frozenset([node['host'], node['port']]))
57
58     def connect(self, node, attempts=5):
59         """Connect to node prior to running exec_command or scp.
60
61         If there already is a connection to the node, this method reuses it.
62
63         :param node: Node in topology.
64         :param attempts: Number of reconnect attempts.
65         :type node: dict
66         :type attempts: int
67         :raises IOError: If cannot connect to host.
68         """
69         self._node = node
70         node_hash = self._node_hash(node)
71         if node_hash in SSH.__existing_connections:
72             self._ssh = SSH.__existing_connections[node_hash]
73             if self._ssh.get_transport().is_active():
74                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
75             else:
76                 if attempts > 0:
77                     self._reconnect(attempts-1)
78                 else:
79                     raise IOError('Cannot connect to {host}'.
80                                   format(host=node['host']))
81         else:
82             try:
83                 start = time()
84                 pkey = None
85                 if 'priv_key' in node:
86                     pkey = RSAKey.from_private_key(
87                         StringIO.StringIO(node['priv_key']))
88
89                 self._ssh = paramiko.SSHClient()
90                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
91
92                 self._ssh.connect(node['host'], username=node['username'],
93                                   password=node.get('password'), pkey=pkey,
94                                   port=node['port'])
95
96                 self._ssh.get_transport().set_keepalive(10)
97
98                 SSH.__existing_connections[node_hash] = self._ssh
99                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
100                              format(
101                                  peer=self._ssh.get_transport().getpeername(),
102                                  total=(time() - start),
103                                  ssh=self._ssh))
104             except SSHException:
105                 raise IOError('Cannot connect to {host}'.
106                               format(host=node['host']))
107             except NoValidConnectionsError as err:
108                 logger.error(repr(err))
109                 raise IOError('Unable to connect to port {port} on {host}'.
110                               format(port=node['port'], host=node['host']))
111
112     def disconnect(self, node):
113         """Close SSH connection to the node.
114
115         :param node: The node to disconnect from.
116         :type node: dict
117         """
118         node_hash = self._node_hash(node)
119         if node_hash in SSH.__existing_connections:
120             logger.debug('Disconnecting peer: {host}, {port}'.
121                          format(host=node['host'], port=node['port']))
122             ssh = SSH.__existing_connections.pop(node_hash)
123             ssh.close()
124
125     def _reconnect(self, attempts=0):
126         """Close the SSH connection and open it again.
127
128         :param attempts: Number of reconnect attempts.
129         :type attempts: int
130         """
131         node = self._node
132         self.disconnect(node)
133         self.connect(node, attempts)
134         logger.debug('Reconnecting peer done: {host}, {port}'.
135                      format(host=node['host'], port=node['port']))
136
137     def exec_command(self, cmd, timeout=10):
138         """Execute SSH command on a new channel on the connected Node.
139
140         :param cmd: Command to run on the Node.
141         :param timeout: Maximal time in seconds to wait until the command is
142         done. If set to None then wait forever.
143         :type cmd: str
144         :type timeout: int
145         :return return_code, stdout, stderr
146         :rtype: tuple(int, str, str)
147         :raise SSHTimeout: If command is not finished in timeout time.
148         """
149         stdout = StringIO.StringIO()
150         stderr = StringIO.StringIO()
151         try:
152             chan = self._ssh.get_transport().open_session(timeout=5)
153             peer = self._ssh.get_transport().getpeername()
154         except AttributeError:
155             self._reconnect()
156             chan = self._ssh.get_transport().open_session(timeout=5)
157             peer = self._ssh.get_transport().getpeername()
158         except SSHException:
159             self._reconnect()
160             chan = self._ssh.get_transport().open_session(timeout=5)
161             peer = self._ssh.get_transport().getpeername()
162         chan.settimeout(timeout)
163
164         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
165                      .format(peer=peer, timeout=timeout, cmd=cmd))
166
167         start = time()
168         chan.exec_command(cmd)
169         while not chan.exit_status_ready() and timeout is not None:
170             if chan.recv_ready():
171                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
172
173             if chan.recv_stderr_ready():
174                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
175
176             if time() - start > timeout:
177                 raise SSHTimeout(
178                     'Timeout exception during execution of command: {cmd}\n'
179                     'Current contents of stdout buffer: {stdout}\n'
180                     'Current contents of stderr buffer: {stderr}\n'
181                     .format(cmd=cmd, stdout=stdout.getvalue(),
182                             stderr=stderr.getvalue())
183                 )
184
185             sleep(0.1)
186         return_code = chan.recv_exit_status()
187
188         while chan.recv_ready():
189             stdout.write(chan.recv(self.__MAX_RECV_BUF))
190
191         while chan.recv_stderr_ready():
192             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
193
194         end = time()
195         logger.trace('exec_command on {peer} took {total} seconds'.
196                      format(peer=peer, total=end-start))
197
198         logger.trace('return RC {rc}'.format(rc=return_code))
199         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
200         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
201         return return_code, stdout.getvalue(), stderr.getvalue()
202
203     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
204         """Execute SSH command with sudo on a new channel on the connected Node.
205
206         :param cmd: Command to be executed.
207         :param cmd_input: Input redirected to the command.
208         :param timeout: Timeout.
209         :returns: return_code, stdout, stderr
210
211         :Example:
212
213         >>> from ssh import SSH
214         >>> ssh = SSH()
215         >>> ssh.connect(node)
216         >>> # Execute command without input (sudo -S cmd)
217         >>> ssh.exec_command_sudo("ifconfig eth0 down")
218         >>> # Execute command with input (sudo -S cmd <<< "input")
219         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
220         """
221         if cmd_input is None:
222             command = 'sudo -S {c}'.format(c=cmd)
223         else:
224             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
225         return self.exec_command(command, timeout)
226
227     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
228                          timeout=30):
229         """Execute command in LXC on a new SSH channel on the connected Node.
230
231         :param lxc_cmd: Command to be executed.
232         :param lxc_name: LXC name.
233         :param lxc_params: Additional parameters for LXC attach.
234         :param sudo: Run in privileged LXC mode. Default: privileged
235         :param timeout: Timeout.
236         :type lxc_cmd: str
237         :type lxc_name: str
238         :type lxc_params: str
239         :type sudo: bool
240         :type timeout: int
241         :returns: return_code, stdout, stderr
242         """
243         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
244             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
245
246         if sudo:
247             command = 'sudo -S {c}'.format(c=command)
248         return self.exec_command(command, timeout)
249
250     def interactive_terminal_open(self, time_out=45):
251         """Open interactive terminal on a new channel on the connected Node.
252
253         :param time_out: Timeout in seconds.
254         :returns: SSH channel with opened terminal.
255
256         .. warning:: Interruptingcow is used here, and it uses
257            signal(SIGALRM) to let the operating system interrupt program
258            execution. This has the following limitations: Python signal
259            handlers only apply to the main thread, so you cannot use this
260            from other threads. You must not use this in a program that
261            uses SIGALRM itself (this includes certain profilers)
262         """
263         chan = self._ssh.get_transport().open_session()
264         chan.get_pty()
265         chan.invoke_shell()
266         chan.settimeout(int(time_out))
267         chan.set_combine_stderr(True)
268
269         buf = ''
270         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
271             try:
272                 chunk = chan.recv(self.__MAX_RECV_BUF)
273                 if not chunk:
274                     break
275                 buf += chunk
276                 if chan.exit_status_ready():
277                     logger.error('Channel exit status ready')
278                     break
279             except socket.timeout:
280                 logger.error('Socket timeout: {0}'.format(buf))
281                 raise Exception('Socket timeout: {0}'.format(buf))
282         return chan
283
284     def interactive_terminal_exec_command(self, chan, cmd, prompt):
285         """Execute command on interactive terminal.
286
287         interactive_terminal_open() method has to be called first!
288
289         :param chan: SSH channel with opened terminal.
290         :param cmd: Command to be executed.
291         :param prompt: Command prompt, sequence of characters used to
292         indicate readiness to accept commands.
293         :returns: Command output.
294
295         .. warning:: Interruptingcow is used here, and it uses
296            signal(SIGALRM) to let the operating system interrupt program
297            execution. This has the following limitations: Python signal
298            handlers only apply to the main thread, so you cannot use this
299            from other threads. You must not use this in a program that
300            uses SIGALRM itself (this includes certain profilers)
301         """
302         chan.sendall('{c}\n'.format(c=cmd))
303         buf = ''
304         while not buf.endswith(prompt):
305             try:
306                 chunk = chan.recv(self.__MAX_RECV_BUF)
307                 if not chunk:
308                     break
309                 buf += chunk
310                 if chan.exit_status_ready():
311                     logger.error('Channel exit status ready')
312                     break
313             except socket.timeout:
314                 logger.error('Socket timeout during execution of command: '
315                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
316                 raise Exception('Socket timeout during execution of command: '
317                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
318         tmp = buf.replace(cmd.replace('\n', ''), '')
319         for item in prompt:
320             tmp.replace(item, '')
321         return tmp
322
323     @staticmethod
324     def interactive_terminal_close(chan):
325         """Close interactive terminal SSH channel.
326
327         :param chan: SSH channel to be closed.
328         """
329         chan.close()
330
331     def scp(self, local_path, remote_path, get=False, timeout=30,
332             wildcard=False):
333         """Copy files from local_path to remote_path or vice versa.
334
335         connect() method has to be called first!
336
337         :param local_path: Path to local file that should be uploaded; or
338         path where to save remote file.
339         :param remote_path: Remote path where to place uploaded file; or
340         path to remote file which should be downloaded.
341         :param get: scp operation to perform. Default is put.
342         :param timeout: Timeout value in seconds.
343         :param wildcard: If path has wildcard characters. Default is false.
344         :type local_path: str
345         :type remote_path: str
346         :type get: bool
347         :type timeout: int
348         :type wildcard: bool
349         """
350         if not get:
351             logger.trace('SCP {0} to {1}:{2}'.format(
352                 local_path, self._ssh.get_transport().getpeername(),
353                 remote_path))
354         else:
355             logger.trace('SCP {0}:{1} to {2}'.format(
356                 self._ssh.get_transport().getpeername(), remote_path,
357                 local_path))
358         # SCPCLient takes a paramiko transport as its only argument
359         if not wildcard:
360             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
361         else:
362             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
363                             socket_timeout=timeout)
364         start = time()
365         if not get:
366             scp.put(local_path, remote_path)
367         else:
368             scp.get(remote_path, local_path)
369         scp.close()
370         end = time()
371         logger.trace('SCP took {0} seconds'.format(end-start))
372
373
374 def exec_cmd(node, cmd, timeout=600, sudo=False):
375     """Convenience function to ssh/exec/return rc, out & err.
376
377     Returns (rc, stdout, stderr).
378
379     :param node: The node to execute command on.
380     :param cmd: Command to execute.
381     :param timeout: Timeout value in seconds. Default: 600.
382     :param sudo: Sudo privilege execution flag. Default: False.
383     :type node: dict
384     :type cmd: str
385     :type timeout: int
386     :type sudo: bool
387     :returns: RC, Stdout, Stderr.
388     :rtype: tuple(int, str, str)
389     """
390     if node is None:
391         raise TypeError('Node parameter is None')
392     if cmd is None:
393         raise TypeError('Command parameter is None')
394     if not cmd:
395         raise ValueError('Empty command parameter')
396
397     ssh = SSH()
398
399     if node.get('host_port') is not None:
400         ssh_node = dict()
401         ssh_node['host'] = '127.0.0.1'
402         ssh_node['port'] = node['port']
403         ssh_node['username'] = node['username']
404         ssh_node['password'] = node['password']
405         import pexpect
406         options = '-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null'
407         tnl = '-L {port}:127.0.0.1:{port}'.format(port=node['port'])
408         ssh_cmd = 'ssh {tnl} {op} {user}@{host} -p {host_port}'.\
409             format(tnl=tnl, op=options, user=node['host_username'],
410                    host=node['host'], host_port=node['host_port'])
411         logger.trace('Initializing local port forwarding:\n{ssh_cmd}'.
412                      format(ssh_cmd=ssh_cmd))
413         child = pexpect.spawn(ssh_cmd)
414         child.expect('.* password: ')
415         logger.trace(child.after)
416         child.sendline(node['host_password'])
417         child.expect('Welcome .*')
418         logger.trace(child.after)
419         logger.trace('Local port forwarding finished.')
420     else:
421         ssh_node = node
422
423     try:
424         ssh.connect(ssh_node)
425     except SSHException as err:
426         logger.error("Failed to connect to node" + repr(err))
427         return None, None, None
428
429     try:
430         if not sudo:
431             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
432         else:
433             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
434                                                                timeout=timeout)
435     except SSHException as err:
436         logger.error(repr(err))
437         return None, None, None
438
439     return ret_code, stdout, stderr
440
441
442 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False, message=None):
443     """Convenience function to ssh/exec/return out & err.
444
445     Verifies that return code is zero.
446
447     :param node: DUT node.
448     :param cmd: Command to be executed.
449     :param timeout: Timeout value in seconds. Default: 600.
450     :param sudo: Sudo privilege execution flag. Default: False.
451     :param message: Error message in case of failure. Default: None.
452     :type node: dict
453     :type cmd: str
454     :type timeout: int
455     :type sudo: bool
456     :type message: str
457     :returns: Stdout, Stderr.
458     :rtype: tuple(str, str)
459     :raises RuntimeError: If bash return code is not 0.
460     """
461     ret_code, stdout, stderr = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
462     msg = ('Command execution failed: "{cmd}"\n{stderr}'.
463            format(cmd=cmd, stderr=stderr) if message is None else message)
464     if ret_code != 0:
465         raise RuntimeError(msg)
466
467     return stdout, stderr