a2bb9b1dc4862c6c9e7fe12ace77bdf0f1d99f10
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import paramiko
14 from paramiko import RSAKey
15 import StringIO
16 from scp import SCPClient
17 from time import time
18 from robot.api import logger
19 from interruptingcow import timeout
20 from robot.utils.asserts import assert_equal, assert_not_equal
21
22 __all__ = ["exec_cmd", "exec_cmd_no_error"]
23
24 # TODO: load priv key
25
26
27 class SSH(object):
28
29     __MAX_RECV_BUF = 10*1024*1024
30     __existing_connections = {}
31
32     def __init__(self):
33         pass
34
35     def _node_hash(self, node):
36         return hash(frozenset([node['host'], node['port']]))
37
38     def connect(self, node):
39         """Connect to node prior to running exec_command or scp.
40
41         If there already is a connection to the node, this method reuses it.
42         """
43         node_hash = self._node_hash(node)
44         if node_hash in SSH.__existing_connections:
45             self._ssh = SSH.__existing_connections[node_hash]
46             logger.debug('reusing ssh: {0}'.format(self._ssh))
47         else:
48             start = time()
49             pkey = None
50             if 'priv_key' in node:
51                 pkey = RSAKey.from_private_key(
52                         StringIO.StringIO(node['priv_key']))
53
54             self._ssh = paramiko.SSHClient()
55             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
56
57             self._ssh.connect(node['host'], username=node['username'],
58                               password=node.get('password'), pkey=pkey)
59
60             SSH.__existing_connections[node_hash] = self._ssh
61
62             logger.trace('connect took {} seconds'.format(time() - start))
63             logger.debug('new ssh: {0}'.format(self._ssh))
64
65         logger.debug('Connect peer: {0}'.
66                 format(self._ssh.get_transport().getpeername()))
67         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
68
69     def exec_command(self, cmd, timeout=10):
70         """Execute SSH command on a new channel on the connected Node.
71
72         Returns (return_code, stdout, stderr).
73         """
74         logger.trace('exec_command on {0}: {1}'
75                 .format(self._ssh.get_transport().getpeername(), cmd))
76         start = time()
77         chan = self._ssh.get_transport().open_session()
78         if timeout is not None:
79             chan.settimeout(int(timeout))
80         chan.exec_command(cmd)
81         end = time()
82         logger.trace('exec_command on {0} took {1} seconds'.format(
83             self._ssh.get_transport().getpeername(), end-start))
84
85         stdout = ""
86         while True:
87             buf = chan.recv(self.__MAX_RECV_BUF)
88             stdout += buf
89             if not buf:
90                 break
91
92         stderr = ""
93         while True:
94             buf = chan.recv_stderr(self.__MAX_RECV_BUF)
95             stderr += buf
96             if not buf:
97                 break
98
99         return_code = chan.recv_exit_status()
100         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
101
102         return (return_code, stdout, stderr)
103
104     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
105         """Execute SSH command with sudo on a new channel on the connected Node.
106
107            :param cmd: Command to be executed.
108            :param cmd_input: Input redirected to the command.
109            :param timeout: Timeout.
110            :return: return_code, stdout, stderr
111
112            :Example:
113
114             >>> from ssh import SSH
115             >>> ssh = SSH()
116             >>> ssh.connect(node)
117             >>> #Execute command without input (sudo -S cmd)
118             >>> ssh.exec_command_sudo("ifconfig eth0 down")
119             >>> #Execute command with input (sudo -S cmd <<< "input")
120             >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
121         """
122         if cmd_input is None:
123             command = 'sudo -S {c}'.format(c=cmd)
124         else:
125             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
126         return self.exec_command(command, timeout)
127
128     def interactive_terminal_open(self, time_out=10):
129         """Open interactive terminal on a new channel on the connected Node.
130
131            :param time_out: Timeout in seconds.
132            :return: SSH channel with opened terminal.
133
134            .. warning:: Interruptingcow is used here, and it uses
135                signal(SIGALRM) to let the operating system interrupt program
136                execution. This has the following limitations: Python signal
137                handlers only apply to the main thread, so you cannot use this
138                from other threads. You must not use this in a program that
139                uses SIGALRM itself (this includes certain profilers)
140         """
141         chan = self._ssh.get_transport().open_session()
142         chan.get_pty()
143         chan.invoke_shell()
144         chan.settimeout(int(time_out))
145
146         buf = ''
147         try:
148             with timeout(time_out, exception=RuntimeError):
149                 while not buf.endswith(':~$ '):
150                     if chan.recv_ready():
151                         buf = chan.recv(4096)
152         except RuntimeError:
153             raise Exception('Open interactive terminal timeout.')
154         return chan
155
156     def interactive_terminal_exec_command(self, chan, cmd, prompt,
157                                           time_out=10):
158         """Execute command on interactive terminal.
159
160            interactive_terminal_open() method has to be called first!
161
162            :param chan: SSH channel with opened terminal.
163            :param cmd: Command to be executed.
164            :param prompt: Command prompt, sequence of characters used to
165                indicate readiness to accept commands.
166            :param time_out: Timeout in seconds.
167            :return: Command output.
168
169            .. warning:: Interruptingcow is used here, and it uses
170                signal(SIGALRM) to let the operating system interrupt program
171                execution. This has the following limitations: Python signal
172                handlers only apply to the main thread, so you cannot use this
173                from other threads. You must not use this in a program that
174                uses SIGALRM itself (this includes certain profilers)
175         """
176         chan.sendall('{c}\n'.format(c=cmd))
177         buf = ''
178         try:
179             with timeout(time_out, exception=RuntimeError):
180                 while not buf.endswith(prompt):
181                     if chan.recv_ready():
182                         buf += chan.recv(4096)
183         except RuntimeError:
184             raise Exception("Exec '{c}' timeout.".format(c=cmd))
185         tmp = buf.replace(cmd.replace('\n', ''), '')
186         return tmp.replace(prompt, '')
187
188     def interactive_terminal_close(self, chan):
189         """Close interactive terminal SSH channel.
190
191            :param: chan: SSH channel to be closed.
192         """
193         chan.close()
194
195     def scp(self, local_path, remote_path):
196         """Copy files from local_path to remote_path.
197
198         connect() method has to be called first!
199         """
200         logger.trace('SCP {0} to {1}:{2}'.format(
201             local_path, self._ssh.get_transport().getpeername(), remote_path))
202         # SCPCLient takes a paramiko transport as its only argument
203         scp = SCPClient(self._ssh.get_transport())
204         start = time()
205         scp.put(local_path, remote_path)
206         scp.close()
207         end = time()
208         logger.trace('SCP took {0} seconds'.format(end-start))
209
210
211 def exec_cmd(node, cmd, timeout=None, sudo=False):
212     """Convenience function to ssh/exec/return rc, out & err.
213
214     Returns (rc, stdout, stderr).
215     """
216     if node is None:
217         raise TypeError('Node parameter is None')
218     if cmd is None:
219         raise TypeError('Command parameter is None')
220     if len(cmd) == 0:
221         raise ValueError('Empty command parameter')
222
223     ssh = SSH()
224     try:
225         ssh.connect(node)
226     except Exception, e:
227         logger.error("Failed to connect to node" + e)
228         return None
229
230     try:
231         if not sudo:
232             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
233         else:
234             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
235                                                                timeout=timeout)
236     except Exception, e:
237         logger.error(e)
238         return None
239
240     return (ret_code, stdout, stderr)
241
242 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
243     """Convenience function to ssh/exec/return out & err.
244     Verifies that return code is zero.
245
246     Returns (stdout, stderr).
247     """
248     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
249     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
250                  format(cmd, stderr))
251     return (stdout, stderr)