Allow repeated connects on PAPI socket transport
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
1 #
2 # VPP Unix Domain Socket Transport.
3 #
4 import socket
5 import struct
6 import threading
7 import select
8 import multiprocessing
9 try:
10     import queue as queue
11 except ImportError:
12     import Queue as queue
13 import logging
14
15
class VppTransportSocketIOError(IOError):
    """I/O error raised by the VPP Unix domain socket transport."""
18
19
class VppTransport(object):
    """VPP binary-API transport over a Unix domain stream socket.

    Every message on the wire is preceded by a 16-byte big-endian header
    (struct format ``>QII``) whose middle field is the payload length.
    After connect() a daemon thread reads messages from the socket and
    either queues them for read() or passes them to the parent's
    asynchronous message handler.
    """
    VppTransportSocketIOError = VppTransportSocketIOError

    def __init__(self, parent, read_timeout, server_address):
        """Store the configuration; no socket is opened until connect().

        :param parent: owning API object; this class uses its
            ``messages``, ``header``, ``api``, ``has_context`` and
            ``msg_handler_async`` members.
        :param read_timeout: socket and queue timeout in seconds; values
            that are not positive are replaced by 1.
        :param server_address: path of the Unix domain socket to connect to.
        """
        self.connected = False
        self.read_timeout = read_timeout if read_timeout > 0 else 1
        self.parent = parent
        self.server_address = server_address
        self.header = struct.Struct('>QII')
        self.message_table = {}
        self.sque = multiprocessing.Queue()
        self.q = multiprocessing.Queue()
        self.message_thread = None  # Will be set on connect().
        self.socket = None  # Will be set on connect().

    def msg_thread_func(self):
        """Receive loop executed by the background message thread.

        Waits on both the VPP socket and the shutdown queue.  Messages for
        which the parent reports a context go to the read queue, all others
        go to the parent's asynchronous handler.  On shutdown request, end
        of stream or socket failure a ``None`` is put on the read queue and
        the thread returns.

        :raises VppTransportSocketIOError: if select() reports an object
            that is neither the socket nor the shutdown queue.
        """
        while True:
            try:
                rlist, _, _ = select.select([self.socket,
                                             self.sque._reader], [], [])
            except (socket.error, select.error, ValueError):
                # ValueError is what select() raises when the socket has
                # already been closed (its file descriptor is -1).
                # Terminate thread
                logging.error('select failed')
                self.q.put(None)
                return

            for r in rlist:
                if r == self.sque._reader:
                    # Terminate
                    self.q.put(None)
                    return

                elif r == self.socket:
                    try:
                        msg = self._read()
                    except (socket.error, IOError):
                        # IOError also covers the transport's own error
                        # class on Python 2, where socket.error is only a
                        # sibling subclass of IOError.
                        self.q.put(None)
                        return
                    if not msg:
                        self.q.put(None)
                        return
                    # Put either to local queue or if context == 0
                    # callback queue
                    if self.parent.has_context(msg):
                        self.q.put(msg)
                    else:
                        self.parent.msg_handler_async(msg)
                else:
                    raise VppTransportSocketIOError(
                        2, 'Unknown response from select')

    def connect(self, name, pfx, msg_handler, rx_qlen):
        """Connect to VPP, register the client and start the reader thread.

        :param name: client name sent in the sockclnt_create message.
        :param pfx: not used by this transport.
        :param msg_handler: not used by this transport.
        :param rx_qlen: not used by this transport.
        :returns: 0 on success.
        :raises RuntimeError: if the transport is already connected.
        :raises socket.error: if the socket connection fails.
        :raises VppTransportSocketIOError: if the registration reply is
            missing or is not the expected message.
        """
        if self.message_thread is not None:
            raise RuntimeError(
                "PAPI socket transport connect: You need to disconnect first.")

        # Use fresh queues for every connection.  The shutdown token and
        # the None sentinel left behind by a previous connection are never
        # consumed, and would make the new reader thread exit at once.
        self.sque = multiprocessing.Queue()
        self.q = multiprocessing.Queue()

        # Create a UDS socket
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.read_timeout)
        self.socket = sock

        try:
            # Connect the socket to the port where the server is listening
            try:
                sock.connect(self.server_address)
            except socket.error as msg:
                logging.error(
                    "{} on socket {}".format(msg, self.server_address))
                raise

            self.connected = True
            self._handshake(name)

            thread = threading.Thread(target=self.msg_thread_func)
            thread.daemon = True
            thread.start()
        except Exception:
            # Roll back so that a later connect() is possible.
            self.connected = False
            sock.close()
            raise

        # Published only once the thread runs, so disconnect() never has
        # to join a thread that was not started.
        self.message_thread = thread
        return 0

    def _handshake(self, name):
        """Exchange sockclnt_create / reply and load the message table."""
        # Initialise sockclnt_create
        sockclnt_create = self.parent.messages['sockclnt_create']
        sockclnt_create_reply = self.parent.messages['sockclnt_create_reply']

        # The request is sent with message id 15 and the reply must carry
        # message id 16.
        args = {'_vl_msg_id': 15,
                'name': name,
                'context': 124}
        self.write(sockclnt_create.pack(args))
        msg = self._read()
        if not msg:
            raise VppTransportSocketIOError(
                1, 'Connection closed before sockclnt_create_reply')
        hdr, length = self.parent.header.unpack(msg, 0)
        if hdr.msgid != 16:
            raise VppTransportSocketIOError('Invalid reply message')

        r, length = sockclnt_create_reply.unpack(msg)
        self.socket_index = r.index
        table = {}
        for m in r.message_table:
            table[m.name.rstrip(b'\x00\x13')] = m.index
        # Replace the content in place, dropping entries of any earlier
        # connection while keeping the same dict object.
        self.message_table.clear()
        self.message_table.update(table)

    def disconnect(self):
        """Unregister from VPP, close the socket and stop the reader thread.

        Calling this without a preceding successful connect() is a no-op.

        :returns: the value returned by sockclnt_delete, or 0 if that call
            failed with IOError or the transport was not connected.
        """
        if self.message_thread is None:
            # Never connected, or already disconnected.
            self.connected = False
            return 0

        rv = 0
        try:
            # Might fail, if VPP closes socket before packet makes it out
            try:
                rv = self.parent.api.sockclnt_delete(index=self.socket_index)
            except IOError:
                pass
        finally:
            # Always release the resources, whatever sockclnt_delete did.
            self.connected = False
            self.socket.close()
            self.sque.put(True)  # Terminate listening thread
            self.message_thread.join()
            # Allow additional connect() calls.
            self.message_thread = None
        return rv

    def suspend(self):
        """Do nothing; present for transport interface compatibility."""
        pass

    def resume(self):
        """Do nothing; present for transport interface compatibility."""
        pass

    def callback(self):
        """Not supported by this transport."""
        raise NotImplementedError

    def get_callback(self, do_async):
        """Return the callback method; ``do_async`` is ignored."""
        return self.callback

    def get_msg_index(self, name):
        """Return the index stored for message ``name``, or 0 if unknown."""
        return self.message_table.get(name, 0)

    def msg_table_max_index(self):
        """Return the number of entries in the message table."""
        return len(self.message_table)

    def write(self, buf):
        """Send a binary-packed message to VPP.

        :param buf: packed message payload.
        :raises VppTransportSocketIOError: if the transport is not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, 'Not connected')

        # Send header, then payload.  sendall() keeps sending until the
        # whole buffer is out, whereas send() may stop after a part of it.
        self.socket.sendall(self.header.pack(0, len(buf), 0))
        self.socket.sendall(buf)

    def _recv_exact(self, nbytes):
        """Receive exactly ``nbytes`` bytes from the socket.

        :returns: ``bytes`` when a single recv() delivered everything (this
            includes an empty result when the peer closed the connection
            before sending anything), otherwise a ``bytearray``.
        :raises VppTransportSocketIOError: if the peer closes the
            connection after only a part of the data arrived.
        """
        chunk = self.socket.recv(nbytes)
        if not chunk or len(chunk) == nbytes:
            return chunk

        buf = bytearray(nbytes)
        view = memoryview(buf)
        got = len(chunk)
        view[:got] = chunk
        while got < nbytes:
            received = self.socket.recv_into(view[got:], nbytes - got)
            if not received:
                # Zero bytes means end of stream; without this check the
                # loop would never terminate.
                raise VppTransportSocketIOError(
                    1, 'Connection closed in the middle of a message')
            got += received
        return buf

    def _read(self):
        """Read one complete message (header stripped) from the socket.

        :returns: the message payload, or ``None`` if the peer closed the
            connection before a new header arrived.
        :raises VppTransportSocketIOError: if the connection is closed in
            the middle of a header or payload.
        """
        hdr = self._recv_exact(self.header.size)
        if not hdr:
            return None
        (_, length, _) = self.header.unpack(hdr)  # If at head of message

        # Read rest of message
        msg = self._recv_exact(length)
        if length and not msg:
            raise VppTransportSocketIOError(
                1, 'Connection closed in the middle of a message')
        return msg

    def read(self):
        """Return the next queued message.

        :returns: the message, or ``None`` if nothing arrived within the
            read timeout or the reader thread has terminated.
        :raises VppTransportSocketIOError: if the transport is not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, 'Not connected')
        try:
            return self.q.get(True, self.read_timeout)
        except queue.Empty:
            return None