vppapigen: fix tests and run on verify 47/20947/4
authorPaul Vinciguerra <pvinci@vinciconsulting.com>
Wed, 31 Jul 2019 04:34:05 +0000 (00:34 -0400)
committerOle Troan <ot@cisco.com>
Wed, 18 Sep 2019 09:40:17 +0000 (11:40 +0200)
- changes vppapigen to only process an import once.

Type: fix

Change-Id: Ifcbcfcc69fdfb80d63195a17701762d0c239d7b4
Signed-off-by: Paul Vinciguerra <pvinci@vinciconsulting.com>
Signed-off-by: Ole Troan <ot@cisco.com>
Makefile
src/tools/vppapigen/test_vppapigen.py
src/tools/vppapigen/vppapigen.py

index adb0334..baf9845 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -623,6 +623,8 @@ endif
 
 verify: pkg-verify
 ifeq ($(OS_ID)-$(OS_VERSION_ID),ubuntu-18.04)
+       $(call banner,"Testing vppapigen")
+       @src/tools/vppapigen/test_vppapigen.py
        $(call banner,"Running tests")
        @make COMPRESS_FAILED_TEST_LOGS=yes RETRIES=3 test
 endif
index a8a0a49..5b64310 100755 (executable)
@@ -26,8 +26,8 @@ class TestTypedef(unittest.TestCase):
 
     def test_duplicatetype(self):
         test_string = '''
-        typeonly define foo1 { u8 dummy; };
-        typeonly define foo1 { u8 dummy; };
+        typedef foo1 { u8 dummy; };
+        typedef foo1 { u8 dummy; };
         '''
         self.assertRaises(KeyError, self.parser.parse_string, test_string)
 
@@ -39,23 +39,29 @@ class TestDefine(unittest.TestCase):
 
     def test_unknowntype(self):
         test_string = 'define foo { foobar foo;};'
-        self.assertRaises(ParseError, self.parser.parse_string, test_string)
+        with self.assertRaises(ParseError) as ctx:
+            self.parser.parse_string(test_string)
+        self.assertIn('Undefined type: foobar', str(ctx.exception))
+
         test_string = 'define { u8 foo;};'
-        self.assertRaises(ParseError, self.parser.parse_string, test_string)
+        with self.assertRaises(ParseError) as ctx:
+            self.parser.parse_string(test_string)
 
     def test_flags(self):
         test_string = '''
           manual_print dont_trace manual_endian define foo { u8 foo; };
+          define foo_reply {u32 context; i32 retval; };
         '''
         r = self.parser.parse_string(test_string)
         self.assertIsNotNone(r)
         s = self.parser.process(r)
         self.assertIsNotNone(s)
-        for d in s['defines']:
-            self.assertTrue(d.dont_trace)
-            self.assertTrue(d.manual_endian)
-            self.assertTrue(d.manual_print)
-            self.assertFalse(d.autoreply)
+        for d in s['Define']:
+            if d.name == 'foo':
+                self.assertTrue(d.dont_trace)
+                self.assertTrue(d.manual_endian)
+                self.assertTrue(d.manual_print)
+                self.assertFalse(d.autoreply)
 
         test_string = '''
           nonexisting_flag define foo { u8 foo; };
@@ -71,11 +77,14 @@ class TestService(unittest.TestCase):
 
     def test_service(self):
         test_string = '''
-         service foo { rpc foo (show_version) returns (show_version) };
+         autoreply define show_version { u8 foo;};
+         service { rpc show_version returns show_version_reply; };
         '''
         r = self.parser.parse_string(test_string)
-        print('R', r)
+        s = self.parser.process(r)
+        self.assertEqual(s['Service'][0].caller, 'show_version')
+        self.assertEqual(s['Service'][0].reply, 'show_version_reply')
 
 
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main(verbosity=2)
index fa7e47a..bb4e2c4 100755 (executable)
@@ -22,10 +22,15 @@ sys.dont_write_bytecode = True
 # Global dictionary of new types (including enums)
 global_types = {}
 
+seen_imports = {}
+
 
 def global_type_add(name, obj):
     '''Add new type to the dictionary of types '''
     type_name = 'vl_api_' + name + '_t'
+    if type_name in global_types:
+        raise KeyError("Attempted redefinition of {!r} with {!r}.".format(
+            name, obj))
     global_types[type_name] = obj
 
 
@@ -320,20 +325,35 @@ class Enum():
 
 
 class Import():
-    def __init__(self, filename):
-        self.filename = filename
 
-        # Deal with imports
-        parser = VPPAPI(filename=filename)
-        dirlist = dirlist_get()
-        f = filename
-        for dir in dirlist:
-            f = os.path.join(dir, filename)
-            if os.path.exists(f):
-                break
+    def __new__(cls, *args, **kwargs):
+        if args[0] not in seen_imports:
+            instance = super().__new__(cls)
+            instance._initialized = False
+            seen_imports[args[0]] = instance
+
+        return seen_imports[args[0]]
 
-        with open(f, encoding='utf-8') as fd:
-            self.result = parser.parse_file(fd, None)
+    def __init__(self, filename):
+        if self._initialized:
+            return
+        else:
+            self.filename = filename
+            # Deal with imports
+            parser = VPPAPI(filename=filename)
+            dirlist = dirlist_get()
+            f = filename
+            for dir in dirlist:
+                f = os.path.join(dir, filename)
+                if os.path.exists(f):
+                    break
+            if sys.version[0] == '2':
+                with open(f) as fd:
+                    self.result = parser.parse_file(fd, None)
+            else:
+                with open(f, encoding='utf-8') as fd:
+                    self.result = parser.parse_file(fd, None)
+            self._initialized = True
 
     def __repr__(self):
         return self.filename
@@ -859,9 +879,10 @@ class VPPAPI(object):
                 continue
             if isinstance(o, Import):
                 result.append(o)
-                self.process_imports(o.result, True, result)
+                result = self.process_imports(o.result, True, result)
             else:
                 result.append(o)
+        return result
 
 
 # Add message ids to each message.
@@ -955,7 +976,7 @@ def main():
     if args.output_module == 'C':
         s = parser.process(parsed_objects)
     else:
-        parser.process_imports(parsed_objects, False, result)
+        result = parser.process_imports(parsed_objects, False, result)
         s = parser.process(result)
 
     # Add msg_id field