Interactive terminal fixes
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 import StringIO
15 from time import time, sleep
16
17 import socket
18 import paramiko
19 from paramiko import RSAKey
20 from paramiko.ssh_exception import SSHException
21 from scp import SCPClient
22 from interruptingcow import timeout
23 from robot.api import logger
24 from robot.utils.asserts import assert_equal
25
26 __all__ = ["exec_cmd", "exec_cmd_no_error"]
27
28 # TODO: load priv key
29
30
31 class SSH(object):
32
33     __MAX_RECV_BUF = 10*1024*1024
34     __existing_connections = {}
35
36     def __init__(self):
37         self._ssh = None
38         self._node = None
39
40     @staticmethod
41     def _node_hash(node):
42         return hash(frozenset([node['host'], node['port']]))
43
44     def connect(self, node):
45         """Connect to node prior to running exec_command or scp.
46
47         If there already is a connection to the node, this method reuses it.
48         """
49         self._node = node
50         node_hash = self._node_hash(node)
51         if node_hash in SSH.__existing_connections:
52             self._ssh = SSH.__existing_connections[node_hash]
53             logger.debug('reusing ssh: {0}'.format(self._ssh))
54         else:
55             start = time()
56             pkey = None
57             if 'priv_key' in node:
58                 pkey = RSAKey.from_private_key(
59                         StringIO.StringIO(node['priv_key']))
60
61             self._ssh = paramiko.SSHClient()
62             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
63
64             self._ssh.connect(node['host'], username=node['username'],
65                               password=node.get('password'), pkey=pkey,
66                               port=node['port'])
67
68             self._ssh.get_transport().set_keepalive(10)
69
70             SSH.__existing_connections[node_hash] = self._ssh
71
72             logger.trace('connect took {} seconds'.format(time() - start))
73             logger.debug('new ssh: {0}'.format(self._ssh))
74
75         logger.debug('Connect peer: {0}'.
76                      format(self._ssh.get_transport().getpeername()))
77         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
78
79     def disconnect(self, node):
80         """Close SSH connection to the node.
81
82         :param node: The node to disconnect from.
83         :type node: dict
84         """
85         node_hash = self._node_hash(node)
86         if node_hash in SSH.__existing_connections:
87             logger.debug('Disconnecting peer: {}, {}'.
88                          format(node['host'], node['port']))
89             ssh = SSH.__existing_connections.pop(node_hash)
90             ssh.close()
91
92     def _reconnect(self):
93         node = self._node
94         self.disconnect(node)
95         self.connect(node)
96         logger.debug('Reconnecting peer done: {}'.
97                      format(self._ssh.get_transport().getpeername()))
98
99     def exec_command(self, cmd, timeout=10):
100         """Execute SSH command on a new channel on the connected Node.
101
102         :param cmd: Command to run on the Node.
103         :param timeout: Maximal time in seconds to wait while the command is
104         done. If is None then wait forever.
105         :type cmd: str
106         :type timeout: int
107         :return return_code, stdout, stderr
108         :rtype: tuple(int, str, str)
109         :raise socket.timeout: If command is not finished in timeout time.
110         """
111         start = time()
112         stdout = StringIO.StringIO()
113         stderr = StringIO.StringIO()
114         try:
115             chan = self._ssh.get_transport().open_session(timeout=5)
116         except AttributeError:
117             self._reconnect()
118             chan = self._ssh.get_transport().open_session(timeout=5)
119         except SSHException:
120             self._reconnect()
121             chan = self._ssh.get_transport().open_session(timeout=5)
122         chan.settimeout(timeout)
123         logger.trace('exec_command on {0}: {1}'
124                      .format(self._ssh.get_transport().getpeername(), cmd))
125
126         chan.exec_command(cmd)
127         while not chan.exit_status_ready() and timeout is not None:
128             if chan.recv_ready():
129                 stdout.write(chan.recv(self.__MAX_RECV_BUF))
130
131             if chan.recv_stderr_ready():
132                 stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
133
134             if time() - start > timeout:
135                 raise socket.timeout(
136                     'Timeout exception.\n'
137                     'Current contents of stdout buffer: {0}\n'
138                     'Current contents of stderr buffer: {1}\n'
139                     .format(stdout.getvalue(), stderr.getvalue())
140                 )
141
142             sleep(0.1)
143         return_code = chan.recv_exit_status()
144
145         while chan.recv_ready():
146             stdout.write(chan.recv(self.__MAX_RECV_BUF))
147
148         while chan.recv_stderr_ready():
149             stderr.write(chan.recv_stderr(self.__MAX_RECV_BUF))
150
151         end = time()
152         logger.trace('exec_command on {0} took {1} seconds'.format(
153             self._ssh.get_transport().getpeername(), end-start))
154
155         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
156
157         logger.trace('return RC {}'.format(return_code))
158         logger.trace('return STDOUT {}'.format(stdout.getvalue()))
159         logger.trace('return STDERR {}'.format(stderr.getvalue()))
160         return return_code, stdout.getvalue(), stderr.getvalue()
161
162     def exec_command_sudo(self, cmd, cmd_input=None, timeout=30):
163         """Execute SSH command with sudo on a new channel on the connected Node.
164
165         :param cmd: Command to be executed.
166         :param cmd_input: Input redirected to the command.
167         :param timeout: Timeout.
168         :return: return_code, stdout, stderr
169
170         :Example:
171
172         >>> from ssh import SSH
173         >>> ssh = SSH()
174         >>> ssh.connect(node)
175         >>> # Execute command without input (sudo -S cmd)
176         >>> ssh.exec_command_sudo("ifconfig eth0 down")
177         >>> # Execute command with input (sudo -S cmd <<< "input")
178         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
179         """
180         if cmd_input is None:
181             command = 'sudo -S {c}'.format(c=cmd)
182         else:
183             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
184         return self.exec_command(command, timeout)
185
186     def interactive_terminal_open(self, time_out=10):
187         """Open interactive terminal on a new channel on the connected Node.
188
189         :param time_out: Timeout in seconds.
190         :return: SSH channel with opened terminal.
191
192         .. warning:: Interruptingcow is used here, and it uses
193            signal(SIGALRM) to let the operating system interrupt program
194            execution. This has the following limitations: Python signal
195            handlers only apply to the main thread, so you cannot use this
196            from other threads. You must not use this in a program that
197            uses SIGALRM itself (this includes certain profilers)
198         """
199         chan = self._ssh.get_transport().open_session()
200         chan.get_pty()
201         chan.invoke_shell()
202         chan.settimeout(int(time_out))
203
204         buf = ''
205         try:
206             with timeout(time_out, exception=RuntimeError):
207                 while not buf.endswith(':~$ '):
208                     if chan.recv_ready():
209                         buf = chan.recv(4096)
210         except RuntimeError:
211             raise Exception('Open interactive terminal timeout.')
212         return chan
213
214     @staticmethod
215     def interactive_terminal_exec_command(chan, cmd, prompt,
216                                           time_out=30):
217         """Execute command on interactive terminal.
218
219         interactive_terminal_open() method has to be called first!
220
221         :param chan: SSH channel with opened terminal.
222         :param cmd: Command to be executed.
223         :param prompt: Command prompt, sequence of characters used to
224         indicate readiness to accept commands.
225         :param time_out: Timeout in seconds.
226         :return: Command output.
227
228         .. warning:: Interruptingcow is used here, and it uses
229            signal(SIGALRM) to let the operating system interrupt program
230            execution. This has the following limitations: Python signal
231            handlers only apply to the main thread, so you cannot use this
232            from other threads. You must not use this in a program that
233            uses SIGALRM itself (this includes certain profilers)
234         """
235         chan.sendall('{c}\n'.format(c=cmd))
236         buf = ''
237         try:
238             with timeout(time_out, exception=RuntimeError):
239                 while not buf.endswith(prompt):
240                     if chan.recv_ready():
241                         buf += chan.recv(4096)
242         except RuntimeError:
243             raise Exception("Exec '{c}' timeout.".format(c=cmd))
244         tmp = buf.replace(cmd.replace('\n', ''), '')
245         return tmp.replace(prompt, '')
246
247     @staticmethod
248     def interactive_terminal_close(chan):
249         """Close interactive terminal SSH channel.
250
251         :param: chan: SSH channel to be closed.
252         """
253         chan.close()
254
255     def scp(self, local_path, remote_path):
256         """Copy files from local_path to remote_path.
257
258         connect() method has to be called first!
259         """
260         logger.trace('SCP {0} to {1}:{2}'.format(
261             local_path, self._ssh.get_transport().getpeername(), remote_path))
262         # SCPCLient takes a paramiko transport as its only argument
263         scp = SCPClient(self._ssh.get_transport())
264         start = time()
265         scp.put(local_path, remote_path)
266         scp.close()
267         end = time()
268         logger.trace('SCP took {0} seconds'.format(end-start))
269
270
271 def exec_cmd(node, cmd, timeout=600, sudo=False):
272     """Convenience function to ssh/exec/return rc, out & err.
273
274     Returns (rc, stdout, stderr).
275     """
276     if node is None:
277         raise TypeError('Node parameter is None')
278     if cmd is None:
279         raise TypeError('Command parameter is None')
280     if len(cmd) == 0:
281         raise ValueError('Empty command parameter')
282
283     ssh = SSH()
284     try:
285         ssh.connect(node)
286     except Exception, e:
287         logger.error("Failed to connect to node" + str(e))
288         return None, None, None
289
290     try:
291         if not sudo:
292             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
293         else:
294             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
295                                                                timeout=timeout)
296     except Exception, e:
297         logger.error(e)
298         return None, None, None
299
300     return ret_code, stdout, stderr
301
302
303 def exec_cmd_no_error(node, cmd, timeout=600, sudo=False):
304     """Convenience function to ssh/exec/return out & err.
305
306     Verifies that return code is zero.
307
308     Returns (stdout, stderr).
309     """
310     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
311     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
312                  format(cmd, stderr))
313     return stdout, stderr