hs-test: move nginx tests into one file
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_papi.py
index 1e5d23e..30c00cd 100644 (file)
@@ -18,7 +18,6 @@ from __future__ import print_function
 from __future__ import absolute_import
 import ctypes
 import ipaddress
-import sys
 import multiprocessing as mp
 import os
 import queue
@@ -30,6 +29,7 @@ import fnmatch
 import weakref
 import atexit
 import time
+import pkg_resources
 from .vpp_format import verify_enum_hint
 from .vpp_serializer import VPPType, VPPEnumType, VPPEnumFlagType, VPPUnionType
 from .vpp_serializer import VPPMessage, vpp_get_type, VPPTypeAlias
@@ -154,7 +154,7 @@ class VPPValueError(ValueError):
 
 class VPPApiJSONFiles:
     @classmethod
-    def find_api_dir(cls, dirs):
+    def find_api_dir(cls, dirs=[]):
         """Attempt to find the best directory in which API definition
         files may reside. If the value VPP_API_DIR exists in the environment
         then it is first on the search list. If we're inside a recognized
@@ -170,6 +170,9 @@ class VPPApiJSONFiles:
         # in which case, plot a course to likely places in the src tree
         import __main__ as main
 
+        if os.getenv("VPP_API_DIR"):
+            dirs.append(os.getenv("VPP_API_DIR"))
+
         if hasattr(main, "__file__"):
             # get the path of the calling script
             localdir = os.path.dirname(os.path.realpath(main.__file__))
@@ -286,6 +289,18 @@ class VPPApiJSONFiles:
         api = json.loads(json_str)
         return self._process_json(api)
 
+    @classmethod
+    def process_json_array_str(self, json_str):
+        services = {}
+        messages = {}
+
+        apis = json.loads(json_str)
+        for a in apis:
+            m, s = self._process_json(a)
+            messages.update(m)
+            services.update(s)
+        return messages, services
+
     @staticmethod
     def _process_json(api):  # -> Tuple[Dict, Dict]
         types = {}
@@ -371,12 +386,35 @@ class VPPApiJSONFiles:
                 try:
                     messages[m[0]] = VPPMessage(m[0], m[1:])
                 except VPPNotImplementedError:
-                    ### OLE FIXME
                     logger.error("Not implemented error for {}".format(m[0]))
         except KeyError:
             pass
         return messages, services
 
+    @staticmethod
+    def load_api(apifiles=None, apidir=None):
+        messages = {}
+        services = {}
+        if not apifiles:
+            # Pick up API definitions from default directory
+            try:
+                if isinstance(apidir, list):
+                    apifiles = []
+                    for d in apidir:
+                        apifiles += VPPApiJSONFiles.find_api_files(d)
+                else:
+                    apifiles = VPPApiJSONFiles.find_api_files(apidir)
+            except (RuntimeError, VPPApiError):
+                raise VPPRuntimeError
+
+        for file in apifiles:
+            with open(file) as apidef_file:
+                m, s = VPPApiJSONFiles.process_json_file(apidef_file)
+                messages.update(m)
+                services.update(s)
+
+        return apifiles, messages, services
+
 
 class VPPApiClient:
     """VPP interface.
@@ -391,7 +429,6 @@ class VPPApiClient:
     these messages in a background thread.
     """
 
-    apidir = None
     VPPApiError = VPPApiError
     VPPRuntimeError = VPPRuntimeError
     VPPValueError = VPPValueError
@@ -402,6 +439,7 @@ class VPPApiClient:
         self,
         *,
         apifiles=None,
+        apidir=None,
         testmode=False,
         async_thread=True,
         logger=None,
@@ -409,6 +447,7 @@ class VPPApiClient:
         read_timeout=5,
         use_socket=True,
         server_address="/run/vpp/api.sock",
+        bootstrapapi=False,
     ):
         """Create a VPP API object.
 
@@ -436,6 +475,7 @@ class VPPApiClient:
         self.id_msgdef = []
         self.header = VPPType("header", [["u16", "msgid"], ["u32", "client_index"]])
         self.apifiles = []
+        self.apidir = apidir
         self.event_callback = None
         self.message_queue = queue.Queue()
         self.read_timeout = read_timeout
@@ -445,31 +485,37 @@ class VPPApiClient:
         self.server_address = server_address
         self._apifiles = apifiles
         self.stats = {}
+        self.bootstrapapi = bootstrapapi
 
-        if not apifiles:
-            # Pick up API definitions from default directory
+        if not bootstrapapi:
+            if self.apidir is None and hasattr(self.__class__, "apidir"):
+                # Keep supporting the old style of providing apidir.
+                self.apidir = self.__class__.apidir
             try:
