wireguard: add peers roaming support 97/36797/4
authorAlexander Chernavin <achernavin@netgate.com>
Thu, 4 Aug 2022 08:11:57 +0000 (08:11 +0000)
committerAlexander Chernavin <achernavin@netgate.com>
Tue, 9 Aug 2022 15:55:45 +0000 (15:55 +0000)
Type: feature

With this change, peers are able to roam between different external
endpoints. Successfully authenticated handshake or data packet that is
received from a new endpoint will cause the peer's endpoint to be
updated accordingly.

Signed-off-by: Alexander Chernavin <achernavin@netgate.com>
Change-Id: Ib4eb7dfa3403f3fb9e8bbe19ba6237c4960c764c

src/plugins/wireguard/FEATURE.yaml
src/plugins/wireguard/README.rst
src/plugins/wireguard/wireguard_cli.c
src/plugins/wireguard/wireguard_input.c
src/plugins/wireguard/wireguard_peer.c
src/plugins/wireguard/wireguard_peer.h
src/plugins/wireguard/wireguard_send.c
test/test_wireguard.py

index 4c6946d..5c0a588 100644 (file)
@@ -7,5 +7,3 @@ features:
 description: "Wireguard protocol implementation"
 state: development
 properties: [API, CLI]
-missing:
-  - Peers roaming between different external IPs
index ead4125..35dd2c4 100644 (file)
@@ -77,4 +77,3 @@ Main next steps for improving this implementation
 -------------------------------------------------
 
 1. Use all benefits of VPP-engine.
-2. Add peers roaming support
index 214e6a5..5fa6205 100644 (file)
@@ -165,7 +165,7 @@ wg_peer_add_command_fn (vlib_main_t * vm,
   u8 public_key[NOISE_PUBLIC_KEY_LEN + 1];
   fib_prefix_t allowed_ip, *allowed_ips = NULL;
   ip_prefix_t pfx;
-  ip_address_t ip;
+  ip_address_t ip = ip_address_initializer;
   u32 portDst = 0, table_id = 0;
   u32 persistent_keepalive = 0;
   u32 tun_sw_if_index = ~0;
@@ -213,6 +213,12 @@ wg_peer_add_command_fn (vlib_main_t * vm,
        }
     }
 
+  if (0 == vec_len (allowed_ips))
+    {
+      error = clib_error_return (0, "Allowed IPs are not specified");
+      goto done;
+    }
+
   rv = wg_peer_add (tun_sw_if_index, public_key, table_id, &ip_addr_46 (&ip),
                    allowed_ips, portDst, persistent_keepalive, &peer_index);
 
index b85cdc6..22850b8 100644 (file)
@@ -125,16 +125,6 @@ typedef enum
   WG_INPUT_N_NEXT,
 } wg_input_next_t;
 
-/* static void */
-/* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
-/* { */
-/*   if (peer) */
-/*     { */
-/*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
-/*       peer->dst.port = udp_port; */
-/*     } */
-/* } */
-
 static u8
 is_ip4_header (u8 *data)
 {
@@ -171,8 +161,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
     }
 
   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
-  u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
-  u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
+  u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);
+  u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);
 
   message_header_t *header = current_b_data;
 
@@ -269,7 +259,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
            return WG_INPUT_ERROR_PEER;
          }
 
-       // set_peer_address (peer, ip4_src, udp_src_port);
+       wg_peer_update_endpoint (rp->r_peer_idx, &src_ip, udp_src_port);
+
        if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
          {
            vlib_node_increment_counter (vm, node_idx,
@@ -318,7 +309,8 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
            return WG_INPUT_ERROR_PEER;
          }
 
-       // set_peer_address (peer, ip4_src, udp_src_port);
+       wg_peer_update_endpoint (peeri, &src_ip, udp_src_port);
+
        if (noise_remote_begin_session (vm, &peer->remote))
          {
 
@@ -582,6 +574,26 @@ error:
   return ret;
 }
 
+static_always_inline void
+wg_find_outer_addr_port (vlib_buffer_t *b, ip46_address_t *addr, u16 *port,
+                        u8 is_ip4)
+{
+  if (is_ip4)
+    {
+      ip4_udp_header_t *ip4_udp_hdr =
+       vlib_buffer_get_current (b) - sizeof (ip4_udp_header_t);
+      ip46_address_set_ip4 (addr, &ip4_udp_hdr->ip4.src_address);
+      *port = clib_net_to_host_u16 (ip4_udp_hdr->udp.src_port);
+    }
+  else
+    {
+      ip6_udp_header_t *ip6_udp_hdr =
+       vlib_buffer_get_current (b) - sizeof (ip6_udp_header_t);
+      ip46_address_set_ip6 (addr, &ip6_udp_hdr->ip6.src_address);
+      *port = clib_net_to_host_u16 (ip6_udp_hdr->udp.src_port);
+    }
+}
+
 always_inline uword
 wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
                 vlib_frame_t *frame, u8 is_ip4, u16 async_next_node)
