CSIT-622: Stateful Security Groups perf tests
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSHTimeout(Exception):
33     """This exception is raised when a timeout occurs."""
34     pass
35
36
37 class SSH(object):
38     """Contains methods for managing and using SSH connections."""
39
40     __MAX_RECV_BUF = 10*1024*1024
41     __existing_connections = {}
42
43     def __init__(self):
44         self._ssh = None
45         self._node = None
46
47     @staticmethod
48     def _node_hash(node):
49         """Get IP address and port hash from node dictionary.
50
51         :param node: Node in topology.
52         :type node: dict
53         :return: IP address and port for the specified node.
54         :rtype: int
55         """
56
57         return hash(frozenset([node['host'], node['port']]))
58
59     def connect(self, node, attempts=5):
60         """Connect to node prior to running exec_command or scp.
61
62         If there already is a connection to the node, this method reuses it.
63         """
64         try:
65             self._node = node
66             node_hash = self._node_hash(node)
67             if node_hash in SSH.__existing_connections:
68                 self._ssh = SSH.__existing_connections[node_hash]
69                 logger.debug('reusing ssh: {0}'.format(self._ssh))
70             else:
71                 start = time()
72                 pkey = None
73                 if 'priv_key' in node:
74                     pkey = RSAKey.from_private_key(
75                         StringIO.StringIO(node['priv_key']))
76
77                 self._ssh = paramiko.SSHClient()
78                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
79
80                 self._ssh.connect(node['host'], username=node['username'],
81                                   password=node.get('password'), pkey=pkey,
82                                   port=node['port'])
83
84                 self._ssh.get_transport().set_keepalive(10)
85
86                 SSH.__existing_connections[node_hash] = self._ssh
87
88                 logger.trace('connect took {} seconds'.format(time() - start))
89                 logger.debug('new ssh: {0}'.format(self._ssh))
90
91             logger.debug('Connect peer: {0}'.
92                          format(self._ssh.get_transport().getpeername()))
93             logger.debug('Connections: {0}'.
94                          format(str(SSH.__existing_connections)))
95         except:
96             if attempts > 0:
97                 self._reconnect(attempts-1)
98             else:
99                 raise
100
101     def disconnect(self, node):
102         """Close SSH connection to the node.
103
104         :param node: The node to disconnect from.
105         :type node: dict
106         """
107         node_hash = self._node_hash(node)
108         if node_hash in SSH.__existing_connections:
109             logger.debug('Disconnecting peer: {}, {}'.
110                          format(node['host'], node['port']))
111             ssh = SSH.__existing_connections.pop(node_hash)
112             ssh.close()
113
114     def _reconnect(self, attempts=0):
115         """Close the SSH connection and open it again."""
116
117         node = self._node
118         self.disconnect(node)
119         self.connect(node, attempts)
120         logger.debug('Reconnecting peer done: {}'.
121                      format(self._ssh.get_transport().getpeername()))
122
123     def exec_command(self, cmd, timeout=10):
124         """Execute SSH command on a new channel on the connected Node.
125
126         :param cmd: Command to run on the Node.
127         :param timeout: Maximal time in seconds to wait until the command is
128         done. If set to None then wait forever.
129         :type cmd: str
130         :type timeout: int
131         :return return_code, stdout, stderr
132         :rtype: tuple(int, str, str)
133         :raise SSHTimeout: If command is not finished in timeout time.
134         """
135         start = time()
136         stdout = StringIO.StringIO()
137         stderr = StringIO.StringIO()
138         try:
139             chan = self._ssh.get_transport().open_session(timeout=5)
140         except AttributeError:
141             self._reconnect()
142             chan = self._ssh.get_transport().open_session(timeout=5)
143         except SSHException:
144             self._reconnect()
145             chan = self._ssh.get_transport().open_session(timeout=5)
146         chan.settimeout(timeout)
147         logger.trace('exec_command on {0}: {1}'
148                      .format(self._ssh.get_transport().getpeername(), cmd))
149
150         chan.exec_command(cmd)
151         while not chan.exit_status_ready() and timeout is not None:
152             if chan.recv_ready():
153                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
154
155             if chan.recv_stderr_ready():
156                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
157
158             if time() - start > timeout:
159                 raise SSHTimeout(
160                     'Timeout exception.\n'
161                     'Current contents of stdout buffer: {0}\n'
162                     'Current contents of stderr buffer: {1}\n'
163                     .format(stdout.getvalue(), stderr.getvalue())
164                 )
165
166             sleep(0.1)
167         return_code = chan.recv_exit_status()
168
169         while chan.recv_ready():
170             stdout.write(chan.recv(self.__MAX_RECV_BUF))
171
172         while chan.recv_stderr_ready():
173             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
174
175         end = time()
176         logger.trace('exec_command on {0} took {1} seconds'.format(
177             self._ssh.get_transport().getpeername(), end-start))
178
179         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
180
181         logger.trace('return RC {}'.format(return_code))
182         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
183         logger.trace('return STDERR {}'.format(stderr.getvalue()))
184         return return_code, stdout.getvalue(), stderr.getvalue()
185
186     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
187         """Execute SSH command with sudo on a new channel on the connected Node.
188
189         :param cmd: Command to be executed.
190         :param cmd_input: Input redirected to the command.
191         :param timeout: Timeout.
192         :return: return_code, stdout, stderr
193
194         :Example:
195
196         >>> from ssh import SSH
197         >>> ssh = SSH()
198         >>> ssh.connect(node)
199         >>> # Execute command without input (sudo -S cmd)
200         >>> ssh.exec_command_sudo("ifconfig eth0 down")
201         >>> # Execute command with input (sudo -S cmd <<< "input")
202         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
203         """
204         if cmd_input is None:
205             command = 'sudo -S {c}'.format(c=cmd)
206         else:
207             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
208         return self.exec_command(command, timeout)
209
210     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
211                          timeout=30):
212         """Execute command in LXC on a new SSH channel on the connected Node.
213
214         :param lxc_cmd: Command to be executed.
215         :param lxc_name: LXC name.
216         :param lxc_params: Additional parameters for LXC attach.
217         :param sudo: Run in privileged LXC mode. Default: privileged
218         :param timeout: Timeout.
219         :type lxc_cmd: str
220         :type lxc_name: str
221         :type lxc_params: str
222         :type sudo: bool
223         :type timeout: int
224         :return: return_code, stdout, stderr
225         """
226         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
227             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
228
229         if sudo:
230             command = 'sudo -S {c}'.format(c=command)
231         return self.exec_command(command, timeout)
232
233     def interactive_terminal_open(self, time_out=30):
234         """Open interactive terminal on a new channel on the connected Node.
235
236         :param time_out: Timeout in seconds.
237         :return: SSH channel with opened terminal.
238
239         .. warning:: Interruptingcow is used here, and it uses
240            signal(SIGALRM) to let the operating system interrupt program
241            execution. This has the following limitations: Python signal
242            handlers only apply to the main thread, so you cannot use this
243            from other threads. You must not use this in a program that
244            uses SIGALRM itself (this includes certain profilers)
245         """
246         chan = self._ssh.get_transport().open_session()
247         chan.get_pty()
248         chan.invoke_shell()
249         chan.settimeout(int(time_out))
250         chan.set_combine_stderr(True)
251
252         buf = ''
253         while not buf.endswith((":~$ ", "~]$ ")):
254             try:
255                 chunk = chan.recv(self.__MAX_RECV_BUF)
256                 if not chunk:
257                     break
258                 buf += chunk
259                 if chan.exit_status_ready():
260                     logger.error('Channel exit status ready')
261                     break
262             except socket.timeout:
263                 raise Exception('Socket timeout: {0}'.format(buf))
264         return chan
265
266     def interactive_terminal_exec_command(self, chan, cmd, prompt):
267         """Execute command on interactive terminal.
268
269         interactive_terminal_open() method has to be called first!
270
271         :param chan: SSH channel with opened terminal.
272         :param cmd: Command to be executed.
273         :param prompt: Command prompt, sequence of characters used to
274         indicate readiness to accept commands.
275         :return: Command output.
276
277         .. warning:: Interruptingcow is used here, and it uses
278            signal(SIGALRM) to let the operating system interrupt program
279            execution. This has the following limitations: Python signal
280            handlers only apply to the main thread, so you cannot use this
281            from other threads. You must not use this in a program that
282            uses SIGALRM itself (this includes certain profilers)
283         """
284         chan.sendall('{c}\n'.format(c=cmd))
285         buf = ''
286         while not buf.endswith(prompt):
287             try:
288                 chunk = chan.recv(self.__MAX_RECV_BUF)
289                 if not chunk:
290                     break
291                 buf += chunk
292                 if chan.exit_status_ready():
293                     logger.error('Channel exit status ready')
294                     break
295             except socket.timeout:
296                 raise Exception('Socket timeout: {0}'.format(buf))
297         tmp = buf.replace(cmd.replace('\n', ''), '')
298         for item in prompt:
299             tmp.replace(item, '')
300         return tmp
301
302     @staticmethod
303     def interactive_terminal_close(chan):
304         """Close interactive terminal SSH channel.
305
306         :param: chan: SSH channel to be closed.
307         """
308         chan.close()
309
310     def scp(self, local_path, remote_path, get=False):
311         """Copy files from local_path to remote_path or vice versa.
312
313         connect() method has to be called first!
314
315         :param local_path: Path to local file that should be uploaded; or
316         path where to save remote file.
317         :param remote_path: Remote path where to place uploaded file; or
318         path to remote file which should be downloaded.
319         :param get: scp operation to perform. Default is put.
320         :type local_path: str
321         :type remote_path: str
322         :type get: bool
323         """
324         if not get:
325             logger.trace('SCP {0} to {1}:{2}'.format(
326                 local_path, self._ssh.get_transport().getpeername(),
327                 remote_path))
328         else:
329             logger.trace('SCP {0}:{1} to {2}'.format(
330                 self._ssh.get_transport().getpeername(), remote_path,
331                 local_path))
332         # SCPCLient takes a paramiko transport as its only argument
333         scp = SCPClient(self._ssh.get_transport(), socket_timeout=10)
334         start = time()
335         if not get:
336             scp.put(local_path, remote_path)
337         else:
338             scp.get(remote_path, local_path)
339         scp.close()
340         end = time()
341         logger.trace('SCP took {0} seconds'.format(end-start))
342
343
344 def exec_cmd(node, cmd, timeout=600, sudo=False):
345     """Convenience function to ssh/exec/return rc, out & err.
346
347     Returns (rc, stdout, stderr).
348     """
349     if node is None:
350         raise TypeError('Node parameter is None')
351     if cmd is None:
352         raise TypeError('Command parameter is None')
353     if len(cmd) == 0:
354         raise ValueError('Empty command parameter')
355
356     ssh = SSH()
357     try:
358         ssh.connect(node)
359     except SSHException as err:
360         logger.error("Failed to connect to node" + str(err))
361         return None, None, None
362
363     try:
364         if not sudo:
365             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
366         else:
367             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
368                                                                timeout=timeout)
369     except SSHException as err:
370         logger.error(err)
371         return None, None, None
372
373     return ret_code, stdout, stderr
374
375
376 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
377     """Convenience function to ssh/exec/return out & err.
378
379     Verifies that return code is zero.
380
381     Returns (stdout, stderr).
382     """
383     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
384     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
385                  format(cmd, stderr))
386     return stdout, stderr