wireguard: update ESTABLISHED flag
[vpp.git] / src / plugins / wireguard / wireguard_input.c
index b85cdc6..6b8c803 100644 (file)
@@ -125,16 +125,6 @@ typedef enum
   WG_INPUT_N_NEXT,
 } wg_input_next_t;
 
-/* static void */
-/* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
-/* { */
-/*   if (peer) */
-/*     { */
-/*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
-/*       peer->dst.port = udp_port; */
-/*     } */
-/* } */
-
 static u8
 is_ip4_header (u8 *data)
 {
@@ -171,8 +161,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
     }
 
   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
-  u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
-  u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
+  u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);
+  u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);
 
   message_header_t *header = current_b_data;
 
@@ -269,16 +259,13 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
            return WG_INPUT_ERROR_PEER;
          }
 
-       // set_peer_address (peer, ip4_src, udp_src_port);
+       wg_peer_update_endpoint (rp->r_peer_idx, &src_ip, udp_src_port);
+
        if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
          {
            vlib_node_increment_counter (vm, node_idx,
                                         WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
          }
-       else
-         {
-           wg_peer_update_flags (rp->r_peer_idx, WG_PEER_ESTABLISHED, true);
-         }
        break;
       }
     case MESSAGE_HANDSHAKE_RESPONSE:
@@ -318,7 +305,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
            return WG_INPUT_ERROR_PEER;
          }
 
-       // set_peer_address (peer, ip4_src, udp_src_port);
+       wg_peer_update_endpoint (peeri, &src_ip, udp_src_port);
+
        if (noise_remote_begin_session (vm, &peer->remote))
          {
 
@@ -351,9 +339,11 @@ wg_input_post_process (vlib_main_t *vm, vlib_buffer_t *b, u16 *next,
                       bool *is_keepalive)
 {
   next[0] = WG_INPUT_NEXT_PUNT;
+  noise_keypair_t *kp;
 
-  noise_keypair_t *kp =
-    wg_get_active_keypair (&peer->remote, data->receiver_index);
+  if ((kp = wg_get_active_keypair (&peer->remote, data->receiver_index)) ==
+      NULL)
+    return -1;
 
   if (!noise_counter_recv (&kp->kp_ctr, data->counter))
     {
@@ -371,7 +361,7 @@ wg_input_post_process (vlib_main_t *vm, vlib_buffer_t *b, u16 *next,
   if (decr_len == 0)
     {
       *is_keepalive = true;
-      return -1;
+      return 0;
     }
 
   wg_timers_data_received (peer);
@@ -582,6 +572,26 @@ error:
   return ret;
 }
 
+static_always_inline void
+wg_find_outer_addr_port (vlib_buffer_t *b, ip46_address_t *addr, u16 *port,
+                        u8 is_ip4)
+{
+  if (is_ip4)
+    {
+      ip4_udp_header_t *ip4_udp_hdr =
+       vlib_buffer_get_current (b) - sizeof (ip4_udp_header_t);
+      ip46_address_set_ip4 (addr, &ip4_udp_hdr->ip4.src_address);
+      *port = clib_net_to_host_u16 (ip4_udp_hdr->udp.src_port);
+    }
+  else
+    {
+      ip6_udp_header_t *ip6_udp_hdr =
+       vlib_buffer_get_current (b) - sizeof (ip6_udp_header_t);
+      ip46_address_set_ip6 (addr, &ip6_udp_hdr->ip6.src_address);
+      *port = clib_net_to_host_u16 (ip6_udp_hdr->udp.src_port);
+    }
+}
+
 always_inline uword
 wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
                 vlib_frame_t *frame, u8 is_ip4, u16 async_next_node)
@@ -620,6 +630,7 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 
   bool is_keepalive = false;
   u32 *peer_idx = NULL;
+  index_t peeri = INDEX_INVALID;
 
   while (n_left_from > 0)
     {
@@ -653,9 +664,15 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
                                                data->receiver_index);
              if (PREDICT_TRUE (peer_idx != NULL))
                {
-                 peer = wg_peer_get (*peer_idx);
+                 peeri = *peer_idx;
+                 peer = wg_peer_get (peeri);
+                 last_rec_idx = data->receiver_index;
+               }
+             else
+               {
+                 peer = NULL;
+                 last_rec_idx = ~0;
                }
-             last_rec_idx = data->receiver_index;
            }
 
          if (PREDICT_FALSE (!peer_idx))
@@ -727,7 +744,7 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
            }
          else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
            {
-             wg_send_handshake_from_mt (*peer_idx, false);
+             wg_send_handshake_from_mt (peeri, false);
              goto next;
            }
          else if (PREDICT_TRUE (state_cr == SC_OK))
@@ -735,8 +752,6 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
        }
       else
        {
-         peer_idx = NULL;
-
          /* Handshake packets should be processed in main thread */
          if (thread_index != 0)
            {
@@ -770,7 +785,7 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
          t->type = header_type;
          t->current_length = b[0]->current_length;
          t->is_keepalive = is_keepalive;
-         t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
+         t->peer = peer_idx ? peeri : INDEX_INVALID;
        }
 
     next:
@@ -808,23 +823,50 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
        }
 
       message_data_t *data = vlib_buffer_get_current (b[0]);
