Add Vagrantfile for local testing.
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import socket
14 import paramiko
15 from paramiko import RSAKey
16 import StringIO
17 from scp import SCPClient
18 from time import time
19 from robot.api import logger
20 from interruptingcow import timeout
21 from robot.utils.asserts import assert_equal, assert_not_equal
22
23 __all__ = ["exec_cmd", "exec_cmd_no_error"]
24
25 # TODO: load priv key
26
27
28 class SSH(object):
29
30     __MAX_RECV_BUF = 10*1024*1024
31     __existing_connections = {}
32
33     def __init__(self):
34         pass
35
36     def _node_hash(self, node):
37         return hash(frozenset([node['host'], node['port']]))
38
39     def connect(self, node):
40         """Connect to node prior to running exec_command or scp.
41
42         If there already is a connection to the node, this method reuses it.
43         """
44         node_hash = self._node_hash(node)
45         if node_hash in SSH.__existing_connections:
46             self._ssh = SSH.__existing_connections[node_hash]
47             logger.debug('reusing ssh: {0}'.format(self._ssh))
48         else:
49             start = time()
50             pkey = None
51             if 'priv_key' in node:
52                 pkey = RSAKey.from_private_key(
53                         StringIO.StringIO(node['priv_key']))
54
55             self._ssh = paramiko.SSHClient()
56             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
57
58             self._ssh.connect(node['host'], username=node['username'],
59                               password=node.get('password'), pkey=pkey,
60                               port=node['port'])
61
62             SSH.__existing_connections[node_hash] = self._ssh
63
64             logger.trace('connect took {} seconds'.format(time() - start))
65             logger.debug('new ssh: {0}'.format(self._ssh))
66
67         logger.debug('Connect peer: {0}'.
68                 format(self._ssh.get_transport().getpeername()))
69         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
70
71     def disconnect(self, node):
72         """Close SSH connection to the node.
73
74         :param node: The node to disconnect from.
75         :type node: dict
76         """
77         node_hash = self._node_hash(node)
78         if node_hash in SSH.__existing_connections:
79             ssh = SSH.__existing_connections.pop(node_hash)
80             ssh.close()
81
82     def exec_command(self, cmd, timeout=10):
83         """Execute SSH command on a new channel on the connected Node.
84
85         Returns (return_code, stdout, stderr).
86         """
87         logger.trace('exec_command on {0}: {1}'
88                      .format(self._ssh.get_transport().getpeername(), cmd))
89         start = time()
90         chan = self._ssh.get_transport().open_session()
91         if timeout is not None:
92             chan.settimeout(int(timeout))
93         chan.exec_command(cmd)
94         end = time()
95         logger.trace('exec_command on {0} took {1} seconds'.format(
96             self._ssh.get_transport().getpeername(), end-start))
97
98         stdout = ""
99         try:
100             while True:
101                 buf = chan.recv(self.__MAX_RECV_BUF)
102                 stdout += buf
103                 if not buf:
104                     break
105         except socket.timeout:
106             logger.error('Caught timeout exception, current contents '
107                          'of buffer: {0}'.format(stdout))
108             raise
109
110
111         stderr = ""
112         while True:
113             buf = chan.recv_stderr(self.__MAX_RECV_BUF)
114             stderr += buf
115             if not buf:
116                 break
117
118         return_code = chan.recv_exit_status()
119         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
120
121         logger.trace('return RC {}'.format(return_code))
122         logger.trace('return STDOUT {}'.format(stdout))
123         logger.trace('return STDERR {}'.format(stderr))
124         return return_code, stdout, stderr
125
126     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
127         """Execute SSH command with sudo on a new channel on the connected Node.
128
129            :param cmd: Command to be executed.
130            :param cmd_input: Input redirected to the command.
131            :param timeout: Timeout.
132            :return: return_code, stdout, stderr
133
134            :Example:
135
136             >>> from ssh import SSH
137             >>> ssh = SSH()
138             >>> ssh.connect(node)
139             >>> #Execute command without input (sudo -S cmd)
140             >>> ssh.exec_command_sudo("ifconfig eth0 down")
141             >>> #Execute command with input (sudo -S cmd <<< "input")
142             >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
143         """
144         if cmd_input is None:
145             command = 'sudo -S {c}'.format(c=cmd)
146         else:
147             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
148         return self.exec_command(command, timeout)
149
150     def interactive_terminal_open(self, time_out=10):
151         """Open interactive terminal on a new channel on the connected Node.
152
153            :param time_out: Timeout in seconds.
154            :return: SSH channel with opened terminal.
155
156            .. warning:: Interruptingcow is used here, and it uses
157                signal(SIGALRM) to let the operating system interrupt program
158                execution. This has the following limitations: Python signal
159                handlers only apply to the main thread, so you cannot use this
160                from other threads. You must not use this in a program that
161                uses SIGALRM itself (this includes certain profilers)
162         """
163         chan = self._ssh.get_transport().open_session()
164         chan.get_pty()
165         chan.invoke_shell()
166         chan.settimeout(int(time_out))
167
168         buf = ''
169         try:
170             with timeout(time_out, exception=RuntimeError):
171                 while not buf.endswith(':~$ '):
172                     if chan.recv_ready():
173                         buf = chan.recv(4096)
174         except RuntimeError:
175             raise Exception('Open interactive terminal timeout.')
176         return chan
177
178     def interactive_terminal_exec_command(self, chan, cmd, prompt,
179                                           time_out=10):
180         """Execute command on interactive terminal.
181
182            interactive_terminal_open() method has to be called first!
183
184            :param chan: SSH channel with opened terminal.
185            :param cmd: Command to be executed.
186            :param prompt: Command prompt, sequence of characters used to
187                indicate readiness to accept commands.
188            :param time_out: Timeout in seconds.
189            :return: Command output.
190
191            .. warning:: Interruptingcow is used here, and it uses
192                signal(SIGALRM) to let the operating system interrupt program
193                execution. This has the following limitations: Python signal
194                handlers only apply to the main thread, so you cannot use this
195                from other threads. You must not use this in a program that
196                uses SIGALRM itself (this includes certain profilers)
197         """
198         chan.sendall('{c}\n'.format(c=cmd))
199         buf = ''
200         try:
201             with timeout(time_out, exception=RuntimeError):
202                 while not buf.endswith(prompt):
203                     if chan.recv_ready():
204                         buf += chan.recv(4096)
205         except RuntimeError:
206             raise Exception("Exec '{c}' timeout.".format(c=cmd))
207         tmp = buf.replace(cmd.replace('\n', ''), '')
208         return tmp.replace(prompt, '')
209
210     def interactive_terminal_close(self, chan):
211         """Close interactive terminal SSH channel.
212
213            :param: chan: SSH channel to be closed.
214         """
215         chan.close()
216
217     def scp(self, local_path, remote_path):
218         """Copy files from local_path to remote_path.
219
220         connect() method has to be called first!
221         """
222         logger.trace('SCP {0} to {1}:{2}'.format(
223             local_path, self._ssh.get_transport().getpeername(), remote_path))
224         # SCPCLient takes a paramiko transport as its only argument
225         scp = SCPClient(self._ssh.get_transport())
226         start = time()
227         scp.put(local_path, remote_path)
228         scp.close()
229         end = time()
230         logger.trace('SCP took {0} seconds'.format(end-start))
231
232
233 def exec_cmd(node, cmd, timeout=None, sudo=False):
234     """Convenience function to ssh/exec/return rc, out & err.
235
236     Returns (rc, stdout, stderr).
237     """
238     if node is None:
239         raise TypeError('Node parameter is None')
240     if cmd is None:
241         raise TypeError('Command parameter is None')
242     if len(cmd) == 0:
243         raise ValueError('Empty command parameter')
244
245     ssh = SSH()
246     try:
247         ssh.connect(node)
248     except Exception, e:
249         logger.error("Failed to connect to node" + e)
250         return None
251
252     try:
253         if not sudo:
254             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
255         else:
256             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
257                                                                timeout=timeout)
258     except Exception, e:
259         logger.error(e)
260         return None
261
262     return (ret_code, stdout, stderr)
263
264 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
265     """Convenience function to ssh/exec/return out & err.
266     Verifies that return code is zero.
267
268     Returns (stdout, stderr).
269     """
270     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
271     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
272                  format(cmd, stderr))
273     return (stdout, stderr)