vppapigen: fix missing vla check for union class
[vpp.git] / src / tools / vppapigen / test_vppapigen.py
index a8a0a49..cff2400 100755 (executable)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 import unittest
-from vppapigen import VPPAPI, Option, ParseError
+from vppapigen import VPPAPI, Option, ParseError, Union
 
 # TODO
 # - test parsing of options, typedefs, enums, defines, CRC
@@ -18,6 +18,60 @@ class TestVersion(unittest.TestCase):
         r = self.parser.parse_string(version_string)
         self.assertTrue(isinstance(r[0], Option))
 
+class TestUnion(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.parser = VPPAPI()
+
+    def test_union(self):
+        test_string = '''
+        union foo_union {
+        u32 a;
+        u8 b;
+        };
+        '''
+        r = self.parser.parse_string(test_string)
+        self.assertTrue(isinstance(r[0], Union))
+
+    def test_union_vla(self):
+        test_string = '''
+        union foo_union_vla {
+        u32 a;
+        u8 b[a];
+        };
+        autoreply define foo {
+        vl_api_foo_union_vla_t v;
+        };
+        '''
+        r = self.parser.parse_string(test_string)
+        self.assertTrue(isinstance(r[0], Union))
+        self.assertTrue(r[0].vla)
+        s = self.parser.process(r)
+
+
+        test_string2 = '''
+        union foo_union_vla2 {
+        u32 a;
+        u8 b[a];
+        u32 c;
+        };
+        autoreply define foo2 {
+        vl_api_foo_union_vla2_t v;
+        };
+        '''
+        self.assertRaises(ValueError, self.parser.parse_string, test_string2)
+
+        test_string3 = '''
+        union foo_union_vla3 {
+        u32 a;
+        u8 b[a];
+        };
+        autoreply define foo3 {
+        vl_api_foo_union_vla3_t v;
+        u32 x;
+        };
+        '''
+        self.assertRaises(ValueError, self.parser.parse_string, test_string3)
 
 class TestTypedef(unittest.TestCase):
     @classmethod
@@ -26,8 +80,8 @@ class TestTypedef(unittest.TestCase):
 
     def test_duplicatetype(self):
         test_string = '''
-        typeonly define foo1 { u8 dummy; };
-        typeonly define foo1 { u8 dummy; };
+        typedef foo1 { u8 dummy; };
+        typedef foo1 { u8 dummy; };
         '''
         self.assertRaises(KeyError, self.parser.parse_string, test_string)
 
@@ -39,23 +93,29 @@ class TestDefine(unittest.TestCase):
 
     def test_unknowntype(self):
         test_string = 'define foo { foobar foo;};'
-        self.assertRaises(ParseError, self.parser.parse_string, test_string)
+        with self.assertRaises(ParseError) as ctx:
+            self.parser.parse_string(test_string)
+        self.assertIn('Undefined type: foobar', str(ctx.exception))
+
         test_string = 'define { u8 foo;};'
-        self.assertRaises(ParseError, self.parser.parse_string, test_string)
+        with self.assertRaises(ParseError) as ctx:
+            self.parser.parse_string(test_string)
 
     def test_flags(self):
         test_string = '''
           manual_print dont_trace manual_endian define foo { u8 foo; };
+          define foo_reply {u32 context; i32 retval; };
         '''
         r = self.parser.parse_string(test_string)
         self.assertIsNotNone(r)
         s = self.parser.process(r)
         self.assertIsNotNone(s)
-        for d in s['defines']:
-            self.assertTrue(d.dont_trace)
-            self.assertTrue(d.manual_endian)
-            self.assertTrue(d.manual_print)
-            self.assertFalse(d.autoreply)
+        for d in s['Define']:
+            if d.name == 'foo':
+                self.assertTrue(d.dont_trace)
+                self.assertTrue(d.manual_endian)
+                self.assertTrue(d.manual_print)
+                self.assertFalse(d.autoreply)
 
         test_string = '''
           nonexisting_flag define foo { u8 foo; };
@@ -71,11 +131,14 @@ class TestService(unittest.TestCase):
 
     def test_service(self):
         test_string = '''
-         service foo { rpc foo (show_version) returns (show_version) };
+         autoreply define show_version { u8 foo;};
+         service { rpc show_version returns show_version_reply; };
         '''
         r = self.parser.parse_string(test_string)
-        print('R', r)
+        s = self.parser.process(r)
+        self.assertEqual(s['Service'][0].caller, 'show_version')
+        self.assertEqual(s['Service'][0].reply, 'show_version_reply')
 
 
 if __name__ == '__main__':
-    unittest.main()
+    unittest.main(verbosity=2)