Update VPP version downloaded from Nexus.
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import paramiko
14 from paramiko import RSAKey
15 import StringIO
16 from scp import SCPClient
17 from time import time
18 from robot.api import logger
19 from interruptingcow import timeout
20 from robot.utils.asserts import assert_equal, assert_not_equal
21 from socket import timeout as socket_timeout
22
23 __all__ = ["exec_cmd", "exec_cmd_no_error"]
24
25 # TODO: load priv key
26
27
28 class SSH(object):
29
30     __MAX_RECV_BUF = 10*1024*1024
31     __existing_connections = {}
32
33     def __init__(self):
34         pass
35
36     def _node_hash(self, node):
37         return hash(frozenset([node['host'], node['port']]))
38
39     def connect(self, node):
40         """Connect to node prior to running exec_command or scp.
41
42         If there already is a connection to the node, this method reuses it.
43         """
44         node_hash = self._node_hash(node)
45         if node_hash in SSH.__existing_connections:
46             self._ssh = SSH.__existing_connections[node_hash]
47             logger.debug('reusing ssh: {0}'.format(self._ssh))
48         else:
49             start = time()
50             pkey = None
51             if 'priv_key' in node:
52                 pkey = RSAKey.from_private_key(
53                         StringIO.StringIO(node['priv_key']))
54
55             self._ssh = paramiko.SSHClient()
56             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
57
58             self._ssh.connect(node['host'], username=node['username'],
59                               password=node.get('password'), pkey=pkey)
60
61             SSH.__existing_connections[node_hash] = self._ssh
62
63             logger.trace('connect took {} seconds'.format(time() - start))
64             logger.debug('new ssh: {0}'.format(self._ssh))
65
66         logger.debug('Connect peer: {0}'.
67                 format(self._ssh.get_transport().getpeername()))
68         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
69
70     def exec_command(self, cmd, timeout=10):
71         """Execute SSH command on a new channel on the connected Node.
72
73         Returns (return_code, stdout, stderr).
74         """
75         logger.trace('exec_command on {0}: {1}'
76                      .format(self._ssh.get_transport().getpeername(), cmd))
77         start = time()
78         chan = self._ssh.get_transport().open_session()
79         if timeout is not None:
80             chan.settimeout(int(timeout))
81         chan.exec_command(cmd)
82         end = time()
83         logger.trace('exec_command on {0} took {1} seconds'.format(
84             self._ssh.get_transport().getpeername(), end-start))
85
86         stdout = ""
87         while True:
88             try:
89                 buf = chan.recv(self.__MAX_RECV_BUF)
90                 stdout += buf
91                 if not buf:
92                     break
93             except socket_timeout:
94                 logger.trace('Channels stdout timeout occurred')
95                 break
96
97         stderr = ""
98         while True:
99             try:
100                 buf = chan.recv_stderr(self.__MAX_RECV_BUF)
101                 stderr += buf
102                 if not buf:
103                     break
104             except socket_timeout:
105                 logger.trace('Channels stderr timeout occurred')
106                 break
107
108         return_code = chan.recv_exit_status()
109         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
110
111         logger.trace('return RC {}'.format(return_code))
112         logger.trace('return STDOUT {}'.format(stdout))
113         logger.trace('return STDERR {}'.format(stderr))
114         return return_code, stdout, stderr
115
116     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
117         """Execute SSH command with sudo on a new channel on the connected Node.
118
119            :param cmd: Command to be executed.
120            :param cmd_input: Input redirected to the command.
121            :param timeout: Timeout.
122            :return: return_code, stdout, stderr
123
124            :Example:
125
126             >>> from ssh import SSH
127             >>> ssh = SSH()
128             >>> ssh.connect(node)
129             >>> #Execute command without input (sudo -S cmd)
130             >>> ssh.exec_command_sudo("ifconfig eth0 down")
131             >>> #Execute command with input (sudo -S cmd <<< "input")
132             >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
133         """
134         if cmd_input is None:
135             command = 'sudo -S {c}'.format(c=cmd)
136         else:
137             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
138         return self.exec_command(command, timeout)
139
140     def interactive_terminal_open(self, time_out=10):
141         """Open interactive terminal on a new channel on the connected Node.
142
143            :param time_out: Timeout in seconds.
144            :return: SSH channel with opened terminal.
145
146            .. warning:: Interruptingcow is used here, and it uses
147                signal(SIGALRM) to let the operating system interrupt program
148                execution. This has the following limitations: Python signal
149                handlers only apply to the main thread, so you cannot use this
150                from other threads. You must not use this in a program that
151                uses SIGALRM itself (this includes certain profilers)
152         """
153         chan = self._ssh.get_transport().open_session()
154         chan.get_pty()
155         chan.invoke_shell()
156         chan.settimeout(int(time_out))
157
158         buf = ''
159         try:
160             with timeout(time_out, exception=RuntimeError):
161                 while not buf.endswith(':~$ '):
162                     if chan.recv_ready():
163                         buf = chan.recv(4096)
164         except RuntimeError:
165             raise Exception('Open interactive terminal timeout.')
166         return chan
167
168     def interactive_terminal_exec_command(self, chan, cmd, prompt,
169                                           time_out=10):
170         """Execute command on interactive terminal.
171
172            interactive_terminal_open() method has to be called first!
173
174            :param chan: SSH channel with opened terminal.
175            :param cmd: Command to be executed.
176            :param prompt: Command prompt, sequence of characters used to
177                indicate readiness to accept commands.
178            :param time_out: Timeout in seconds.
179            :return: Command output.
180
181            .. warning:: Interruptingcow is used here, and it uses
182                signal(SIGALRM) to let the operating system interrupt program
183                execution. This has the following limitations: Python signal
184                handlers only apply to the main thread, so you cannot use this
185                from other threads. You must not use this in a program that
186                uses SIGALRM itself (this includes certain profilers)
187         """
188         chan.sendall('{c}\n'.format(c=cmd))
189         buf = ''
190         try:
191             with timeout(time_out, exception=RuntimeError):
192                 while not buf.endswith(prompt):
193                     if chan.recv_ready():
194                         buf += chan.recv(4096)
195         except RuntimeError:
196             raise Exception("Exec '{c}' timeout.".format(c=cmd))
197         tmp = buf.replace(cmd.replace('\n', ''), '')
198         return tmp.replace(prompt, '')
199
200     def interactive_terminal_close(self, chan):
201         """Close interactive terminal SSH channel.
202
203            :param: chan: SSH channel to be closed.
204         """
205         chan.close()
206
207     def scp(self, local_path, remote_path):
208         """Copy files from local_path to remote_path.
209
210         connect() method has to be called first!
211         """
212         logger.trace('SCP {0} to {1}:{2}'.format(
213             local_path, self._ssh.get_transport().getpeername(), remote_path))
214         # SCPCLient takes a paramiko transport as its only argument
215         scp = SCPClient(self._ssh.get_transport())
216         start = time()
217         scp.put(local_path, remote_path)
218         scp.close()
219         end = time()
220         logger.trace('SCP took {0} seconds'.format(end-start))
221
222
223 def exec_cmd(node, cmd, timeout=None, sudo=False):
224     """Convenience function to ssh/exec/return rc, out & err.
225
226     Returns (rc, stdout, stderr).
227     """
228     if node is None:
229         raise TypeError('Node parameter is None')
230     if cmd is None:
231         raise TypeError('Command parameter is None')
232     if len(cmd) == 0:
233         raise ValueError('Empty command parameter')
234
235     ssh = SSH()
236     try:
237         ssh.connect(node)
238     except Exception, e:
239         logger.error("Failed to connect to node" + e)
240         return None
241
242     try:
243         if not sudo:
244             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
245         else:
246             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
247                                                                timeout=timeout)
248     except Exception, e:
249         logger.error(e)
250         return None
251
252     return (ret_code, stdout, stderr)
253
254 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
255     """Convenience function to ssh/exec/return out & err.
256     Verifies that return code is zero.
257
258     Returns (stdout, stderr).
259     """
260     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
261     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
262                  format(cmd, stderr))
263     return (stdout, stderr)