tests: preload api files
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_papi.py
index 6755e22..5c08964 100644 (file)
@@ -154,7 +154,7 @@ class VPPValueError(ValueError):
 
 class VPPApiJSONFiles:
     @classmethod
-    def find_api_dir(cls, dirs):
+    def find_api_dir(cls, dirs=[]):
         """Attempt to find the best directory in which API definition
         files may reside. If the value VPP_API_DIR exists in the environment
         then it is first on the search list. If we're inside a recognized
@@ -170,6 +170,9 @@ class VPPApiJSONFiles:
         # in which case, plot a course to likely places in the src tree
         import __main__ as main
 
+        if os.getenv("VPP_API_DIR"):
+            dirs.append(os.getenv("VPP_API_DIR"))
+
         if hasattr(main, "__file__"):
             # get the path of the calling script
             localdir = os.path.dirname(os.path.realpath(main.__file__))
@@ -278,16 +281,15 @@ class VPPApiJSONFiles:
 
     @classmethod
     def process_json_file(self, apidef_file):
-        api = json.load(apidef_file)
-        return self._process_json(api)
+        return self._process_json(apidef_file.read())
 
     @classmethod
     def process_json_str(self, json_str):
-        api = json.loads(json_str)
-        return self._process_json(api)
+        return self._process_json(json_str)
 
     @staticmethod
-    def _process_json(api):  # -> Tuple[Dict, Dict]
+    def _process_json(json_str):  # -> Tuple[Dict, Dict]
+        api = json.loads(json_str)
         types = {}
         services = {}
         messages = {}
@@ -377,6 +379,30 @@ class VPPApiJSONFiles:
             pass
         return messages, services
 
+    @staticmethod
+    def load_api(apifiles=None, apidir=None):
+        messages = {}
+        services = {}
+        if not apifiles:
+            # Pick up API definitions from default directory
+            try:
+                if isinstance(apidir, list):
+                    apifiles = []
+                    for d in apidir:
+                        apifiles += VPPApiJSONFiles.find_api_files(d)
+                else:
+                    apifiles = VPPApiJSONFiles.find_api_files(apidir)
+            except (RuntimeError, VPPApiError):
+                raise VPPRuntimeError
+
+        for file in apifiles:
+            with open(file) as apidef_file:
+                m, s = VPPApiJSONFiles.process_json_file(apidef_file)
+                messages.update(m)
+                services.update(s)
+
+        return apifiles, messages, services
+
 
 class VPPApiClient:
     """VPP interface.
@@ -391,7 +417,6 @@ class VPPApiClient:
     these messages in a background thread.
     """
 
-    apidir = None
     VPPApiError = VPPApiError
     VPPRuntimeError = VPPRuntimeError
     VPPValueError = VPPValueError
@@ -402,6 +427,7 @@ class VPPApiClient:
         self,
         *,
         apifiles=None,
+        apidir=None,
         testmode=False,
         async_thread=True,
         logger=None,
@@ -436,6 +462,7 @@ class VPPApiClient:
         self.id_msgdef = []
         self.header = VPPType("header", [["u16", "msgid"], ["u32", "client_index"]])
         self.apifiles = []
+        self.apidir = apidir
         self.event_callback = None
         self.message_queue = queue.Queue()
         self.read_timeout = read_timeout
@@ -446,24 +473,15 @@ class VPPApiClient:
         self._apifiles = apifiles
         self.stats = {}
 
-        if not apifiles:
-            # Pick up API definitions from default directory
-            try:
-                apifiles = VPPApiJSONFiles.find_api_files(self.apidir)
-            except (RuntimeError, VPPApiError):
-                # In test mode we don't care that we can't find the API files
-                if testmode:
-                    apifiles = []
-                else:
-                    raise VPPRuntimeError
-
-        for file in apifiles:
-            with open(file) as apidef_file:
-                m, s = VPPApiJSONFiles.process_json_file(apidef_file)
-                self.messages.update(m)
-                self.services.update(s)
-
-        self.apifiles = apifiles
+        try:
+            self.apifiles, self.messages, self.services = VPPApiJSONFiles.load_api(
+                apifiles, apidir
+            )
+        except VPPRuntimeError as e:
+            if testmode:
+                self.apifiles = []
+            else:
+                raise e
 
         # Basic sanity check
         if len(self.messages) == 0 and not testmode:
@@ -525,6 +543,13 @@ class VPPApiClient:
 
         return f
 
+    def make_pack_function(self, msg, i, multipart):
+        def f(**kwargs):
+            return self._call_vpp_pack(i, msg, **kwargs)
+
+        f.msg = msg
+        return f
+
     def _register_functions(self, do_async=False):
         self.id_names = [None] * (self.vpp_dictionary_maxid + 1)
         self.id_msgdef = [None] * (self.vpp_dictionary_maxid + 1)
@@ -539,7 +564,9 @@ class VPPApiClient:
                 # Create function for client side messages.
                 if name in self.services:
                     f = self.make_function(msg, i, self.services[name], do_async)
+                    f_pack = self.make_pack_function(msg, i, self.services[name])
                     setattr(self._api, name, FuncWrapper(f))
+                    setattr(self._api, name + "_pack", FuncWrapper(f_pack))
             else:
                 self.logger.debug("No such message type or failed CRC checksum: %s", n)
 
@@ -831,6 +858,13 @@ class VPPApiClient:
         self.transport.write(b)
         return context
 
+    def _call_vpp_pack(self, i, msg, **kwargs):
+        """Given a message, return the binary representation."""
+        kwargs["_vl_msg_id"] = i
+        kwargs["client_index"] = 0
+        kwargs["context"] = 0
+        return msg.pack(kwargs)
+
     def read_blocking(self, no_type_conversion=False, timeout=None):
         """Get next received message from transport within timeout, decoded.