Update Honeycomb tests
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13
14 import StringIO
15 from time import time
16
17 import socket
18 import paramiko
19 from paramiko import RSAKey
20 from scp import SCPClient
21 from interruptingcow import timeout
22 from robot.api import logger
23 from robot.utils.asserts import assert_equal
24
25 __all__ = ["exec_cmd", "exec_cmd_no_error"]
26
27 # TODO: load priv key
28
29
30 class SSH(object):
31
32     __MAX_RECV_BUF = 10*1024*1024
33     __existing_connections = {}
34
35     def __init__(self):
36         self._ssh = None
37
38     @staticmethod
39     def _node_hash(node):
40         return hash(frozenset([node['host'], node['port']]))
41
42     def connect(self, node):
43         """Connect to node prior to running exec_command or scp.
44
45         If there already is a connection to the node, this method reuses it.
46         """
47         node_hash = self._node_hash(node)
48         if node_hash in SSH.__existing_connections:
49             self._ssh = SSH.__existing_connections[node_hash]
50             logger.debug('reusing ssh: {0}'.format(self._ssh))
51         else:
52             start = time()
53             pkey = None
54             if 'priv_key' in node:
55                 pkey = RSAKey.from_private_key(
56                         StringIO.StringIO(node['priv_key']))
57
58             self._ssh = paramiko.SSHClient()
59             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
60
61             self._ssh.connect(node['host'], username=node['username'],
62                               password=node.get('password'), pkey=pkey,
63                               port=node['port'])
64
65             SSH.__existing_connections[node_hash] = self._ssh
66
67             logger.trace('connect took {} seconds'.format(time() - start))
68             logger.debug('new ssh: {0}'.format(self._ssh))
69
70         logger.debug('Connect peer: {0}'.
71                      format(self._ssh.get_transport().getpeername()))
72         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
73
74     def disconnect(self, node):
75         """Close SSH connection to the node.
76
77         :param node: The node to disconnect from.
78         :type node: dict
79         """
80         node_hash = self._node_hash(node)
81         if node_hash in SSH.__existing_connections:
82             ssh = SSH.__existing_connections.pop(node_hash)
83             ssh.close()
84
85     def exec_command(self, cmd, timeout=10):
86         """Execute SSH command on a new channel on the connected Node.
87
88         Returns (return_code, stdout, stderr).
89         """
90         logger.trace('exec_command on {0}: {1}'
91                      .format(self._ssh.get_transport().getpeername(), cmd))
92         start = time()
93         chan = self._ssh.get_transport().open_session()
94         if timeout is not None:
95             chan.settimeout(int(timeout))
96         chan.exec_command(cmd)
97         end = time()
98         logger.trace('exec_command on {0} took {1} seconds'.format(
99             self._ssh.get_transport().getpeername(), end-start))
100
101         stdout = ""
102         try:
103             while True:
104                 buf = chan.recv(self.__MAX_RECV_BUF)
105                 stdout += buf
106                 if not buf:
107                     break
108         except socket.timeout:
109             logger.error('Caught timeout exception, current contents '
110                          'of buffer: {0}'.format(stdout))
111             raise
112
113         stderr = ""
114         while True:
115             buf = chan.recv_stderr(self.__MAX_RECV_BUF)
116             stderr += buf
117             if not buf:
118                 break
119
120         return_code = chan.recv_exit_status()
121         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
122
123         logger.trace('return RC {}'.format(return_code))
124         logger.trace('return STDOUT {}'.format(stdout))
125         logger.trace('return STDERR {}'.format(stderr))
126         return return_code, stdout, stderr
127
128     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
129         """Execute SSH command with sudo on a new channel on the connected Node.
130
131         :param cmd: Command to be executed.
132         :param cmd_input: Input redirected to the command.
133         :param timeout: Timeout.
134         :return: return_code, stdout, stderr
135
136         :Example:
137
138         >>> from ssh import SSH
139         >>> ssh = SSH()
140         >>> ssh.connect(node)
141         >>> # Execute command without input (sudo -S cmd)
142         >>> ssh.exec_command_sudo("ifconfig eth0 down")
143         >>> # Execute command with input (sudo -S cmd <<< "input")
144         >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
145         """
146         if cmd_input is None:
147             command = 'sudo -S {c}'.format(c=cmd)
148         else:
149             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
150         return self.exec_command(command, timeout)
151
152     def interactive_terminal_open(self, time_out=10):
153         """Open interactive terminal on a new channel on the connected Node.
154
155         :param time_out: Timeout in seconds.
156         :return: SSH channel with opened terminal.
157
158         .. warning:: Interruptingcow is used here, and it uses
159            signal(SIGALRM) to let the operating system interrupt program
160            execution. This has the following limitations: Python signal
161            handlers only apply to the main thread, so you cannot use this
162            from other threads. You must not use this in a program that
163            uses SIGALRM itself (this includes certain profilers)
164         """
165         chan = self._ssh.get_transport().open_session()
166         chan.get_pty()
167         chan.invoke_shell()
168         chan.settimeout(int(time_out))
169
170         buf = ''
171         try:
172             with timeout(time_out, exception=RuntimeError):
173                 while not buf.endswith(':~$ '):
174                     if chan.recv_ready():
175                         buf = chan.recv(4096)
176         except RuntimeError:
177             raise Exception('Open interactive terminal timeout.')
178         return chan
179
180     @staticmethod
181     def interactive_terminal_exec_command(chan, cmd, prompt,
182                                           time_out=10):
183         """Execute command on interactive terminal.
184
185         interactive_terminal_open() method has to be called first!
186
187         :param chan: SSH channel with opened terminal.
188         :param cmd: Command to be executed.
189         :param prompt: Command prompt, sequence of characters used to
190         indicate readiness to accept commands.
191         :param time_out: Timeout in seconds.
192         :return: Command output.
193
194         .. warning:: Interruptingcow is used here, and it uses
195            signal(SIGALRM) to let the operating system interrupt program
196            execution. This has the following limitations: Python signal
197            handlers only apply to the main thread, so you cannot use this
198            from other threads. You must not use this in a program that
199            uses SIGALRM itself (this includes certain profilers)
200         """
201         chan.sendall('{c}\n'.format(c=cmd))
202         buf = ''
203         try:
204             with timeout(time_out, exception=RuntimeError):
205                 while not buf.endswith(prompt):
206                     if chan.recv_ready():
207                         buf += chan.recv(4096)
208         except RuntimeError:
209             raise Exception("Exec '{c}' timeout.".format(c=cmd))
210         tmp = buf.replace(cmd.replace('\n', ''), '')
211         return tmp.replace(prompt, '')
212
213     @staticmethod
214     def interactive_terminal_close(chan):
215         """Close interactive terminal SSH channel.
216
217         :param: chan: SSH channel to be closed.
218         """
219         chan.close()
220
221     def scp(self, local_path, remote_path):
222         """Copy files from local_path to remote_path.
223
224         connect() method has to be called first!
225         """
226         logger.trace('SCP {0} to {1}:{2}'.format(
227             local_path, self._ssh.get_transport().getpeername(), remote_path))
228         # SCPCLient takes a paramiko transport as its only argument
229         scp = SCPClient(self._ssh.get_transport())
230         start = time()
231         scp.put(local_path, remote_path)
232         scp.close()
233         end = time()
234         logger.trace('SCP took {0} seconds'.format(end-start))
235
236
237 def exec_cmd(node, cmd, timeout=None, sudo=False):
238     """Convenience function to ssh/exec/return rc, out & err.
239
240     Returns (rc, stdout, stderr).
241     """
242     if node is None:
243         raise TypeError('Node parameter is None')
244     if cmd is None:
245         raise TypeError('Command parameter is None')
246     if len(cmd) == 0:
247         raise ValueError('Empty command parameter')
248
249     ssh = SSH()
250     try:
251         ssh.connect(node)
252     except Exception, e:
253         logger.error("Failed to connect to node" + e)
254         return None
255
256     try:
257         if not sudo:
258             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
259         else:
260             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
261                                                                timeout=timeout)
262     except Exception, e:
263         logger.error(e)
264         return None
265
266     return ret_code, stdout, stderr
267
268
269 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
270     """Convenience function to ssh/exec/return out & err.
271
272     Verifies that return code is zero.
273
274     Returns (stdout, stderr).
275     """
276     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
277     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
278                  format(cmd, stderr))
279     return stdout, stderr