vapi: support VLAs in typedefs 04/14704/4
authorKlement Sekera <ksekera@cisco.com>
Thu, 6 Sep 2018 17:31:36 +0000 (19:31 +0200)
committerNeale Ranns <nranns@cisco.com>
Tue, 11 Sep 2018 05:58:06 +0000 (05:58 +0000)
Change-Id: I3af3916b68189c2174020e5ecc29a7bc45b25efa
Signed-off-by: Klement Sekera <ksekera@cisco.com>
src/vpp-api/vapi/vapi_c_gen.py
src/vpp-api/vapi/vapi_json_parser.py

index c4c5366..eb1006d 100755 (executable)
@@ -56,6 +56,44 @@ class CField(Field):
     def needs_byte_swap(self):
         return self.type.needs_byte_swap()
 
+    def get_vla_field_length_name(self, path):
+        return "%s_%s_array_size" % ("_".join(path), self.name)
+
+    def get_alloc_vla_param_names(self, path):
+        if self.is_vla():
+            result = [self.get_vla_field_length_name(path)]
+        else:
+            result = []
+        if self.type.has_vla():
+            t = self.type.get_alloc_vla_param_names(path + [self.name])
+            result.extend(t)
+        return result
+
+    def get_vla_calc_size_code(self, prefix, path):
+        if self.is_vla():
+            result = ["sizeof(%s.%s[0]) * %s" % (
+                ".".join([prefix] + path),
+                self.name,
+                self.get_vla_field_length_name(path))]
+        else:
+            result = []
+        if self.type.has_vla():
+            t = self.type.get_vla_calc_size_code(prefix, path + [self.name])
+            result.extend(t)
+        return result
+
+    def get_vla_assign_code(self, prefix, path):
+        result = []
+        if self.is_vla():
+            result.append("%s.%s = %s" % (
+                ".".join([prefix] + path),
+                self.nelem_field.name,
+                self.get_vla_field_length_name(path)))
+        if self.type.has_vla():
+            t = self.type.get_vla_assign_code(prefix, path + [self.name])
+            result.extend(t)
+        return result
+
 
 class CStruct(Struct):
     def get_c_def(self):
@@ -65,6 +103,19 @@ class CStruct(Struct):
                             for x in self.fields])),
             "} %s;" % self.get_c_name()])
 
+    def get_vla_assign_code(self, prefix, path):
+        return [x for f in self.fields if f.has_vla()
+                for x in f.get_vla_assign_code(prefix, path)]
+
+    def get_alloc_vla_param_names(self, path):
+        return [x for f in self.fields
+                if f.has_vla()
+                for x in f.get_alloc_vla_param_names(path)]
+
+    def get_vla_calc_size_code(self, prefix, path):
+        return [x for f in self.fields if f.has_vla()
+                for x in f.get_vla_calc_size_code(prefix, path)]
+
 
 class CSimpleType (SimpleType):
 
@@ -213,16 +264,13 @@ class CMessage (Message):
     def get_payload_struct_name(self):
         return "vapi_payload_%s" % self.name
 
-    def get_alloc_func_vla_field_length_name(self, field):
-        return "%s_array_size" % field.name
-
     def get_alloc_func_name(self):
         return "vapi_alloc_%s" % self.name
 
     def get_alloc_vla_param_names(self):
-        return [self.get_alloc_func_vla_field_length_name(f)
-                for f in self.fields
-                if f.nelem_field is not None]
+        return [x for f in self.fields
+                if f.has_vla()
+                for x in f.get_alloc_vla_param_names([])]
 
     def get_alloc_func_decl(self):
         return "%s* %s(struct vapi_ctx_s *ctx%s)" % (
@@ -244,13 +292,9 @@ class CMessage (Message):
             "  %s *msg = NULL;" % self.get_c_name(),
             "  const size_t size = sizeof(%s)%s;" % (
                 self.get_c_name(),
-                "".join([
-                    " + sizeof(msg->payload.%s[0]) * %s" % (
-                        f.name,
-                        self.get_alloc_func_vla_field_length_name(f))
-                    for f in self.fields
-                    if f.nelem_field is not None
-                ])),
+                "".join([" + %s" % x for f in self.fields if f.has_vla()
+                         for x in f.get_vla_calc_size_code("msg->payload",
+                                                           [])])),
             "  /* cast here required to play nicely with C++ world ... */",
             "  msg = (%s*)vapi_msg_alloc(ctx, size);" % self.get_c_name(),
             "  if (!msg) {",
