a94eec4e914944628a4962654e6951f8f036a10c
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import paramiko
14 from paramiko import RSAKey
15 import StringIO
16 from scp import SCPClient
17 from time import time
18 from robot.api import logger
19 from interruptingcow import timeout
20 from robot.utils.asserts import assert_equal, assert_not_equal
21
22 __all__ = ["exec_cmd", "exec_cmd_no_error"]
23
24 # TODO: load priv key
25
26
27 class SSH(object):
28
29     __MAX_RECV_BUF = 10*1024*1024
30     __existing_connections = {}
31
32     def __init__(self):
33         pass
34
35     def _node_hash(self, node):
36         return hash(frozenset([node['host'], node['port']]))
37
38     def connect(self, node):
39         """Connect to node prior to running exec_command or scp.
40
41         If there already is a connection to the node, this method reuses it.
42         """
43         node_hash = self._node_hash(node)
44         if node_hash in SSH.__existing_connections:
45             self._ssh = SSH.__existing_connections[node_hash]
46             logger.debug('reusing ssh: {0}'.format(self._ssh))
47         else:
48             start = time()
49             pkey = None
50             if 'priv_key' in node:
51                 pkey = RSAKey.from_private_key(
52                         StringIO.StringIO(node['priv_key']))
53
54             self._ssh = paramiko.SSHClient()
55             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
56
57             self._ssh.connect(node['host'], username=node['username'],
58                               password=node.get('password'), pkey=pkey,
59                               port=node['port'])
60
61             SSH.__existing_connections[node_hash] = self._ssh
62
63             logger.trace('connect took {} seconds'.format(time() - start))
64             logger.debug('new ssh: {0}'.format(self._ssh))
65
66         logger.debug('Connect peer: {0}'.
67                 format(self._ssh.get_transport().getpeername()))
68         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
69
70     def disconnect(self, node):
71         """Close SSH connection to the node.
72
73         :param node: The node to disconnect from.
74         :type node: dict
75         """
76         node_hash = self._node_hash(node)
77         if node_hash in SSH.__existing_connections:
78             ssh = SSH.__existing_connections.pop(node_hash)
79             ssh.close()
80
81     def exec_command(self, cmd, timeout=10):
82         """Execute SSH command on a new channel on the connected Node.
83
84         Returns (return_code, stdout, stderr).
85         """
86         logger.trace('exec_command on {0}: {1}'
87                      .format(self._ssh.get_transport().getpeername(), cmd))
88         start = time()
89         chan = self._ssh.get_transport().open_session()
90         if timeout is not None:
91             chan.settimeout(int(timeout))
92         chan.exec_command(cmd)
93         end = time()
94         logger.trace('exec_command on {0} took {1} seconds'.format(
95             self._ssh.get_transport().getpeername(), end-start))
96
97         stdout = ""
98         while True:
99             buf = chan.recv(self.__MAX_RECV_BUF)
100             stdout += buf
101             if not buf:
102                 break
103
104         stderr = ""
105         while True:
106             buf = chan.recv_stderr(self.__MAX_RECV_BUF)
107             stderr += buf
108             if not buf:
109                 break
110
111         return_code = chan.recv_exit_status()
112         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
113
114         logger.trace('return RC {}'.format(return_code))
115         logger.trace('return STDOUT {}'.format(stdout))
116         logger.trace('return STDERR {}'.format(stderr))
117         return return_code, stdout, stderr
118
119     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
120         """Execute SSH command with sudo on a new channel on the connected Node.
121
122            :param cmd: Command to be executed.
123            :param cmd_input: Input redirected to the command.
124            :param timeout: Timeout.
125            :return: return_code, stdout, stderr
126
127            :Example:
128
129             >>> from ssh import SSH
130             >>> ssh = SSH()
131             >>> ssh.connect(node)
132             >>> #Execute command without input (sudo -S cmd)
133             >>> ssh.exec_command_sudo("ifconfig eth0 down")
134             >>> #Execute command with input (sudo -S cmd <<< "input")
135             >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
136         """
137         if cmd_input is None:
138             command = 'sudo -S {c}'.format(c=cmd)
139         else:
140             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
141         return self.exec_command(command, timeout)
142
143     def interactive_terminal_open(self, time_out=10):
144         """Open interactive terminal on a new channel on the connected Node.
145
146            :param time_out: Timeout in seconds.
147            :return: SSH channel with opened terminal.
148
149            .. warning:: Interruptingcow is used here, and it uses
150                signal(SIGALRM) to let the operating system interrupt program
151                execution. This has the following limitations: Python signal
152                handlers only apply to the main thread, so you cannot use this
153                from other threads. You must not use this in a program that
154                uses SIGALRM itself (this includes certain profilers)
155         """
156         chan = self._ssh.get_transport().open_session()
157         chan.get_pty()
158         chan.invoke_shell()
159         chan.settimeout(int(time_out))
160
161         buf = ''
162         try:
163             with timeout(time_out, exception=RuntimeError):
164                 while not buf.endswith(':~$ '):
165                     if chan.recv_ready():
166                         buf = chan.recv(4096)
167         except RuntimeError:
168             raise Exception('Open interactive terminal timeout.')
169         return chan
170
171     def interactive_terminal_exec_command(self, chan, cmd, prompt,
172                                           time_out=10):
173         """Execute command on interactive terminal.
174
175            interactive_terminal_open() method has to be called first!
176
177            :param chan: SSH channel with opened terminal.
178            :param cmd: Command to be executed.
179            :param prompt: Command prompt, sequence of characters used to
180                indicate readiness to accept commands.
181            :param time_out: Timeout in seconds.
182            :return: Command output.
183
184            .. warning:: Interruptingcow is used here, and it uses
185                signal(SIGALRM) to let the operating system interrupt program
186                execution. This has the following limitations: Python signal
187                handlers only apply to the main thread, so you cannot use this
188                from other threads. You must not use this in a program that
189                uses SIGALRM itself (this includes certain profilers)
190         """
191         chan.sendall('{c}\n'.format(c=cmd))
192         buf = ''
193         try:
194             with timeout(time_out, exception=RuntimeError):
195                 while not buf.endswith(prompt):
196                     if chan.recv_ready():
197                         buf += chan.recv(4096)
198         except RuntimeError:
199             raise Exception("Exec '{c}' timeout.".format(c=cmd))
200         tmp = buf.replace(cmd.replace('\n', ''), '')
201         return tmp.replace(prompt, '')
202
203     def interactive_terminal_close(self, chan):
204         """Close interactive terminal SSH channel.
205
206            :param: chan: SSH channel to be closed.
207         """
208         chan.close()
209
210     def scp(self, local_path, remote_path):
211         """Copy files from local_path to remote_path.
212
213         connect() method has to be called first!
214         """
215         logger.trace('SCP {0} to {1}:{2}'.format(
216             local_path, self._ssh.get_transport().getpeername(), remote_path))
217         # SCPCLient takes a paramiko transport as its only argument
218         scp = SCPClient(self._ssh.get_transport())
219         start = time()
220         scp.put(local_path, remote_path)
221         scp.close()
222         end = time()
223         logger.trace('SCP took {0} seconds'.format(end-start))
224
225
226 def exec_cmd(node, cmd, timeout=None, sudo=False):
227     """Convenience function to ssh/exec/return rc, out & err.
228
229     Returns (rc, stdout, stderr).
230     """
231     if node is None:
232         raise TypeError('Node parameter is None')
233     if cmd is None:
234         raise TypeError('Command parameter is None')
235     if len(cmd) == 0:
236         raise ValueError('Empty command parameter')
237
238     ssh = SSH()
239     try:
240         ssh.connect(node)
241     except Exception, e:
242         logger.error("Failed to connect to node" + e)
243         return None
244
245     try:
246         if not sudo:
247             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
248         else:
249             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
250                                                                timeout=timeout)
251     except Exception, e:
252         logger.error(e)
253         return None
254
255     return (ret_code, stdout, stderr)
256
257 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
258     """Convenience function to ssh/exec/return out & err.
259     Verifies that return code is zero.
260
261     Returns (stdout, stderr).
262     """
263     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
264     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
265                  format(cmd, stderr))
266     return (stdout, stderr)