+      ip46_address_t out_src_ip;
+      u16 out_udp_src_port;
+
+      wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4);
 
       if (data->receiver_index != last_rec_idx)
        {
          peer_idx =
            wg_index_table_lookup (&wmp->index_table, data->receiver_index);
-         peer = wg_peer_get (*peer_idx);
-         last_rec_idx = data->receiver_index;
+         if (PREDICT_TRUE (peer_idx != NULL))
+           {
+             peeri = *peer_idx;
+             peer = wg_peer_get (peeri);
+             last_rec_idx = data->receiver_index;
+           }
+         else
+           {
+             peer = NULL;
+             last_rec_idx = ~0;
+           }
        }
 
-      if (PREDICT_FALSE (wg_input_post_process (vm, b[0], data_next, peer,
-                                               data, &is_keepalive) < 0))
-       goto trace;
+      if (PREDICT_TRUE (peer != NULL))
+       {
+         if (PREDICT_FALSE (wg_input_post_process (vm, b[0], data_next, peer,
+                                                   data, &is_keepalive) < 0))
+           goto trace;
+       }
+      else
+       {
+         data_next[0] = WG_INPUT_NEXT_PUNT;
+         goto trace;
+       }
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
+         if (PREDICT_FALSE (
+               !ip46_address_is_equal (&peer->dst.addr, &out_src_ip) ||
+               peer->dst.port != out_udp_src_port))
+           wg_peer_update_endpoint_from_mt (peeri, &out_src_ip,
+                                            out_udp_src_port);
          wg_timers_any_authenticated_packet_received_opt (peer, time);
          wg_timers_any_authenticated_packet_traversal (peer);
+         wg_peer_update_flags (*peer_idx, WG_PEER_ESTABLISHED, true);
          last_peer_time_idx = peer_idx;
        }
 
@@ -841,7 +883,7 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
          t->type = header_type;
          t->current_length = b[0]->current_length;
          t->is_keepalive = is_keepalive;
-         t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
+         t->peer = peer_idx ? peeri : INDEX_INVALID;
        }
 
       b += 1;
@@ -890,7 +932,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 }
 
 always_inline uword
-wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
+wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame,
+              u8 is_ip4)
 {
   vnet_main_t *vnm = vnet_get_main ();
   vnet_interface_main_t *im = &vnm->interface_main;
@@ -902,6 +945,7 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
   wg_peer_t *peer = NULL;
   u32 *peer_idx = NULL;
   u32 *last_peer_time_idx = NULL;
+  index_t peeri = INDEX_INVALID;
   u32 last_rec_idx = ~0;
   f64 time = clib_time_now (&vm->clib_time) + vm->time_offset;
 
@@ -925,14 +969,27 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 
       bool is_keepalive = false;
       message_data_t *data = vlib_buffer_get_current (b[0]);
+      ip46_address_t out_src_ip;
+      u16 out_udp_src_port;
+
+      wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4);
 
       if (data->receiver_index != last_rec_idx)
        {
          peer_idx =
            wg_index_table_lookup (&wmp->index_table, data->receiver_index);
 
-         peer = wg_peer_get (*peer_idx);
-         last_rec_idx = data->receiver_index;
+         if (PREDICT_TRUE (peer_idx != NULL))
+           {
+             peeri = *peer_idx;
+             peer = wg_peer_get (peeri);
+             last_rec_idx = data->receiver_index;
+           }
+         else
+           {
+             peer = NULL;
+             last_rec_idx = ~0;
+           }
        }
 
       if (PREDICT_TRUE (peer != NULL))
@@ -949,8 +1006,14 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
+         if (PREDICT_FALSE (
+               !ip46_address_is_equal (&peer->dst.addr, &out_src_ip) ||
+               peer->dst.port != out_udp_src_port))
+           wg_peer_update_endpoint_from_mt (peeri, &out_src_ip,
+                                            out_udp_src_port);
          wg_timers_any_authenticated_packet_received_opt (peer, time);
          wg_timers_any_authenticated_packet_traversal (peer);
+         wg_peer_update_flags (*peer_idx, WG_PEER_ESTABLISHED, true);
          last_peer_time_idx = peer_idx;
        }
 
@@ -966,7 +1029,7 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
          wg_input_post_trace_t *t =
            vlib_add_trace (vm, node, b[0], sizeof (*t));
          t->next = next[0];
-         t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
+         t->peer = peer_idx ? peeri : INDEX_INVALID;
        }
 
       b += 1;
@@ -995,13 +1058,13 @@ VLIB_NODE_FN (wg6_input_node)
 VLIB_NODE_FN (wg4_input_post_node)
 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
 {
-  return wg_input_post (vm, node, from_frame);
+  return wg_input_post (vm, node, from_frame, /* is_ip4 */ 1);
 }
 
 VLIB_NODE_FN (wg6_input_post_node)
 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
 {
-  return wg_input_post (vm, node, from_frame);
+  return wg_input_post (vm, node, from_frame, /* is_ip4 */ 0);
 }
 
 /* *INDENT-OFF* */