vapi: improve vl_api_string_t handling
[vpp.git] / src / vpp-api / vapi / vapi_c_gen.py
index 37f5ac1..609f4bb 100755 (executable)
@@ -23,7 +23,7 @@ class CField(Field):
         return "vapi_type_%s" % self.name
 
     def get_c_def(self):
-        if self.type.get_c_name() == "vl_api_string_t":
+        if self.type.get_c_name() == "string":
             if self.len:
                 return "u8 %s[%d];" % (self.name, self.len)
             else:
@@ -85,12 +85,15 @@ class CField(Field):
     def needs_byte_swap(self):
         return self.type.needs_byte_swap()
 
-    def get_vla_field_length_name(self, path):
+    def get_vla_parameter_name(self, path):
         return "%s_%s_array_size" % ("_".join(path), self.name)
 
+    def get_vla_field_name(self, path):
+        return ".".join(path + [self.nelem_field.name])
+
     def get_alloc_vla_param_names(self, path):
         if self.is_vla():
-            result = [self.get_vla_field_length_name(path)]
+            result = [self.get_vla_parameter_name(path)]
         else:
             result = []
         if self.type.has_vla():
@@ -98,20 +101,22 @@ class CField(Field):
             result.extend(t)
         return result
 
-    def get_vla_calc_size_code(self, prefix, path):
+    def get_vla_calc_size_code(self, prefix, path, is_alloc):
         if self.is_vla():
             result = [
                 "sizeof(%s.%s[0]) * %s"
                 % (
                     ".".join([prefix] + path),
                     self.name,
-                    self.get_vla_field_length_name(path),
+                    self.get_vla_parameter_name(path)
+                    if is_alloc
+                    else "%s.%s" % (prefix, self.get_vla_field_name(path)),
                 )
             ]
         else:
             result = []
         if self.type.has_vla():
-            t = self.type.get_vla_calc_size_code(prefix, path + [self.name])
+            t = self.type.get_vla_calc_size_code(prefix, path + [self.name], is_alloc)
             result.extend(t)
         return result
 
@@ -123,7 +128,7 @@ class CField(Field):
                 % (
                     ".".join([prefix] + path),
                     self.nelem_field.name,
-                    self.get_vla_field_length_name(path),
+                    self.get_vla_parameter_name(path),
                 )
             )
         if self.type.has_vla():
@@ -173,12 +178,12 @@ class CStruct(Struct):
             for x in f.get_alloc_vla_param_names(path)
         ]
 
-    def get_vla_calc_size_code(self, prefix, path):
+    def get_vla_calc_size_code(self, prefix, path, is_alloc):
         return [
             x
             for f in self.fields
             if f.has_vla()
-            for x in f.get_vla_calc_size_code(prefix, path)
+            for x in f.get_vla_calc_size_code(prefix, path, is_alloc)
         ]
 
 
@@ -288,6 +293,8 @@ class CUnion(Union):
 
 class CStructType(StructType, CStruct):
     def get_c_name(self):
+        if self.name == "vl_api_string_t":
+            return "vl_api_string_t"
         return "vapi_type_%s" % self.name
 
     def get_swap_to_be_func_name(self):
@@ -398,7 +405,9 @@ class CMessage(Message):
                             " + %s" % x
                             for f in self.fields
                             if f.has_vla()
-                            for x in f.get_vla_calc_size_code("msg->payload", [])
+                            for x in f.get_vla_calc_size_code(
+                                "msg->payload", [], is_alloc=True
+                            )
                         ]
                     ),
                 ),
@@ -442,10 +451,12 @@ class CMessage(Message):
                 "  return sizeof(*msg)%s;"
                 % "".join(
                     [
-                        "+ msg->payload.%s * sizeof(msg->payload.%s[0])"
-                        % (f.nelem_field.name, f.name)
+                        " + %s" % x
                         for f in self.fields
-                        if f.nelem_field is not None
+                        if f.has_vla()
+                        for x in f.get_vla_calc_size_code(
+                            "msg->payload", [], is_alloc=False
+                        )
                     ]
                 ),
                 "}",
@@ -885,6 +896,20 @@ def gen_json_unified_header(parser, logger, j, io, name):
     print("#ifdef __cplusplus")
     print('extern "C" {')
     print("#endif")
+
+    print("#ifndef __vl_api_string_swap_fns_defined__")
+    print("#define __vl_api_string_swap_fns_defined__")
+    print("")
+    print("#include <vlibapi/api_types.h>")
+    print("")
+    function_attrs = "static inline "
+    o = parser.types["vl_api_string_t"]
+    print("%s%s" % (function_attrs, o.get_swap_to_be_func_def()))
+    print("")
+    print("%s%s" % (function_attrs, o.get_swap_to_host_func_def()))
+    print("")
+    print("#endif //__vl_api_string_swap_fns_defined__")
+
     if name == "memclnt.api.vapi.h":
         print("")
         print(