CSIT-176 Fix interactive SSH console deadlock
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSH(object):
33     """Contains methods for managing and using SSH connections."""
34
35     __MAX_RECV_BUF = 10*1024*1024
36     __existing_connections = {}
37
38     def __init__(self):
39         self._ssh = None
40         self._node = None
41
42     @staticmethod
43     def _node_hash(node):
44         """Get IP address and port hash from node dictionary.
45
46         :param node: Node in topology.
47         :type node: dict
48         :return: IP address and port for the specified node.
49         :rtype: int
50         """
51
52         return hash(frozenset([node['host'], node['port']]))
53
54     def connect(self, node):
55         """Connect to node prior to running exec_command or scp.
56
57         If there already is a connection to the node, this method reuses it.
58         """
59         self._node = node
60         node_hash = self._node_hash(node)
61         if node_hash in SSH.__existing_connections:
62             self._ssh = SSH.__existing_connections[node_hash]
63             logger.debug('reusing ssh: {0}'.format(self._ssh))
64         else:
65             start = time()
66             pkey = None
67             if 'priv_key' in node:
68                 pkey = RSAKey.from_private_key(
69                     StringIO.StringIO(node['priv_key']))
70
71             self._ssh = paramiko.SSHClient()
72             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
73
74             self._ssh.connect(node['host'], username=node['username'],
75                               password=node.get('password'), pkey=pkey,
76                               port=node['port'])
77
78             self._ssh.get_transport().set_keepalive(10)
79
80             SSH.__existing_connections[node_hash] = self._ssh
81
82             logger.trace('connect took {} seconds'.format(time() - start))
83             logger.debug('new ssh: {0}'.format(self._ssh))
84
85         logger.debug('Connect peer: {0}'.
86                      format(self._ssh.get_transport().getpeername()))
87         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
88
89     def disconnect(self, node):
90         """Close SSH connection to the node.
91
92         :param node: The node to disconnect from.
93         :type node: dict
94         """
95         node_hash = self._node_hash(node)
96         if node_hash in SSH.__existing_connections:
97             logger.debug('Disconnecting peer: {}, {}'.
98                          format(node['host'], node['port']))
99             ssh = SSH.__existing_connections.pop(node_hash)
100             ssh.close()
101
102     def _reconnect(self):
103         """Close the SSH connection and open it again."""
104
105         node = self._node
106         self.disconnect(node)
107         self.connect(node)
108         logger.debug('Reconnecting peer done: {}'.
109                      format(self._ssh.get_transport().getpeername()))
110
111     def exec_command(self, cmd, timeout=10):
112         """Execute SSH command on a new channel on the connected Node.
113
114         :param cmd: Command to run on the Node.
115         :param timeout: Maximal time in seconds to wait until the command is
116         done. If set to None then wait forever.
117         :type cmd: str
118         :type timeout: int
119         :return return_code, stdout, stderr
120         :rtype: tuple(int, str, str)
121         :raise socket.timeout: If command is not finished in timeout time.
122         """
123         start = time()
124         stdout = StringIO.StringIO()
125         stderr = StringIO.StringIO()
126         try:
127             chan = self._ssh.get_transport().open_session(timeout=5)
128         except AttributeError:
129             self._reconnect()
130             chan = self._ssh.get_transport().open_session(timeout=5)
131         except SSHException:
132             self._reconnect()
133             chan = self._ssh.get_transport().open_session(timeout=5)
134         chan.settimeout(timeout)
135         logger.trace('exec_command on {0}: {1}'
136                      .format(self._ssh.get_transport().getpeername(), cmd))
137
138         chan.exec_command(cmd)
139         while not chan.exit_status_ready() and timeout is not None:
140             if chan.recv_ready():
141                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
142
143             if chan.recv_stderr_ready():
144                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
145
146             if time() - start > timeout:
147                 raise socket.timeout(
148                     'Timeout exception.\n'
149                     'Current contents of stdout buffer: {0}\n'
150                     'Current contents of stderr buffer: {1}\n'
151                     .format(stdout.getvalue(), stderr.getvalue())
152                 )
153
154             sleep(0.1)
155         return_code = chan.recv_exit_status()
156
157         while chan.recv_ready():
158             stdout.write(chan.recv(self.__MAX_RECV_BUF))
159
160         while chan.recv_stderr_ready():
161             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
162
163         end = time()
164         logger.trace('exec_command on {0} took {1} seconds'.format(
165             self._ssh.get_transport().getpeername(), end-start))
166
167         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
168
169         logger.trace('return RC {}'.format(return_code))
170         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
171         logger.trace('return STDERR {}'.format(stderr.getvalue()))
172         return return_code, stdout.getvalue(), stderr.getvalue()
173
174     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
175         """Execute SSH command with sudo on a new channel on the connected Node.
176
177         :param cmd: Command to be executed.
178         :param cmd_input: Input redirected to the command.
179         :param timeout: Timeout.
180         :return: return_code, stdout, stderr
181
182         :Example:
183
184         >>> from ssh import SSH
185         >>> ssh = SSH()
186         >>> ssh.connect(node)
187         >>> # Execute command without input (sudo -S cmd)
188         >>> ssh.exec_command_sudo("ifconfig eth0 down")
189         >>> # Execute command with input (sudo -S cmd <<< "input")
190         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
191         """
192         if cmd_input is None:
193             command = 'sudo -S {c}'.format(c=cmd)
194         else:
195             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
196         return self.exec_command(command, timeout)
197
198     def interactive_terminal_open(self, time_out=10):
199         """Open interactive terminal on a new channel on the connected Node.
200
201         :param time_out: Timeout in seconds.
202         :return: SSH channel with opened terminal.
203
204         .. warning:: Interruptingcow is used here, and it uses
205            signal(SIGALRM) to let the operating system interrupt program
206            execution. This has the following limitations: Python signal
207            handlers only apply to the main thread, so you cannot use this
208            from other threads. You must not use this in a program that
209            uses SIGALRM itself (this includes certain profilers)
210         """
211         chan = self._ssh.get_transport().open_session()
212         chan.get_pty()
213         chan.invoke_shell()
214         chan.settimeout(int(time_out))
215         chan.set_combine_stderr(True)
216
217         buf = ''
218         while not buf.endswith(':~$ '):
219             try:
220                 chunk = chan.recv(self.__MAX_RECV_BUF)
221                 if not chunk:
222                     break
223                 buf += chunk
224                 if chan.exit_status_ready():
225                     logger.error('Channel exit status ready')
226                     break
227             except socket.timeout:
228                 raise Exception('Socket timeout: {0}'.format(buf))
229         return chan
230
231     def interactive_terminal_exec_command(self, chan, cmd, prompt,
232                                           time_out=30):
233         """Execute command on interactive terminal.
234
235         interactive_terminal_open() method has to be called first!
236
237         :param chan: SSH channel with opened terminal.
238         :param cmd: Command to be executed.
239         :param prompt: Command prompt, sequence of characters used to
240         indicate readiness to accept commands.
241         :param time_out: Timeout in seconds.
242         :return: Command output.
243
244         .. warning:: Interruptingcow is used here, and it uses
245            signal(SIGALRM) to let the operating system interrupt program
246            execution. This has the following limitations: Python signal
247            handlers only apply to the main thread, so you cannot use this
248            from other threads. You must not use this in a program that
249            uses SIGALRM itself (this includes certain profilers)
250         """
251         chan.sendall('{c}\n'.format(c=cmd))
252         buf = ''
253         while not buf.endswith(prompt):
254             try:
255                 chunk = chan.recv(self.__MAX_RECV_BUF)
256                 if not chunk:
257                     break
258                 buf += chunk
259                 if chan.exit_status_ready():
260                     logger.error('Channel exit status ready')
261                     break
262             except socket.timeout:
263                 raise Exception('Socket timeout: {0}'.format(buf))
264         tmp = buf.replace(cmd.replace('\n', ''), '')
265         return tmp.replace(prompt, '')
266
267     @staticmethod
268     def interactive_terminal_close(chan):
269         """Close interactive terminal SSH channel.
270
271         :param: chan: SSH channel to be closed.
272         """
273         chan.close()
274
275     def scp(self, local_path, remote_path):
276         """Copy files from local_path to remote_path.
277
278         connect() method has to be called first!
279         """
280         logger.trace('SCP {0} to {1}:{2}'.format(
281             local_path, self._ssh.get_transport().getpeername(), remote_path))
282         # SCPCLient takes a paramiko transport as its only argument
283         scp = SCPClient(self._ssh.get_transport())
284         start = time()
285         scp.put(local_path, remote_path)
286         scp.close()
287         end = time()
288         logger.trace('SCP took {0} seconds'.format(end-start))
289
290
291 def exec_cmd(node, cmd, timeout=600, sudo=False):
292     """Convenience function to ssh/exec/return rc, out & err.
293
294     Returns (rc, stdout, stderr).
295     """
296     if node is None:
297         raise TypeError('Node parameter is None')
298     if cmd is None:
299         raise TypeError('Command parameter is None')
300     if len(cmd) == 0:
301         raise ValueError('Empty command parameter')
302
303     ssh = SSH()
304     try:
305         ssh.connect(node)
306     except Exception as err:
307         logger.error("Failed to connect to node" + str(err))
308         return None, None, None
309
310     try:
311         if not sudo:
312             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
313         else:
314             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
315                                                                timeout=timeout)
316     except Exception as err:
317         logger.error(err)
318         return None, None, None
319
320     return ret_code, stdout, stderr
321
322
323 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
324     """Convenience function to ssh/exec/return out & err.
325
326     Verifies that return code is zero.
327
328     Returns (stdout, stderr).
329     """
330     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
331     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
332                  format(cmd, stderr))
333     return stdout, stderr