wireguard: compute checksum for outer ipv6 header
[vpp.git] / src / plugins / wireguard / wireguard_output_tun.c
index d1b1d6b..f613d6c 100644 (file)
@@ -307,6 +307,22 @@ error:
   return ret;
 }
 
+static_always_inline void
+wg_calc_checksum (vlib_main_t *vm, vlib_buffer_t *b)
+{
+  int bogus = 0;
+  u8 ip_ver_out = (*((u8 *) vlib_buffer_get_current (b)) >> 4);
+
+  /* IPv6 UDP checksum is mandatory */
+  if (ip_ver_out == 6)
+    {
+      ip6_header_t *ip6 =
+       (ip6_header_t *) ((u8 *) vlib_buffer_get_current (b));
+      udp_header_t *udp = ip6_next_header (ip6);
+      udp->checksum = ip6_tcp_udp_icmp_compute_checksum (vm, b, ip6, &bogus);
+    }
+}
+
 /* is_ip4 - inner header flag */
 always_inline uword
 wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
@@ -555,6 +571,14 @@ wg_output_tun_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
       /* wg-output-process-ops */
       wg_output_process_ops (vm, node, ptd->crypto_ops, sync_bufs, nexts,
                             drop_next);
+
+      int n_left_from_sync_bufs = n_sync;
+      while (n_left_from_sync_bufs > 0)
+       {
+         n_left_from_sync_bufs--;
+         wg_calc_checksum (vm, sync_bufs[n_left_from_sync_bufs]);
+       }
+
       vlib_buffer_enqueue_to_next (vm, node, sync_bi, nexts, n_sync);
     }
   if (n_async)
@@ -627,6 +651,11 @@ wg_output_tun_post (vlib_main_t *vm, vlib_node_runtime_t *node,
       next[2] = (wg_post_data (b[2]))->next_index;
       next[3] = (wg_post_data (b[3]))->next_index;
 
+      wg_calc_checksum (vm, b[0]);
+      wg_calc_checksum (vm, b[1]);
+      wg_calc_checksum (vm, b[2]);
+      wg_calc_checksum (vm, b[3]);
+
       if (PREDICT_FALSE (node->flags & VLIB_NODE_FLAG_TRACE))
        {
          if (b[0]->flags & VLIB_BUFFER_IS_TRACED)
@@ -671,6 +700,8 @@ wg_output_tun_post (vlib_main_t *vm, vlib_node_runtime_t *node,
 
   while (n_left > 0)
     {
+      wg_calc_checksum (vm, b[0]);
+
       next[0] = (wg_post_data (b[0]))->next_index;
       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
                         (b[0]->flags & VLIB_BUFFER_IS_TRACED)))