wireguard: add peers roaming support
[vpp.git] / src / plugins / wireguard / wireguard_input.c
index ef60d50..22850b8 100644 (file)
@@ -25,6 +25,7 @@
 #define foreach_wg_input_error                                                \
   _ (NONE, "No error")                                                        \
   _ (HANDSHAKE_MAC, "Invalid MAC handshake")                                  \
+  _ (HANDSHAKE_RATELIMITED, "Handshake ratelimited")                          \
   _ (PEER, "Peer error")                                                      \
   _ (INTERFACE, "Interface error")                                            \
   _ (DECRYPTION, "Failed during decryption")                                  \
@@ -32,6 +33,7 @@
   _ (HANDSHAKE_SEND, "Failed while sending Handshake")                        \
   _ (HANDSHAKE_RECEIVE, "Failed while receiving Handshake")                   \
   _ (COOKIE_DECRYPTION, "Failed during Cookie decryption")                    \
+  _ (COOKIE_SEND, "Failed during sending Cookie")                             \
   _ (TOO_BIG, "Packet too big")                                               \
   _ (UNDEFINED, "Undefined error")                                            \
   _ (CRYPTO_ENGINE_ERROR, "crypto engine error (packet dropped)")
@@ -123,16 +125,6 @@ typedef enum
   WG_INPUT_N_NEXT,
 } wg_input_next_t;
 
-/* static void */
-/* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
-/* { */
-/*   if (peer) */
-/*     { */
-/*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
-/*       peer->dst.port = udp_port; */
-/*     } */
-/* } */
-
 static u8
 is_ip4_header (u8 *data)
 {
@@ -169,11 +161,10 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
     }
 
   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
-  u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
-  u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
+  u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);
+  u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);
 
   message_header_t *header = current_b_data;
