papi: add a per-call _timeout option
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
index 115a2c2..d6431ca 100644 (file)
@@ -29,9 +29,13 @@ class VppTransport(object):
         self.server_address = server_address
         self.header = struct.Struct('>QII')
         self.message_table = {}
+        # These queues can be accessed async.
+        # They are always up, but replaced on connect.
+        # TODO: Use multiprocessing.Pipe instead of multiprocessing.Queue
+        # if possible.
+        self.sque = multiprocessing.Queue()
+        self.q = multiprocessing.Queue()
         # The following fields are set in connect().
-        self.sque = None
-        self.q = None
         self.message_thread = None
         self.socket = None
 
@@ -92,7 +96,13 @@ class VppTransport(object):
 
         self.connected = True
 
-        # TODO: Can this block be moved even later?
+        # Queues' feeder threads from previous connect may still be sending.
+        # Close and join to avoid any errors.
+        self.sque.close()
+        self.q.close()
+        self.sque.join_thread()
+        self.q.join_thread()
+        # Finally safe to replace.
         self.sque = multiprocessing.Queue()
         self.q = multiprocessing.Queue()
         self.message_thread = threading.Thread(target=self.msg_thread_func)
@@ -115,7 +125,7 @@ class VppTransport(object):
         r, length = sockclnt_create_reply.unpack(msg)
         self.socket_index = r.index
         for m in r.message_table:
-            n = m.name.rstrip(b'\x00\x13')
+            n = m.name
             self.message_table[n] = m.index
 
         self.message_thread.daemon = True
@@ -139,14 +149,15 @@ class VppTransport(object):
             self.socket.close()
         if self.sque is not None:
             self.sque.put(True)  # Terminate listening thread
-        if self.message_thread is not None:
+        if self.message_thread is not None and self.message_thread.is_alive():
             # Allow additional connect() calls.
             self.message_thread.join()
+        # Wipe message table, VPP can be restarted with different plugins.
+        self.message_table = {}
         # Collect garbage.
-        self.sque = None
-        self.q = None
         self.message_thread = None
         self.socket = None
+        # Queues will be collected after connect replaces them.
         return rv
 
     def suspend(self):
@@ -177,37 +188,49 @@ class VppTransport(object):
 
         # Send header
         header = self.header.pack(0, len(buf), 0)
-        n = self.socket.send(header)
-        n = self.socket.send(buf)
+        try:
+            self.socket.sendall(header)
+            self.socket.sendall(buf)
+        except socket.error as err:
+            raise VppTransportSocketIOError(1, 'Sendall error: {err!r}'.format(
+                err=err))
+
+    def _read_fixed(self, size):
+        """Repeat receive until fixed size is read. Return empty on error."""
+        buf = bytearray(size)
+        view = memoryview(buf)
+        left = size
+        while 1:
+            got = self.socket.recv_into(view, left)
+            if got <= 0:
+                # Read error.
+                return ""
+            if got >= left:
+                # TODO: Raise if got > left?
+                break
+            left -= got
+            view = view[got:]
+        return buf
 
     def _read(self):
-        hdr = self.socket.recv(16)
+        """Read single complete message, return it or empty on error."""
+        hdr = self._read_fixed(16)
         if not hdr:
             return
-        (_, l, _) = self.header.unpack(hdr) # If at head of message
+        (_, hdrlen, _) = self.header.unpack(hdr)  # If at head of message
 
         # Read rest of message
-        msg = self.socket.recv(l)
-        if l > len(msg):
-            nbytes = len(msg)
-            buf = bytearray(l)
-            view = memoryview(buf)
-            view[:nbytes] = msg
-            view = view[nbytes:]
-            left = l - nbytes
-            while left:
-                nbytes = self.socket.recv_into(view, left)
-                view = view[nbytes:]
-                left -= nbytes
-            return buf
-        if l == len(msg):
+        msg = self._read_fixed(hdrlen)
+        if hdrlen == len(msg):
             return msg
         raise VppTransportSocketIOError(1, 'Unknown socket read error')
 
-    def read(self):
+    def read(self, timeout=None):
         if not self.connected:
             raise VppTransportSocketIOError(1, 'Not connected')
+        if timeout is None:
+            timeout = self.read_timeout
         try:
-            return self.q.get(True, self.read_timeout)
+            return self.q.get(True, timeout)
         except queue.Empty:
             return None