papi: avoid IOError on disconnect
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
index 29b4c65..6989e9a 100644 (file)
@@ -6,11 +6,22 @@ import struct
 import threading
 import select
 import multiprocessing
-import queue
+try:
+    import queue as queue
+except ImportError:
+    import Queue as queue
 import logging
+from . import vpp_papi
+
+
+class VppTransportSocketIOError(IOError):
+    # TODO: Document different values of error number (first numeric argument).
+    pass
 
 
 class VppTransport(object):
+    VppTransportSocketIOError = VppTransportSocketIOError
+
     def __init__(self, parent, read_timeout, server_address):
         self.connected = False
         self.read_timeout = read_timeout if read_timeout > 0 else 1
@@ -18,9 +29,15 @@ class VppTransport(object):
         self.server_address = server_address
         self.header = struct.Struct('>QII')
         self.message_table = {}
+        # These queues can be accessed async.
+        # They are always up, but replaced on connect.
+        # TODO: Use multiprocessing.Pipe instead of multiprocessing.Queue
+        # if possible.
         self.sque = multiprocessing.Queue()
         self.q = multiprocessing.Queue()
-        self.message_thread = threading.Thread(target=self.msg_thread_func)
+        # The following fields are set in connect().
+        self.message_thread = None
+        self.socket = None
 
     def msg_thread_func(self):
         while True:
@@ -50,18 +67,24 @@ class VppTransport(object):
                         return
                     # Put either to local queue or if context == 0
                     # callback queue
-                    r = self.parent.decode_incoming_msg(msg)
-                    if hasattr(r, 'context') and r.context > 0:
+                    if self.parent.has_context(msg):
                         self.q.put(msg)
                     else:
                         self.parent.msg_handler_async(msg)
                 else:
-                    raise IOError(2, 'Unknown response from select')
+                    raise VppTransportSocketIOError(
+                        2, 'Unknown response from select')
 
     def connect(self, name, pfx, msg_handler, rx_qlen):
+        # TODO: Reorder the actions and add "roll-backs",
+        # to restore clean disconnect state when failure happens durng connect.
+
+        if self.message_thread is not None:
+            raise VppTransportSocketIOError(
+                1, "PAPI socket transport connect: Need to disconnect first.")
 
         # Create a UDS socket
-        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
+        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
         self.socket.settimeout(self.read_timeout)
 
         # Connect the socket to the port where the server is listening
@@ -72,6 +95,18 @@ class VppTransport(object):
             raise
 
         self.connected = True
+
+        # Queues' feeder threads from previous connect may still be sending.
+        # Close and join to avoid any errors.
+        self.sque.close()
+        self.q.close()
+        self.sque.join_thread()
+        self.q.join_thread()
+        # Finally safe to replace.
+        self.sque = multiprocessing.Queue()
+        self.q = multiprocessing.Queue()
+        self.message_thread = threading.Thread(target=self.msg_thread_func)
+
         # Initialise sockclnt_create
         sockclnt_create = self.parent.messages['sockclnt_create']
         sockclnt_create_reply = self.parent.messages['sockclnt_create_reply']
@@ -84,7 +119,8 @@ class VppTransport(object):
         msg = self._read()
         hdr, length = self.parent.header.unpack(msg, 0)
         if hdr.msgid != 16:
-            raise IOError('Invalid reply message')
+            # TODO: Add first numeric argument.
+            raise VppTransportSocketIOError('Invalid reply message')
 
         r, length = sockclnt_create_reply.unpack(msg)
         self.socket_index = r.index
@@ -98,15 +134,28 @@ class VppTransport(object):
         return 0
 
     def disconnect(self):
+        # TODO: Support repeated disconnect calls, recommend users to call
+        # disconnect when they are not sure what the state is after failures.
+        # TODO: Any volunteer for comprehensive docstrings?
         rv = 0
-        try:  # Might fail, if VPP closes socket before packet makes it out
+        try:
+            # Might fail, if VPP closes socket before packet makes it out,
+            # or if there was a failure during connect().
             rv = self.parent.api.sockclnt_delete(index=self.socket_index)
-        except IOError:
+        except (IOError, vpp_papi.VPPApiError):
             pass
         self.connected = False
-        self.socket.close()
-        self.sque.put(True)  # Terminate listening thread
-        self.message_thread.join()
+        if self.socket is not None:
+            self.socket.close()
+        if self.sque is not None:
+            self.sque.put(True)  # Terminate listening thread
+        if self.message_thread is not None:
+            # Allow additional connect() calls.
+            self.message_thread.join()
+        # Collect garbage.
+        self.message_thread = None
+        self.socket = None
+        # Queues will be collected after connect replaces them.
         return rv
 
     def suspend(self):
@@ -116,7 +165,7 @@ class VppTransport(object):
         pass
 
     def callback(self):
-        raise NotImplemented
+        raise NotImplementedError
 
     def get_callback(self, do_async):
         return self.callback
@@ -133,7 +182,7 @@ class VppTransport(object):
     def write(self, buf):
         """Send a binary-packed message to VPP."""
         if not self.connected:
-            raise IOError(1, 'Not connected')
+            raise VppTransportSocketIOError(1, 'Not connected')
 
         # Send header
         header = self.header.pack(0, len(buf), 0)
@@ -141,40 +190,32 @@ class VppTransport(object):
         n = self.socket.send(buf)
 
     def _read(self):
-        # Header and message
-        try:
-            msg = self.socket.recv(4096)
-            if len(msg) == 0:
-                return None
-        except socket.error as message:
-            logging.error(message)
-            raise
-
-        (_, l, _) = self.header.unpack(msg[:16])
+        hdr = self.socket.recv(16)
+        if not hdr:
+            return
+        (_, l, _) = self.header.unpack(hdr) # If at head of message
 
+        # Read rest of message
+        msg = self.socket.recv(l)
         if l > len(msg):
-            buf = bytearray(l + 16)
+            nbytes = len(msg)
+            buf = bytearray(l)
             view = memoryview(buf)
-            view[:4096] = msg
-            view = view[4096:]
-            # Read rest of message
-            remaining_bytes = l - 4096 + 16
-            while remaining_bytes > 0:
-                bytes_to_read = (remaining_bytes if remaining_bytes
-                                 <= 4096 else 4096)
-                nbytes = self.socket.recv_into(view, bytes_to_read)
-                if nbytes == 0:
-                    logging.error('recv failed')
-                    break
+            view[:nbytes] = msg
+            view = view[nbytes:]
+            left = l - nbytes
+            while left:
+                nbytes = self.socket.recv_into(view, left)
                 view = view[nbytes:]
-                remaining_bytes -= nbytes
-        else:
-            buf = msg
-        return buf[16:]
+                left -= nbytes
+            return buf
+        if l == len(msg):
+            return msg
+        raise VppTransportSocketIOError(1, 'Unknown socket read error')
 
     def read(self):
         if not self.connected:
-            raise IOError(1, 'Not connected')
+            raise VppTransportSocketIOError(1, 'Not connected')
         try:
             return self.q.get(True, self.read_timeout)
         except queue.Empty: