Fix various pylint 1.5.4 warnings
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2018 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSHTimeout(Exception):
33     """This exception is raised when a timeout occurs."""
34     pass
35
36
37 class SSH(object):
38     """Contains methods for managing and using SSH connections."""
39
40     __MAX_RECV_BUF = 10*1024*1024
41     __existing_connections = {}
42
43     def __init__(self):
44         self._ssh = None
45         self._node = None
46
47     @staticmethod
48     def _node_hash(node):
49         """Get IP address and port hash from node dictionary.
50
51         :param node: Node in topology.
52         :type node: dict
53         :returns: IP address and port for the specified node.
54         :rtype: int
55         """
56
57         return hash(frozenset([node['host'], node['port']]))
58
59     def connect(self, node, attempts=5):
60         """Connect to node prior to running exec_command or scp.
61
62         If there already is a connection to the node, this method reuses it.
63         """
64         try:
65             self._node = node
66             node_hash = self._node_hash(node)
67             if node_hash in SSH.__existing_connections:
68                 self._ssh = SSH.__existing_connections[node_hash]
69                 logger.debug('reusing ssh: {0}'.format(self._ssh))
70             else:
71                 start = time()
72                 pkey = None
73                 if 'priv_key' in node:
74                     pkey = RSAKey.from_private_key(
75                         StringIO.StringIO(node['priv_key']))
76
77                 self._ssh = paramiko.SSHClient()
78                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
79
80                 self._ssh.connect(node['host'], username=node['username'],
81                                   password=node.get('password'), pkey=pkey,
82                                   port=node['port'])
83
84                 self._ssh.get_transport().set_keepalive(10)
85
86                 SSH.__existing_connections[node_hash] = self._ssh
87
88                 logger.trace('connect took {} seconds'.format(time() - start))
89                 logger.debug('new ssh: {0}'.format(self._ssh))
90
91             logger.debug('Connect peer: {0}'.
92                          format(self._ssh.get_transport().getpeername()))
93             logger.debug('Connections: {0}'.
94                          format(str(SSH.__existing_connections)))
95         except RuntimeError as exc:
96             if attempts > 0:
97                 self._reconnect(attempts-1)
98             else:
99                 raise exc
100
101     def disconnect(self, node):
102         """Close SSH connection to the node.
103
104         :param node: The node to disconnect from.
105         :type node: dict
106         """
107         node_hash = self._node_hash(node)
108         if node_hash in SSH.__existing_connections:
109             logger.debug('Disconnecting peer: {}, {}'.
110                          format(node['host'], node['port']))
111             ssh = SSH.__existing_connections.pop(node_hash)
112             ssh.close()
113
114     def _reconnect(self, attempts=0):
115         """Close the SSH connection and open it again."""
116
117         node = self._node
118         self.disconnect(node)
119         self.connect(node, attempts)
120         logger.debug('Reconnecting peer done: {}'.
121                      format(self._ssh.get_transport().getpeername()))
122
123     def exec_command(self, cmd, timeout=10):
124         """Execute SSH command on a new channel on the connected Node.
125
126         :param cmd: Command to run on the Node.
127         :param timeout: Maximal time in seconds to wait until the command is
128         done. If set to None then wait forever.
129         :type cmd: str
130         :type timeout: int
131         :return return_code, stdout, stderr
132         :rtype: tuple(int, str, str)
133         :raise SSHTimeout: If command is not finished in timeout time.
134         """
135         stdout = StringIO.StringIO()
136         stderr = StringIO.StringIO()
137         try:
138             chan = self._ssh.get_transport().open_session(timeout=5)
139             peer = self._ssh.get_transport().getpeername()
140         except AttributeError:
141             self._reconnect()
142             chan = self._ssh.get_transport().open_session(timeout=5)
143             peer = self._ssh.get_transport().getpeername()
144         except SSHException:
145             self._reconnect()
146             chan = self._ssh.get_transport().open_session(timeout=5)
147             peer = self._ssh.get_transport().getpeername()
148         chan.settimeout(timeout)
149
150         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
151                      .format(peer=peer, timeout=timeout, cmd=cmd))
152
153         start = time()
154         chan.exec_command(cmd)
155         while not chan.exit_status_ready() and timeout is not None:
156             if chan.recv_ready():
157                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
158
159             if chan.recv_stderr_ready():
160                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
161
162             if time() - start > timeout:
163                 raise SSHTimeout(
164                     'Timeout exception during execution of command: {cmd}\n'
165                     'Current contents of stdout buffer: {stdout}\n'
166                     'Current contents of stderr buffer: {stderr}\n'
167                     .format(cmd=cmd, stdout=stdout.getvalue(),
168                             stderr=stderr.getvalue())
169                 )
170
171             sleep(0.1)
172         return_code = chan.recv_exit_status()
173
174         while chan.recv_ready():
175             stdout.write(chan.recv(self.__MAX_RECV_BUF))
176
177         while chan.recv_stderr_ready():
178             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
179
180         end = time()
181         logger.trace('exec_command on {peer} took {total} seconds'.
182                      format(peer=peer, total=end-start))
183
184         logger.trace('return RC {rc}'.format(rc=return_code))
185         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
186         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
187         return return_code, stdout.getvalue(), stderr.getvalue()
188
189     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
190         """Execute SSH command with sudo on a new channel on the connected Node.
191
192         :param cmd: Command to be executed.
193         :param cmd_input: Input redirected to the command.
194         :param timeout: Timeout.
195         :returns: return_code, stdout, stderr
196
197         :Example:
198
199         >>> from ssh import SSH
200         >>> ssh = SSH()
201         >>> ssh.connect(node)
202         >>> # Execute command without input (sudo -S cmd)
203         >>> ssh.exec_command_sudo("ifconfig eth0 down")
204         >>> # Execute command with input (sudo -S cmd <<< "input")
205         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
206         """
207         if cmd_input is None:
208             command = 'sudo -S {c}'.format(c=cmd)
209         else:
210             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
211         return self.exec_command(command, timeout)
212
213     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
214                          timeout=30):
215         """Execute command in LXC on a new SSH channel on the connected Node.
216
217         :param lxc_cmd: Command to be executed.
218         :param lxc_name: LXC name.
219         :param lxc_params: Additional parameters for LXC attach.
220         :param sudo: Run in privileged LXC mode. Default: privileged
221         :param timeout: Timeout.
222         :type lxc_cmd: str
223         :type lxc_name: str
224         :type lxc_params: str
225         :type sudo: bool
226         :type timeout: int
227         :returns: return_code, stdout, stderr
228         """
229         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
230             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
231
232         if sudo:
233             command = 'sudo -S {c}'.format(c=command)
234         return self.exec_command(command, timeout)
235
236     def interactive_terminal_open(self, time_out=30):
237         """Open interactive terminal on a new channel on the connected Node.
238
239         FIXME: Convert or document other possible exceptions, such as
240         socket.error or SSHException.
241
242         .. warning:: Interruptingcow is used here, and it uses
243            signal(SIGALRM) to let the operating system interrupt program
244            execution. This has the following limitations: Python signal
245            handlers only apply to the main thread, so you cannot use this
246            from other threads. You must not use this in a program that
247            uses SIGALRM itself (this includes certain profilers)
248
249         :param time_out: Timeout in seconds.
250         :returns: SSH channel with opened terminal.
251         :raise IOError: If receive attempt results in socket.timeout.
252         """
253         chan = self._ssh.get_transport().open_session()
254         chan.get_pty()
255         chan.invoke_shell()
256         chan.settimeout(int(time_out))
257         chan.set_combine_stderr(True)
258
259         buf = ''
260         while not buf.endswith((":~$ ", "~]$ ", "~]# ")):
261             try:
262                 chunk = chan.recv(self.__MAX_RECV_BUF)
263                 if not chunk:
264                     break
265                 buf += chunk
266                 if chan.exit_status_ready():
267                     logger.error('Channel exit status ready')
268                     break
269             except socket.timeout:
270                 logger.error('Socket timeout: {0}'.format(buf))
271                 # TODO: Find out which exception would callers appreciate here.
272                 raise IOError('Socket timeout: {0}'.format(buf))
273         return chan
274
275     def interactive_terminal_exec_command(self, chan, cmd, prompt):
276         """Execute command on interactive terminal.
277
278         interactive_terminal_open() method has to be called first!
279
280         .. warning:: Interruptingcow is used here, and it uses
281            signal(SIGALRM) to let the operating system interrupt program
282            execution. This has the following limitations: Python signal
283            handlers only apply to the main thread, so you cannot use this
284            from other threads. You must not use this in a program that
285            uses SIGALRM itself (this includes certain profilers)
286
287         :param chan: SSH channel with opened terminal.
288         :param cmd: Command to be executed.
289         :param prompt: Command prompt, sequence of characters used to
290             indicate readiness to accept commands.
291         :returns: Command output.
292         :raise IOError: If receive attempt results in socket.timeout.
293         """
294         chan.sendall('{c}\n'.format(c=cmd))
295         buf = ''
296         while not buf.endswith(prompt):
297             try:
298                 chunk = chan.recv(self.__MAX_RECV_BUF)
299                 if not chunk:
300                     break
301                 buf += chunk
302                 if chan.exit_status_ready():
303                     logger.error('Channel exit status ready')
304                     break
305             except socket.timeout:
306                 logger.error('Socket timeout during execution of command: '
307                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
308                 # TODO: Find out which exception would callers appreciate here.
309                 raise IOError('Socket timeout during execution of command: '
310                               '{0}\nBuffer content:\n{1}'.format(cmd, buf))
311         tmp = buf.replace(cmd.replace('\n', ''), '')
312         for item in prompt:
313             tmp.replace(item, '')
314         return tmp
315
316     @staticmethod
317     def interactive_terminal_close(chan):
318         """Close interactive terminal SSH channel.
319
320         :param: chan: SSH channel to be closed.
321         """
322         chan.close()
323
324     def scp(self, local_path, remote_path, get=False, timeout=10):
325         """Copy files from local_path to remote_path or vice versa.
326
327         connect() method has to be called first!
328
329         :param local_path: Path to local file that should be uploaded; or
330         path where to save remote file.
331         :param remote_path: Remote path where to place uploaded file; or
332         path to remote file which should be downloaded.
333         :param get: scp operation to perform. Default is put.
334         :param timeout: Timeout value in seconds.
335         :type local_path: str
336         :type remote_path: str
337         :type get: bool
338         :type timeout: int
339         """
340         if not get:
341             logger.trace('SCP {0} to {1}:{2}'.format(
342                 local_path, self._ssh.get_transport().getpeername(),
343                 remote_path))
344         else:
345             logger.trace('SCP {0}:{1} to {2}'.format(
346                 self._ssh.get_transport().getpeername(), remote_path,
347                 local_path))
348         # SCPCLient takes a paramiko transport as its only argument
349         scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
350         start = time()
351         if not get:
352             scp.put(local_path, remote_path)
353         else:
354             scp.get(remote_path, local_path)
355         scp.close()
356         end = time()
357         logger.trace('SCP took {0} seconds'.format(end-start))
358
359
360 def exec_cmd(node, cmd, timeout=600, sudo=False):
361     """Convenience function to ssh/exec/return rc, out & err.
362
363     FIXME: Document :param, :type, :raise and similar.
364     Returns (rc, stdout, stderr).
365     """
366     if node is None:
367         raise TypeError('Node parameter is None')
368     if cmd is None:
369         raise TypeError('Command parameter is None')
370     if len(cmd) == 0:
371         raise ValueError('Empty command parameter')
372
373     ssh = SSH()
374     try:
375         ssh.connect(node)
376     except SSHException as err:
377         logger.error("Failed to connect to node" + str(err))
378         return None, None, None
379
380     try:
381         if not sudo:
382             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
383         else:
384             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
385                                                                timeout=timeout)
386     except SSHException as err:
387         logger.error(err)
388         return None, None, None
389
390     return ret_code, stdout, stderr
391
392
393 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
394     """Convenience function to ssh/exec/return out & err.
395
396     Verifies that return code is zero.
397
398     Returns (stdout, stderr).
399     """
400     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
401     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
402                  format(cmd, stderr))
403     return stdout, stderr