CSIT-339: Add Keywords for SMT
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSHTimeout(Exception):
33     """This exception is raised when a timeout occurs."""
34     pass
35
36
37 class SSH(object):
38     """Contains methods for managing and using SSH connections."""
39
40     __MAX_RECV_BUF = 10*1024*1024
41     __existing_connections = {}
42
43     def __init__(self):
44         self._ssh = None
45         self._node = None
46
47     @staticmethod
48     def _node_hash(node):
49         """Get IP address and port hash from node dictionary.
50
51         :param node: Node in topology.
52         :type node: dict
53         :return: IP address and port for the specified node.
54         :rtype: int
55         """
56
57         return hash(frozenset([node['host'], node['port']]))
58
59     def connect(self, node):
60         """Connect to node prior to running exec_command or scp.
61
62         If there already is a connection to the node, this method reuses it.
63         """
64         self._node = node
65         node_hash = self._node_hash(node)
66         if node_hash in SSH.__existing_connections:
67             self._ssh = SSH.__existing_connections[node_hash]
68             logger.debug('reusing ssh: {0}'.format(self._ssh))
69         else:
70             start = time()
71             pkey = None
72             if 'priv_key' in node:
73                 pkey = RSAKey.from_private_key(
74                     StringIO.StringIO(node['priv_key']))
75
76             self._ssh = paramiko.SSHClient()
77             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
78
79             self._ssh.connect(node['host'], username=node['username'],
80                               password=node.get('password'), pkey=pkey,
81                               port=node['port'])
82
83             self._ssh.get_transport().set_keepalive(10)
84
85             SSH.__existing_connections[node_hash] = self._ssh
86
87             logger.trace('connect took {} seconds'.format(time() - start))
88             logger.debug('new ssh: {0}'.format(self._ssh))
89
90         logger.debug('Connect peer: {0}'.
91                      format(self._ssh.get_transport().getpeername()))
92         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
93
94     def disconnect(self, node):
95         """Close SSH connection to the node.
96
97         :param node: The node to disconnect from.
98         :type node: dict
99         """
100         node_hash = self._node_hash(node)
101         if node_hash in SSH.__existing_connections:
102             logger.debug('Disconnecting peer: {}, {}'.
103                          format(node['host'], node['port']))
104             ssh = SSH.__existing_connections.pop(node_hash)
105             ssh.close()
106
107     def _reconnect(self):
108         """Close the SSH connection and open it again."""
109
110         node = self._node
111         self.disconnect(node)
112         self.connect(node)
113         logger.debug('Reconnecting peer done: {}'.
114                      format(self._ssh.get_transport().getpeername()))
115
116     def exec_command(self, cmd, timeout=10):
117         """Execute SSH command on a new channel on the connected Node.
118
119         :param cmd: Command to run on the Node.
120         :param timeout: Maximal time in seconds to wait until the command is
121         done. If set to None then wait forever.
122         :type cmd: str
123         :type timeout: int
124         :return return_code, stdout, stderr
125         :rtype: tuple(int, str, str)
126         :raise SSHTimeout: If command is not finished in timeout time.
127         """
128         start = time()
129         stdout = StringIO.StringIO()
130         stderr = StringIO.StringIO()
131         try:
132             chan = self._ssh.get_transport().open_session(timeout=5)
133         except AttributeError:
134             self._reconnect()
135             chan = self._ssh.get_transport().open_session(timeout=5)
136         except SSHException:
137             self._reconnect()
138             chan = self._ssh.get_transport().open_session(timeout=5)
139         chan.settimeout(timeout)
140         logger.trace('exec_command on {0}: {1}'
141                      .format(self._ssh.get_transport().getpeername(), cmd))
142
143         chan.exec_command(cmd)
144         while not chan.exit_status_ready() and timeout is not None:
145             if chan.recv_ready():
146                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
147
148             if chan.recv_stderr_ready():
149                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
150
151             if time() - start > timeout:
152                 raise SSHTimeout(
153                     'Timeout exception.\n'
154                     'Current contents of stdout buffer: {0}\n'
155                     'Current contents of stderr buffer: {1}\n'
156                     .format(stdout.getvalue(), stderr.getvalue())
157                 )
158
159             sleep(0.1)
160         return_code = chan.recv_exit_status()
161
162         while chan.recv_ready():
163             stdout.write(chan.recv(self.__MAX_RECV_BUF))
164
165         while chan.recv_stderr_ready():
166             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
167
168         end = time()
169         logger.trace('exec_command on {0} took {1} seconds'.format(
170             self._ssh.get_transport().getpeername(), end-start))
171
172         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
173
174         logger.trace('return RC {}'.format(return_code))
175         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
176         logger.trace('return STDERR {}'.format(stderr.getvalue()))
177         return return_code, stdout.getvalue(), stderr.getvalue()
178
179     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
180         """Execute SSH command with sudo on a new channel on the connected Node.
181
182         :param cmd: Command to be executed.
183         :param cmd_input: Input redirected to the command.
184         :param timeout: Timeout.
185         :return: return_code, stdout, stderr
186
187         :Example:
188
189         >>> from ssh import SSH
190         >>> ssh = SSH()
191         >>> ssh.connect(node)
192         >>> # Execute command without input (sudo -S cmd)
193         >>> ssh.exec_command_sudo("ifconfig eth0 down")
194         >>> # Execute command with input (sudo -S cmd <<< "input")
195         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
196         """
197         if cmd_input is None:
198             command = 'sudo -S {c}'.format(c=cmd)
199         else:
200             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
201         return self.exec_command(command, timeout)
202
203     def interactive_terminal_open(self, time_out=30):
204         """Open interactive terminal on a new channel on the connected Node.
205
206         :param time_out: Timeout in seconds.
207         :return: SSH channel with opened terminal.
208
209         .. warning:: Interruptingcow is used here, and it uses
210            signal(SIGALRM) to let the operating system interrupt program
211            execution. This has the following limitations: Python signal
212            handlers only apply to the main thread, so you cannot use this
213            from other threads. You must not use this in a program that
214            uses SIGALRM itself (this includes certain profilers)
215         """
216         chan = self._ssh.get_transport().open_session()
217         chan.get_pty()
218         chan.invoke_shell()
219         chan.settimeout(int(time_out))
220         chan.set_combine_stderr(True)
221
222         buf = ''
223         while not buf.endswith((":~$ ", "~]$ ")):
224             try:
225                 chunk = chan.recv(self.__MAX_RECV_BUF)
226                 if not chunk:
227                     break
228                 buf += chunk
229                 if chan.exit_status_ready():
230                     logger.error('Channel exit status ready')
231                     break
232             except socket.timeout:
233                 raise Exception('Socket timeout: {0}'.format(buf))
234         return chan
235
236     def interactive_terminal_exec_command(self, chan, cmd, prompt):
237         """Execute command on interactive terminal.
238
239         interactive_terminal_open() method has to be called first!
240
241         :param chan: SSH channel with opened terminal.
242         :param cmd: Command to be executed.
243         :param prompt: Command prompt, sequence of characters used to
244         indicate readiness to accept commands.
245         :return: Command output.
246
247         .. warning:: Interruptingcow is used here, and it uses
248            signal(SIGALRM) to let the operating system interrupt program
249            execution. This has the following limitations: Python signal
250            handlers only apply to the main thread, so you cannot use this
251            from other threads. You must not use this in a program that
252            uses SIGALRM itself (this includes certain profilers)
253         """
254         chan.sendall('{c}\n'.format(c=cmd))
255         buf = ''
256         while not buf.endswith(prompt):
257             try:
258                 chunk = chan.recv(self.__MAX_RECV_BUF)
259                 if not chunk:
260                     break
261                 buf += chunk
262                 if chan.exit_status_ready():
263                     logger.error('Channel exit status ready')
264                     break
265             except socket.timeout:
266                 raise Exception('Socket timeout: {0}'.format(buf))
267         tmp = buf.replace(cmd.replace('\n', ''), '')
268         for p in prompt:
269             tmp.replace(p, '')
270         return tmp
271
272     @staticmethod
273     def interactive_terminal_close(chan):
274         """Close interactive terminal SSH channel.
275
276         :param: chan: SSH channel to be closed.
277         """
278         chan.close()
279
280     def scp(self, local_path, remote_path):
281         """Copy files from local_path to remote_path.
282
283         connect() method has to be called first!
284         """
285         logger.trace('SCP {0} to {1}:{2}'.format(
286             local_path, self._ssh.get_transport().getpeername(), remote_path))
287         # SCPCLient takes a paramiko transport as its only argument
288         scp = SCPClient(self._ssh.get_transport())
289         start = time()
290         scp.put(local_path, remote_path)
291         scp.close()
292         end = time()
293         logger.trace('SCP took {0} seconds'.format(end-start))
294
295
296 def exec_cmd(node, cmd, timeout=600, sudo=False):
297     """Convenience function to ssh/exec/return rc, out & err.
298
299     Returns (rc, stdout, stderr).
300     """
301     if node is None:
302         raise TypeError('Node parameter is None')
303     if cmd is None:
304         raise TypeError('Command parameter is None')
305     if len(cmd) == 0:
306         raise ValueError('Empty command parameter')
307
308     ssh = SSH()
309     try:
310         ssh.connect(node)
311     except SSHException as err:
312         logger.error("Failed to connect to node" + str(err))
313         return None, None, None
314
315     try:
316         if not sudo:
317             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
318         else:
319             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
320                                                                timeout=timeout)
321     except SSHException as err:
322         logger.error(err)
323         return None, None, None
324
325     return ret_code, stdout, stderr
326
327
328 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
329     """Convenience function to ssh/exec/return out & err.
330
331     Verifies that return code is zero.
332
333     Returns (stdout, stderr).
334     """
335     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
336     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
337                  format(cmd, stderr))
338     return stdout, stderr