tests: use socket transport instead of shared memory
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
1 #
2 # VPP Unix Domain Socket Transport.
3 #
4 import socket
5 import struct
6 import threading
7 import select
8 import multiprocessing
9 import queue
10 import logging
11
12 logger = logging.getLogger('vpp_papi.transport')
13 logger.addHandler(logging.NullHandler())
14
15
16 class VppTransportSocketIOError(IOError):
17     # TODO: Document different values of error number (first numeric argument).
18     pass
19
20
21 class VppTransport:
22     VppTransportSocketIOError = VppTransportSocketIOError
23
24     def __init__(self, parent, read_timeout, server_address):
25         self.connected = False
26         self.read_timeout = read_timeout if read_timeout > 0 else 1
27         self.parent = parent
28         self.server_address = server_address
29         self.header = struct.Struct('>QII')
30         self.message_table = {}
31         # These queues can be accessed async.
32         # They are always up, but replaced on connect.
33         # TODO: Use multiprocessing.Pipe instead of multiprocessing.Queue
34         # if possible.
35         self.sque = multiprocessing.Queue()
36         self.q = multiprocessing.Queue()
37         # The following fields are set in connect().
38         self.message_thread = None
39         self.socket = None
40
41     def msg_thread_func(self):
42         while True:
43             try:
44                 rlist, _, _ = select.select([self.socket,
45                                              self.sque._reader], [], [])
46             except socket.error:
47                 # Terminate thread
48                 logging.error('select failed')
49                 self.q.put(None)
50                 return
51
52             for r in rlist:
53                 if r == self.sque._reader:
54                     # Terminate
55                     self.q.put(None)
56                     return
57
58                 elif r == self.socket:
59                     try:
60                         msg = self._read()
61                         if not msg:
62                             self.q.put(None)
63                             return
64                     except socket.error:
65                         self.q.put(None)
66                         return
67                     # Put either to local queue or if context == 0
68                     # callback queue
69                     if self.parent.has_context(msg):
70                         self.q.put(msg)
71                     else:
72                         self.parent.msg_handler_async(msg)
73                 else:
74                     raise VppTransportSocketIOError(
75                         2, 'Unknown response from select')
76
77     def connect(self, name, pfx, msg_handler, rx_qlen):
78         # TODO: Reorder the actions and add "roll-backs",
79         # to restore clean disconnect state when failure happens durng connect.
80
81         if self.message_thread is not None:
82             raise VppTransportSocketIOError(
83                 1, "PAPI socket transport connect: Need to disconnect first.")
84
85         # Create a UDS socket
86         self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
87         self.socket.settimeout(self.read_timeout)
88
89         # Connect the socket to the port where the server is listening
90         try:
91             self.socket.connect(self.server_address)
92         except socket.error as msg:
93             # logging.error("{} on socket {}".format(msg, self.server_address))
94             raise msg
95
96         self.connected = True
97
98         # Queues' feeder threads from previous connect may still be sending.
99         # Close and join to avoid any errors.
100         self.sque.close()
101         self.q.close()
102         self.sque.join_thread()
103         self.q.join_thread()
104         # Finally safe to replace.
105         self.sque = multiprocessing.Queue()
106         self.q = multiprocessing.Queue()
107         self.message_thread = threading.Thread(target=self.msg_thread_func)
108
109         # Initialise sockclnt_create
110         sockclnt_create = self.parent.messages['sockclnt_create']
111         sockclnt_create_reply = self.parent.messages['sockclnt_create_reply']
112
113         args = {'_vl_msg_id': 15,
114                 'name': name,
115                 'context': 124}
116         b = sockclnt_create.pack(args)
117         self.write(b)
118         msg = self._read()
119         hdr, length = self.parent.header.unpack(msg, 0)
120         if hdr.msgid != 16:
121             # TODO: Add first numeric argument.
122             raise VppTransportSocketIOError('Invalid reply message')
123
124         r, length = sockclnt_create_reply.unpack(msg)
125         self.socket_index = r.index
126         for m in r.message_table:
127             n = m.name
128             self.message_table[n] = m.index
129
130         self.message_thread.daemon = True
131         self.message_thread.start()
132
133         return 0
134
135     def disconnect(self):
136         # TODO: Support repeated disconnect calls, recommend users to call
137         # disconnect when they are not sure what the state is after failures.
138         # TODO: Any volunteer for comprehensive docstrings?
139         rv = 0
140         try:
141             # Might fail, if VPP closes socket before packet makes it out,
142             # or if there was a failure during connect().
143             # TODO: manually build message so that .disconnect releases server-side resources
144             rv = self.parent.api.sockclnt_delete(index=self.socket_index)
145         except (IOError, self.parent.VPPApiError):
146             pass
147         self.connected = False
148         if self.socket is not None:
149             self.socket.close()
150         if self.sque is not None:
151             self.sque.put(True)  # Terminate listening thread
152         if self.message_thread is not None and self.message_thread.is_alive():
153             # Allow additional connect() calls.
154             self.message_thread.join()
155         # Wipe message table, VPP can be restarted with different plugins.
156         self.message_table = {}
157         # Collect garbage.
158         self.message_thread = None
159         self.socket = None
160         # Queues will be collected after connect replaces them.
161         return rv
162
163     def suspend(self):
164         pass
165
166     def resume(self):
167         pass
168
169     def callback(self):
170         raise NotImplementedError
171
172     def get_callback(self, do_async):
173         return self.callback
174
175     def get_msg_index(self, name):
176         try:
177             return self.message_table[name]
178         except KeyError:
179             return 0
180
181     def msg_table_max_index(self):
182         return len(self.message_table)
183
184     def write(self, buf):
185         """Send a binary-packed message to VPP."""
186         if not self.connected:
187             raise VppTransportSocketIOError(1, 'Not connected')
188
189         # Send header
190         header = self.header.pack(0, len(buf), 0)
191         try:
192             self.socket.sendall(header)
193             self.socket.sendall(buf)
194         except socket.error as err:
195             raise VppTransportSocketIOError(1, 'Sendall error: {err!r}'.format(
196                 err=err))
197
198     def _read_fixed(self, size):
199         """Repeat receive until fixed size is read. Return empty on error."""
200         buf = bytearray(size)
201         view = memoryview(buf)
202         left = size
203         while 1:
204             got = self.socket.recv_into(view, left)
205             if got <= 0:
206                 # Read error.
207                 return ""
208             if got >= left:
209                 # TODO: Raise if got > left?
210                 break
211             left -= got
212             view = view[got:]
213         return buf
214
215     def _read(self):
216         """Read single complete message, return it or empty on error."""
217         hdr = self._read_fixed(16)
218         if not hdr:
219             return
220         (_, hdrlen, _) = self.header.unpack(hdr)  # If at head of message
221
222         # Read rest of message
223         msg = self._read_fixed(hdrlen)
224         if hdrlen == len(msg):
225             return msg
226         raise VppTransportSocketIOError(1, 'Unknown socket read error')
227
228     def read(self, timeout=None):
229         if not self.connected:
230             raise VppTransportSocketIOError(1, 'Not connected')
231         if timeout is None:
232             timeout = self.read_timeout
233         try:
234             return self.q.get(True, timeout)
235         except queue.Empty:
236             return None