@@ -735,8 +747,6 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
        }
       else
        {
-         peer_idx = NULL;
-
          /* Handshake packets should be processed in main thread */
          if (thread_index != 0)
            {
@@ -808,6 +818,10 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
        }
 
       message_data_t *data = vlib_buffer_get_current (b[0]);
+      ip46_address_t out_src_ip;
+      u16 out_udp_src_port;
+
+      wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4);
 
       if (data->receiver_index != last_rec_idx)
        {
@@ -823,6 +837,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
+         wg_peer_update_endpoint_from_mt (*peer_idx, &out_src_ip,
+                                          out_udp_src_port);
          wg_timers_any_authenticated_packet_received_opt (peer, time);
          wg_timers_any_authenticated_packet_traversal (peer);
          last_peer_time_idx = peer_idx;
@@ -890,7 +906,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 }
 
 always_inline uword
-wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
+wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame,
+              u8 is_ip4)
 {
   vnet_main_t *vnm = vnet_get_main ();
   vnet_interface_main_t *im = &vnm->interface_main;
@@ -925,6 +942,10 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 
       bool is_keepalive = false;
       message_data_t *data = vlib_buffer_get_current (b[0]);
+      ip46_address_t out_src_ip;
+      u16 out_udp_src_port;
+
+      wg_find_outer_addr_port (b[0], &out_src_ip, &out_udp_src_port, is_ip4);
 
       if (data->receiver_index != last_rec_idx)
        {
@@ -949,6 +970,8 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
+         wg_peer_update_endpoint_from_mt (*peer_idx, &out_src_ip,
+                                          out_udp_src_port);
          wg_timers_any_authenticated_packet_received_opt (peer, time);
          wg_timers_any_authenticated_packet_traversal (peer);
          last_peer_time_idx = peer_idx;
@@ -995,13 +1018,13 @@ VLIB_NODE_FN (wg6_input_node)
 VLIB_NODE_FN (wg4_input_post_node)
 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
 {
-  return wg_input_post (vm, node, from_frame);
+  return wg_input_post (vm, node, from_frame, /* is_ip4 */ 1);
 }
 
 VLIB_NODE_FN (wg6_input_post_node)
 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
 {
-  return wg_input_post (vm, node, from_frame);
+  return wg_input_post (vm, node, from_frame, /* is_ip4 */ 0);
 }
 
 /* *INDENT-OFF* */
index 589f712..922ca8c 100644 (file)
@@ -16,6 +16,7 @@
 
 #include <vnet/adj/adj_midchain.h>
 #include <vnet/fib/fib_table.h>
+#include <vnet/fib/fib_entry_track.h>
 #include <wireguard/wireguard_peer.h>
 #include <wireguard/wireguard_if.h>
 #include <wireguard/wireguard_messages.h>
@@ -63,13 +64,14 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
   wg_peer_endpoint_reset (&peer->src);
   wg_peer_endpoint_reset (&peer->dst);
 
-  adj_index_t *adj_index;
-  vec_foreach (adj_index, peer->adj_indices)
+  wg_peer_adj_t *peer_adj;
+  vec_foreach (peer_adj, peer->adjs)
     {
-      if (INDEX_INVALID != *adj_index)
-       {
-         wg_peer_by_adj_index[*adj_index] = INDEX_INVALID;
-       }
+      wg_peer_by_adj_index[peer_adj->adj_index] = INDEX_INVALID;
+      if (FIB_NODE_INDEX_INVALID != peer_adj->fib_entry_index)
+       fib_entry_untrack (peer_adj->fib_entry_index, peer_adj->sibling_index);
+      if (adj_is_valid (peer_adj->adj_index))
+       adj_nbr_midchain_unstack (peer_adj->adj_index);
     }
   peer->input_thread_index = ~0;
   peer->output_thread_index = ~0;
@@ -83,8 +85,9 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
   peer->new_handshake_interval_tick = 0;
   peer->rehandshake_interval_tick = 0;
   peer->timer_need_another_keepalive = false;
+  vec_free (peer->rewrite);
   vec_free (peer->allowed_ips);
-  vec_free (peer->adj_indices);
+  vec_free (peer->adjs);
 }
 
 static void
@@ -96,17 +99,17 @@ wg_peer_init (vlib_main_t * vm, wg_peer_t * peer)
 }
 
 static void