-  under_load = false;
 
   if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
     {
@@ -211,11 +202,13 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
       if (NULL == wg_if)
        continue;
 
+      under_load = wg_if_is_under_load (vm, wg_if);
       mac_state = cookie_checker_validate_macs (
        vm, &wg_if->cookie_checker, macs, current_b_data, len, under_load,
        &src_ip, udp_src_port);
       if (mac_state == INVALID_MAC)
        {
+         wg_if_dec_handshake_num (wg_if);
          wg_if = NULL;
          continue;
        }
@@ -230,6 +223,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
     packet_needs_cookie = false;
   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
     packet_needs_cookie = true;
+  else if (mac_state == VALID_MAC_WITH_COOKIE_BUT_RATELIMITED)
+    return WG_INPUT_ERROR_HANDSHAKE_RATELIMITED;
   else
     return WG_INPUT_ERROR_HANDSHAKE_MAC;
 
@@ -241,8 +236,16 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
 
        if (packet_needs_cookie)
          {
-           // TODO: Add processing
+
+           if (!wg_send_handshake_cookie (vm, message->sender_index,
+                                          &wg_if->cookie_checker, macs,
+                                          &ip_addr_46 (&wg_if->src_ip),
+                                          wg_if->port, &src_ip, udp_src_port))
+             return WG_INPUT_ERROR_COOKIE_SEND;
+
+           return WG_INPUT_ERROR_NONE;
          }
+
        noise_remote_t *rp;
        if (noise_consume_initiation
            (vm, noise_local_get (wg_if->local_idx), &rp,
@@ -256,7 +259,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
            return WG_INPUT_ERROR_PEER;
          }
 
-       // set_peer_address (peer, ip4_src, udp_src_port);
+       wg_peer_update_endpoint (rp->r_peer_idx, &src_ip, udp_src_port);
+
        if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
          {
            vlib_node_increment_counter (vm, node_idx,
@@ -271,6 +275,18 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
     case MESSAGE_HANDSHAKE_RESPONSE:
       {
        message_handshake_response_t *resp = current_b_data;
+
+       if (packet_needs_cookie)
+         {
+           if (!wg_send_handshake_cookie (vm, resp->sender_index,
+                                          &wg_if->cookie_checker, macs,
+                                          &ip_addr_46 (&wg_if->src_ip),
+                                          wg_if->port, &src_ip, udp_src_port))
+             return WG_INPUT_ERROR_COOKIE_SEND;
+
+           return WG_INPUT_ERROR_NONE;
+         }
+
        index_t peeri = INDEX_INVALID;
        u32 *entry =
          wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
@@ -292,12 +308,9 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
          {
            return WG_INPUT_ERROR_PEER;
          }
-       if (packet_needs_cookie)
-         {
-           // TODO: Add processing
-         }
 
-       // set_peer_address (peer, ip4_src, udp_src_port);
+       wg_peer_update_endpoint (peeri, &src_ip, udp_src_port);
+
        if (noise_remote_begin_session (vm, &peer->remote))
          {
 
@@ -561,6 +574,26 @@ error:
   return ret;
 }
 
+static_always_inline void
+wg_find_outer_addr_port (vlib_buffer_t *b, ip46_address_t *addr, u16 *port,
+                        u8 is_ip4)
+{
+  if (is_ip4)
+    {
+      ip4_udp_header_t *ip4_udp_hdr =
+       vlib_buffer_get_current (b) - sizeof (ip4_udp_header_t);
+      ip46_address_set_ip4 (addr, &ip4_udp_hdr->ip4.src_address);
+      *port = clib_net_to_host_u16 (ip4_udp_hdr->udp.src_port);
+    }
+  else
+    {
+      ip6_udp_header_t *ip6_udp_hdr =
+       vlib_buffer_get_current (b) - sizeof (ip6_udp_header_t);
+      ip46_address_set_ip6 (addr, &ip6_udp_hdr->ip6.src_address);
+      *port = clib_net_to_host_u16 (ip6_udp_hdr->udp.src_port);
+    }
+}
+
 always_inline uword
 wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
                 vlib_frame_t *frame, u8 is_ip4, u16 async_next_node)
@@ -714,8 +747,6 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
        }
       else
        {
-         peer_idx = NULL;
-
          /* Handshake packets should be processed in main thread */
          if (thread_index != 0)
            {
@@ -787,6 +818,10 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
        }
 
       message_data_t *data = vlib_buffer_get_current (b[0]);
+      ip46_address_t out_src_ip;
+      u16 out_udp_src_port;
+
+      wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4);
 
       if (data->receiver_index != last_rec_idx)
        {
@@ -802,6 +837,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
+         wg_peer_update_endpoint_from_mt (*peer_idx, &out_src_ip,
+                                          out_udp_src_port);
          wg_timers_any_authenticated_packet_received_opt (peer, time);
          wg_timers_any_authenticated_packet_traversal (peer);
          last_peer_time_idx = peer_idx;
@@ -869,7 +906,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 }
 
 always_inline uword
-wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
+wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame,
+              u8 is_ip4)
 {
   vnet_main_t *vnm = vnet_get_main ();
   vnet_interface_main_t *im = &vnm->interface_main;
@@ -904,6 +942,10 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 
       bool is_keepalive = false;
       message_data_t *data = vlib_buffer_get_current (b[0]);
+      ip46_address_t out_src_ip;
+      u16 out_udp_src_port;
+
+      wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4);
 
       if (data->receiver_index != last_rec_idx)
        {
@@ -928,6 +970,8 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
+         wg_peer_update_endpoint_from_mt (*peer_idx, &out_src_ip,
+                                          out_udp_src_port);
          wg_timers_any_authenticated_packet_received_opt (peer, time);
          wg_timers_any_authenticated_packet_traversal (peer);
          last_peer_time_idx = peer_idx;
@@ -974,13 +1018,13 @@ VLIB_NODE_FN (wg6_input_node)
 VLIB_NODE_FN (wg4_input_post_node)
 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
 {
-  return wg_input_post (vm, node, from_frame);
+  return wg_input_post (vm, node, from_frame, /* is_ip4 */ 1);
 }
 
 VLIB_NODE_FN (wg6_input_post_node)
 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
 {
-  return wg_input_post (vm, node, from_frame);
+  return wg_input_post (vm, node, from_frame, /* is_ip4 */ 0);
 }
 
 /* *INDENT-OFF* */