api: verify message size on receipt
[vpp.git] / src / vlibapi / api_shared.c
index dd51ee5..f11344e 100644 (file)
@@ -500,8 +500,8 @@ vl_msg_api_barrier_release (void)
 }
 
 always_inline void
-msg_handler_internal (api_main_t * am,
-                     void *the_msg, int trace_it, int do_it, int free_it)
+msg_handler_internal (api_main_t *am, void *the_msg, uword msg_len,
+                     int trace_it, int do_it, int free_it)
 {
   u16 id = clib_net_to_host_u16 (*((u16 *) the_msg));
   u8 *(*print_fp) (void *, void *);
@@ -545,8 +545,35 @@ msg_handler_internal (api_main_t * am,
            }
        }
 
-      if (do_it)
+      uword calc_size = 0;
+      uword (*calc_size_fp) (void *);
+      calc_size_fp = am->msg_calc_size_funcs[id];
+      ASSERT (NULL != calc_size_fp);
+      if (calc_size_fp)
        {
+         calc_size = (*calc_size_fp) (the_msg);
+         ASSERT (calc_size <= msg_len);
+         if (calc_size > msg_len)
+           {
+             clib_warning (
+               "Truncated message '%s' (id %u) received, calculated size "
+               "%lu is bigger than actual size %llu, message dropped.",
+               am->msg_names[id], id, calc_size, msg_len);
+           }
+       }
+      else
+       {
+         clib_warning ("Message '%s' (id %u) has NULL calc_size_func, cannot "
+                       "verify message size is correct",
+                       am->msg_names[id], id);
+       }
+
+      /* don't process message if it's truncated, otherwise byte swaps
+       * and stuff could corrupt memory even beyond message if it's malicious
+       * e.g. VLA length field set to 1M elements, but VLA empty */
+      if (do_it && calc_size <= msg_len)
+       {
+
          if (!am->is_mp_safe[id])
            {
              vl_msg_api_barrier_trace_context (am->msg_names[id]);
@@ -569,6 +596,7 @@ msg_handler_internal (api_main_t * am,
          if (PREDICT_FALSE (vec_len (am->perf_counter_cbs) != 0))
            clib_call_callbacks (am->perf_counter_cbs, am, id,
                                 1 /* after */ );
+
          if (!am->is_mp_safe[id])
            vl_msg_api_barrier_release ();
        }
@@ -767,32 +795,30 @@ vl_msg_api_handler_with_vm_node (api_main_t * am, svm_region_t * vlib_rp,
 }
 
 void
-vl_msg_api_handler (void *the_msg)
+vl_msg_api_handler (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
 
-  msg_handler_internal (am, the_msg,
-                       (am->rx_trace
-                        && am->rx_trace->enabled) /* trace_it */ ,
-                       1 /* do_it */ , 1 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+                       (am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+                       1 /* do_it */, 1 /* free_it */);
 }
 
 void
-vl_msg_api_handler_no_free (void *the_msg)
+vl_msg_api_handler_no_free (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
-  msg_handler_internal (am, the_msg,
-                       (am->rx_trace
-                        && am->rx_trace->enabled) /* trace_it */ ,
-                       1 /* do_it */ , 0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+                       (am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+                       1 /* do_it */, 0 /* free_it */);
 }
 
 void
-vl_msg_api_handler_no_trace_no_free (void *the_msg)
+vl_msg_api_handler_no_trace_no_free (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
-  msg_handler_internal (am, the_msg, 0 /* trace_it */ , 1 /* do_it */ ,
-                       0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len, 0 /* trace_it */, 1 /* do_it */,
+                       0 /* free_it */);
 }
 
 /*
@@ -805,14 +831,13 @@ vl_msg_api_handler_no_trace_no_free (void *the_msg)
  *
  */
 void
-vl_msg_api_trace_only (void *the_msg)
+vl_msg_api_trace_only (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
 
-  msg_handler_internal (am, the_msg,
-                       (am->rx_trace
-                        && am->rx_trace->enabled) /* trace_it */ ,
-                       0 /* do_it */ , 0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+                       (am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+                       0 /* do_it */, 0 /* free_it */);
 }
 
 void
@@ -863,14 +888,13 @@ vl_msg_api_get_msg_length (void *msg_arg)
  * vl_msg_api_socket_handler
  */
 void
-vl_msg_api_socket_handler (void *the_msg)
+vl_msg_api_socket_handler (void *the_msg, uword msg_len)
 {
   api_main_t *am = vlibapi_get_main ();
 
-  msg_handler_internal (am, the_msg,
-                       (am->rx_trace
-                        && am->rx_trace->enabled) /* trace_it */ ,
-                       1 /* do_it */ , 0 /* free_it */ );
+  msg_handler_internal (am, the_msg, msg_len,
+                       (am->rx_trace && am->rx_trace->enabled) /* trace_it */,
+                       1 /* do_it */, 0 /* free_it */);
 }
 
 #define foreach_msg_api_vector                                                \
@@ -882,6 +906,7 @@ vl_msg_api_socket_handler (void *the_msg)
   _ (msg_print_json_handlers)                                                 \
   _ (msg_tojson_handlers)                                                     \
   _ (msg_fromjson_handlers)                                                   \
+  _ (msg_calc_size_funcs)                                                     \
   _ (api_trace_cfg)                                                           \
   _ (message_bounce)                                                          \
   _ (is_mp_safe)                                                              \
@@ -927,6 +952,7 @@ vl_msg_api_config (vl_msg_api_msg_config_t * c)
   am->msg_print_json_handlers[c->id] = c->print_json;
   am->msg_tojson_handlers[c->id] = c->tojson;
   am->msg_fromjson_handlers[c->id] = c->fromjson;
+  am->msg_calc_size_funcs[c->id] = c->calc_size;
   am->message_bounce[c->id] = c->message_bounce;
   am->is_mp_safe[c->id] = c->is_mp_safe;
   am->is_autoendian[c->id] = c->is_autoendian;
@@ -948,7 +974,8 @@ vl_msg_api_config (vl_msg_api_msg_config_t * c)
 void
 vl_msg_api_set_handlers (int id, char *name, void *handler, void *cleanup,
                         void *endian, void *print, int size, int traced,
-                        void *print_json, void *tojson, void *fromjson)
+                        void *print_json, void *tojson, void *fromjson,
+                        void *calc_size)
 {
   vl_msg_api_msg_config_t cfg;
   vl_msg_api_msg_config_t *c = &cfg;
@@ -969,6 +996,7 @@ vl_msg_api_set_handlers (int id, char *name, void *handler, void *cleanup,
   c->tojson = tojson;
   c->fromjson = fromjson;
   c->print_json = print_json;
+  c->calc_size = calc_size;
   vl_msg_api_config (c);
 }
 
@@ -999,8 +1027,11 @@ vl_msg_api_queue_handler (svm_queue_t * q)
 {
   uword msg;
 
-  while (!svm_queue_sub (q, (u8 *) & msg, SVM_Q_WAIT, 0))
-    vl_msg_api_handler ((void *) msg);
+  while (!svm_queue_sub (q, (u8 *) &msg, SVM_Q_WAIT, 0))
+    {
+      msgbuf_t *msgbuf = (msgbuf_t *) ((u8 *) msg - offsetof (msgbuf_t, data));
+      vl_msg_api_handler ((void *) msg, ntohl (msgbuf->data_len));
+    }
 }
 
 u32