tests: preload api files
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_papi.py
index a9edfed..5c08964 100644 (file)
@@ -281,16 +281,15 @@ class VPPApiJSONFiles:
 
     @classmethod
     def process_json_file(self, apidef_file):
-        api = json.load(apidef_file)
-        return self._process_json(api)
+        return self._process_json(apidef_file.read())
 
     @classmethod
     def process_json_str(self, json_str):
-        api = json.loads(json_str)
-        return self._process_json(api)
+        return self._process_json(json_str)
 
     @staticmethod
-    def _process_json(api):  # -> Tuple[Dict, Dict]
+    def _process_json(json_str):  # -> Tuple[Dict, Dict]
+        api = json.loads(json_str)
         types = {}
         services = {}
         messages = {}
@@ -380,6 +379,30 @@ class VPPApiJSONFiles:
             pass
         return messages, services
 
+    @staticmethod
+    def load_api(apifiles=None, apidir=None):
+        messages = {}
+        services = {}
+        if not apifiles:
+            # Pick up API definitions from default directory
+            try:
+                if isinstance(apidir, list):
+                    apifiles = []
+                    for d in apidir:
+                        apifiles += VPPApiJSONFiles.find_api_files(d)
+                else:
+                    apifiles = VPPApiJSONFiles.find_api_files(apidir)
+            except (RuntimeError, VPPApiError):
+                raise VPPRuntimeError
+
+        for file in apifiles:
+            with open(file) as apidef_file:
+                m, s = VPPApiJSONFiles.process_json_file(apidef_file)
+                messages.update(m)
+                services.update(s)
+
+        return apifiles, messages, services
+
 
 class VPPApiClient:
     """VPP interface.
@@ -394,7 +417,6 @@ class VPPApiClient:
     these messages in a background thread.
     """
 
-    apidir = None
     VPPApiError = VPPApiError
     VPPRuntimeError = VPPRuntimeError
     VPPValueError = VPPValueError
@@ -405,6 +427,7 @@ class VPPApiClient:
         self,
         *,
         apifiles=None,
+        apidir=None,
         testmode=False,
         async_thread=True,
         logger=None,
@@ -439,6 +462,7 @@ class VPPApiClient:
         self.id_msgdef = []
         self.header = VPPType("header", [["u16", "msgid"], ["u32", "client_index"]])
         self.apifiles = []
+        self.apidir = apidir
         self.event_callback = None
         self.message_queue = queue.Queue()
         self.read_timeout = read_timeout
@@ -449,29 +473,15 @@ class VPPApiClient:
         self._apifiles = apifiles
         self.stats = {}
 
-        if not apifiles:
-            # Pick up API definitions from default directory
-            try:
-                if isinstance(self.apidir, list):
-                    apifiles = []
-                    for d in self.apidir:
-                        apifiles += VPPApiJSONFiles.find_api_files(d)
-                else:
-                    apifiles = VPPApiJSONFiles.find_api_files(self.apidir)
-            except (RuntimeError, VPPApiError):
-                # In test mode we don't care that we can't find the API files
-                if testmode:
-                    apifiles = []
-                else:
-                    raise VPPRuntimeError
-
-        for file in apifiles:
-            with open(file) as apidef_file:
-                m, s = VPPApiJSONFiles.process_json_file(apidef_file)
-                self.messages.update(m)
-                self.services.update(s)
-
-        self.apifiles = apifiles
+        try:
+            self.apifiles, self.messages, self.services = VPPApiJSONFiles.load_api(
+                apifiles, apidir
+            )
+        except VPPRuntimeError as e:
+            if testmode:
+                self.apifiles = []
+            else:
+                raise e
 
         # Basic sanity check
         if len(self.messages) == 0 and not testmode: