SSH connect use port specified in node dict
[csit.git] / resources / libraries / python / ssh.py
1 # Copyright (c) 2016 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import paramiko
14 from paramiko import RSAKey
15 import StringIO
16 from scp import SCPClient
17 from time import time
18 from robot.api import logger
19 from interruptingcow import timeout
20 from robot.utils.asserts import assert_equal, assert_not_equal
21 from socket import timeout as socket_timeout
22
23 __all__ = ["exec_cmd", "exec_cmd_no_error"]
24
25 # TODO: load priv key
26
27
28 class SSH(object):
29
30     __MAX_RECV_BUF = 10*1024*1024
31     __existing_connections = {}
32
33     def __init__(self):
34         pass
35
36     def _node_hash(self, node):
37         return hash(frozenset([node['host'], node['port']]))
38
39     def connect(self, node):
40         """Connect to node prior to running exec_command or scp.
41
42         If there already is a connection to the node, this method reuses it.
43         """
44         node_hash = self._node_hash(node)
45         if node_hash in SSH.__existing_connections:
46             self._ssh = SSH.__existing_connections[node_hash]
47             logger.debug('reusing ssh: {0}'.format(self._ssh))
48         else:
49             start = time()
50             pkey = None
51             if 'priv_key' in node:
52                 pkey = RSAKey.from_private_key(
53                         StringIO.StringIO(node['priv_key']))
54
55             self._ssh = paramiko.SSHClient()
56             self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
57
58             self._ssh.connect(node['host'], username=node['username'],
59                               password=node.get('password'), pkey=pkey,
60                               port=node['port'])
61
62             SSH.__existing_connections[node_hash] = self._ssh
63
64             logger.trace('connect took {} seconds'.format(time() - start))
65             logger.debug('new ssh: {0}'.format(self._ssh))
66
67         logger.debug('Connect peer: {0}'.
68                 format(self._ssh.get_transport().getpeername()))
69         logger.debug('Connections: {0}'.format(str(SSH.__existing_connections)))
70
71     def exec_command(self, cmd, timeout=10):
72         """Execute SSH command on a new channel on the connected Node.
73
74         Returns (return_code, stdout, stderr).
75         """
76         logger.trace('exec_command on {0}: {1}'
77                      .format(self._ssh.get_transport().getpeername(), cmd))
78         start = time()
79         chan = self._ssh.get_transport().open_session()
80         if timeout is not None:
81             chan.settimeout(int(timeout))
82         chan.exec_command(cmd)
83         end = time()
84         logger.trace('exec_command on {0} took {1} seconds'.format(
85             self._ssh.get_transport().getpeername(), end-start))
86
87         stdout = ""
88         while True:
89             try:
90                 buf = chan.recv(self.__MAX_RECV_BUF)
91                 stdout += buf
92                 if not buf:
93                     break
94             except socket_timeout:
95                 logger.trace('Channels stdout timeout occurred')
96                 break
97
98         stderr = ""
99         while True:
100             try:
101                 buf = chan.recv_stderr(self.__MAX_RECV_BUF)
102                 stderr += buf
103                 if not buf:
104                     break
105             except socket_timeout:
106                 logger.trace('Channels stderr timeout occurred')
107                 break
108
109         return_code = chan.recv_exit_status()
110         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
111
112         logger.trace('return RC {}'.format(return_code))
113         logger.trace('return STDOUT {}'.format(stdout))
114         logger.trace('return STDERR {}'.format(stderr))
115         return return_code, stdout, stderr
116
117     def exec_command_sudo(self, cmd, cmd_input=None, timeout=10):
118         """Execute SSH command with sudo on a new channel on the connected Node.
119
120            :param cmd: Command to be executed.
121            :param cmd_input: Input redirected to the command.
122            :param timeout: Timeout.
123            :return: return_code, stdout, stderr
124
125            :Example:
126
127             >>> from ssh import SSH
128             >>> ssh = SSH()
129             >>> ssh.connect(node)
130             >>> #Execute command without input (sudo -S cmd)
131             >>> ssh.exec_command_sudo("ifconfig eth0 down")
132             >>> #Execute command with input (sudo -S cmd <<< "input")
133             >>> ssh.exec_command_sudo("vpp_api_test", "dump_interface_table")
134         """
135         if cmd_input is None:
136             command = 'sudo -S {c}'.format(c=cmd)
137         else:
138             command = 'sudo -S {c} <<< "{i}"'.format(c=cmd, i=cmd_input)
139         return self.exec_command(command, timeout)
140
141     def interactive_terminal_open(self, time_out=10):
142         """Open interactive terminal on a new channel on the connected Node.
143
144            :param time_out: Timeout in seconds.
145            :return: SSH channel with opened terminal.
146
147            .. warning:: Interruptingcow is used here, and it uses
148                signal(SIGALRM) to let the operating system interrupt program
149                execution. This has the following limitations: Python signal
150                handlers only apply to the main thread, so you cannot use this
151                from other threads. You must not use this in a program that
152                uses SIGALRM itself (this includes certain profilers)
153         """
154         chan = self._ssh.get_transport().open_session()
155         chan.get_pty()
156         chan.invoke_shell()
157         chan.settimeout(int(time_out))
158
159         buf = ''
160         try:
161             with timeout(time_out, exception=RuntimeError):
162                 while not buf.endswith(':~$ '):
163                     if chan.recv_ready():
164                         buf = chan.recv(4096)
165         except RuntimeError:
166             raise Exception('Open interactive terminal timeout.')
167         return chan
168
169     def interactive_terminal_exec_command(self, chan, cmd, prompt,
170                                           time_out=10):
171         """Execute command on interactive terminal.
172
173            interactive_terminal_open() method has to be called first!
174
175            :param chan: SSH channel with opened terminal.
176            :param cmd: Command to be executed.
177            :param prompt: Command prompt, sequence of characters used to
178                indicate readiness to accept commands.
179            :param time_out: Timeout in seconds.
180            :return: Command output.
181
182            .. warning:: Interruptingcow is used here, and it uses
183                signal(SIGALRM) to let the operating system interrupt program
184                execution. This has the following limitations: Python signal
185                handlers only apply to the main thread, so you cannot use this
186                from other threads. You must not use this in a program that
187                uses SIGALRM itself (this includes certain profilers)
188         """
189         chan.sendall('{c}\n'.format(c=cmd))
190         buf = ''
191         try:
192             with timeout(time_out, exception=RuntimeError):
193                 while not buf.endswith(prompt):
194                     if chan.recv_ready():
195                         buf += chan.recv(4096)
196         except RuntimeError:
197             raise Exception("Exec '{c}' timeout.".format(c=cmd))
198         tmp = buf.replace(cmd.replace('\n', ''), '')
199         return tmp.replace(prompt, '')
200
201     def interactive_terminal_close(self, chan):
202         """Close interactive terminal SSH channel.
203
204            :param: chan: SSH channel to be closed.
205         """
206         chan.close()
207
208     def scp(self, local_path, remote_path):
209         """Copy files from local_path to remote_path.
210
211         connect() method has to be called first!
212         """
213         logger.trace('SCP {0} to {1}:{2}'.format(
214             local_path, self._ssh.get_transport().getpeername(), remote_path))
215         # SCPCLient takes a paramiko transport as its only argument
216         scp = SCPClient(self._ssh.get_transport())
217         start = time()
218         scp.put(local_path, remote_path)
219         scp.close()
220         end = time()
221         logger.trace('SCP took {0} seconds'.format(end-start))
222
223
224 def exec_cmd(node, cmd, timeout=None, sudo=False):
225     """Convenience function to ssh/exec/return rc, out & err.
226
227     Returns (rc, stdout, stderr).
228     """
229     if node is None:
230         raise TypeError('Node parameter is None')
231     if cmd is None:
232         raise TypeError('Command parameter is None')
233     if len(cmd) == 0:
234         raise ValueError('Empty command parameter')
235
236     ssh = SSH()
237     try:
238         ssh.connect(node)
239     except Exception, e:
240         logger.error("Failed to connect to node" + e)
241         return None
242
243     try:
244         if not sudo:
245             (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
246         else:
247             (ret_code, stdout, stderr) = ssh.exec_command_sudo(cmd,
248                                                                timeout=timeout)
249     except Exception, e:
250         logger.error(e)
251         return None
252
253     return (ret_code, stdout, stderr)
254
255 def exec_cmd_no_error(node, cmd, timeout=None, sudo=False):
256     """Convenience function to ssh/exec/return out & err.
257     Verifies that return code is zero.
258
259     Returns (stdout, stderr).
260     """
261     (rc, stdout, stderr) = exec_cmd(node, cmd, timeout=timeout, sudo=sudo)
262     assert_equal(rc, 0, 'Command execution failed: "{}"\n{}'.
263                  format(cmd, stderr))
264     return (stdout, stderr)