Fixed SSH exec_command timeout
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import paramiko
14 from paramiko import RSAKey
15 import StringIO
16 from scp import SCPClient
17 from time import time
18 from robot.api import logger
19 from interruptingcow import timeout
20 from robot.utils.asserts import assert_equal, assert_not_equal
21
22 __all__ = ["exec_cmd", "exec_cmd_no_error"]
23
24 # TODO: load priv key
25
26
27 class SSH(object):
28
29     __MAX_RECV_BUF = 10*1024*1024
30     __existing_connections = {}
31
32     def __init__(self):
33         pass
34
35     def _node_hash(self, node):
36         return hash(frozenset([node['host'], node['port']]))
37
38     def connect(self, node):
39         """Connect to node prior to running exec_command or scp.
40
41         If there already is a connection to the node, this method reuses it.
42         """
43         node_hash = self._node_hash(node)
44         if node_hash in SSH.__existing_connections:
45             self._ssh = SSH.__existing_connections[node_hash]
46             logger.debug('reusing ssh: {0}'.format(self._ssh))
47         else:
48             start = time()
49             pkey = None
50             if 'priv_key' in node:
51                 pkey = RSAKey.from_private_key(
52                         StringIO.StringIO(node['priv_key']))
53
54             self._ssh = paramiko.SSHClient()
55             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
56
57             self._ssh.connect(node['host'], username=node['username'],
58                               password=node.get('password'), pkey=pkey,
59                               port=node['port'])
60
61             SSH.__existing_connections[node_hash] = self._ssh
62
63             logger.trace('connect took {} seconds'.format(time() - start))
64             logger.debug('new ssh: {0}'.format(self._ssh))
65
66         logger.debug('Connect peer: {0}'.
67                 format(self._ssh.get_transport().getpeername()))
68         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
69
70     def exec_command(self, cmd, timeout=10):
71         """Execute SSH command on a new channel on the connected Node.
72
73         Returns (return_code, stdout, stderr).
74         """
75         logger.trace('exec_command on {0}: {1}'
76                      .format(self._ssh.get_transport().getpeername(), cmd))
77         start = time()
78         chan = self._ssh.get_transport().open_session()
79         if timeout is not None:
80             chan.settimeout(int(timeout))
81         chan.exec_command(cmd)
82         end = time()
83         logger.trace('exec_command on {0} took {1} seconds'.format(
84             self._ssh.get_transport().getpeername(), end-start))
85
86         stdout = ""
87         while True:
88             buf = chan.recv(self.__MAX_RECV_BUF)
89             stdout += buf
90             if not buf:
91                 break
92
93         stderr = ""
94         while True:
95             buf = chan.recv_stderr(self.__MAX_RECV_BUF)
96             stderr += buf
97             if not buf:
98                 break
99
100         return_code = chan.recv_exit_status()
101         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
102
103         logger.trace('return RC {}'.format(return_code))
104         logger.trace('return STDOUT {}'.format(stdout))
105         logger.trace('return STDERR {}'.format(stderr))
106         return return_code, stdout, stderr
107
108     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
109         """Execute SSH command with sudo on a new channel on the connected Node.
110
111            :param cmd: Command to be executed.
112            :param cmd_input: Input redirected to the command.
113            :param timeout: Timeout.
114            :return: return_code, stdout, stderr
115
116            :Example:
117
118             >>> from ssh import SSH
119             >>> ssh = SSH()
120             >>> ssh.connect(node)
121             >>> #Execute command without input (sudo -S cmd)
122             >>> ssh.exec_command_sudo("ifconfig eth0 down")
123             >>> #Execute command with input (sudo -S cmd <<< "input")
124             >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
125         """
126         if cmd_input is None:
127             command = 'sudo -S {c}'.format(c=cmd)
128         else:
129             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
130         return self.exec_command(command, timeout)
131
132     def interactive_terminal_open(self, time_out=10):
133         """Open interactive terminal on a new channel on the connected Node.
134
135            :param time_out: Timeout in seconds.
136            :return: SSH channel with opened terminal.
137
138            .. warning:: Interruptingcow is used here, and it uses
139                signal(SIGALRM) to let the operating system interrupt program
140                execution. This has the following limitations: Python signal
141                handlers only apply to the main thread, so you cannot use this
142                from other threads. You must not use this in a program that
143                uses SIGALRM itself (this includes certain profilers)
144         """
145         chan = self._ssh.get_transport().open_session()
146         chan.get_pty()
147         chan.invoke_shell()
148         chan.settimeout(int(time_out))
149
150         buf = ''
151         try:
152             with timeout(time_out, exception=RuntimeError):
153                 while not buf.endswith(':~$ '):
154                     if chan.recv_ready():
155                         buf = chan.recv(4096)
156         except RuntimeError:
157             raise Exception('Open interactive terminal timeout.')
158         return chan
159
160     def interactive_terminal_exec_command(self, chan, cmd, prompt,
161                                           time_out=10):
162         """Execute command on interactive terminal.
163
164            interactive_terminal_open() method has to be called first!
165
166            :param chan: SSH channel with opened terminal.
167            :param cmd: Command to be executed.
168            :param prompt: Command prompt, sequence of characters used to
169                indicate readiness to accept commands.
170            :param time_out: Timeout in seconds.
171            :return: Command output.
172
173            .. warning:: Interruptingcow is used here, and it uses
174                signal(SIGALRM) to let the operating system interrupt program
175                execution. This has the following limitations: Python signal
176                handlers only apply to the main thread, so you cannot use this
177                from other threads. You must not use this in a program that
178                uses SIGALRM itself (this includes certain profilers)
179         """
180         chan.sendall('{c}\n'.format(c=cmd))
181         buf = ''
182         try:
183             with timeout(time_out, exception=RuntimeError):
184                 while not buf.endswith(prompt):
185                     if chan.recv_ready():
186                         buf += chan.recv(4096)
187         except RuntimeError:
188             raise Exception("Exec '{c}' timeout.".format(c=cmd))
189         tmp = buf.replace(cmd.replace('\n', ''), '')
190         return tmp.replace(prompt, '')
191
192     def interactive_terminal_close(self, chan):
193         """Close interactive terminal SSH channel.
194
195            :param: chan: SSH channel to be closed.
196         """
197         chan.close()
198
199     def scp(self, local_path, remote_path):
200         """Copy files from local_path to remote_path.
201
202         connect() method has to be called first!
203         """
204         logger.trace('SCP {0} to {1}:{2}'.format(
205             local_path, self._ssh.get_transport().getpeername(), remote_path))
206         # SCPCLient takes a paramiko transport as its only argument
207         scp = SCPClient(self._ssh.get_transport())
208         start = time()
209         scp.put(local_path, remote_path)
210         scp.close()
211         end = time()
212         logger.trace('SCP took {0} seconds'.format(end-start))
213
214
215 def exec_cmd(node, cmd, timeout=None, sudo=False):
216     """Convenience function to ssh/exec/return rc, out & err.
217
218     Returns (rc, stdout, stderr).
219     """
220     if node is None:
221         raise TypeError('Node parameter is None')
222     if cmd is None:
223         raise TypeError('Command parameter is None')
224     if len(cmd) == 0:
225         raise ValueError('Empty command parameter')
226
227     ssh = SSH()
228     try:
229         ssh.connect(node)
230     except Exception, e:
231         logger.error("Failed to connect to node" + e)
232         return None
233
234     try:
235         if not sudo:
236             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
237         else:
238             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
239                                                                timeout=timeout)
240     except Exception, e:
241         logger.error(e)
242         return None
243
244     return (ret_code, stdout, stderr)
245
246 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
247     """Convenience function to ssh/exec/return out & err.
248     Verifies that return code is zero.
249
250     Returns (stdout, stderr).
251     """
252     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
253     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
254                  format(cmd, stderr))
255     return (stdout, stderr)