-                apifiles = VPPApiJSONFiles.find_api_files(self.apidir)
-            except (RuntimeError, VPPApiError):
-                # In test mode we don't care that we can't find the API files
+                self.apifiles, self.messages, self.services = VPPApiJSONFiles.load_api(
+                    apifiles, self.apidir
+                )
+            except VPPRuntimeError as e:
                 if testmode:
-                    apifiles = []
+                    self.apifiles = []
                 else:
-                    raise VPPRuntimeError
-
-        for file in apifiles:
-            with open(file) as apidef_file:
-                m, s = VPPApiJSONFiles.process_json_file(apidef_file)
-                self.messages.update(m)
-                self.services.update(s)
-
-        self.apifiles = apifiles
+                    raise e
+        else:
+            # Bootstrap the API (memclnt.api bundled with VPP PAPI)
+            resource_path = "/".join(("data", "memclnt.api.json"))
+            file_content = pkg_resources.resource_string(__name__, resource_path)
+            self.messages, self.services = VPPApiJSONFiles.process_json_str(
+                file_content
+            )
 
         # Basic sanity check
         if len(self.messages) == 0 and not testmode:
             raise VPPValueError(1, "Missing JSON message definitions")
-        if not (verify_enum_hint(VppEnum.vl_api_address_family_t)):
-            raise VPPRuntimeError("Invalid address family hints. " "Cannot continue.")
+        if not bootstrapapi:
+            if not (verify_enum_hint(VppEnum.vl_api_address_family_t)):
+                raise VPPRuntimeError(
+                    "Invalid address family hints. " "Cannot continue."
+                )
 
         self.transport = VppTransport(
             self, read_timeout=read_timeout, server_address=server_address
@@ -525,6 +571,13 @@ class VPPApiClient:
 
         return f
 
+    def make_pack_function(self, msg, i, multipart):
+        def f(**kwargs):
+            return self._call_vpp_pack(i, msg, **kwargs)
+
+        f.msg = msg
+        return f
+
     def _register_functions(self, do_async=False):
         self.id_names = [None] * (self.vpp_dictionary_maxid + 1)
         self.id_msgdef = [None] * (self.vpp_dictionary_maxid + 1)
@@ -539,17 +592,38 @@ class VPPApiClient:
                 # Create function for client side messages.
                 if name in self.services:
                     f = self.make_function(msg, i, self.services[name], do_async)
+                    f_pack = self.make_pack_function(msg, i, self.services[name])
                     setattr(self._api, name, FuncWrapper(f))
+                    setattr(self._api, name + "_pack", FuncWrapper(f_pack))
             else:
                 self.logger.debug("No such message type or failed CRC checksum: %s", n)
 
+    def get_api_definitions(self):
+        """get_api_definition. Bootstrap from the embedded memclnt.api.json file."""
+
+        # Bootstrap so we can call the get_api_json function
+        self._register_functions(do_async=False)
+
+        r = self.api.get_api_json()
+        if r.retval != 0:
+            raise VPPApiError("Failed to load API definitions from VPP")
+
+        # Process JSON
+        m, s = VPPApiJSONFiles.process_json_array_str(r.json)
+        self.messages.update(m)
+        self.services.update(s)
+
     def connect_internal(self, name, msg_handler, chroot_prefix, rx_qlen, do_async):
         pfx = chroot_prefix.encode("utf-8") if chroot_prefix else None
 
-        rv = self.transport.connect(name, pfx, msg_handler, rx_qlen)
+        rv = self.transport.connect(name, pfx, msg_handler, rx_qlen, do_async)
         if rv != 0:
             raise VPPIOError(2, "Connect failed")
         self.vpp_dictionary_maxid = self.transport.msg_table_max_index()
+
+        # Register functions
+        if self.bootstrapapi:
+            self.get_api_definitions()
         self._register_functions(do_async=do_async)
 
         # Initialise control ping
@@ -558,6 +632,7 @@ class VPPApiClient:
             ("control_ping" + "_" + crc[2:])
         )
         self.control_ping_msgdef = self.messages["control_ping"]
+
         if self.async_thread:
             self.event_thread = threading.Thread(target=self.thread_msg_handler)
             self.event_thread.daemon = True
@@ -629,6 +704,7 @@ class VPPApiClient:
         )
 
         (i, ci, context), size = header.unpack(msg, 0)
+
         if self.id_names[i] == "rx_thread_exit":
             return
 
@@ -831,6 +907,13 @@ class VPPApiClient:
         self.transport.write(b)
         return context
 
+    def _call_vpp_pack(self, i, msg, **kwargs):
+        """Given a message, return the binary representation."""
+        kwargs["_vl_msg_id"] = i
+        kwargs["client_index"] = 0
+        kwargs["context"] = 0
+        return msg.pack(kwargs)
+
     def read_blocking(self, no_type_conversion=False, timeout=None):
         """Get next received message from transport within timeout, decoded.