vapi: support VLAs in typedefs
[vpp.git] / src / vpp-api / vapi / vapi_c_gen.py
index c4c5366..eb1006d 100755 (executable)
@@ -56,6 +56,44 @@ class CField(Field):
     def needs_byte_swap(self):
         return self.type.needs_byte_swap()
 
+    def get_vla_field_length_name(self, path):
+        return "%s_%s_array_size" % ("_".join(path), self.name)
+
+    def get_alloc_vla_param_names(self, path):
+        if self.is_vla():
+            result = [self.get_vla_field_length_name(path)]
+        else:
+            result = []
+        if self.type.has_vla():
+            t = self.type.get_alloc_vla_param_names(path + [self.name])
+            result.extend(t)
+        return result
+
+    def get_vla_calc_size_code(self, prefix, path):
+        if self.is_vla():
+            result = ["sizeof(%s.%s[0]) * %s" % (
+                ".".join([prefix] + path),
+                self.name,
+                self.get_vla_field_length_name(path))]
+        else:
+            result = []
+        if self.type.has_vla():
+            t = self.type.get_vla_calc_size_code(prefix, path + [self.name])
+            result.extend(t)
+        return result
+
+    def get_vla_assign_code(self, prefix, path):
+        result = []
+        if self.is_vla():
+            result.append("%s.%s = %s" % (
+                ".".join([prefix] + path),
+                self.nelem_field.name,
+                self.get_vla_field_length_name(path)))
+        if self.type.has_vla():
+            t = self.type.get_vla_assign_code(prefix, path + [self.name])
+            result.extend(t)
+        return result
+
 
 class CStruct(Struct):
     def get_c_def(self):
@@ -65,6 +103,19 @@ class CStruct(Struct):
                             for x in self.fields])),
             "} %s;" % self.get_c_name()])
 
+    def get_vla_assign_code(self, prefix, path):
+        return [x for f in self.fields if f.has_vla()
+                for x in f.get_vla_assign_code(prefix, path)]
+
+    def get_alloc_vla_param_names(self, path):
+        return [x for f in self.fields
+                if f.has_vla()
+                for x in f.get_alloc_vla_param_names(path)]
+
+    def get_vla_calc_size_code(self, prefix, path):
+        return [x for f in self.fields if f.has_vla()
+                for x in f.get_vla_calc_size_code(prefix, path)]
+
 
 class CSimpleType (SimpleType):
 
@@ -213,16 +264,13 @@ class CMessage (Message):
     def get_payload_struct_name(self):
         return "vapi_payload_%s" % self.name
 
-    def get_alloc_func_vla_field_length_name(self, field):
-        return "%s_array_size" % field.name
-
     def get_alloc_func_name(self):
         return "vapi_alloc_%s" % self.name
 
     def get_alloc_vla_param_names(self):
-        return [self.get_alloc_func_vla_field_length_name(f)
-                for f in self.fields
-                if f.nelem_field is not None]
+        return [x for f in self.fields
+                if f.has_vla()
+                for x in f.get_alloc_vla_param_names([])]
 
     def get_alloc_func_decl(self):
         return "%s* %s(struct vapi_ctx_s *ctx%s)" % (
@@ -244,13 +292,9 @@ class CMessage (Message):
             "  %s *msg = NULL;" % self.get_c_name(),
             "  const size_t size = sizeof(%s)%s;" % (
                 self.get_c_name(),
-                "".join([
-                    " + sizeof(msg->payload.%s[0]) * %s" % (
-                        f.name,
-                        self.get_alloc_func_vla_field_length_name(f))
-                    for f in self.fields
-                    if f.nelem_field is not None
-                ])),
+                "".join([" + %s" % x for f in self.fields if f.has_vla()
+                         for x in f.get_vla_calc_size_code("msg->payload",
+                                                           [])])),
             "  /* cast here required to play nicely with C++ world ... */",
             "  msg = (%s*)vapi_msg_alloc(ctx, size);" % self.get_c_name(),
             "  if (!msg) {",
@@ -259,11 +303,9 @@ class CMessage (Message):
         ] + extra + [
             "  msg->header._vl_msg_id = vapi_lookup_vl_msg_id(ctx, %s);" %
             self.get_msg_id_name(),
-            "\n".join(["  msg->payload.%s = %s;" % (
-                f.nelem_field.name,
-                self.get_alloc_func_vla_field_length_name(f))
-                for f in self.fields
-                if f.nelem_field is not None]),
+            "".join(["  %s;\n" % line
+                     for f in self.fields if f.has_vla()
+                     for line in f.get_vla_assign_code("msg->payload", [])]),
             "  return msg;",
             "}"])
 
@@ -588,23 +630,22 @@ def emit_definition(parser, json_file, emitted, o):
             print("%s%s" % (function_attrs, o.get_swap_to_host_func_def()))
             print("")
             print("%s%s" % (function_attrs, o.get_calc_msg_size_func_def()))
-            print("")
             if not o.is_reply and not o.is_event:
+                print("")
                 print("%s%s" % (function_attrs, o.get_alloc_func_def()))
                 print("")
                 print("%s%s" % (function_attrs, o.get_op_func_def()))
-                print("")
-            print("%s" % o.get_c_constructor())
             print("")
+            print("%s" % o.get_c_constructor())
             if o.is_reply or o.is_event:
-                print("%s%s;" % (function_attrs, o.get_event_cb_func_def()))
                 print("")
+                print("%s%s;" % (function_attrs, o.get_event_cb_func_def()))
         elif hasattr(o, "get_swap_to_be_func_def"):
             print("%s%s" % (function_attrs, o.get_swap_to_be_func_def()))
             print("")
             print("%s%s" % (function_attrs, o.get_swap_to_host_func_def()))
-            print("")
         print("#endif")
+        print("")
     emitted.append(o)