Adding DMM build artifacts
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2018 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25
26 __all__ = ["exec_cmd", "exec_cmd_no_error"]
27
28 # TODO: load priv key
29
30
31 class SSHTimeout(Exception):
32     """This exception is raised when a timeout occurs."""
33     pass
34
35
36 class SSH(object):
37     """Contains methods for managing and using SSH connections."""
38
39     __MAX_RECV_BUF = 10*1024*1024
40     __existing_connections = {}
41
42     def __init__(self):
43         self._ssh = None
44         self._node = None
45
46     @staticmethod
47     def _node_hash(node):
48         """Get IP address and port hash from node dictionary.
49
50         :param node: Node in topology.
51         :type node: dict
52         :returns: IP address and port for the specified node.
53         :rtype: int
54         """
55
56         return hash(frozenset([node['host'], node['port']]))
57
58     def connect(self, node, attempts=5):
59         """Connect to node prior to running exec_command or scp.
60
61         If there already is a connection to the node, this method reuses it.
62
63         :param node: Node in topology.
64         :param attempts: Number of reconnect attempts.
65         :type node: dict
66         :type attempts: int
67         :raises IOError: If cannot connect to host.
68         """
69         self._node = node
70         node_hash = self._node_hash(node)
71         if node_hash in SSH.__existing_connections:
72             self._ssh = SSH.__existing_connections[node_hash]
73             if self._ssh.get_transport().is_active():
74                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
75             else:
76                 if attempts > 0:
77                     self._reconnect(attempts-1)
78                 else:
79                     raise IOError('Cannot connect to {host}'.
80                                   format(host=node['host']))
81         else:
82             try:
83                 start = time()
84                 pkey = None
85                 if 'priv_key' in node:
86                     pkey = RSAKey.from_private_key(
87                         StringIO.StringIO(node['priv_key']))
88
89                 self._ssh = paramiko.SSHClient()
90                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
91
92                 self._ssh.connect(node['host'], username=node['username'],
93                                   password=node.get('password'), pkey=pkey,
94                                   port=node['port'])
95
96                 self._ssh.get_transport().set_keepalive(10)
97
98                 SSH.__existing_connections[node_hash] = self._ssh
99                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
100                              format(
101                                  peer=self._ssh.get_transport().getpeername(),
102                                  total=(time() - start),
103                                  ssh=self._ssh))
104             except SSHException:
105                 raise IOError('Cannot connect to {host}'.
106                               format(host=node['host']))
107
108     def disconnect(self, node):
109         """Close SSH connection to the node.
110
111         :param node: The node to disconnect from.
112         :type node: dict
113         """
114         node_hash = self._node_hash(node)
115         if node_hash in SSH.__existing_connections:
116             logger.debug('Disconnecting peer: {host}, {port}'.
117                          format(host=node['host'], port=node['port']))
118             ssh = SSH.__existing_connections.pop(node_hash)
119             ssh.close()
120
121     def _reconnect(self, attempts=0):
122         """Close the SSH connection and open it again.
123
124         :param attempts: Number of reconnect attempts.
125         :type attempts: int
126         """
127         node = self._node
128         self.disconnect(node)
129         self.connect(node, attempts)
130         logger.debug('Reconnecting peer done: {host}, {port}'.
131                      format(host=node['host'], port=node['port']))
132
133     def exec_command(self, cmd, timeout=10):
134         """Execute SSH command on a new channel on the connected Node.
135
136         :param cmd: Command to run on the Node.
137         :param timeout: Maximal time in seconds to wait until the command is
138         done. If set to None then wait forever.
139         :type cmd: str
140         :type timeout: int
141         :return return_code, stdout, stderr
142         :rtype: tuple(int, str, str)
143         :raise SSHTimeout: If command is not finished in timeout time.
144         """
145         stdout = StringIO.StringIO()
146         stderr = StringIO.StringIO()
147         try:
148             chan = self._ssh.get_transport().open_session(timeout=5)
149             peer = self._ssh.get_transport().getpeername()
150         except AttributeError:
151             self._reconnect()
152             chan = self._ssh.get_transport().open_session(timeout=5)
153             peer = self._ssh.get_transport().getpeername()
154         except SSHException:
155             self._reconnect()
156             chan = self._ssh.get_transport().open_session(timeout=5)
157             peer = self._ssh.get_transport().getpeername()
158         chan.settimeout(timeout)
159
160         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
161                      .format(peer=peer, timeout=timeout, cmd=cmd))
162
163         start = time()
164         chan.exec_command(cmd)
165         while not chan.exit_status_ready() and timeout is not None:
166             if chan.recv_ready():
167                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
168
169             if chan.recv_stderr_ready():
170                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
171
172             if time() - start > timeout:
173                 raise SSHTimeout(
174                     'Timeout exception during execution of command: {cmd}\n'
175                     'Current contents of stdout buffer: {stdout}\n'
176                     'Current contents of stderr buffer: {stderr}\n'
177                     .format(cmd=cmd, stdout=stdout.getvalue(),
178                             stderr=stderr.getvalue())
179                 )
180
181             sleep(0.1)
182         return_code = chan.recv_exit_status()
183
184         while chan.recv_ready():
185             stdout.write(chan.recv(self.__MAX_RECV_BUF))
186
187         while chan.recv_stderr_ready():
188             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
189
190         end = time()
191         logger.trace('exec_command on {peer} took {total} seconds'.
192                      format(peer=peer, total=end-start))
193
194         logger.trace('return RC {rc}'.format(rc=return_code))
195         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
196         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
197         return return_code, stdout.getvalue(), stderr.getvalue()
198
199     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
200         """Execute SSH command with sudo on a new channel on the connected Node.
201
202         :param cmd: Command to be executed.
203         :param cmd_input: Input redirected to the command.
204         :param timeout: Timeout.
205         :returns: return_code, stdout, stderr
206
207         :Example:
208
209         >>> from ssh import SSH
210         >>> ssh = SSH()
211         >>> ssh.connect(node)
212         >>> # Execute command without input (sudo -S cmd)
213         >>> ssh.exec_command_sudo("ifconfig eth0 down")
214         >>> # Execute command with input (sudo -S cmd <<< "input")
215         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
216         """
217         if cmd_input is None:
218             command = 'sudo -S {c}'.format(c=cmd)
219         else:
220             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
221         return self.exec_command(command, timeout)
222
223     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
224                          timeout=30):
225         """Execute command in LXC on a new SSH channel on the connected Node.
226
227         :param lxc_cmd: Command to be executed.
228         :param lxc_name: LXC name.
229         :param lxc_params: Additional parameters for LXC attach.
230         :param sudo: Run in privileged LXC mode. Default: privileged
231         :param timeout: Timeout.
232         :type lxc_cmd: str
233         :type lxc_name: str
234         :type lxc_params: str
235         :type sudo: bool
236         :type timeout: int
237         :returns: return_code, stdout, stderr
238         """
239         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
240             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
241
242         if sudo:
243             command = 'sudo -S {c}'.format(c=command)
244         return self.exec_command(command, timeout)
245
246     def interactive_terminal_open(self, time_out=45):
247         """Open interactive terminal on a new channel on the connected Node.
248
249         :param time_out: Timeout in seconds.
250         :returns: SSH channel with opened terminal.
251
252         .. warning:: Interruptingcow is used here, and it uses
253            signal(SIGALRM) to let the operating system interrupt program
254            execution. This has the following limitations: Python signal
255            handlers only apply to the main thread, so you cannot use this
256            from other threads. You must not use this in a program that
257            uses SIGALRM itself (this includes certain profilers)
258         """
259         chan = self._ssh.get_transport().open_session()
260         chan.get_pty()
261         chan.invoke_shell()
262         chan.settimeout(int(time_out))
263         chan.set_combine_stderr(True)
264
265         buf = ''
266         while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
267             try:
268                 chunk = chan.recv(self.__MAX_RECV_BUF)
269                 if not chunk:
270                     break
271                 buf += chunk
272                 if chan.exit_status_ready():
273                     logger.error('Channel exit status ready')
274                     break
275             except socket.timeout:
276                 logger.error('Socket timeout: {0}'.format(buf))
277                 raise Exception('Socket timeout: {0}'.format(buf))
278         return chan
279
280     def interactive_terminal_exec_command(self, chan, cmd, prompt):
281         """Execute command on interactive terminal.
282
283         interactive_terminal_open() method has to be called first!
284
285         :param chan: SSH channel with opened terminal.
286         :param cmd: Command to be executed.
287         :param prompt: Command prompt, sequence of characters used to
288         indicate readiness to accept commands.
289         :returns: Command output.
290
291         .. warning:: Interruptingcow is used here, and it uses
292            signal(SIGALRM) to let the operating system interrupt program
293            execution. This has the following limitations: Python signal
294            handlers only apply to the main thread, so you cannot use this
295            from other threads. You must not use this in a program that
296            uses SIGALRM itself (this includes certain profilers)
297         """
298         chan.sendall('{c}\n'.format(c=cmd))
299         buf = ''
300         while not buf.endswith(prompt):
301             try:
302                 chunk = chan.recv(self.__MAX_RECV_BUF)
303                 if not chunk:
304                     break
305                 buf += chunk
306                 if chan.exit_status_ready():
307                     logger.error('Channel exit status ready')
308                     break
309             except socket.timeout:
310                 logger.error('Socket timeout during execution of command: '
311                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
312                 raise Exception('Socket timeout during execution of command: '
313                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
314         tmp = buf.replace(cmd.replace('\n', ''), '')
315         for item in prompt:
316             tmp.replace(item, '')
317         return tmp
318
319     @staticmethod
320     def interactive_terminal_close(chan):
321         """Close interactive terminal SSH channel.
322
323         :param: chan: SSH channel to be closed.
324         """
325         chan.close()
326
327     def scp(self, local_path, remote_path, get=False, timeout=30,
328             wildcard=False):
329         """Copy files from local_path to remote_path or vice versa.
330
331         connect() method has to be called first!
332
333         :param local_path: Path to local file that should be uploaded; or
334         path where to save remote file.
335         :param remote_path: Remote path where to place uploaded file; or
336         path to remote file which should be downloaded.
337         :param get: scp operation to perform. Default is put.
338         :param timeout: Timeout value in seconds.
339         :param wildcard: If path has wildcard characters. Default is false.
340         :type local_path: str
341         :type remote_path: str
342         :type get: bool
343         :type timeout: int
344         :type wildcard: bool
345         """
346         if not get:
347             logger.trace('SCP {0} to {1}:{2}'.format(
348                 local_path, self._ssh.get_transport().getpeername(),
349                 remote_path))
350         else:
351             logger.trace('SCP {0}:{1} to {2}'.format(
352                 self._ssh.get_transport().getpeername(), remote_path,
353                 local_path))
354         # SCPCLient takes a paramiko transport as its only argument
355         if not wildcard:
356             scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
357         else:
358             scp = SCPClient(self._ssh.get_transport(), sanitize=lambda x: x,
359                             socket_timeout=timeout)
360         start = time()
361         if not get:
362             scp.put(local_path, remote_path)
363         else:
364             scp.get(remote_path, local_path)
365         scp.close()
366         end = time()
367         logger.trace('SCP took {0} seconds'.format(end-start))
368
369
370 def exec_cmd(node, cmd, timeout=600, sudo=False):
371     """Convenience function to ssh/exec/return rc, out & err.
372
373     Returns (rc, stdout, stderr).
374     """
375     if node is None:
376         raise TypeError('Node parameter is None')
377     if cmd is None:
378         raise TypeError('Command parameter is None')
379     if not cmd:
380         raise ValueError('Empty command parameter')
381
382     ssh = SSH()
383     try:
384         ssh.connect(node)
385     except SSHException as err:
386         logger.error("Failed to connect to node" + str(err))
387         return None, None, None
388
389     try:
390         if not sudo:
391             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
392         else:
393             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
394                                                                timeout=timeout)
395     except SSHException as err:
396         logger.error(err)
397         return None, None, None
398
399     return ret_code, stdout, stderr
400
401
402 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False, message=None):
403     """Convenience function to ssh/exec/return out & err.
404
405     Verifies that return code is zero.
406
407     :param node: DUT node.
408     :param cmd: Command to be executed.
409     :param timeout: Timeout value in seconds. Default: 600.
410     :param sudo: Sudo privilege execution flag. Default: False.
411     :param message: Error message in case of failure. Default: None.
412     :type node: dict
413     :type cmd: str
414     :type timeout: int
415     :type sudo: bool
416     :type message: str
417     :returns: Stdout, Stderr.
418     :rtype: tuple(str, str)
419     :raise RuntimeError: If bash return code is not 0.
420     """
421     ret_code, stdout, stderr = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
422     msg = ('Command execution failed: "{cmd}"\n{stderr}'.
423            format(cmd=cmd, stderr=stderr) if message is None else message)
424     if ret_code != 0:
425         raise RuntimeError(msg)
426
427     return stdout, stderr