api: verify message size on receipt
[vpp.git] / src / tools / vppapigen / vppapigen_c.py
index 1e18a83..f93e898 100644 (file)
@@ -210,7 +210,8 @@ class ToJSON():
             if b[1] == 0:
                     continue
             write('    if (a & {})\n'.format(b[0]))
-            write('       cJSON_AddItemToArray(array, cJSON_CreateString("{}"));\n'.format(b[0]))
+            write(
+                '       cJSON_AddItemToArray(array, cJSON_CreateString("{}"));\n'.format(b[0]))
         write('    return array;\n')
         write('}\n')
 
@@ -685,7 +686,7 @@ TOP_BOILERPLATE = '''\
     || defined(vl_printfun) ||defined(vl_endianfun) \\
     || defined(vl_api_version)||defined(vl_typedefs) \\
     || defined(vl_msg_name)||defined(vl_msg_name_crc_list) \\
-    || defined(vl_api_version_tuple)
+    || defined(vl_api_version_tuple) || defined(vl_calcsizefun)
 /* ok, something was selected */
 #else
 #warning no content included from {input_filename}
@@ -750,7 +751,7 @@ def msg_name_crc_list(s, suffix):
 
     for t in s['Define']:
         output += "\\\n_(VL_API_%s, %s, %08x) " % \
-                   (t.name.upper(), t.name, t.crc)
+            (t.name.upper(), t.name, t.crc)
     output += "\n#endif"
 
     return output
@@ -970,13 +971,13 @@ static inline void *vl_api_{name}_t_print{suffix} (vl_api_{name}_t *a, void *han
 
         write(signature.format(name=t.name, suffix='_json'))
         write('    cJSON * o = vl_api_{}_t_tojson(a);\n'.format(t.name))
-        write('    (void)s;\n');
+        write('    (void)s;\n')
         write('    char *out = cJSON_Print(o);\n')
-        write('    vl_print(handle, out);\n');
+        write('    vl_print(handle, out);\n')
         write('    cJSON_Delete(o);\n')
-        write('    cJSON_free(out);\n');
+        write('    cJSON_free(out);\n')
         write('    return handle;\n')
-        write('}\n\n');
+        write('}\n\n')
 
     write("\n#endif")
     write("\n#endif /* vl_printfun */\n")
@@ -1145,7 +1146,7 @@ static inline void vl_api_{name}_t_endian (vl_api_{name}_t *a)
 '''
 
     for t in objs:
