vapi: support services
[vpp.git] / src / vpp-api / vapi / vapi_c_gen.py
index 84d9aca..37f5ac1 100755 (executable)
@@ -477,7 +477,7 @@ class CMessage(Message):
                 {{
                   VAPI_ERR("Truncated '{self.name}' msg received, received %lu"
                     "bytes, expected %lu bytes.", buf_size,
-                    sizeof({self.get_calc_msg_size_func_name()}));
+                    {self.get_calc_msg_size_func_name()}(msg));
                   return -1;
                 }}
               return 0;
@@ -615,45 +615,66 @@ class CMessage(Message):
         return "vapi_%s" % self.name
 
     def get_op_func_decl(self):
-        if self.reply.has_payload():
-            return "vapi_error_e %s(%s)" % (
-                self.get_op_func_name(),
-                ",\n  ".join(
-                    [
-                        "struct vapi_ctx_s *ctx",
-                        "%s *msg" % self.get_c_name(),
-                        "vapi_error_e (*callback)(struct vapi_ctx_s *ctx",
-                        "                         void *callback_ctx",
-                        "                         vapi_error_e rv",
-                        "                         bool is_last",
-                        "                         %s *reply)"
-                        % self.reply.get_payload_struct_name(),
-                        "void *callback_ctx",
-                    ]
-                ),
-            )
-        else:
-            return "vapi_error_e %s(%s)" % (
-                self.get_op_func_name(),
-                ",\n  ".join(
-                    [
-                        "struct vapi_ctx_s *ctx",
-                        "%s *msg" % self.get_c_name(),
-                        "vapi_error_e (*callback)(struct vapi_ctx_s *ctx",
-                        "                         void *callback_ctx",
-                        "                         vapi_error_e rv",
-                        "                         bool is_last)",
-                        "void *callback_ctx",
-                    ]
-                ),
-            )
+        stream_param_lines = []
+        if self.has_stream_msg:
+            stream_param_lines = [
+                "vapi_error_e (*details_callback)(struct vapi_ctx_s *ctx",
+                "                                 void *callback_ctx",
+                "                                 vapi_error_e rv",
+                "                                 bool is_last",
+                "                                 %s *details)"
+                % self.stream_msg.get_payload_struct_name(),
+                "void *details_callback_ctx",
+            ]
+
+        return "vapi_error_e %s(%s)" % (
+            self.get_op_func_name(),
+            ",\n  ".join(
+                [
+                    "struct vapi_ctx_s *ctx",
+                    "%s *msg" % self.get_c_name(),
+                    "vapi_error_e (*reply_callback)(struct vapi_ctx_s *ctx",
+                    "                               void *callback_ctx",
+                    "                               vapi_error_e rv",
+                    "                               bool is_last",
+                    "                               %s *reply)"
+                    % self.reply.get_payload_struct_name(),
+                ]
+                + [
+                    "void *reply_callback_ctx",
+                ]
+                + stream_param_lines
+            ),
+        )
 
     def get_op_func_def(self):
+        param_check_lines = ["  if (!msg || !reply_callback) {"]
+        store_request_lines = [
+            "    vapi_store_request(ctx, req_context, %s, %s, "
+            % (
+                self.reply.get_msg_id_name(),
+                "VAPI_REQUEST_DUMP" if self.reply_is_stream else "VAPI_REQUEST_REG",
+            ),
+            "                       (vapi_cb_t)reply_callback, reply_callback_ctx);",
+        ]
+        if self.has_stream_msg:
+            param_check_lines = [
+                "  if (!msg || !reply_callback || !details_callback) {"
+            ]
+            store_request_lines = [
+                f"    vapi_store_request(ctx, req_context, {self.stream_msg.get_msg_id_name()}, VAPI_REQUEST_STREAM, ",
+                "                       (vapi_cb_t)details_callback, details_callback_ctx);",
+                f"    vapi_store_request(ctx, req_context, {self.reply.get_msg_id_name()}, VAPI_REQUEST_REG, ",
+                "                       (vapi_cb_t)reply_callback, reply_callback_ctx);",
+            ]
+
         return "\n".join(
             [
                 "%s" % self.get_op_func_decl(),
                 "{",
-                "  if (!msg || !callback) {",
+            ]
+            + param_check_lines
+            + [
                 "    return VAPI_EINVAL;",
                 "  }",
                 "  if (vapi_is_nonblocking(ctx) && vapi_requests_full(ctx)) {",
@@ -669,14 +690,12 @@ class CMessage(Message):
                 (
                     "  if (VAPI_OK == (rv = vapi_send_with_control_ping "
                     "(ctx, msg, req_context))) {"
-                    if self.reply_is_stream
+                    if (self.reply_is_stream and not self.has_stream_msg)
                     else "  if (VAPI_OK == (rv = vapi_send (ctx, msg))) {"
                 ),
-                (
-                    "    vapi_store_request(ctx, req_context, %s, "
-                    "(vapi_cb_t)callback, callback_ctx);"
-                    % ("true" if self.reply_is_stream else "false")
-                ),
+            ]
+            + store_request_lines
+            + [
                 "    if (VAPI_OK != vapi_producer_unlock (ctx)) {",
                 "      abort (); /* this really shouldn't happen */",
                 "    }",
@@ -792,6 +811,8 @@ def emit_definition(parser, json_file, emitted, o):
             emit_definition(parser, json_file, emitted, x)
     if hasattr(o, "reply"):
         emit_definition(parser, json_file, emitted, o.reply)
+    if hasattr(o, "stream_msg"):
+        emit_definition(parser, json_file, emitted, o.stream_msg)
     if hasattr(o, "get_c_def"):
         if (
             o not in parser.enums_by_json[json_file]
@@ -820,14 +841,14 @@ def emit_definition(parser, json_file, emitted, o):
             print("%s%s" % (function_attrs, o.get_calc_msg_size_func_def()))
             print("")
             print("%s%s" % (function_attrs, o.get_verify_msg_size_func_def()))
-            if not o.is_reply and not o.is_event:
+            if not o.is_reply and not o.is_event and not o.is_stream:
                 print("")
                 print("%s%s" % (function_attrs, o.get_alloc_func_def()))
                 print("")
                 print("%s%s" % (function_attrs, o.get_op_func_def()))
             print("")
             print("%s" % o.get_c_constructor())
-            if o.is_reply or o.is_event:
+            if (o.is_reply or o.is_event) and not o.is_stream:
                 print("")
                 print("%s%s;" % (function_attrs, o.get_event_cb_func_def()))
         elif hasattr(o, "get_swap_to_be_func_def"):