-wg_peer_adj_stack (wg_peer_t *peer, adj_index_t ai)
+wg_peer_adj_stack (wg_peer_t *peer, wg_peer_adj_t *peer_adj)
 {
   ip_adjacency_t *adj;
   u32 sw_if_index;
   wg_if_t *wgi;
   fib_protocol_t fib_proto;
 
-  if (!adj_is_valid (ai))
+  if (!adj_is_valid (peer_adj->adj_index))
     return;
 
-  adj = adj_get (ai);
+  adj = adj_get (peer_adj->adj_index);
   sw_if_index = adj->rewrite_header.sw_if_index;
   u8 is_ip4 = ip46_address_is_ip4 (&peer->src.addr);
   fib_proto = is_ip4 ? FIB_PROTOCOL_IP4 : FIB_PROTOCOL_IP6;
@@ -116,9 +119,10 @@ wg_peer_adj_stack (wg_peer_t *peer, adj_index_t ai)
   if (!wgi)
     return;
 
-  if (!vnet_sw_interface_is_admin_up (vnet_get_main (), wgi->sw_if_index))
+  if (!vnet_sw_interface_is_admin_up (vnet_get_main (), wgi->sw_if_index) ||
+      !wg_peer_can_send (peer))
     {
-      adj_midchain_delegate_unstack (ai);
+      adj_nbr_midchain_unstack (peer_adj->adj_index);
     }
   else
     {
@@ -132,8 +136,13 @@ wg_peer_adj_stack (wg_peer_t *peer, adj_index_t ai)
       u32 fib_index;
 
       fib_index = fib_table_find (fib_proto, peer->table_id);
+      peer_adj->fib_entry_index =
+       fib_entry_track (fib_index, &dst, FIB_NODE_TYPE_ADJ,
+                        peer_adj->adj_index, &peer_adj->sibling_index);
 
-      adj_midchain_delegate_stack (ai, fib_index, &dst);
+      adj_nbr_midchain_stack_on_fib_entry (
+       peer_adj->adj_index, peer_adj->fib_entry_index,
+       fib_forw_chain_type_from_fib_proto (dst.fp_proto));
     }
 }
 
@@ -198,11 +207,11 @@ walk_rc_t
 wg_peer_if_admin_state_change (index_t peeri, void *data)
 {
   wg_peer_t *peer;
-  adj_index_t *adj_index;
+  wg_peer_adj_t *peer_adj;
   peer = wg_peer_get (peeri);
-  vec_foreach (adj_index, peer->adj_indices)
+  vec_foreach (peer_adj, peer->adjs)
     {
-      wg_peer_adj_stack (peer, *adj_index);
+      wg_peer_adj_stack (peer, peer_adj);
     }
   return (WALK_CONTINUE);
 }
@@ -215,6 +224,7 @@ wg_peer_if_adj_change (index_t peeri, void *data)
   ip_adjacency_t *adj;
   wg_peer_t *peer;
   fib_prefix_t *allowed_ip;
+  wg_peer_adj_t *peer_adj;
 
   adj = adj_get (*adj_index);
 
@@ -224,17 +234,21 @@ wg_peer_if_adj_change (index_t peeri, void *data)
       if (fib_prefix_is_cover_addr_46 (allowed_ip,
                                       &adj->sub_type.nbr.next_hop))
        {
-         vec_add1 (peer->adj_indices, *adj_index);
+         vec_add2 (peer->adjs, peer_adj, 1);
+         peer_adj->adj_index = *adj_index;
+         peer_adj->fib_entry_index = FIB_NODE_INDEX_INVALID;
+         peer_adj->sibling_index = ~0;
+
          vec_validate_init_empty (wg_peer_by_adj_index, *adj_index,
                                   INDEX_INVALID);
-         wg_peer_by_adj_index[*adj_index] = peer - wg_peer_pool;
+         wg_peer_by_adj_index[*adj_index] = peeri;
 
          fixup = wg_peer_get_fixup (peer, adj_get_link_type (*adj_index));
          adj_nbr_midchain_update_rewrite (*adj_index, fixup, NULL,
                                           ADJ_FLAG_MIDCHAIN_IP_STACK,
                                           vec_dup (peer->rewrite));
 
-         wg_peer_adj_stack (peer, *adj_index);
+         wg_peer_adj_stack (peer, peer_adj);
          return (WALK_STOP);
        }
     }
