Fix pylint warnings in python libraries
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from interruptingcow import timeout as icTimeout
25 from robot.api import logger
26 from robot.utils.asserts import assert_equal
27
28 __all__ = ["exec_cmd", "exec_cmd_no_error"]
29
30 # TODO: load priv key
31
32
33 class SSH(object):
34     """Contains methods for managing and using SSH connections."""
35
36     __MAX_RECV_BUF = 10*1024*1024
37     __existing_connections = {}
38
39     def __init__(self):
40         self._ssh = None
41         self._node = None
42
43     @staticmethod
44     def _node_hash(node):
45         """Get IP address and port hash from node dictionary.
46
47         :param node: Node in topology.
48         :type node: dict
49         :return: IP address and port for the specified node.
50         :rtype: int
51         """
52
53         return hash(frozenset([node['host'], node['port']]))
54
55     def connect(self, node):
56         """Connect to node prior to running exec_command or scp.
57
58         If there already is a connection to the node, this method reuses it.
59         """
60         self._node = node
61         node_hash = self._node_hash(node)
62         if node_hash in SSH.__existing_connections:
63             self._ssh = SSH.__existing_connections[node_hash]
64             logger.debug('reusing ssh: {0}'.format(self._ssh))
65         else:
66             start = time()
67             pkey = None
68             if 'priv_key' in node:
69                 pkey = RSAKey.from_private_key(
70                     StringIO.StringIO(node['priv_key']))
71
72             self._ssh = paramiko.SSHClient()
73             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
74
75             self._ssh.connect(node['host'], username=node['username'],
76                               password=node.get('password'), pkey=pkey,
77                               port=node['port'])
78
79             self._ssh.get_transport().set_keepalive(10)
80
81             SSH.__existing_connections[node_hash] = self._ssh
82
83             logger.trace('connect took {} seconds'.format(time() - start))
84             logger.debug('new ssh: {0}'.format(self._ssh))
85
86         logger.debug('Connect peer: {0}'.
87                      format(self._ssh.get_transport().getpeername()))
88         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
89
90     def disconnect(self, node):
91         """Close SSH connection to the node.
92
93         :param node: The node to disconnect from.
94         :type node: dict
95         """
96         node_hash = self._node_hash(node)
97         if node_hash in SSH.__existing_connections:
98             logger.debug('Disconnecting peer: {}, {}'.
99                          format(node['host'], node['port']))
100             ssh = SSH.__existing_connections.pop(node_hash)
101             ssh.close()
102
103     def _reconnect(self):
104         """Close the SSH connection and open it again."""
105
106         node = self._node
107         self.disconnect(node)
108         self.connect(node)
109         logger.debug('Reconnecting peer done: {}'.
110                      format(self._ssh.get_transport().getpeername()))
111
112     def exec_command(self, cmd, timeout=10):
113         """Execute SSH command on a new channel on the connected Node.
114
115         :param cmd: Command to run on the Node.
116         :param timeout: Maximal time in seconds to wait until the command is
117         done. If set to None then wait forever.
118         :type cmd: str
119         :type timeout: int
120         :return return_code, stdout, stderr
121         :rtype: tuple(int, str, str)
122         :raise socket.timeout: If command is not finished in timeout time.
123         """
124         start = time()
125         stdout = StringIO.StringIO()
126         stderr = StringIO.StringIO()
127         try:
128             chan = self._ssh.get_transport().open_session(timeout=5)
129         except AttributeError:
130             self._reconnect()
131             chan = self._ssh.get_transport().open_session(timeout=5)
132         except SSHException:
133             self._reconnect()
134             chan = self._ssh.get_transport().open_session(timeout=5)
135         chan.settimeout(timeout)
136         logger.trace('exec_command on {0}: {1}'
137                      .format(self._ssh.get_transport().getpeername(), cmd))
138
139         chan.exec_command(cmd)
140         while not chan.exit_status_ready() and timeout is not None:
141             if chan.recv_ready():
142                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
143
144             if chan.recv_stderr_ready():
145                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
146
147             if time() - start > timeout:
148                 raise socket.timeout(
149                     'Timeout exception.\n'
150                     'Current contents of stdout buffer: {0}\n'
151                     'Current contents of stderr buffer: {1}\n'
152                     .format(stdout.getvalue(), stderr.getvalue())
153                 )
154
155             sleep(0.1)
156         return_code = chan.recv_exit_status()
157
158         while chan.recv_ready():
159             stdout.write(chan.recv(self.__MAX_RECV_BUF))
160
161         while chan.recv_stderr_ready():
162             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
163
164         end = time()
165         logger.trace('exec_command on {0} took {1} seconds'.format(
166             self._ssh.get_transport().getpeername(), end-start))
167
168         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
169
170         logger.trace('return RC {}'.format(return_code))
171         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
172         logger.trace('return STDERR {}'.format(stderr.getvalue()))
173         return return_code, stdout.getvalue(), stderr.getvalue()
174
175     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
176         """Execute SSH command with sudo on a new channel on the connected Node.
177
178         :param cmd: Command to be executed.
179         :param cmd_input: Input redirected to the command.
180         :param timeout: Timeout.
181         :return: return_code, stdout, stderr
182
183         :Example:
184
185         >>> from ssh import SSH
186         >>> ssh = SSH()
187         >>> ssh.connect(node)
188         >>> # Execute command without input (sudo -S cmd)
189         >>> ssh.exec_command_sudo("ifconfig eth0 down")
190         >>> # Execute command with input (sudo -S cmd <<< "input")
191         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
192         """
193         if cmd_input is None:
194             command = 'sudo -S {c}'.format(c=cmd)
195         else:
196             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
197         return self.exec_command(command, timeout)
198
199     def interactive_terminal_open(self, time_out=10):
200         """Open interactive terminal on a new channel on the connected Node.
201
202         :param time_out: Timeout in seconds.
203         :return: SSH channel with opened terminal.
204
205         .. warning:: Interruptingcow is used here, and it uses
206            signal(SIGALRM) to let the operating system interrupt program
207            execution. This has the following limitations: Python signal
208            handlers only apply to the main thread, so you cannot use this
209            from other threads. You must not use this in a program that
210            uses SIGALRM itself (this includes certain profilers)
211         """
212         chan = self._ssh.get_transport().open_session()
213         chan.get_pty()
214         chan.invoke_shell()
215         chan.settimeout(int(time_out))
216
217         buf = ''
218         try:
219             with icTimeout(time_out, exception=RuntimeError):
220                 while not buf.endswith(':~$ '):
221                     if chan.recv_ready():
222                         buf = chan.recv(4096)
223         except RuntimeError:
224             raise Exception('Open interactive terminal timeout.')
225         return chan
226
227     @staticmethod
228     def interactive_terminal_exec_command(chan, cmd, prompt,
229                                           time_out=30):
230         """Execute command on interactive terminal.
231
232         interactive_terminal_open() method has to be called first!
233
234         :param chan: SSH channel with opened terminal.
235         :param cmd: Command to be executed.
236         :param prompt: Command prompt, sequence of characters used to
237         indicate readiness to accept commands.
238         :param time_out: Timeout in seconds.
239         :return: Command output.
240
241         .. warning:: Interruptingcow is used here, and it uses
242            signal(SIGALRM) to let the operating system interrupt program
243            execution. This has the following limitations: Python signal
244            handlers only apply to the main thread, so you cannot use this
245            from other threads. You must not use this in a program that
246            uses SIGALRM itself (this includes certain profilers)
247         """
248         chan.sendall('{c}\n'.format(c=cmd))
249         buf = ''
250         try:
251             with icTimeout(time_out, exception=RuntimeError):
252                 while not buf.endswith(prompt):
253                     if chan.recv_ready():
254                         buf += chan.recv(4096)
255         except RuntimeError:
256             raise Exception("Exec '{c}' timeout.".format(c=cmd))
257         tmp = buf.replace(cmd.replace('\n', ''), '')
258         return tmp.replace(prompt, '')
259
260     @staticmethod
261     def interactive_terminal_close(chan):
262         """Close interactive terminal SSH channel.
263
264         :param: chan: SSH channel to be closed.
265         """
266         chan.close()
267
268     def scp(self, local_path, remote_path):
269         """Copy files from local_path to remote_path.
270
271         connect() method has to be called first!
272         """
273         logger.trace('SCP {0} to {1}:{2}'.format(
274             local_path, self._ssh.get_transport().getpeername(), remote_path))
275         # SCPCLient takes a paramiko transport as its only argument
276         scp = SCPClient(self._ssh.get_transport())
277         start = time()
278         scp.put(local_path, remote_path)
279         scp.close()
280         end = time()
281         logger.trace('SCP took {0} seconds'.format(end-start))
282
283
284 def exec_cmd(node, cmd, timeout=600, sudo=False):
285     """Convenience function to ssh/exec/return rc, out & err.
286
287     Returns (rc, stdout, stderr).
288     """
289     if node is None:
290         raise TypeError('Node parameter is None')
291     if cmd is None:
292         raise TypeError('Command parameter is None')
293     if len(cmd) == 0:
294         raise ValueError('Empty command parameter')
295
296     ssh = SSH()
297     try:
298         ssh.connect(node)
299     except Exception as err:
300         logger.error("Failed to connect to node" + str(err))
301         return None, None, None
302
303     try:
304         if not sudo:
305             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
306         else:
307             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
308                                                                timeout=timeout)
309     except Exception as err:
310         logger.error(err)
311         return None, None, None
312
313     return ret_code, stdout, stderr
314
315
316 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
317     """Convenience function to ssh/exec/return out & err.
318
319     Verifies that return code is zero.
320
321     Returns (stdout, stderr).
322     """
323     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
324     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
325                  format(cmd, stderr))
326     return stdout, stderr