Fix copy-paste error in vpp_papi/vpp_transport_socket.py
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
1 #
2 # VPP Unix Domain Socket Transport.
3 #
4 import socket
5 import struct
6 import threading
7 import select
8 import multiprocessing
9 try:
10     import queue as queue
11 except ImportError:
12     import Queue as queue
13 import logging
14
15
class VppTransportSocketIOError(IOError):
    """I/O error raised by the VPP Unix-domain-socket transport.

    Usually constructed with ``(errno, message)`` arguments, for example
    when a write or read is attempted while not connected.
    """
18
19
class VppTransport(object):
    """VPP binary-API transport over a Unix domain stream socket.

    Every message on the wire is preceded by a 16-byte big-endian header
    packed as ``>QII``; only the middle field (the payload length) is
    interpreted by this class, and :meth:`write` sends the other two as 0.
    After :meth:`connect`, a background thread reads incoming messages and
    either queues them for :meth:`read` or passes them to the parent's
    asynchronous handler.
    """

    VppTransportSocketIOError = VppTransportSocketIOError

    def __init__(self, parent, read_timeout, server_address):
        """Store configuration; no connection is opened here.

        :param parent: owning object; this class uses its ``messages``,
            ``header``, ``api``, ``has_context()`` and
            ``msg_handler_async()``.
        :param read_timeout: seconds used as the socket timeout and as the
            wait in :meth:`read`; values <= 0 are replaced by 1.
        :param server_address: address passed to ``connect()`` of the
            AF_UNIX socket.
        """
        self.connected = False
        self.read_timeout = read_timeout if read_timeout > 0 else 1
        self.parent = parent
        self.server_address = server_address
        self.header = struct.Struct('>QII')
        self.message_table = {}
        # Both queues are recreated by every connect().
        self.sque = multiprocessing.Queue()  # stop request for reader thread
        self.q = multiprocessing.Queue()  # messages waiting for read()
        self.socket = None  # Will be set on connect().
        self.socket_index = None  # Assigned by VPP in connect().
        self.message_thread = None  # Will be set on connect().

    def msg_thread_func(self):
        """Reader-thread body: dispatch incoming messages until stopped.

        Messages for which ``parent.has_context(msg)`` is true are put on
        ``self.q``; all others go to ``parent.msg_handler_async(msg)``.
        On a stop request, end of stream or socket error, ``None`` is put
        on ``self.q`` and the thread returns.
        """
        while True:
            try:
                rlist, _, _ = select.select([self.socket,
                                             self.sque._reader], [], [])
            except (select.error, socket.error, ValueError):
                # ValueError: the socket was closed (fileno() == -1).
                # Terminate thread
                logging.error('select failed')
                self.q.put(None)
                return

            for r in rlist:
                if r == self.sque._reader:
                    # disconnect() asked us to terminate.
                    self.q.put(None)
                    return

                elif r == self.socket:
                    try:
                        msg = self._read()
                    except (socket.error, VppTransportSocketIOError):
                        self.q.put(None)
                        return
                    if not msg:
                        # Connection closed by the peer.
                        self.q.put(None)
                        return
                    # Put either to local queue or if context == 0
                    # callback queue
                    if self.parent.has_context(msg):
                        self.q.put(msg)
                    else:
                        self.parent.msg_handler_async(msg)
                else:
                    raise VppTransportSocketIOError(
                        2, 'Unknown response from select')

    def connect(self, name, pfx, msg_handler, rx_qlen):
        """Connect to VPP, register as client ``name``, start the reader.

        ``pfx``, ``msg_handler`` and ``rx_qlen`` are accepted for interface
        compatibility and are not used by this transport.

        :returns: 0 on success.
        :raises RuntimeError: if a previous connection was not disconnected.
        :raises socket.error: if the socket cannot be connected.
        :raises VppTransportSocketIOError: on a missing or unexpected reply.
        """
        if self.message_thread is not None:
            raise RuntimeError(
                "PAPI socket transport connect: You need to disconnect first.")

        # Create a UDS socket
        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.socket.settimeout(self.read_timeout)

        # Connect the socket to the port where the server is listening
        try:
            self.socket.connect(self.server_address)
        except socket.error as msg:
            logging.error("{} on socket {}".format(msg, self.server_address))
            # Do not leak the descriptor; connect() may be retried.
            self.socket.close()
            raise

        self.connected = True
        self.socket_index = None
        # Fresh queues: a previous reader thread never drains its stop token
        # from sque (and leaves a trailing None in q), which would make the
        # new thread exit immediately.
        self.sque = multiprocessing.Queue()
        self.q = multiprocessing.Queue()
        # Created only after the socket is up, so a failed socket connect
        # does not leave connect() refusing with "disconnect first".
        self.message_thread = threading.Thread(target=self.msg_thread_func)

        # Initialise sockclnt_create
        sockclnt_create = self.parent.messages['sockclnt_create']
        sockclnt_create_reply = self.parent.messages['sockclnt_create_reply']

        args = {'_vl_msg_id': 15,
                'name': name,
                'context': 124}
        self.write(sockclnt_create.pack(args))
        msg = self._read()
        if not msg:
            # VPP closed the connection before replying.
            raise VppTransportSocketIOError(1, 'Invalid reply message')
        hdr, length = self.parent.header.unpack(msg, 0)
        if hdr.msgid != 16:
            raise VppTransportSocketIOError(1, 'Invalid reply message')

        r, length = sockclnt_create_reply.unpack(msg)
        self.socket_index = r.index
        # Rebuild from scratch so entries from an earlier connection to a
        # different VPP instance do not linger.
        self.message_table = {}
        for m in r.message_table:
            n = m.name.rstrip(b'\x00\x13')
            self.message_table[n] = m.index

        self.message_thread.daemon = True
        self.message_thread.start()

        return 0

    def disconnect(self):
        """Unregister from VPP, stop the reader thread, close the socket.

        Safe to call after a failed or partial connect(), and more than once.

        :returns: the value of ``parent.api.sockclnt_delete()``, or 0 if it
            was not called or raised IOError.
        """
        rv = 0
        if self.connected and self.socket_index is not None:
            # Might fail, if VPP closes socket before packet makes it out
            try:
                rv = self.parent.api.sockclnt_delete(index=self.socket_index)
            except IOError:
                pass
        self.connected = False
        # Stop and join the reader before closing the socket it selects on.
        self.sque.put(True)
        if self.message_thread is not None and self.message_thread.is_alive():
            self.message_thread.join()
        if self.socket is not None:
            self.socket.close()
        # Allow additional connect() calls.
        self.message_thread = None
        return rv

    def suspend(self):
        """No-op for this transport."""
        pass

    def resume(self):
        """No-op for this transport."""
        pass

    def callback(self):
        """Not supported by the socket transport."""
        raise NotImplementedError

    def get_callback(self, do_async):
        """Return :meth:`callback`, regardless of ``do_async``."""
        return self.callback

    def get_msg_index(self, name):
        """Return the VPP message index for ``name``, or 0 if unknown."""
        return self.message_table.get(name, 0)

    def msg_table_max_index(self):
        """Return the number of entries in the message table."""
        return len(self.message_table)

    def write(self, buf):
        """Send a binary-packed message to VPP, prefixed with its header.

        :raises VppTransportSocketIOError: if not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, 'Not connected')

        # sendall() keeps writing until the whole buffer is sent; send()
        # may transmit only part of it and report the count.
        self.socket.sendall(self.header.pack(0, len(buf), 0))
        self.socket.sendall(buf)

    def _read_fixed(self, size):
        """Read exactly ``size`` bytes; return None if the peer closes first.

        One ``recv`` usually suffices and its result is returned unchanged;
        otherwise the rest is collected into a preallocated bytearray.
        """
        chunk = self.socket.recv(size)
        if len(chunk) == size:
            return chunk
        if not chunk:
            return None  # end of stream
        buf = bytearray(size)
        view = memoryview(buf)
        got = len(chunk)
        view[:got] = chunk
        while got < size:
            nbytes = self.socket.recv_into(view[got:], size - got)
            if nbytes == 0:
                return None  # end of stream in the middle of the read
            got += nbytes
        return buf

    def _read(self):
        """Read one complete message and return its payload.

        :returns: the payload, or None if the connection closed before a
            complete header arrived.
        :raises VppTransportSocketIOError: if the connection closed in the
            middle of a payload.
        """
        hdr = self._read_fixed(self.header.size)
        if not hdr:
            return None
        (_, length, _) = self.header.unpack(hdr)

        # Read rest of message
        msg = self._read_fixed(length)
        if msg is None:
            raise VppTransportSocketIOError(1, 'Unknown socket read error')
        return msg

    def read(self):
        """Return the next message queued by the reader thread.

        Waits up to ``read_timeout`` seconds and returns None on timeout;
        None is also returned when the reader thread has terminated.

        :raises VppTransportSocketIOError: if not connected.
        """
        if not self.connected:
            raise VppTransportSocketIOError(1, 'Not connected')
        try:
            return self.q.get(True, self.read_timeout)
        except queue.Empty:
            return None