@@ -313,6 +327,71 @@ wg_peer_update_flags (index_t peeri, wg_peer_flags flag, bool add_del)
   wg_api_peer_event (peeri, peer->flags);
 }
 
+void
+wg_peer_update_endpoint (index_t peeri, const ip46_address_t *addr, u16 port)
+{
+  wg_peer_t *peer = wg_peer_get (peeri);
+
+  if (ip46_address_is_equal (&peer->dst.addr, addr) && peer->dst.port == port)
+    return;
+
+  wg_peer_endpoint_init (&peer->dst, addr, port);
+
+  u8 is_ip4 = ip46_address_is_ip4 (&peer->dst.addr);
+  vec_free (peer->rewrite);
+  peer->rewrite = wg_build_rewrite (&peer->src.addr, peer->src.port,
+                                   &peer->dst.addr, peer->dst.port, is_ip4);
+
+  wg_peer_adj_t *peer_adj;
+  vec_foreach (peer_adj, peer->adjs)
+    {
+      if (FIB_NODE_INDEX_INVALID != peer_adj->fib_entry_index)
+       {
+         fib_entry_untrack (peer_adj->fib_entry_index,
+                            peer_adj->sibling_index);
+         peer_adj->fib_entry_index = FIB_NODE_INDEX_INVALID;
+         peer_adj->sibling_index = ~0;
+       }
+
+      if (adj_is_valid (peer_adj->adj_index))
+       {
+         adj_midchain_fixup_t fixup =
+           wg_peer_get_fixup (peer, adj_get_link_type (peer_adj->adj_index));
+         adj_nbr_midchain_update_rewrite (peer_adj->adj_index, fixup, NULL,
+                                          ADJ_FLAG_MIDCHAIN_IP_STACK,
+                                          vec_dup (peer->rewrite));
+         wg_peer_adj_stack (peer, peer_adj);
+       }
+    }
+}
+
+typedef struct wg_peer_upd_ep_args_t_
+{
+  index_t peeri;
+  ip46_address_t addr;
+  u16 port;
+} wg_peer_upd_ep_args_t;
+
+static void
+wg_peer_update_endpoint_thread_fn (wg_peer_upd_ep_args_t *args)
+{
+  wg_peer_update_endpoint (args->peeri, &args->addr, args->port);
+}
+
+void
+wg_peer_update_endpoint_from_mt (index_t peeri, const ip46_address_t *addr,
+                                u16 port)
+{
+  wg_peer_upd_ep_args_t args = {
+    .peeri = peeri,
+    .port = port,
+  };
+
+  ip46_address_copy (&args.addr, addr);
+  vlib_rpc_call_main_thread (wg_peer_update_endpoint_thread_fn, (u8 *) &args,
+                            sizeof (args));
+}
+
 int
 wg_peer_add (u32 tun_sw_if_index, const u8 public_key[NOISE_PUBLIC_KEY_LEN],
             u32 table_id, const ip46_address_t *endpoint,
@@ -345,7 +424,7 @@ wg_peer_add (u32 tun_sw_if_index, const u8 public_key[NOISE_PUBLIC_KEY_LEN],
   if (pool_elts (wg_peer_pool) > MAX_PEERS)
     return (VNET_API_ERROR_LIMIT_EXCEEDED);
 
-  pool_get (wg_peer_pool, peer);
+  pool_get_zero (wg_peer_pool, peer);
 
   wg_peer_init (vm, peer);
 
@@ -428,9 +507,9 @@ format_wg_peer (u8 * s, va_list * va)
 {
   index_t peeri = va_arg (*va, index_t);
   fib_prefix_t *allowed_ip;
-  adj_index_t *adj_index;
   u8 key[NOISE_KEY_LEN_BASE64];
   wg_peer_t *peer;
+  wg_peer_adj_t *peer_adj;
 
   peer = wg_peer_get (peeri);
   key_to_base64 (peer->remote.r_public, NOISE_PUBLIC_KEY_LEN, key);
@@ -443,9 +522,9 @@ format_wg_peer (u8 * s, va_list * va)
     peer->wg_sw_if_index, peer->persistent_keepalive_interval, peer->flags,
     pool_elts (peer->api_clients));
   s = format (s, "\n  adj:");
-  vec_foreach (adj_index, peer->adj_indices)
+  vec_foreach (peer_adj, peer->adjs)
     {
-      s = format (s, " %d", *adj_index);
+      s = format (s, " %d", peer_adj->adj_index);
     }
   s = format (s, "\n  key:%=s %U", key, format_hex_bytes,
              peer->remote.r_public, NOISE_PUBLIC_KEY_LEN);
index a14f269..c07ea89 100644 (file)
@@ -68,6 +68,13 @@ typedef enum
   WG_PEER_ESTABLISHED = 0x2,
 } wg_peer_flags;
 
+typedef struct wg_peer_adj_t_
+{
+  adj_index_t adj_index;
+  fib_node_index_t fib_entry_index;
+  u32 sibling_index;
+} wg_peer_adj_t;
+
 typedef struct wg_peer
 {
   noise_remote_t remote;
@@ -80,7 +87,7 @@ typedef struct wg_peer
   wg_peer_endpoint_t dst;
   wg_peer_endpoint_t src;
   u32 table_id;
-  adj_index_t *adj_indices;
+  wg_peer_adj_t *adjs;
 
   /* rewrite built from address information */
   u8 *rewrite;
@@ -144,6 +151,10 @@ adj_walk_rc_t wg_peer_adj_walk (adj_index_t ai, void *data);
 
 void wg_api_peer_event (index_t peeri, wg_peer_flags flags);
 void wg_peer_update_flags (index_t peeri, wg_peer_flags flag, bool add_del);
+void wg_peer_update_endpoint (index_t peeri, const ip46_address_t *addr,
+                             u16 port);
+void wg_peer_update_endpoint_from_mt (index_t peeri,
+                                     const ip46_address_t *addr, u16 port);
 
 static inline bool
 wg_peer_is_dead (wg_peer_t *peer)
@@ -200,6 +211,12 @@ fib_prefix_is_cover_addr_46 (const fib_prefix_t *p1, const ip46_address_t *ip)
   return (false);
 }
 
+static inline bool
+wg_peer_can_send (wg_peer_t *peer)
+{
+  return peer && peer->rewrite;
+}
+
 #endif // __included_wg_peer_h__
 
 /*
index 509fe70..93e808a 100644 (file)
@@ -104,6 +104,9 @@ u8 *
 wg_build_rewrite (ip46_address_t *src_addr, u16 src_port,
                  ip46_address_t *dst_addr, u16 dst_port, u8 is_ip4)
 {
+  if (ip46_address_is_zero (dst_addr) || 0 == dst_port)
+    return NULL;
+
   u8 *rewrite = NULL;
   if (is_ip4)
     {
@@ -151,6 +154,9 @@ wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry)
 {
   ASSERT (vm->thread_index == 0);
 
+  if (!wg_peer_can_send (peer))
+    return false;
+
   message_handshake_initiation_t packet;
 
   if (!is_retry)
@@ -224,6 +230,9 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
 {
   ASSERT (vm->thread_index == 0);
 
+  if (!wg_peer_can_send (peer))
+    return false;
+
   u32 size_of_packet = message_data_len (0);
   message_data_t *packet =
     (message_data_t *) wg_main.per_thread_data[vm->thread_index].data;
@@ -278,6 +287,9 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
 {
   message_handshake_response_t packet;
 
+  if (!wg_peer_can_send (peer))
+    return false;
+
   if (noise_create_response (vm,
                             &peer->remote,
                             &packet.sender_index,
@@ -329,10 +341,14 @@ wg_send_handshake_cookie (vlib_main_t *vm, u32 sender_index,
 
   u32 bi0 = 0;
   u8 is_ip4 = ip46_address_is_ip4 (remote_addr);
+  bool ret;
   rewrite = wg_build_rewrite (wg_if_addr, wg_if_port, remote_addr, remote_port,
                              is_ip4);
-  if (!wg_create_buffer (vm, rewrite, (u8 *) &packet, sizeof (packet), &bi0,
-                        is_ip4))
+
+  ret = wg_create_buffer (vm, rewrite, (u8 *) &packet, sizeof (packet), &bi0,
+                         is_ip4);
+  vec_free (rewrite);
+  if (!ret)
     return false;
 
   ip46_enqueue_packet (vm, bi0, is_ip4);
index b8c5d2a..95cfe68 100644 (file)
@@ -137,15 +137,6 @@ class VppWgInterface(VppInterface):
         return "wireguard-%d" % self._sw_if_index
 
 
-def find_route(test, prefix, is_ip6, table_id=0):
-    routes = test.vapi.ip_route_dump(table_id, is_ip6)
-
-    for e in routes:
-        if table_id == e.route.table_id and str(e.route.prefix) == str(prefix):
-            return True
-    return False
-
-
 NOISE_HANDSHAKE_NAME = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
 NOISE_IDENTIFIER_NAME = b"WireGuard v1 zx2c4 Jason@zx2c4.com"
 
@@ -176,6 +167,10 @@ class VppWgPeer(VppObject):
 
         self.noise = NoiseConnection.from_name(NOISE_HANDSHAKE_NAME)
 
+    def change_endpoint(self, endpoint, port):
+        self.endpoint = endpoint
+        self.port = port
+
     def add_vpp_config(self, is_ip6=False):
         rv = self._test.vapi.wireguard_peer_add(
             peer={
@@ -206,10 +201,12 @@ class VppWgPeer(VppObject):
         peers = self._test.vapi.wireguard_peers_dump()
 
         for p in peers:
+            # "::" endpoint will be returned as "0.0.0.0" in peer's details
+            endpoint = "0.0.0.0" if self.endpoint == "::" else self.endpoint
             if (
                 p.peer.public_key == self.public_key_bytes()
                 and p.peer.port == self.port
-                and str(p.peer.endpoint) == self.endpoint
+                and str(p.peer.endpoint) == endpoint
                 and p.peer.sw_if_index == self.itf.sw_if_index
                 and len(self.allowed_ips) == p.peer.n_allowed_ips
             ):
@@ -470,17 +467,17 @@ class VppWgPeer(VppObject):
     def validate_encapped(self, rxs, tx, is_ip6=False):
         for rx in rxs:
             if is_ip6 is False:
-                rx = IP(self.decrypt_transport(rx))
+                rx = IP(self.decrypt_transport(rx, is_ip6=is_ip6))
 
-                # chech the oringial packet is present
+                # check the original packet is present
                 self._test.assertEqual(rx[IP].dst, tx[IP].dst)
                 self._test.assertEqual(rx[IP].ttl, tx[IP].ttl - 1)
             else:
-                rx = IPv6(self.decrypt_transport(rx))
+                rx = IPv6(self.decrypt_transport(rx, is_ip6=is_ip6))
 
-                # chech the oringial packet is present
+                # check the original packet is present
                 self._test.assertEqual(rx[IPv6].dst, tx[IPv6].dst)
-                self._test.assertEqual(rx[IPv6].ttl, tx[IPv6].ttl - 1)
+                self._test.assertEqual(rx[IPv6].hlim, tx[IPv6].hlim - 1)
 
     def want_events(self):
         self._test.vapi.want_wireguard_peer_events(
@@ -997,6 +994,237 @@ class TestWg(VppTestCase):
         peer_2.remove_vpp_config()
         wg0.remove_vpp_config()
 
+    def _test_wg_peer_roaming_on_handshake_tmpl(self, is_endpoint_set, is_resp, is_ip6):
+        port = 12323
+
+        # create wg interface
+        if is_ip6:
+            wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip6()
+        else:
+            wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create more remote hosts
+        NUM_REMOTE_HOSTS = 2
+        self.pg1.generate_remote_hosts(NUM_REMOTE_HOSTS)
+        if is_ip6:
+            self.pg1.configure_ipv6_neighbors()
+        else:
+            self.pg1.configure_ipv4_neighbors()
+
+        # create a peer
+        if is_ip6:
+            peer_1 = VppWgPeer(
+                test=self,
+                itf=wg0,
+                endpoint=self.pg1.remote_hosts[0].ip6 if is_endpoint_set else "::",
+                port=port + 1 if is_endpoint_set else 0,
+                allowed_ips=["1::3:0/112"],
+            ).add_vpp_config()
+        else:
+            peer_1 = VppWgPeer(
+                test=self,
+                itf=wg0,
+                endpoint=self.pg1.remote_hosts[0].ip4 if is_endpoint_set else "0.0.0.0",
+                port=port + 1 if is_endpoint_set else 0,
+                allowed_ips=["10.11.3.0/24"],
+            ).add_vpp_config()
+        self.assertTrue(peer_1.query_vpp_config())
+
+        if is_resp:
+            # wait for the peer to send a handshake initiation
+            rxs = self.pg1.get_capture(1, timeout=2)
+            # prepare a handshake response
+            resp = peer_1.consume_init(rxs[0], self.pg1, is_ip6=is_ip6)
+            # change endpoint
+            if is_ip6:
+                peer_1.change_endpoint(self.pg1.remote_hosts[1].ip6, port + 100)
+                resp[IPv6].src, resp[UDP].sport = peer_1.endpoint, peer_1.port
+            else:
+                peer_1.change_endpoint(self.pg1.remote_hosts[1].ip4, port + 100)
+                resp[IP].src, resp[UDP].sport = peer_1.endpoint, peer_1.port
+            # send the handshake response
+            # expect a keepalive message sent to the new endpoint
+            rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+            # verify the keepalive message
+            b = peer_1.decrypt_transport(rxs[0], is_ip6=is_ip6)
+            self.assertEqual(0, len(b))
+        else:
+            # change endpoint
+            if is_ip6:
+                peer_1.change_endpoint(self.pg1.remote_hosts[1].ip6, port + 100)
+            else:
+                peer_1.change_endpoint(self.pg1.remote_hosts[1].ip4, port + 100)
+            # prepare and send a handshake initiation
+            # expect a handshake response sent to the new endpoint
+            init = peer_1.mk_handshake(self.pg1, is_ip6=is_ip6)
+            rxs = self.send_and_expect(self.pg1, [init], self.pg1)
+            # verify the response
+            peer_1.consume_response(rxs[0], is_ip6=is_ip6)
+        self.assertTrue(peer_1.query_vpp_config())
+
+        # remove configs
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_peer_roaming_on_init_v4(self):
+        """Peer roaming on handshake initiation (v4)"""
+        self._test_wg_peer_roaming_on_handshake_tmpl(
+            is_endpoint_set=False, is_resp=False, is_ip6=False
+        )
+
+    def test_wg_peer_roaming_on_init_v6(self):
+        """Peer roaming on handshake initiation (v6)"""
+        self._test_wg_peer_roaming_on_handshake_tmpl(
+            is_endpoint_set=False, is_resp=False, is_ip6=True
+        )
+
+    def test_wg_peer_roaming_on_resp_v4(self):
+        """Peer roaming on handshake response (v4)"""
+        self._test_wg_peer_roaming_on_handshake_tmpl(
+            is_endpoint_set=True, is_resp=True, is_ip6=False
+        )
+
+    def test_wg_peer_roaming_on_resp_v6(self):
+        """Peer roaming on handshake response (v6)"""
+        self._test_wg_peer_roaming_on_handshake_tmpl(
+            is_endpoint_set=True, is_resp=True, is_ip6=True
+        )
+
+    def _test_wg_peer_roaming_on_data_tmpl(self, is_async, is_ip6):
+        self.vapi.wg_set_async_mode(is_async)
+        port = 12323
+
+        # create wg interface
+        if is_ip6:
+            wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip6()
+        else:
+            wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create more remote hosts
+        NUM_REMOTE_HOSTS = 2
+        self.pg1.generate_remote_hosts(NUM_REMOTE_HOSTS)
+        if is_ip6:
+            self.pg1.configure_ipv6_neighbors()
+        else:
+            self.pg1.configure_ipv4_neighbors()
+
+        # create a peer
+        if is_ip6:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_hosts[0].ip6, port + 1, ["1::3:0/112"]
+            ).add_vpp_config()
+        else:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_hosts[0].ip4, port + 1, ["10.11.3.0/24"]
+            ).add_vpp_config()
+        self.assertTrue(peer_1.query_vpp_config())
+
+        # create a route to rewrite traffic into the wg interface
+        if is_ip6:
+            r1 = VppIpRoute(
+                self, "1::3:0", 112, [VppRoutePath("1::3:1", wg0.sw_if_index)]
+            ).add_vpp_config()
+        else:
+            r1 = VppIpRoute(
+                self, "10.11.3.0", 24, [VppRoutePath("10.11.3.1", wg0.sw_if_index)]
+            ).add_vpp_config()
+
+        # wait for the peer to send a handshake initiation
+        rxs = self.pg1.get_capture(1, timeout=2)
+
+        # prepare and send a handshake response
+        # expect a keepalive message
+        resp = peer_1.consume_init(rxs[0], self.pg1, is_ip6=is_ip6)
+        rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+
+        # verify the keepalive message
+        b = peer_1.decrypt_transport(rxs[0], is_ip6=is_ip6)
+        self.assertEqual(0, len(b))
+
+        # change endpoint
+        if is_ip6:
+            peer_1.change_endpoint(self.pg1.remote_hosts[1].ip6, port + 100)
+        else:
+            peer_1.change_endpoint(self.pg1.remote_hosts[1].ip4, port + 100)
+
+        # prepare and send a data packet
+        # expect endpoint change
+        if is_ip6:
+            ip_header = IPv6(src="1::3:1", dst=self.pg0.remote_ip6, hlim=20)
+        else:
+            ip_header = IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20)
+        data = (
+            peer_1.mk_tunnel_header(self.pg1, is_ip6=is_ip6)
+            / Wireguard(message_type=4, reserved_zero=0)
+            / WireguardTransport(
+                receiver_index=peer_1.sender,
+                counter=0,
+                encrypted_encapsulated_packet=peer_1.encrypt_transport(
+                    ip_header / UDP(sport=222, dport=223) / Raw()
+                ),
+            )
+        )
+        rxs = self.send_and_expect(self.pg1, [data], self.pg0)
+        if is_ip6:
+            self.assertEqual(rxs[0][IPv6].dst, self.pg0.remote_ip6)
+            self.assertEqual(rxs[0][IPv6].hlim, 19)
+        else:
+            self.assertEqual(rxs[0][IP].dst, self.pg0.remote_ip4)
+            self.assertEqual(rxs[0][IP].ttl, 19)
+        self.assertTrue(peer_1.query_vpp_config())
+
+        # prepare and send a packet that will be rewritten into the wg interface
+        # expect a data packet sent to the new endpoint
+        if is_ip6:
+            ip_header = IPv6(src=self.pg0.remote_ip6, dst="1::3:2")
+        else:
+            ip_header = IP(src=self.pg0.remote_ip4, dst="10.11.3.2")
+        p = (
+            Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac)
+            / ip_header
+            / UDP(sport=555, dport=556)
+            / Raw()
+        )
+        rxs = self.send_and_expect(self.pg0, [p], self.pg1)
+
+        # verify the data packet
+        peer_1.validate_encapped(rxs, p, is_ip6=is_ip6)
+
+        # remove configs
+        r1.remove_vpp_config()
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_peer_roaming_on_data_v4_sync(self):
+        """Peer roaming on data packet (v4, sync)"""
+        self._test_wg_peer_roaming_on_data_tmpl(is_async=False, is_ip6=False)
+
+    def test_wg_peer_roaming_on_data_v6_sync(self):
+        """Peer roaming on data packet (v6, sync)"""
+        self._test_wg_peer_roaming_on_data_tmpl(is_async=False, is_ip6=True)
+
+    def test_wg_peer_roaming_on_data_v4_async(self):
+        """Peer roaming on data packet (v4, async)"""
+        self._test_wg_peer_roaming_on_data_tmpl(is_async=True, is_ip6=False)
+
+    def test_wg_peer_roaming_on_data_v6_async(self):
+        """Peer roaming on data packet (v6, async)"""
+        self._test_wg_peer_roaming_on_data_tmpl(is_async=True, is_ip6=True)
+
     def test_wg_peer_resp(self):
         """Send handshake response"""
         port = 12323
@@ -1197,7 +1425,7 @@ class TestWg(VppTestCase):
         for rx in rxs:
             rx = IP(peer_1.decrypt_transport(rx))
 
-            # chech the oringial packet is present
+            # check the original packet is present
             self.assertEqual(rx[IP].dst, p[IP].dst)
             self.assertEqual(rx[IP].ttl, p[IP].ttl - 1)
 
@@ -1358,7 +1586,7 @@ class TestWg(VppTestCase):
         for rx in rxs:
             rx = IPv6(peer_1.decrypt_transport(rx, True))
 
-            # chech the oringial packet is present
+            # check the original packet is present
             self.assertEqual(rx[IPv6].dst, p[IPv6].dst)
             self.assertEqual(rx[IPv6].hlim, p[IPv6].hlim - 1)
 
@@ -1499,7 +1727,7 @@ class TestWg(VppTestCase):
         for rx in rxs:
             rx = IPv6(peer_1.decrypt_transport(rx))
 
-            # chech the oringial packet is present
+            # check the original packet is present
             self.assertEqual(rx[IPv6].dst, p[IPv6].dst)
             self.assertEqual(rx[IPv6].hlim, p[IPv6].hlim - 1)
 
@@ -1638,7 +1866,7 @@ class TestWg(VppTestCase):
         for rx in rxs:
             rx = IP(peer_1.decrypt_transport(rx, True))
 
-            # chech the oringial packet is present
+            # check the original packet is present
             self.assertEqual(rx[IP].dst, p[IP].dst)
             self.assertEqual(rx[IP].ttl, p[IP].ttl - 1)