0009bde59bb916870fe2e640a2a480ee412e89d2
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSHTimeout(Exception):
33     """This exception is raised when a timeout occurs."""
34     pass
35
36
37 class SSH(object):
38     """Contains methods for managing and using SSH connections."""
39
40     __MAX_RECV_BUF = 10*1024*1024
41     __existing_connections = {}
42
43     def __init__(self):
44         self._ssh = None
45         self._node = None
46
47     @staticmethod
48     def _node_hash(node):
49         """Get IP address and port hash from node dictionary.
50
51         :param node: Node in topology.
52         :type node: dict
53         :return: IP address and port for the specified node.
54         :rtype: int
55         """
56
57         return hash(frozenset([node['host'], node['port']]))
58
59     def connect(self, node, attempts=5):
60         """Connect to node prior to running exec_command or scp.
61
62         If there already is a connection to the node, this method reuses it.
63         """
64         try:
65             self._node = node
66             node_hash = self._node_hash(node)
67             if node_hash in SSH.__existing_connections:
68                 self._ssh = SSH.__existing_connections[node_hash]
69                 logger.debug('reusing ssh: {0}'.format(self._ssh))
70             else:
71                 start = time()
72                 pkey = None
73                 if 'priv_key' in node:
74                     pkey = RSAKey.from_private_key(
75                         StringIO.StringIO(node['priv_key']))
76
77                 self._ssh = paramiko.SSHClient()
78                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
79
80                 self._ssh.connect(node['host'], username=node['username'],
81                                   password=node.get('password'), pkey=pkey,
82                                   port=node['port'])
83
84                 self._ssh.get_transport().set_keepalive(10)
85
86                 SSH.__existing_connections[node_hash] = self._ssh
87
88                 logger.trace('connect took {} seconds'.format(time() - start))
89                 logger.debug('new ssh: {0}'.format(self._ssh))
90
91             logger.debug('Connect peer: {0}'.
92                          format(self._ssh.get_transport().getpeername()))
93             logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
94         except:
95             if attempts > 0:
96                 self._reconnect(attempts-1)
97             else:
98                 raise
99
100     def disconnect(self, node):
101         """Close SSH connection to the node.
102
103         :param node: The node to disconnect from.
104         :type node: dict
105         """
106         node_hash = self._node_hash(node)
107         if node_hash in SSH.__existing_connections:
108             logger.debug('Disconnecting peer: {}, {}'.
109                          format(node['host'], node['port']))
110             ssh = SSH.__existing_connections.pop(node_hash)
111             ssh.close()
112
113     def _reconnect(self, attempts=0):
114         """Close the SSH connection and open it again."""
115
116         node = self._node
117         self.disconnect(node)
118         self.connect(node, attempts)
119         logger.debug('Reconnecting peer done: {}'.
120                      format(self._ssh.get_transport().getpeername()))
121
122     def exec_command(self, cmd, timeout=10):
123         """Execute SSH command on a new channel on the connected Node.
124
125         :param cmd: Command to run on the Node.
126         :param timeout: Maximal time in seconds to wait until the command is
127         done. If set to None then wait forever.
128         :type cmd: str
129         :type timeout: int
130         :return return_code, stdout, stderr
131         :rtype: tuple(int, str, str)
132         :raise SSHTimeout: If command is not finished in timeout time.
133         """
134         start = time()
135         stdout = StringIO.StringIO()
136         stderr = StringIO.StringIO()
137         try:
138             chan = self._ssh.get_transport().open_session(timeout=5)
139         except AttributeError:
140             self._reconnect()
141             chan = self._ssh.get_transport().open_session(timeout=5)
142         except SSHException:
143             self._reconnect()
144             chan = self._ssh.get_transport().open_session(timeout=5)
145         chan.settimeout(timeout)
146         logger.trace('exec_command on {0}: {1}'
147                      .format(self._ssh.get_transport().getpeername(), cmd))
148
149         chan.exec_command(cmd)
150         while not chan.exit_status_ready() and timeout is not None:
151             if chan.recv_ready():
152                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
153
154             if chan.recv_stderr_ready():
155                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
156
157             if time() - start > timeout:
158                 raise SSHTimeout(
159                     'Timeout exception.\n'
160                     'Current contents of stdout buffer: {0}\n'
161                     'Current contents of stderr buffer: {1}\n'
162                     .format(stdout.getvalue(), stderr.getvalue())
163                 )
164
165             sleep(0.1)
166         return_code = chan.recv_exit_status()
167
168         while chan.recv_ready():
169             stdout.write(chan.recv(self.__MAX_RECV_BUF))
170
171         while chan.recv_stderr_ready():
172             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
173
174         end = time()
175         logger.trace('exec_command on {0} took {1} seconds'.format(
176             self._ssh.get_transport().getpeername(), end-start))
177
178         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
179
180         logger.trace('return RC {}'.format(return_code))
181         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
182         logger.trace('return STDERR {}'.format(stderr.getvalue()))
183         return return_code, stdout.getvalue(), stderr.getvalue()
184
185     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
186         """Execute SSH command with sudo on a new channel on the connected Node.
187
188         :param cmd: Command to be executed.
189         :param cmd_input: Input redirected to the command.
190         :param timeout: Timeout.
191         :return: return_code, stdout, stderr
192
193         :Example:
194
195         >>> from ssh import SSH
196         >>> ssh = SSH()
197         >>> ssh.connect(node)
198         >>> # Execute command without input (sudo -S cmd)
199         >>> ssh.exec_command_sudo("ifconfig eth0 down")
200         >>> # Execute command with input (sudo -S cmd <<< "input")
201         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
202         """
203         if cmd_input is None:
204             command = 'sudo -S {c}'.format(c=cmd)
205         else:
206             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
207         return self.exec_command(command, timeout)
208
209     def interactive_terminal_open(self, time_out=30):
210         """Open interactive terminal on a new channel on the connected Node.
211
212         :param time_out: Timeout in seconds.
213         :return: SSH channel with opened terminal.
214
215         .. warning:: Interruptingcow is used here, and it uses
216            signal(SIGALRM) to let the operating system interrupt program
217            execution. This has the following limitations: Python signal
218            handlers only apply to the main thread, so you cannot use this
219            from other threads. You must not use this in a program that
220            uses SIGALRM itself (this includes certain profilers)
221         """
222         chan = self._ssh.get_transport().open_session()
223         chan.get_pty()
224         chan.invoke_shell()
225         chan.settimeout(int(time_out))
226         chan.set_combine_stderr(True)
227
228         buf = ''
229         while not buf.endswith((":~$ ", "~]$ ")):
230             try:
231                 chunk = chan.recv(self.__MAX_RECV_BUF)
232                 if not chunk:
233                     break
234                 buf += chunk
235                 if chan.exit_status_ready():
236                     logger.error('Channel exit status ready')
237                     break
238             except socket.timeout:
239                 raise Exception('Socket timeout: {0}'.format(buf))
240         return chan
241
242     def interactive_terminal_exec_command(self, chan, cmd, prompt):
243         """Execute command on interactive terminal.
244
245         interactive_terminal_open() method has to be called first!
246
247         :param chan: SSH channel with opened terminal.
248         :param cmd: Command to be executed.
249         :param prompt: Command prompt, sequence of characters used to
250         indicate readiness to accept commands.
251         :return: Command output.
252
253         .. warning:: Interruptingcow is used here, and it uses
254            signal(SIGALRM) to let the operating system interrupt program
255            execution. This has the following limitations: Python signal
256            handlers only apply to the main thread, so you cannot use this
257            from other threads. You must not use this in a program that
258            uses SIGALRM itself (this includes certain profilers)
259         """
260         chan.sendall('{c}\n'.format(c=cmd))
261         buf = ''
262         while not buf.endswith(prompt):
263             try:
264                 chunk = chan.recv(self.__MAX_RECV_BUF)
265                 if not chunk:
266                     break
267                 buf += chunk
268                 if chan.exit_status_ready():
269                     logger.error('Channel exit status ready')
270                     break
271             except socket.timeout:
272                 raise Exception('Socket timeout: {0}'.format(buf))
273         tmp = buf.replace(cmd.replace('\n', ''), '')
274         for p in prompt:
275             tmp.replace(p, '')
276         return tmp
277
278     @staticmethod
279     def interactive_terminal_close(chan):
280         """Close interactive terminal SSH channel.
281
282         :param: chan: SSH channel to be closed.
283         """
284         chan.close()
285
286     def scp(self, local_path, remote_path):
287         """Copy files from local_path to remote_path.
288
289         connect() method has to be called first!
290         """
291         logger.trace('SCP {0} to {1}:{2}'.format(
292             local_path, self._ssh.get_transport().getpeername(), remote_path))
293         # SCPCLient takes a paramiko transport as its only argument
294         scp = SCPClient(self._ssh.get_transport(), socket_timeout=10)
295         start = time()
296         scp.put(local_path, remote_path)
297         scp.close()
298         end = time()
299         logger.trace('SCP took {0} seconds'.format(end-start))
300
301
302 def exec_cmd(node, cmd, timeout=600, sudo=False):
303     """Convenience function to ssh/exec/return rc, out & err.
304
305     Returns (rc, stdout, stderr).
306     """
307     if node is None:
308         raise TypeError('Node parameter is None')
309     if cmd is None:
310         raise TypeError('Command parameter is None')
311     if len(cmd) == 0:
312         raise ValueError('Empty command parameter')
313
314     ssh = SSH()
315     try:
316         ssh.connect(node)
317     except SSHException as err:
318         logger.error("Failed to connect to node" + str(err))
319         return None, None, None
320
321     try:
322         if not sudo:
323             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
324         else:
325             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
326                                                                timeout=timeout)
327     except SSHException as err:
328         logger.error(err)
329         return None, None, None
330
331     return ret_code, stdout, stderr
332
333
334 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
335     """Convenience function to ssh/exec/return out & err.
336
337     Verifies that return code is zero.
338
339     Returns (stdout, stderr).
340     """
341     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
342     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
343                  format(cmd, stderr))
344     return stdout, stderr