0887a765ba51d39ecf84ee3ba26db977cdbea9d7
[vpp.git] / test / resources / libraries / python / ssh.py
1 # Copyright (c) 2015 Cisco and/or its affiliates.
2 # Licensed under the Apache License, Version 2.0 (the "License");
3 # you may not use this file except in compliance with the License.
4 # You may obtain a copy of the License at:
5 #
6 #     http://www.apache.org/licenses/LICENSE-2.0
7 #
8 # Unless required by applicable law or agreed to in writing, software
9 # distributed under the License is distributed on an "AS IS" BASIS,
10 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 # See the License for the specific language governing permissions and
12 # limitations under the License.
13 import paramiko
14 from scp import SCPClient
15 from time import time
16 from robot.api import logger
17
18 __all__ = ["exec_cmd"]
19
20 # TODO: Attempt to recycle SSH connections
21 # TODO: load priv key
22
23 class SSH(object):
24
25     __MAX_RECV_BUF = 10*1024*1024
26     __existing_connections = {}
27
28     def __init__(self):
29         self._ssh = paramiko.SSHClient()
30         self._ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
31         self._hostname = None
32
33     def _node_hash(self, node):
34         return hash(frozenset([node['host'], node['port']]))
35
36     def connect(self, node):
37         """Connect to node prior to running exec_command or scp.
38
39         If there already is a connection to the node, this method reuses it.
40         """
41         self._hostname = node['host']
42         node_hash = self._node_hash(node)
43         if node_hash in self.__existing_connections:
44             self._ssh = self.__existing_connections[node_hash]
45         else:
46             start = time()
47             self._ssh.connect(node['host'], username=node['username'],
48                     password=node['password'])
49             self.__existing_connections[node_hash] = self._ssh
50             logger.trace('connect took {} seconds'.format(time() - start))
51
52     def exec_command(self, cmd, timeout=10):
53         """Execute SSH command on a new channel on the connected Node.
54
55         Returns (return_code, stdout, stderr).
56         """
57         start = time()
58         chan = self._ssh.get_transport().open_session()
59         if timeout is not None:
60             chan.settimeout(int(timeout))
61         chan.exec_command(cmd)
62         end = time()
63         logger.trace('exec_command "{0}" on {1} took {2} seconds'.format(cmd,
64             self._hostname, end-start))
65
66
67         stdout = ""
68         while True:
69             buf = chan.recv(self.__MAX_RECV_BUF)
70             stdout += buf
71             if not buf:
72                 break
73
74         stderr = ""
75         while True:
76             buf = chan.recv_stderr(self.__MAX_RECV_BUF)
77             stderr += buf
78             if not buf:
79                 break
80
81         return_code = chan.recv_exit_status()
82         logger.trace('chan_recv/_stderr took {} seconds'.format(time()-end))
83
84         return (return_code, stdout, stderr)
85
86     def scp(self, local_path, remote_path):
87         """Copy files from local_path to remote_path.
88
89         connect() method has to be called first!
90         """
91         logger.trace('SCP {0} to {1}:{2}'.format(
92             local_path, self._hostname, remote_path))
93         # SCPCLient takes a paramiko transport as its only argument
94         scp = SCPClient(self._ssh.get_transport())
95         start = time()
96         scp.put(local_path, remote_path)
97         scp.close()
98         end = time()
99         logger.trace('SCP took {0} seconds'.format(end-start))
100
101 def exec_cmd(node, cmd, timeout=None):
102     """Convenience function to ssh/exec/return rc & out.
103
104     Returns (rc, stdout).
105     """
106     if node is None:
107         raise TypeError('Node parameter is None')
108     if cmd is None:
109         raise TypeError('Command parameter is None')
110     if len(cmd) == 0:
111         raise ValueError('Empty command parameter')
112
113     ssh = SSH()
114     try:
115         ssh.connect(node)
116     except Exception, e:
117         logger.error("Failed to connect to node" + e)
118         return None
119
120     try:
121         (ret_code, stdout, stderr) = ssh.exec_command(cmd, timeout=timeout)
122     except Exception, e:
123         logger.error(e)
124         return None
125
126     return (ret_code, stdout, stderr)
127