wireguard: compute checksum for outer ipv6 header
[vpp.git] / src / plugins / wireguard / wireguard_send.c
index 509fe70..72fa110 100644 (file)
@@ -41,7 +41,8 @@ ip46_enqueue_packet (vlib_main_t *vm, u32 bi0, int is_ip4)
 }
 
 static void
-wg_buffer_prepend_rewrite (vlib_buffer_t *b0, const u8 *rewrite, u8 is_ip4)
+wg_buffer_prepend_rewrite (vlib_main_t *vm, vlib_buffer_t *b0,
+                          const u8 *rewrite, u8 is_ip4)
 {
   if (is_ip4)
     {
@@ -70,10 +71,15 @@ wg_buffer_prepend_rewrite (vlib_buffer_t *b0, const u8 *rewrite, u8 is_ip4)
       /* copy only ip6 and udp header; wireguard header not needed */
       clib_memcpy (hdr6, rewrite, sizeof (ip6_udp_header_t));
 
-      hdr6->udp.length =
+      hdr6->ip6.payload_length = hdr6->udp.length =
        clib_host_to_net_u16 (b0->current_length - sizeof (ip6_header_t));
 
-      hdr6->ip6.payload_length = clib_host_to_net_u16 (b0->current_length);
+      /* IPv6 UDP checksum is mandatory */
+      int bogus = 0;
+      ip6_header_t *ip6_0 = &(hdr6->ip6);
+      hdr6->udp.checksum =
+       ip6_tcp_udp_icmp_compute_checksum (vm, b0, ip6_0, &bogus);
+      ASSERT (bogus == 0);
     }
 }
 
@@ -95,7 +101,7 @@ wg_create_buffer (vlib_main_t *vm, const u8 *rewrite, const u8 *packet,
 
   b0->current_length = packet_len;
 
-  wg_buffer_prepend_rewrite (b0, rewrite, is_ip4);
+  wg_buffer_prepend_rewrite (vm, b0, rewrite, is_ip4);
 
   return true;
 }
@@ -104,6 +110,9 @@ u8 *
 wg_build_rewrite (ip46_address_t *src_addr, u16 src_port,
                  ip46_address_t *dst_addr, u16 dst_port, u8 is_ip4)
 {
+  if (ip46_address_is_zero (dst_addr) || 0 == dst_port)
+    return NULL;
+
   u8 *rewrite = NULL;
   if (is_ip4)
     {
@@ -151,6 +160,9 @@ wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry)
 {
   ASSERT (vm->thread_index == 0);
 
+  if (!wg_peer_can_send (peer))
+    return false;
+
   message_handshake_initiation_t packet;
 
   if (!is_retry)
@@ -224,6 +236,9 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
 {
   ASSERT (vm->thread_index == 0);
 
+  if (!wg_peer_can_send (peer))
+    return false;
+
   u32 size_of_packet = message_data_len (0);
   message_data_t *packet =
     (message_data_t *) wg_main.per_thread_data[vm->thread_index].data;
@@ -278,6 +293,9 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
 {
   message_handshake_response_t packet;
 
+  if (!wg_peer_can_send (peer))
+    return false;
+
   if (noise_create_response (vm,
                             &peer->remote,
                             &packet.sender_index,
@@ -294,7 +312,6 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
          wg_timers_session_derived (peer);
          wg_timers_any_authenticated_packet_sent (peer);
          wg_timers_any_authenticated_packet_traversal (peer);
-         peer->last_sent_handshake = vlib_time_now (vm);
 
          u32 bi0 = 0;
          u8 is_ip4 = ip46_address_is_ip4 (&peer->dst.addr);
@@ -329,10 +346,14 @@ wg_send_handshake_cookie (vlib_main_t *vm, u32 sender_index,
 
   u32 bi0 = 0;
   u8 is_ip4 = ip46_address_is_ip4 (remote_addr);
+  bool ret;
   rewrite = wg_build_rewrite (wg_if_addr, wg_if_port, remote_addr, remote_port,
                              is_ip4);
-  if (!wg_create_buffer (vm, rewrite, (u8 *) &packet, sizeof (packet), &bi0,
-                        is_ip4))
+
+  ret = wg_create_buffer (vm, rewrite, (u8 *) &packet, sizeof (packet), &bi0,
+                         is_ip4);
+  vec_free (rewrite);
+  if (!ret)
     return false;
 
   ip46_enqueue_packet (vm, bi0, is_ip4);