Fix copypaste in vpp_papi/vpp_transport_socket.py
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_transport_socket.py
index 4341cad..b00d15d 100644 (file)
@@ -29,7 +29,7 @@ class VppTransport(object):
         self.message_table = {}
         self.sque = multiprocessing.Queue()
         self.q = multiprocessing.Queue()
-        self.message_thread = threading.Thread(target=self.msg_thread_func)
+        self.message_thread = None  # Will be set on connect().
 
     def msg_thread_func(self):
         while True:
@@ -69,8 +69,13 @@ class VppTransport(object):
 
     def connect(self, name, pfx, msg_handler, rx_qlen):
 
+        if self.message_thread is not None:
+            raise RuntimeError(
+                "PAPI socket transport connect: You need to disconnect first.")
+        self.message_thread = threading.Thread(target=self.msg_thread_func)
+
         # Create a UDS socket
-        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
+        self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
         self.socket.settimeout(self.read_timeout)
 
         # Connect the socket to the port where the server is listening
@@ -107,6 +112,7 @@ class VppTransport(object):
         return 0
 
     def disconnect(self):
+        # TODO: Should we detect if user forgot to connect first?
         rv = 0
         try:  # Might fail, if VPP closes socket before packet makes it out
             rv = self.parent.api.sockclnt_delete(index=self.socket_index)
@@ -116,6 +122,8 @@ class VppTransport(object):
         self.socket.close()
         self.sque.put(True)  # Terminate listening thread
         self.message_thread.join()
+        # Allow additional connect() calls.
+        self.message_thread = None
         return rv
 
     def suspend(self):
@@ -150,36 +158,28 @@ class VppTransport(object):
         n = self.socket.send(buf)
 
     def _read(self):
-        # Header and message
-        try:
-            msg = self.socket.recv(4096)
-            if len(msg) == 0:
-                return None
-        except socket.error as message:
-            logging.error(message)
-            raise
-
-        (_, l, _) = self.header.unpack(msg[:16])
+        hdr = self.socket.recv(16)
+        if not hdr:
+            return
+        (_, l, _) = self.header.unpack(hdr) # If at head of message
 
+        # Read rest of message
+        msg = self.socket.recv(l)
         if l > len(msg):
-            buf = bytearray(l + 16)
+            nbytes = len(msg)
+            buf = bytearray(l)
             view = memoryview(buf)
-            view[:4096] = msg
-            view = view[4096:]
-            # Read rest of message
-            remaining_bytes = l - 4096 + 16
-            while remaining_bytes > 0:
-                bytes_to_read = (remaining_bytes if remaining_bytes
-                                 <= 4096 else 4096)
-                nbytes = self.socket.recv_into(view, bytes_to_read)
-                if nbytes == 0:
-                    logging.error('recv failed')
-                    break
+            view[:nbytes] = msg
+            view = view[nbytes:]
+            left = l - nbytes
+            while left:
+                nbytes = self.socket.recv_into(view, left)
                 view = view[nbytes:]
-                remaining_bytes -= nbytes
-        else:
-            buf = msg
-        return buf[16:]
+                left -= nbytes
+            return buf
+        if l == len(msg):
+            return msg
+        raise VppTransportSocketIOError(1, 'Unknown socket read error')
 
     def read(self):
         if not self.connected: