4bed173bbb359005c71294c5d427a8e1328d3dcd
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2018 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 """Library for SSH connection management."""
15
16 import StringIO
17 from time import time, sleep
18
19 import socket
20 import paramiko
21 from paramiko import RSAKey
22 from paramiko.ssh_exception import SSHException
23 from scp import SCPClient
24 from robot.api import logger
25 from robot.utils.asserts import assert_equal
26
27 __all__ = ["exec_cmd", "exec_cmd_no_error"]
28
29 # TODO: load priv key
30
31
32 class SSHTimeout(Exception):
33     """This exception is raised when a timeout occurs."""
34     pass
35
36
37 class SSH(object):
38     """Contains methods for managing and using SSH connections."""
39
40     __MAX_RECV_BUF = 10*1024*1024
41     __existing_connections = {}
42
43     def __init__(self):
44         self._ssh = None
45         self._node = None
46
47     @staticmethod
48     def _node_hash(node):
49         """Get IP address and port hash from node dictionary.
50
51         :param node: Node in topology.
52         :type node: dict
53         :returns: IP address and port for the specified node.
54         :rtype: int
55         """
56
57         return hash(frozenset([node['host'], node['port']]))
58
59     def connect(self, node, attempts=5):
60         """Connect to node prior to running exec_command or scp.
61
62         If there already is a connection to the node, this method reuses it.
63
64         :param node: Node in topology.
65         :param attempts: Number of reconnect attempts.
66         :type node: dict
67         :type attempts: int
68         :raises IOError: If cannot connect to host.
69         """
70         self._node = node
71         node_hash = self._node_hash(node)
72         if node_hash in SSH.__existing_connections:
73             self._ssh = SSH.__existing_connections[node_hash]
74             if self._ssh.get_transport().is_active():
75                 logger.debug('Reusing SSH: {ssh}'.format(ssh=self._ssh))
76             else:
77                 if attempts > 0:
78                     self._reconnect(attempts-1)
79                 else:
80                     raise IOError('Cannot connect to {host}'.
81                                   format(host=node['host']))
82         else:
83             try:
84                 start = time()
85                 pkey = None
86                 if 'priv_key' in node:
87                     pkey = RSAKey.from_private_key(
88                         StringIO.StringIO(node['priv_key']))
89
90                 self._ssh = paramiko.SSHClient()
91                 self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
92
93                 self._ssh.connect(node['host'], username=node['username'],
94                                   password=node.get('password'), pkey=pkey,
95                                   port=node['port'])
96
97                 self._ssh.get_transport().set_keepalive(10)
98
99                 SSH.__existing_connections[node_hash] = self._ssh
100                 logger.debug('New SSH to {peer} took {total} seconds: {ssh}'.
101                              format(
102                                  peer=self._ssh.get_transport().getpeername(),
103                                  total=(time() - start),
104                                  ssh=self._ssh))
105             except SSHException:
106                 raise IOError('Cannot connect to {host}'.
107                               format(host=node['host']))
108
109     def disconnect(self, node):
110         """Close SSH connection to the node.
111
112         :param node: The node to disconnect from.
113         :type node: dict
114         """
115         node_hash = self._node_hash(node)
116         if node_hash in SSH.__existing_connections:
117             logger.debug('Disconnecting peer: {host}, {port}'.
118                          format(host=node['host'], port=node['port']))
119             ssh = SSH.__existing_connections.pop(node_hash)
120             ssh.close()
121
122     def _reconnect(self, attempts=0):
123         """Close the SSH connection and open it again.
124
125         :param attempts: Number of reconnect attempts.
126         :type attempts: int
127         """
128         node = self._node
129         self.disconnect(node)
130         self.connect(node, attempts)
131         logger.debug('Reconnecting peer done: {host}, {port}'.
132                      format(host=node['host'], port=node['port']))
133
134     def exec_command(self, cmd, timeout=10):
135         """Execute SSH command on a new channel on the connected Node.
136
137         :param cmd: Command to run on the Node.
138         :param timeout: Maximal time in seconds to wait until the command is
139         done. If set to None then wait forever.
140         :type cmd: str
141         :type timeout: int
142         :return return_code, stdout, stderr
143         :rtype: tuple(int, str, str)
144         :raise SSHTimeout: If command is not finished in timeout time.
145         """
146         stdout = StringIO.StringIO()
147         stderr = StringIO.StringIO()
148         try:
149             chan = self._ssh.get_transport().open_session(timeout=5)
150             peer = self._ssh.get_transport().getpeername()
151         except AttributeError:
152             self._reconnect()
153             chan = self._ssh.get_transport().open_session(timeout=5)
154             peer = self._ssh.get_transport().getpeername()
155         except SSHException:
156             self._reconnect()
157             chan = self._ssh.get_transport().open_session(timeout=5)
158             peer = self._ssh.get_transport().getpeername()
159         chan.settimeout(timeout)
160
161         logger.trace('exec_command on {peer} with timeout {timeout}: {cmd}'
162                      .format(peer=peer, timeout=timeout, cmd=cmd))
163
164         start = time()
165         chan.exec_command(cmd)
166         while not chan.exit_status_ready() and timeout is not None:
167             if chan.recv_ready():
168                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
169
170             if chan.recv_stderr_ready():
171                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
172
173             if time() - start > timeout:
174                 raise SSHTimeout(
175                     'Timeout exception during execution of command: {cmd}\n'
176                     'Current contents of stdout buffer: {stdout}\n'
177                     'Current contents of stderr buffer: {stderr}\n'
178                     .format(cmd=cmd, stdout=stdout.getvalue(),
179                             stderr=stderr.getvalue())
180                 )
181
182             sleep(0.1)
183         return_code = chan.recv_exit_status()
184
185         while chan.recv_ready():
186             stdout.write(chan.recv(self.__MAX_RECV_BUF))
187
188         while chan.recv_stderr_ready():
189             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
190
191         end = time()
192         logger.trace('exec_command on {peer} took {total} seconds'.
193                      format(peer=peer, total=end-start))
194
195         logger.trace('return RC {rc}'.format(rc=return_code))
196         logger.trace('return STDOUT {stdout}'.format(stdout=stdout.getvalue()))
197         logger.trace('return STDERR {stderr}'.format(stderr=stderr.getvalue()))
198         return return_code, stdout.getvalue(), stderr.getvalue()
199
200     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
201         """Execute SSH command with sudo on a new channel on the connected Node.
202
203         :param cmd: Command to be executed.
204         :param cmd_input: Input redirected to the command.
205         :param timeout: Timeout.
206         :returns: return_code, stdout, stderr
207
208         :Example:
209
210         >>> from ssh import SSH
211         >>> ssh = SSH()
212         >>> ssh.connect(node)
213         >>> # Execute command without input (sudo -S cmd)
214         >>> ssh.exec_command_sudo("ifconfig eth0 down")
215         >>> # Execute command with input (sudo -S cmd <<< "input")
216         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
217         """
218         if cmd_input is None:
219             command = 'sudo -S {c}'.format(c=cmd)
220         else:
221             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
222         return self.exec_command(command, timeout)
223
224     def exec_command_lxc(self, lxc_cmd, lxc_name, lxc_params='', sudo=True,
225                          timeout=30):
226         """Execute command in LXC on a new SSH channel on the connected Node.
227
228         :param lxc_cmd: Command to be executed.
229         :param lxc_name: LXC name.
230         :param lxc_params: Additional parameters for LXC attach.
231         :param sudo: Run in privileged LXC mode. Default: privileged
232         :param timeout: Timeout.
233         :type lxc_cmd: str
234         :type lxc_name: str
235         :type lxc_params: str
236         :type sudo: bool
237         :type timeout: int
238         :returns: return_code, stdout, stderr
239         """
240         command = "lxc-attach {p} --name {n} -- /bin/sh -c '{c}'"\
241             .format(p=lxc_params, n=lxc_name, c=lxc_cmd)
242
243         if sudo:
244             command = 'sudo -S {c}'.format(c=command)
245         return self.exec_command(command, timeout)
246
247     def interactive_terminal_open(self, time_out=30):
248         """Open interactive terminal on a new channel on the connected Node.
249
250         :param time_out: Timeout in seconds.
251         :returns: SSH channel with opened terminal.
252
253         .. warning:: Interruptingcow is used here, and it uses
254            signal(SIGALRM) to let the operating system interrupt program
255            execution. This has the following limitations: Python signal
256            handlers only apply to the main thread, so you cannot use this
257            from other threads. You must not use this in a program that
258            uses SIGALRM itself (this includes certain profilers)
259         """
260         chan = self._ssh.get_transport().open_session()
261         chan.get_pty()
262         chan.invoke_shell()
263         chan.settimeout(int(time_out))
264         chan.set_combine_stderr(True)
265
266         buf = ''
267         while not buf.endswith((":~$ ", "~]$ ", "~]# ")):
268             try:
269                 chunk = chan.recv(self.__MAX_RECV_BUF)
270                 if not chunk:
271                     break
272                 buf += chunk
273                 if chan.exit_status_ready():
274                     logger.error('Channel exit status ready')
275                     break
276             except socket.timeout:
277                 logger.error('Socket timeout: {0}'.format(buf))
278                 raise Exception('Socket timeout: {0}'.format(buf))
279         return chan
280
281     def interactive_terminal_exec_command(self, chan, cmd, prompt):
282         """Execute command on interactive terminal.
283
284         interactive_terminal_open() method has to be called first!
285
286         :param chan: SSH channel with opened terminal.
287         :param cmd: Command to be executed.
288         :param prompt: Command prompt, sequence of characters used to
289         indicate readiness to accept commands.
290         :returns: Command output.
291
292         .. warning:: Interruptingcow is used here, and it uses
293            signal(SIGALRM) to let the operating system interrupt program
294            execution. This has the following limitations: Python signal
295            handlers only apply to the main thread, so you cannot use this
296            from other threads. You must not use this in a program that
297            uses SIGALRM itself (this includes certain profilers)
298         """
299         chan.sendall('{c}\n'.format(c=cmd))
300         buf = ''
301         while not buf.endswith(prompt):
302             try:
303                 chunk = chan.recv(self.__MAX_RECV_BUF)
304                 if not chunk:
305                     break
306                 buf += chunk
307                 if chan.exit_status_ready():
308                     logger.error('Channel exit status ready')
309                     break
310             except socket.timeout:
311                 logger.error('Socket timeout during execution of command: '
312                              '{0}\nBuffer content:\n{1}'.format(cmd, buf))
313                 raise Exception('Socket timeout during execution of command: '
314                                 '{0}\nBuffer content:\n{1}'.format(cmd, buf))
315         tmp = buf.replace(cmd.replace('\n', ''), '')
316         for item in prompt:
317             tmp.replace(item, '')
318         return tmp
319
320     @staticmethod
321     def interactive_terminal_close(chan):
322         """Close interactive terminal SSH channel.
323
324         :param: chan: SSH channel to be closed.
325         """
326         chan.close()
327
328     def scp(self, local_path, remote_path, get=False, timeout=30):
329         """Copy files from local_path to remote_path or vice versa.
330
331         connect() method has to be called first!
332
333         :param local_path: Path to local file that should be uploaded; or
334         path where to save remote file.
335         :param remote_path: Remote path where to place uploaded file; or
336         path to remote file which should be downloaded.
337         :param get: scp operation to perform. Default is put.
338         :param timeout: Timeout value in seconds.
339         :type local_path: str
340         :type remote_path: str
341         :type get: bool
342         :type timeout: int
343         """
344         if not get:
345             logger.trace('SCP {0} to {1}:{2}'.format(
346                 local_path, self._ssh.get_transport().getpeername(),
347                 remote_path))
348         else:
349             logger.trace('SCP {0}:{1} to {2}'.format(
350                 self._ssh.get_transport().getpeername(), remote_path,
351                 local_path))
352         # SCPCLient takes a paramiko transport as its only argument
353         scp = SCPClient(self._ssh.get_transport(), socket_timeout=timeout)
354         start = time()
355         if not get:
356             scp.put(local_path, remote_path)
357         else:
358             scp.get(remote_path, local_path)
359         scp.close()
360         end = time()
361         logger.trace('SCP took {0} seconds'.format(end-start))
362
363
364 def exec_cmd(node, cmd, timeout=600, sudo=False):
365     """Convenience function to ssh/exec/return rc, out & err.
366
367     Returns (rc, stdout, stderr).
368     """
369     if node is None:
370         raise TypeError('Node parameter is None')
371     if cmd is None:
372         raise TypeError('Command parameter is None')
373     if len(cmd) == 0:
374         raise ValueError('Empty command parameter')
375
376     ssh = SSH()
377     try:
378         ssh.connect(node)
379     except SSHException as err:
380         logger.error("Failed to connect to node" + str(err))
381         return None, None, None
382
383     try:
384         if not sudo:
385             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
386         else:
387             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
388                                                                timeout=timeout)
389     except SSHException as err:
390         logger.error(err)
391         return None, None, None
392
393     return ret_code, stdout, stderr
394
395
396 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
397     """Convenience function to ssh/exec/return out & err.
398
399     Verifies that return code is zero.
400
401     Returns (stdout, stderr).
402     """
403     (ret_code, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
404     assert_equal(ret_code, 0, 'Command execution failed: "{}"\n{}'.
405                  format(cmd, stderr))
406     return stdout, stderr