vpp_transport_socket: make connect more resilient
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
1 #
2 # VPP Unix Domain Socket Transport.
3 #
4 import socket
5 import struct
6 import threading
7 import select
8 import multiprocessing
9 try:
10     import queue as queue
11 except ImportError:
12     import Queue as queue
13 import logging
14 from . import vpp_papi
15
16
17 class VppTransportSocketIOError(IOError):
18     # TODO: Document different values of error number (first numeric argument).
19     pass
20
21
22 class VppTransport(object):
23     VppTransportSocketIOError = VppTransportSocketIOError
24
25     def __init__(self, parent, read_timeout, server_address):
26         self.connected = False
27         self.read_timeout = read_timeout if read_timeout > 0 else 1
28         self.parent = parent
29         self.server_address = server_address
30         self.header = struct.Struct('>QII')
31         self.message_table = {}
32         # The following fields are set in connect().
33         self.sque = None
34         self.q = None
35         self.message_thread = None
36         self.socket = None
37
38     def msg_thread_func(self):
39         while True:
40             try:
41                 rlist, _, _ = select.select([self.socket,
42                                              self.sque._reader], [], [])
43             except socket.error:
44                 # Terminate thread
45                 logging.error('select failed')
46                 self.q.put(None)
47                 return
48
49             for r in rlist:
50                 if r == self.sque._reader:
51                     # Terminate
52                     self.q.put(None)
53                     return
54
55                 elif r == self.socket:
56                     try:
57                         msg = self._read()
58                         if not msg:
59                             self.q.put(None)
60                             return
61                     except socket.error:
62                         self.q.put(None)
63                         return
64                     # Put either to local queue or if context == 0
65                     # callback queue
66                     if self.parent.has_context(msg):
67                         self.q.put(msg)
68                     else:
69                         self.parent.msg_handler_async(msg)
70                 else:
71                     raise VppTransportSocketIOError(
72                         2, 'Unknown response from select')
73
74     def connect(self, name, pfx, msg_handler, rx_qlen):
75         # TODO: Reorder the actions and add "roll-backs",
76         # to restore clean disconnect state when failure happens durng connect.
77
78         if self.message_thread is not None:
79             raise VppTransportSocketIOError(
80                 1, "PAPI socket transport connect: Need to disconnect first.")
81
82         # Create a UDS socket
83         self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
84         self.socket.settimeout(self.read_timeout)
85
86         # Connect the socket to the port where the server is listening
87         try:
88             self.socket.connect(self.server_address)
89         except socket.error as msg:
90             logging.error("{} on socket {}".format(msg, self.server_address))
91             raise
92
93         self.connected = True
94
95         # TODO: Can this block be moved even later?
96         self.sque = multiprocessing.Queue()
97         self.q = multiprocessing.Queue()
98         self.message_thread = threading.Thread(target=self.msg_thread_func)
99
100         # Initialise sockclnt_create
101         sockclnt_create = self.parent.messages['sockclnt_create']
102         sockclnt_create_reply = self.parent.messages['sockclnt_create_reply']
103
104         args = {'_vl_msg_id': 15,
105                 'name': name,
106                 'context': 124}
107         b = sockclnt_create.pack(args)
108         self.write(b)
109         msg = self._read()
110         hdr, length = self.parent.header.unpack(msg, 0)
111         if hdr.msgid != 16:
112             # TODO: Add first numeric argument.
113             raise VppTransportSocketIOError('Invalid reply message')
114
115         r, length = sockclnt_create_reply.unpack(msg)
116         self.socket_index = r.index
117         for m in r.message_table:
118             n = m.name.rstrip(b'\x00\x13')
119             self.message_table[n] = m.index
120
121         self.message_thread.daemon = True
122         self.message_thread.start()
123
124         return 0
125
126     def disconnect(self):
127         # TODO: Support repeated disconnect calls, recommend users to call
128         # disconnect when they are not sure what the state is after failures.
129         # TODO: Any volunteer for comprehensive docstrings?
130         rv = 0
131         try:
132             # Might fail, if VPP closes socket before packet makes it out,
133             # or if there was a failure during connect().
134             rv = self.parent.api.sockclnt_delete(index=self.socket_index)
135         except (IOError, vpp_papi.VPPApiError):
136             pass
137         self.connected = False
138         if self.socket is not None:
139             self.socket.close()
140         if self.sque is not None:
141             self.sque.put(True)  # Terminate listening thread
142         if self.message_thread is not None:
143             # Allow additional connect() calls.
144             self.message_thread.join()
145         # Collect garbage.
146         self.sque = None
147         self.q = None
148         self.message_thread = None
149         self.socket = None
150         return rv
151
152     def suspend(self):
153         pass
154
155     def resume(self):
156         pass
157
158     def callback(self):
159         raise NotImplementedError
160
161     def get_callback(self, do_async):
162         return self.callback
163
164     def get_msg_index(self, name):
165         try:
166             return self.message_table[name]
167         except KeyError:
168             return 0
169
170     def msg_table_max_index(self):
171         return len(self.message_table)
172
173     def write(self, buf):
174         """Send a binary-packed message to VPP."""
175         if not self.connected:
176             raise VppTransportSocketIOError(1, 'Not connected')
177
178         # Send header
179         header = self.header.pack(0, len(buf), 0)
180         n = self.socket.send(header)
181         n = self.socket.send(buf)
182
183     def _read(self):
184         hdr = self.socket.recv(16)
185         if not hdr:
186             return
187         (_, l, _) = self.header.unpack(hdr) # If at head of message
188
189         # Read rest of message
190         msg = self.socket.recv(l)
191         if l > len(msg):
192             nbytes = len(msg)
193             buf = bytearray(l)
194             view = memoryview(buf)
195             view[:nbytes] = msg
196             view = view[nbytes:]
197             left = l - nbytes
198             while left:
199                 nbytes = self.socket.recv_into(view, left)
200                 view = view[nbytes:]
201                 left -= nbytes
202             return buf
203         if l == len(msg):
204             return msg
205         raise VppTransportSocketIOError(1, 'Unknown socket read error')
206
207     def read(self):
208         if not self.connected:
209             raise VppTransportSocketIOError(1, 'Not connected')
210         try:
211             return self.q.get(True, self.read_timeout)
212         except queue.Empty:
213             return None