@@ -259,11 +303,9 @@ class CMessage (Message):
         ] + extra + [
             "  msg->header._vl_msg_id = vapi_lookup_vl_msg_id(ctx, %s);" %
             self.get_msg_id_name(),
-            "\n".join(["  msg->payload.%s = %s;" % (
-                f.nelem_field.name,
-                self.get_alloc_func_vla_field_length_name(f))
-                for f in self.fields
-                if f.nelem_field is not None]),
+            "".join(["  %s;\n" % line
+                     for f in self.fields if f.has_vla()
+                     for line in f.get_vla_assign_code("msg->payload", [])]),
             "  return msg;",
             "}"])
 
@@ -588,23 +630,22 @@ def emit_definition(parser, json_file, emitted, o):
             print("%s%s" % (function_attrs, o.get_swap_to_host_func_def()))
             print("")
             print("%s%s" % (function_attrs, o.get_calc_msg_size_func_def()))
-            print("")
             if not o.is_reply and not o.is_event:
+                print("")
                 print("%s%s" % (function_attrs, o.get_alloc_func_def()))
                 print("")
                 print("%s%s" % (function_attrs, o.get_op_func_def()))
-                print("")
-            print("%s" % o.get_c_constructor())
             print("")
+            print("%s" % o.get_c_constructor())
             if o.is_reply or o.is_event:
-                print("%s%s;" % (function_attrs, o.get_event_cb_func_def()))
                 print("")
+                print("%s%s;" % (function_attrs, o.get_event_cb_func_def()))
         elif hasattr(o, "get_swap_to_be_func_def"):
             print("%s%s" % (function_attrs, o.get_swap_to_be_func_def()))
             print("")
             print("%s%s" % (function_attrs, o.get_swap_to_host_func_def()))
-            print("")
         print("#endif")
+        print("")
     emitted.append(o)
 
 
index 7995424..8728a1a 100644 (file)
@@ -28,13 +28,21 @@ class Field(object):
 
     def __str__(self):
         if self.len is None:
-            return "name: %s, type: %s" % (self.name, self.type)
+            return "Field(name: %s, type: %s)" % (self.name, self.type)
         elif self.len > 0:
-            return "name: %s, type: %s, length: %s" % (self.name, self.type,
-                                                       self.len)
+            return "Field(name: %s, type: %s, length: %s)" % (self.name,
+                                                              self.type,
+                                                              self.len)
         else:
-            return ("name: %s, type: %s, variable length stored in: %s" %
-                    (self.name, self.type, self.nelem_field))
+            return (
+                "Field(name: %s, type: %s, variable length stored in: %s)" %
+                (self.name, self.type, self.nelem_field))
+
+    def is_vla(self):
+        return self.nelem_field is not None
+
+    def has_vla(self):
+        return self.is_vla() or self.type.has_vla()
 
 
 class Type(object):
@@ -53,6 +61,9 @@ class SimpleType (Type):
     def __str__(self):
         return self.name
 
+    def has_vla(self):
+        return False
+
 
 def get_msg_header_defs(struct_type_class, field_class, json_parser, logger):
     return [
@@ -83,6 +94,12 @@ class Struct(object):
     def __str__(self):
         return "[%s]" % "], [".join([str(f) for f in self.fields])
 
+    def has_vla(self):
+        for f in self.fields:
+            if f.has_vla():
+                return True
+        return False
+
 
 class Enum(SimpleType):
     def __init__(self, name, value_pairs, enumtype):
@@ -110,6 +127,9 @@ class Union(Type):
             "], [" .join(["%s %s" % (i, j) for i, j in self.type_pairs])
         )
 
+    def has_vla(self):
+        return False
+
 
 class Message(object):
 
@@ -190,6 +210,13 @@ class Message(object):
                 fields.append(p)
         self.fields = fields
         self.depends = [f.type for f in self.fields]
+        logger.debug("Parsed message: %s" % self)
+
+    def __str__(self):
+        return "Message(%s, [%s], {crc: %s}" % \
+            (self.name,
+             "], [".join([str(f) for f in self.fields]),
+             self.crc)
 
 
 class StructType (Type, Struct):