-        if t.__class__.__name__ == 'Enum' or t.__class__.__name__ == 'EnumFlag' :
+        if t.__class__.__name__ == 'Enum' or t.__class__.__name__ == 'EnumFlag':
             output += signature.format(name=t.name)
             if t.enumtype in ENDIAN_STRINGS:
                 output += ('    *a = {}(*a);\n'
@@ -1187,6 +1188,78 @@ static inline void vl_api_{name}_t_endian (vl_api_{name}_t *a)
     return output
 
 
+def calc_size_fun(objs, modulename):
+    '''Main entry point for calculate size function generation'''
+    output = '''\
+
+/****** Calculate size functions *****/\n\
+#ifdef vl_calcsizefun
+#ifndef included_{module}_calcsizefun
+#define included_{module}_calcsizefun
+
+'''
+    output = output.format(module=modulename)
+
+    signature = '''\
+/* calculate message size of message in network byte order */
+static inline uword vl_api_{name}_t_calc_size (vl_api_{name}_t *a)
+{{
+'''
+
+    for o in objs:
+        tname = o.__class__.__name__
+
+        output += signature.format(name=o.name)
+        output += f"      return sizeof(*a)"
+        if tname == 'Using':
+            if 'length' in o.alias:
+                try:
+                    tmp = int(o.alias['length'])
+                    if tmp == 0:
+                        raise (f"Unexpected length '0' for alias {o}")
+                except:
+                    # output += f" + vl_api_{o.alias.name}_t_calc_size({o.name})"
+                    print("culprit:")
+                    print(o)
+                    print(dir(o.alias))
+                    print(o.alias)
+                    raise
+        elif tname == 'Enum' or tname == 'EnumFlag':
+            pass
+        else:
+            for b in o.block:
+                if b.type == 'Option':
+                    continue
+                elif b.type == 'Field':
+                    if b.fieldtype.startswith('vl_api_'):
+                        output += f" - sizeof(a->{b.fieldname})"
+                        output += f" + {b.fieldtype}_calc_size(&a->{b.fieldname})"
+                elif b.type == 'Array':
+                    if b.lengthfield:
+                        m = list(filter(lambda x: x.fieldname == b.lengthfield, o.block))
+                        if len(m) != 1:
+                            raise Exception(f"Expected 1 match for field '{b.lengthfield}', got '{m}'")
+                        lf = m[0]
+                        if lf.fieldtype in ENDIAN_STRINGS:
+                            output += f" + {ENDIAN_STRINGS[lf.fieldtype]}(a->{b.lengthfield}) * sizeof(a->{b.fieldname}[0])"
+                        elif lf.fieldtype == "u8":
+                            output += f" + a->{b.lengthfield} * sizeof(a->{b.fieldname}[0])"
+                        else:
+                            raise Exception(f"Don't know how to endian swap {lf.fieldtype}")
+                    else:
+                        # Fixed length strings decay to nul terminated u8
+                        if b.fieldtype == 'string':
+                            if b.modern_vla:
+                                output += f" + vl_api_string_len(&a->{b.fieldname})"
+
+        output += ";\n"
+        output += '}\n\n'
+    output += "\n#endif"
+    output += "\n#endif /* vl_calcsizefun */\n\n"
+
+    return output
+
+
 def version_tuple(s, module):
     '''Generate semantic version string'''
     output = '''\
@@ -1336,6 +1409,10 @@ def generate_c_boilerplate(services, defines, counters, file_crc,
 #include "{module}.api.h"
 #undef vl_endianfun
 
+#define vl_calcsizefun
+#include "{module}.api.h"
+#undef vl_calsizefun
+
 /* instantiate all the print functions we know about */
 #define vl_print(handle, ...) vlib_cli_output (handle, __VA_ARGS__)
 #define vl_printfun
@@ -1371,6 +1448,7 @@ def generate_c_boilerplate(services, defines, counters, file_crc,
               '   .print_json = vl_api_{n}_t_print_json,\n'
               '   .tojson = vl_api_{n}_t_tojson,\n'
               '   .fromjson = vl_api_{n}_t_fromjson,\n'
+              '   .calc_size = vl_api_{n}_t_calc_size,\n'
               '   .is_autoendian = {auto}}};\n'
               .format(n=s.caller, ID=s.caller.upper(),
                       auto=d.autoendian))
@@ -1389,6 +1467,7 @@ def generate_c_boilerplate(services, defines, counters, file_crc,
                   '  .print_json = vl_api_{n}_t_print_json,\n'
                   '  .tojson = vl_api_{n}_t_tojson,\n'
                   '  .fromjson = vl_api_{n}_t_fromjson,\n'
+                  '  .calc_size = vl_api_{n}_t_calc_size,\n'
                   '  .is_autoendian = {auto}}};\n'
                   .format(n=s.reply, ID=s.reply.upper(),
                           auto=d.autoendian))
@@ -1427,6 +1506,10 @@ def generate_c_test_boilerplate(services, defines, file_crc, module, plugin,
 #include "{module}.api.h"
 #undef vl_endianfun
 
+#define vl_calcsizefun
+#include "{module}.api.h"
+#undef vl_calsizefun
+
 /* instantiate all the print functions we know about */
 #define vl_print(handle, ...) vlib_cli_output (handle, __VA_ARGS__)
 #define vl_printfun
@@ -1488,7 +1571,8 @@ def generate_c_test_boilerplate(services, defines, file_crc, module, plugin,
               '                           sizeof(vl_api_{n}_t), 1,\n'
               '                           vl_api_{n}_t_print_json,\n'
               '                           vl_api_{n}_t_tojson,\n'
-              '                           vl_api_{n}_t_fromjson);\n'
+              '                           vl_api_{n}_t_fromjson,\n'
+              '                           vl_api_{n}_t_calc_size);\n'
               .format(n=s.reply, ID=s.reply.upper()))
         write('   hash_set_mem (vam->function_by_name, "{n}", api_{n});\n'
               .format(n=s.caller))
@@ -1510,7 +1594,8 @@ def generate_c_test_boilerplate(services, defines, file_crc, module, plugin,
                   '                           sizeof(vl_api_{n}_t), 1,\n'
                   '                           vl_api_{n}_t_print_json,\n'
                   '                           vl_api_{n}_t_tojson,\n'
-                  '                           vl_api_{n}_t_fromjson);\n'
+                  '                           vl_api_{n}_t_fromjson,\n'
+                  '                           vl_api_{n}_t_calc_size);\n'
                   .format(n=e, ID=e.upper()))
 
     write('}\n')
@@ -1729,6 +1814,10 @@ def generate_c_test2_boilerplate(services, defines, module, stream):
 #include "{module}.api.h"
 #undef vl_endianfun
 
+#define vl_calcsizefun
+#include "{module}.api.h"
+#undef vl_calsizefun
+
 #define vl_print(handle, ...) vlib_cli_output (handle, __VA_ARGS__)
 #define vl_printfun
 #include "{module}.api.h"
@@ -1863,6 +1952,7 @@ def run(args, apifilename, s):
     output += stream.getvalue()
     stream.close()
     output += endianfun(s['types'] + s['Define'], modulename)
+    output += calc_size_fun(s['types'] + s['Define'], modulename)
     output += version_tuple(s, basename)
     output += BOTTOM_BOILERPLATE.format(input_filename=basename,
                                         file_crc=s['file_crc'])