Improve exceptions in vpp_transport_socket.py
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
1 #
2 # VPP Unix Domain Socket Transport.
3 #
4 import socket
5 import struct
6 import threading
7 import select
8 import multiprocessing
9 try:
10     import queue as queue
11 except ImportError:
12     import Queue as queue
13 import logging
14
15
class VppTransportSocketIOError(IOError):
    """Error raised by the VPP Unix domain socket transport.

    Raised with the usual ``IOError(errno, strerror)`` arguments.  The
    numeric first argument used by this module is:

    * ``1`` -- transport state or protocol error: not connected, connect()
      called while already connected, or a missing/unexpected reply to
      ``sockclnt_create``.
    * ``2`` -- ``select()`` in the receiver thread reported a descriptor the
      transport is not waiting on.
    """


class VppTransport(object):
    """VPP binary API transport over a Unix domain (``AF_UNIX``) stream socket.

    Every message on the wire is prefixed with a 16-byte big-endian header
    (``'>QII'``) whose second field is the payload length.  After connect(),
    a daemon thread reads incoming messages: those for which
    ``parent.has_context(msg)`` is true are queued on ``self.q`` for read(),
    all others are handed to ``parent.msg_handler_async``.
    """
    VppTransportSocketIOError = VppTransportSocketIOError

    def __init__(self, parent, read_timeout, server_address):
        """Create an unconnected transport.

        :param parent: owning API client; this class uses its ``messages``,
            ``header``, ``api``, ``has_context`` and ``msg_handler_async``.
        :param read_timeout: timeout in seconds for socket operations and for
            read(); non-positive values are replaced by 1.
        :param server_address: address passed to ``socket.connect()`` for an
            ``AF_UNIX`` socket (normally the path of the VPP API socket).
        """
        self.connected = False
        self.read_timeout = read_timeout if read_timeout > 0 else 1
        self.parent = parent
        self.server_address = server_address
        self.header = struct.Struct('>QII')
        self.message_table = {}
        # Control queue: an item put here tells the receiver thread to stop.
        self.sque = multiprocessing.Queue()
        # Replies for read(); the receiver thread puts None when it exits.
        self.q = multiprocessing.Queue()
        self.message_thread = None  # Will be set on connect().

    def msg_thread_func(self):
        """Receive loop run by the background thread started in connect().

        Returns, after putting ``None`` on ``self.q`` so a waiting read()
        wakes up, when ``sque`` becomes readable, when the peer closes the
        socket, or on a socket error.
        """
        while True:
            try:
                rlist, _, _ = select.select([self.socket,
                                             self.sque._reader], [], [])
            except socket.error:
                # Terminate thread
                logging.error('select failed')
                self.q.put(None)
                return

            for r in rlist:
                if r == self.sque._reader:
                    # Termination requested by disconnect().  The token is
                    # left in the queue; disconnect() consumes it.
                    self.q.put(None)
                    return

                elif r == self.socket:
                    try:
                        msg = self._read()
                    except socket.error:
                        self.q.put(None)
                        return
                    if not msg:
                        # Peer closed the connection (or sent an empty
                        # payload): stop the thread.
                        self.q.put(None)
                        return
                    # Messages the parent is waiting for go to the local
                    # queue; everything else goes to the async handler.
                    if self.parent.has_context(msg):
                        self.q.put(msg)
                    else:
                        self.parent.msg_handler_async(msg)
                else:
                    raise VppTransportSocketIOError(
                        2, 'Unknown response from select')

    def connect(self, name, pfx, msg_handler, rx_qlen):
        """Connect to VPP and perform the ``sockclnt_create`` handshake.

        :param name: client name sent in ``sockclnt_create``.
        :param pfx: not used by this transport.
        :param msg_handler: not used by this transport.
        :param rx_qlen: not used by this transport.
        :returns: 0 on success.
        :raises VppTransportSocketIOError: errno 1 if already connected, or
            if the handshake reply is missing or unexpected.
        :raises socket.error: if the socket cannot be connected.

        On failure the socket is closed and the transport is left
        disconnected, so connect() can simply be retried.
        """
        if self.message_thread is not None:
            raise VppTransportSocketIOError(
                1, "PAPI socket transport connect: Need to disconnect first.")

        # Create a UDS socket
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.settimeout(self.read_timeout)

        # Connect to the VPP API socket.
        try:
            self.socket.connect(self.server_address)
        except socket.error as msg:
            logging.error("{} on socket {}".format(msg, self.server_address))
            self.socket.close()
            raise

        self.connected = True
        try:
            self._sockclnt_create(name)
        except Exception:
            # Leave no half-open connection behind.
            self.connected = False
            self.socket.close()
            raise

        # The receiver thread is created only after a successful handshake,
        # so a non-None message_thread always means "connected".
        self.message_thread = threading.Thread(target=self.msg_thread_func)
        self.message_thread.daemon = True
        self.message_thread.start()

        return 0

    def _sockclnt_create(self, name):
        """Send ``sockclnt_create`` and load the reply's message table."""
        sockclnt_create = self.parent.messages['sockclnt_create']
        sockclnt_create_reply = self.parent.messages['sockclnt_create_reply']

        # NOTE(review): ids 15/16 are hard-coded, presumably the fixed ids of
        # sockclnt_create / sockclnt_create_reply -- confirm against VPP.
        args = {'_vl_msg_id': 15,
                'name': name,
                'context': 124}
        self.write(sockclnt_create.pack(args))
        msg = self._read()
        if not msg:
            raise VppTransportSocketIOError(
                1, 'Connection closed before sockclnt_create_reply')
        hdr, length = self.parent.header.unpack(msg, 0)
        if hdr.msgid != 16:
            raise VppTransportSocketIOError(1, 'Invalid reply message')

        r, length = sockclnt_create_reply.unpack(msg)
        self.socket_index = r.index
        for m in r.message_table:
            # Strip trailing padding bytes from the fixed-size name field.
            n = m.name.rstrip(b'\x00\x13')
            self.message_table[n] = m.index

    def disconnect(self):
        """Tear down the connection established by connect().

        :returns: the result of ``parent.api.sockclnt_delete``, or 0 if that
            call raised an IOError.
        :raises VppTransportSocketIOError: errno 1 if not connected (connect()
            never succeeded, or disconnect() was already called).
        """
        if self.message_thread is None:
            raise VppTransportSocketIOError(1, 'Not connected')
        rv = 0
        try:  # Might fail, if VPP closes socket before packet makes it out
            rv = self.parent.api.sockclnt_delete(index=self.socket_index)
        except IOError:
            pass
        self.connected = False
        # Stop the receiver thread before closing the socket, so it never
        # select()s on a closed descriptor.
        self.sque.put(True)
        self.message_thread.join()
        # The thread only polls sque, it never reads it: consume the token so
        # the receiver thread of a later connect() does not stop at once.
        self.sque.get()
        self.socket.close()
        # Allow additional connect() calls.
        self.message_thread = None
        return rv

    def suspend(self):
        pass

    def resume(self):
        pass

    def callback(self):
        raise NotImplementedError

    def get_callback(self, do_async):
        return self.callback

    def get_msg_index(self, name):
        """Return the VPP message index for ``name``, or 0 if unknown."""
        return self.message_table.get(name, 0)

    def msg_table_max_index(self):
        return len(self.message_table)

    def write(self, buf):
        """Send a binary-packed message to VPP, prefixed with the header.

        :raises VppTransportSocketIOError: errno 1 if not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, 'Not connected')

        # sendall() keeps writing until every byte is sent; a bare send()
        # may transmit only part of the buffer.
        self.socket.sendall(self.header.pack(0, len(buf), 0))
        self.socket.sendall(buf)

    def _recv_exactly(self, length):
        """Read exactly ``length`` bytes from the socket.

        :returns: the data (``bytes`` or ``bytearray``), or None if the peer
            closed the connection before ``length`` bytes arrived.
        """
        chunk = self.socket.recv(length)
        if len(chunk) == length:
            return chunk
        if not chunk:
            return None
        buf = bytearray(length)
        view = memoryview(buf)
        received = len(chunk)
        view[:received] = chunk
        while received < length:
            nbytes = self.socket.recv_into(view[received:], length - received)
            if not nbytes:
                # Peer closed mid-message; recv_into would keep returning 0.
                return None
            received += nbytes
        return buf

    def _read(self):
        """Read one framed message from the socket.

        :returns: the payload following the header, or None if the peer
            closed the connection (before or in the middle of a message).
        """
        hdr = self._recv_exactly(self.header.size)
        if not hdr:
            return None
        (_, length, _) = self.header.unpack(hdr)
        return self._recv_exactly(length)

    def read(self):
        """Return the next queued reply, or None on timeout/thread exit.

        :raises VppTransportSocketIOError: errno 1 if not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, 'Not connected')
        try:
            return self.q.get(True, self.read_timeout)
        except queue.Empty:
            return None