vppapigen: fix missing vla check for union class 17/22117/2
authorOle Troan <ot@cisco.com>
Wed, 18 Sep 2019 10:12:47 +0000 (12:12 +0200)
committerOle Trøan <otroan@employees.org>
Wed, 18 Sep 2019 15:10:12 +0000 (15:10 +0000)
Type: fix
Signed-off-by: Ole Troan <ot@cisco.com>
Change-Id: Ie775cf3469d761847ac39cf0d80a3ec6463b7928

src/tools/vppapigen/test_vppapigen.py
src/tools/vppapigen/vppapigen.py

index 5b64310..cff2400 100755 (executable)
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 import unittest
-from vppapigen import VPPAPI, Option, ParseError
+from vppapigen import VPPAPI, Option, ParseError, Union
 
 # TODO
 # - test parsing of options, typedefs, enums, defines, CRC
@@ -18,6 +18,60 @@ class TestVersion(unittest.TestCase):
         r = self.parser.parse_string(version_string)
         self.assertTrue(isinstance(r[0], Option))
 
+class TestUnion(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.parser = VPPAPI()
+
+    def test_union(self):
+        test_string = '''
+        union foo_union {
+        u32 a;
+        u8 b;
+        };
+        '''
+        r = self.parser.parse_string(test_string)
+        self.assertTrue(isinstance(r[0], Union))
+
+    def test_union_vla(self):
+        test_string = '''
+        union foo_union_vla {
+        u32 a;
+        u8 b[a];
+        };
+        autoreply define foo {
+        vl_api_foo_union_vla_t v;
+        };
+        '''
+        r = self.parser.parse_string(test_string)
+        self.assertTrue(isinstance(r[0], Union))
+        self.assertTrue(r[0].vla)
+        s = self.parser.process(r)
+
+
+        test_string2 = '''
+        union foo_union_vla2 {
+        u32 a;
+        u8 b[a];
+        u32 c;
+        };
+        autoreply define foo2 {
+        vl_api_foo_union_vla2_t v;
+        };
+        '''
+        self.assertRaises(ValueError, self.parser.parse_string, test_string2)
+
+        test_string3 = '''
+        union foo_union_vla3 {
+        u32 a;
+        u8 b[a];
+        };
+        autoreply define foo3 {
+        vl_api_foo_union_vla3_t v;
+        u32 x;
+        };
+        '''
+        self.assertRaises(ValueError, self.parser.parse_string, test_string3)
 
 class TestTypedef(unittest.TestCase):
     @classmethod
index bb4e2c4..03362b0 100755 (executable)
@@ -145,6 +145,33 @@ def crc_block_combine(block, crc):
     return binascii.crc32(s, crc) & 0xffffffff
 
 
+def vla_is_last_check(name, block):
+    vla = False
+    for i, b in enumerate(block):
+        if isinstance(b, Array) and b.vla:
+            vla = True
+            if i + 1 < len(block):
+                raise ValueError(
+                    'VLA field "{}" must be the last field in message "{}"'
+                    .format(b.fieldname, name))
+        elif b.fieldtype.startswith('vl_api_'):
+            if global_types[b.fieldtype].vla:
+                vla = True
+                if i + 1 < len(block):
+                    raise ValueError(
+                        'VLA field "{}" must be the last '
+                        'field in message "{}"'
+                        .format(b.fieldname, name))
+        elif b.fieldtype == 'string' and b.length == 0:
+            vla = True
+            if i + 1 < len(block):
+                raise ValueError(
+                    'VLA field "{}" must be the last '
+                    'field in message "{}"'
+                    .format(b.fieldname, name))
+    return vla
+
+
 class Service():
     def __init__(self, caller, reply, events=None, stream=False):
         self.caller = caller
@@ -166,34 +193,10 @@ class Typedef():
                 self.manual_print = True
             elif f == 'manual_endian':
                 self.manual_endian = True
-        for b in block:
-            # Tag length field of a VLA
-            if isinstance(b, Array):
-                if b.lengthfield:
-                    for b2 in block:
-                        if b2.fieldname == b.lengthfield:
-                            b2.vla_len = True
 
         global_type_add(name, self)
 
-        self.vla = False
-
-        for i, b in enumerate(block):
-            if isinstance(b, Array):
-                if b.length == 0:
-                    self.vla = True
-                    if i + 1 < len(block):
-                        raise ValueError(
-                            'VLA field "{}" must be the last '
-                            'field in message "{}"'
-                            .format(b.fieldname, name))
-            elif b.fieldtype == 'string':
-                self.vla = True
-                if i + 1 < len(block):
-                    raise ValueError(
-                        'VLA field "{}" must be the last '
-                        'field in message "{}"'
-                        .format(b.fieldname, name))
+        self.vla = vla_is_last_check(name, block)
 
     def __repr__(self):
         return self.name + str(self.flags) + str(self.block)
@@ -232,8 +235,6 @@ class Union():
         self.manual_endian = False
         self.name = name
 
-        self.manual_print = False
-        self.manual_endian = False
         for f in flags:
             if f == 'manual_print':
                 self.manual_print = True
@@ -242,6 +243,8 @@ class Union():
 
         self.block = block
         self.crc = str(block).encode()
+        self.vla = vla_is_last_check(name, block)
+
         global_type_add(name, self)
 
     def __repr__(self):
@@ -269,34 +272,12 @@ class Define():
             elif f == 'autoreply':
                 self.autoreply = True
 
-        for i, b in enumerate(block):
+        for b in block:
             if isinstance(b, Option):
                 if b[1] == 'singular' and b[2] == 'true':
                     self.singular = True
                 block.remove(b)
-
-            if isinstance(b, Array) and b.vla and i + 1 < len(block):
-                raise ValueError(
-                    'VLA field "{}" must be the last field in message "{}"'
-                    .format(b.fieldname, name))
-            elif b.fieldtype.startswith('vl_api_'):
-                if (global_types[b.fieldtype].vla and i + 1 < len(block)):
-                    raise ValueError(
-                        'VLA field "{}" must be the last '
-                        'field in message "{}"'
-                        .format(b.fieldname, name))
-            elif b.fieldtype == 'string' and b.length == 0:
-                if i + 1 < len(block):
-                    raise ValueError(
-                        'VLA field "{}" must be the last '
-                        'field in message "{}"'
-                        .format(b.fieldname, name))
-            # Tag length field of a VLA
-            if isinstance(b, Array):
-                if b.lengthfield:
-                    for b2 in block:
-                        if b2.fieldname == b.lengthfield:
-                            b2.vla_len = True
+        self.vla = vla_is_last_check(name, block)
 
     def __repr__(self):
         return self.name + str(self.flags) + str(self.block)