FIX: VAT SSH timeout
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSHTimeout(Exception):
33     """This exception is raised when a timeout occurs."""
34     pass
35
36
37 class SSH(object):
38     """Contains methods for managing and using SSH connections."""
39
40     __MAX_RECV_BUF = 10*1024*1024
41     __existing_connections = {}
42
43     def __init__(self):
44         self._ssh = None
45         self._node = None
46
47     @staticmethod
48     def _node_hash(node):
49         """Get IP address and port hash from node dictionary.
50
51         :param node: Node in topology.
52         :type node: dict
53         :return: IP address and port for the specified node.
54         :rtype: int
55         """
56
57         return hash(frozenset([node['host'], node['port']]))
58
59     def connect(self, node, attempts=5):
60         """Connect to node prior to running exec_command or scp.
61
62         If there already is a connection to the node, this method reuses it.
63         """
64         try:
65             self._node = node
66             node_hash = self._node_hash(node)
67             if node_hash in SSH.__existing_connections:
68                 self._ssh = SSH.__existing_connections[node_hash]
69                 logger.debug('reusing ssh: {0}'.format(self._ssh))
70             else:
71                 start = time()
72                 pkey = None
73                 if 'priv_key' in node:
74                     pkey = RSAKey.from_private_key(
75                         StringIO.StringIO(node['priv_key']))
76
77                 self._ssh = paramiko.SSHClient()
78                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
79
80                 self._ssh.connect(node['host'], username=node['username'],
81                                   password=node.get('password'), pkey=pkey,
82                                   port=node['port'])
83
84                 self._ssh.get_transport().set_keepalive(10)
85
86                 SSH.__existing_connections[node_hash] = self._ssh
87
88                 logger.trace('connect took {} seconds'.format(time() - start))
89                 logger.debug('new ssh: {0}'.format(self._ssh))
90
91             logger.debug('Connect peer: {0}'.
92                          format(self._ssh.get_transport().getpeername()))
93             logger.debug('Connections: {0}'.
94                          format(str(SSH.__existing_connections)))
95         except:
96             if attempts > 0:
97                 self._reconnect(attempts-1)
98             else:
99                 raise
100
101     def disconnect(self, node):
102         """Close SSH connection to the node.
103
104         :param node: The node to disconnect from.
105         :type node: dict
106         """
107         node_hash = self._node_hash(node)
108         if node_hash in SSH.__existing_connections:
109             logger.debug('Disconnecting peer: {}, {}'.
110                          format(node['host'], node['port']))
111             ssh = SSH.__existing_connections.pop(node_hash)
112             ssh.close()
113
114     def _reconnect(self, attempts=0):
115         """Close the SSH connection and open it again."""
116
117         node = self._node
118         self.disconnect(node)
119         self.connect(node, attempts)
120         logger.debug('Reconnecting peer done: {}'.
121                      format(self._ssh.get_transport().getpeername()))
122
123     def exec_command(self, cmd, timeout=10):
124         """Execute SSH command on a new channel on the connected Node.
125
126         :param cmd: Command to run on the Node.
127         :param timeout: Maximal time in seconds to wait until the command is
128         done. If set to None then wait forever.
129         :type cmd: str
130         :type timeout: int
131         :return return_code, stdout, stderr
132         :rtype: tuple(int, str, str)
133         :raise SSHTimeout: If command is not finished in timeout time.
134         """
135         stdout = StringIO.StringIO()
136         stderr = StringIO.StringIO()
137         try:
138             chan = self._ssh.get_transport().open_session(timeout=5)
139         except AttributeError:
140             self._reconnect()
141             chan = self._ssh.get_transport().open_session(timeout=5)
142         except SSHException:
143             self._reconnect()
144             chan = self._ssh.get_transport().open_session(timeout=5)
145         chan.settimeout(timeout)
146         logger.trace('exec_command on {0}: {1}'
147                      .format(self._ssh.get_transport().getpeername(), cmd))
148
149         start = time()
150         chan.exec_command(cmd)
151         while not chan.exit_status_ready() and timeout is not None:
152             if chan.recv_ready():
153                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
154
155             if chan.recv_stderr_ready():
156                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
157
158             if time() - start > timeout:
159                 raise SSHTimeout(
160                     'Timeout exception during execution of command: {0}\n'
161                     'Current contents of stdout buffer: {1}\n'
162                     'Current contents of stderr buffer: {2}\n'
163                     .format(cmd, stdout.getvalue(), stderr.getvalue())
164                 )
165
166             sleep(0.1)
167         return_code = chan.recv_exit_status()
168
169         while chan.recv_ready():
170             stdout.write(chan.recv(self.__MAX_RECV_BUF))
171
172         while chan.recv_stderr_ready():
173             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
174
175         end = time()
176         logger.trace('exec_command on {0} took {1} seconds'.format(
177             self._ssh.get_transport().getpeername(), end-start))
178
179         logger.trace('return RC {}'.format(return_code))
180         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
181         logger.trace('return STDERR {}'.format(stderr.getvalue()))
182         return return_code, stdout.getvalue(), stderr.getvalue()
183
184     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
185         """Execute SSH command with sudo on a new channel on the connected Node.
186
187         :param cmd: Command to be executed.
188         :param cmd_input: Input redirected to the command.
189         :param timeout: Timeout.
190         :return: return_code, stdout, stderr
191
192         :Example:
193
194         >>> from ssh import SSH
195         >>> ssh = SSH()
196         >>> ssh.connect(node)
197         >>> # Execute command without input (sudo -S cmd)
198         >>> ssh.exec_command_sudo("ifconfig eth0 down")
199         >>> # Execute command with input (sudo -S cmd <<< "input")
200         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
201         """
202         if cmd_input is None:
203             command = 'sudo -S {c}'.format(c=cmd)
204         else:
205             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
206         return self.exec_command(command, timeout)
207
208     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
209                          timeout=30):
210         """Execute command in LXC on a new SSH channel on the connected Node.
211
212         :param lxc_cmd: Command to be executed.
213         :param lxc_name: LXC name.
214         :param lxc_params: Additional parameters for LXC attach.
215         :param sudo: Run in privileged LXC mode. Default: privileged
216         :param timeout: Timeout.
217         :type lxc_cmd: str
218         :type lxc_name: str
219         :type lxc_params: str
220         :type sudo: bool
221         :type timeout: int
222         :return: return_code, stdout, stderr
223         """
224         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
225             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
226
227         if sudo:
228             command = 'sudo -S {c}'.format(c=command)
229         return self.exec_command(command, timeout)
230
231     def interactive_terminal_open(self, time_out=30):
232         """Open interactive terminal on a new channel on the connected Node.
233
234         :param time_out: Timeout in seconds.
235         :return: SSH channel with opened terminal.
236
237         .. warning:: Interruptingcow is used here, and it uses
238            signal(SIGALRM) to let the operating system interrupt program
239            execution. This has the following limitations: Python signal
240            handlers only apply to the main thread, so you cannot use this
241            from other threads. You must not use this in a program that
242            uses SIGALRM itself (this includes certain profilers)
243         """
244         chan = self._ssh.get_transport().open_session()
245         chan.get_pty()
246         chan.invoke_shell()
247         chan.settimeout(int(time_out))
248         chan.set_combine_stderr(True)
249
250         buf = ''
251         while not buf.endswith((":~$ ", "~]$ ", "~]# ")):
252             try:
253                 chunk = chan.recv(self.__MAX_RECV_BUF)
254                 if not chunk:
255                     break
256                 buf += chunk
257                 if chan.exit_status_ready():
258                     logger.error('Channel exit status ready')
259                     break
260             except socket.timeout:
261                 logger.error('Socket timeout: {0}'.format(buf))
262                 raise Exception('Socket timeout: {0}'.format(buf))
263         return chan
264
265     def interactive_terminal_exec_command(self, chan, cmd, prompt):
266         """Execute command on interactive terminal.
267
268         interactive_terminal_open() method has to be called first!
269
270         :param chan: SSH channel with opened terminal.
271         :param cmd: Command to be executed.
272         :param prompt: Command prompt, sequence of characters used to
273         indicate readiness to accept commands.
274         :return: Command output.
275
276         .. warning:: Interruptingcow is used here, and it uses
277            signal(SIGALRM) to let the operating system interrupt program
278            execution. This has the following limitations: Python signal
279            handlers only apply to the main thread, so you cannot use this
280            from other threads. You must not use this in a program that
281            uses SIGALRM itself (this includes certain profilers)
282         """
283         chan.sendall('{c}\n'.format(c=cmd))
284         buf = ''
285         while not buf.endswith(prompt):
286             try:
287                 chunk = chan.recv(self.__MAX_RECV_BUF)
288                 if not chunk:
289                     break
290                 buf += chunk
291                 if chan.exit_status_ready():
292                     logger.error('Channel exit status ready')
293                     break
294             except socket.timeout:
295                 logger.error('Socket timeout during execution of command: '
296                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
297                 raise Exception('Socket timeout during execution of command: '
298                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
299         tmp = buf.replace(cmd.replace('\n', ''), '')
300         for item in prompt:
301             tmp.replace(item, '')
302         return tmp
303
304     @staticmethod
305     def interactive_terminal_close(chan):
306         """Close interactive terminal SSH channel.
307
308         :param: chan: SSH channel to be closed.
309         """
310         chan.close()
311
312     def scp(self, local_path, remote_path, get=False, timeout=10):
313         """Copy files from local_path to remote_path or vice versa.
314
315         connect() method has to be called first!
316
317         :param local_path: Path to local file that should be uploaded; or
318         path where to save remote file.
319         :param remote_path: Remote path where to place uploaded file; or
320         path to remote file which should be downloaded.
321         :param get: scp operation to perform. Default is put.
322         :param timeout: Timeout value in seconds.
323         :type local_path: str
324         :type remote_path: str
325         :type get: bool
326         :type timeout: int
327         """
328         if not get:
329             logger.trace('SCP {0} to {1}:{2}'.format(
330                 local_path, self._ssh.get_transport().getpeername(),
331                 remote_path))
332         else:
333             logger.trace('SCP {0}:{1} to {2}'.format(
334                 self._ssh.get_transport().getpeername(), remote_path,
335                 local_path))
336         # SCPCLient takes a paramiko transport as its only argument
337         scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
338         start = time()
339         if not get:
340             scp.put(local_path, remote_path)
341         else:
342             scp.get(remote_path, local_path)
343         scp.close()
344         end = time()
345         logger.trace('SCP took {0} seconds'.format(end-start))
346
347
348 def exec_cmd(node, cmd, timeout=600, sudo=False):
349     """Convenience function to ssh/exec/return rc, out & err.
350
351     Returns (rc, stdout, stderr).
352     """
353     if node is None:
354         raise TypeError('Node parameter is None')
355     if cmd is None:
356         raise TypeError('Command parameter is None')
357     if len(cmd) == 0:
358         raise ValueError('Empty command parameter')
359
360     ssh = SSH()
361     try:
362         ssh.connect(node)
363     except SSHException as err:
364         logger.error("Failed to connect to node" + str(err))
365         return None, None, None
366
367     try:
368         if not sudo:
369             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
370         else:
371             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
372                                                                timeout=timeout)
373     except SSHException as err:
374         logger.error(err)
375         return None, None, None
376
377     return ret_code, stdout, stderr
378
379
380 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
381     """Convenience function to ssh/exec/return out & err.
382
383     Verifies that return code is zero.
384
385     Returns (stdout, stderr).
386     """
387     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
388     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
389                  format(cmd, stderr))
390     return stdout, stderr