#!/usr/bin/env python3
import argparse
+import inspect
import os
import sys
import logging
return "vl_api_string_t %s;" % (self.name)
else:
if self.len is not None and type(self.len) != dict:
- return "%s %s[%d];" % (self.type.get_c_name(), self.name, self.len)
+ return "%s %s[%d];" % (self.type.get_c_name(), self.name,
+ self.len)
else:
return "%s %s;" % (self.type.get_c_name(), self.name)
"}",
])
+ def get_verify_msg_size_func_name(self):
+ return f"vapi_verify_{self.name}_msg_size"
+
+ def get_verify_msg_size_func_decl(self):
+ return "int %s(%s *msg, uword buf_size)" % (
+ self.get_verify_msg_size_func_name(),
+ self.get_c_name())
+
+ def get_verify_msg_size_func_def(self):
+ return inspect.cleandoc(
+ f"""
+ {self.get_verify_msg_size_func_decl()}
+ {{
+ if (sizeof({self.get_c_name()}) > buf_size)
+ {{
+ VAPI_ERR("Truncated '{self.name}' msg received, received %lu"
+ "bytes, expected %lu bytes.", buf_size,
+ sizeof({self.get_c_name()}));
+ return -1;
+ }}
+ if ({self.get_calc_msg_size_func_name()}(msg) > buf_size)
+ {{
+ VAPI_ERR("Truncated '{self.name}' msg received, received %lu"
+ "bytes, expected %lu bytes.", buf_size,
+ sizeof({self.get_calc_msg_size_func_name()}));
+ return -1;
+ }}
+ return 0;
+ }}
+ """)
+
def get_c_def(self):
if self.has_payload():
return "\n".join([
if has_context else ' 0,',
(' offsetof(%s, payload),' % self.get_c_name())
if self.has_payload() else ' VAPI_INVALID_MSG_ID,',
- ' sizeof(%s),' % self.get_c_name(),
+ ' (verify_msg_size_fn_t)%s,' %
+ self.get_verify_msg_size_func_name(),
' (generic_swap_fn_t)%s,' % self.get_swap_to_be_func_name(),
' (generic_swap_fn_t)%s,' % self.get_swap_to_host_func_name(),
' VAPI_INVALID_MSG_ID,',
print("%s%s" % (function_attrs, o.get_swap_to_host_func_def()))
print("")
print("%s%s" % (function_attrs, o.get_calc_msg_size_func_def()))
+ print("")
+ print("%s%s" % (function_attrs, o.get_verify_msg_size_func_def()))
if not o.is_reply and not o.is_event:
print("")
print("%s%s" % (function_attrs, o.get_alloc_func_def()))
orig_stdout = sys.stdout
sys.stdout = io
include_guard = "__included_%s" % (
- j.replace(".", "_").replace("/", "_").replace("-", "_").replace("+", "_"))
+ j.replace(".", "_").replace("/", "_").replace("-", "_").replace(
+ "+", "_"))
print("#ifndef %s" % include_guard)
print("#define %s" % include_guard)
print("")