wireguard: add handoff node
[vpp.git] / src / plugins / wireguard / wireguard_if.c
index c91667b..7509923 100644 (file)
@@ -5,6 +5,7 @@
 #include <wireguard/wireguard_messages.h>
 #include <wireguard/wireguard_if.h>
 #include <wireguard/wireguard.h>
+#include <wireguard/wireguard_peer.h>
 
 /* pool of interfaces */
 wg_if_t *wg_if_pool;
@@ -30,28 +31,28 @@ format_wg_if (u8 * s, va_list * args)
 {
   index_t wgii = va_arg (*args, u32);
   wg_if_t *wgi = wg_if_get (wgii);
+  noise_local_t *local = noise_local_get (wgi->local_idx);
   u8 key[NOISE_KEY_LEN_BASE64];
 
-  key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, "[%d] %U src:%U port:%d",
              wgii,
              format_vnet_sw_if_index_name, vnet_get_main (),
              wgi->sw_if_index, format_ip_address, &wgi->src_ip, wgi->port);
 
-  key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
+  key_to_base64 (local->l_private, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " private-key:%s", key);
   s =
-    format (s, " %U", format_hex_bytes, wgi->local.l_private,
+    format (s, " %U", format_hex_bytes, local->l_private,
            NOISE_PUBLIC_KEY_LEN);
 
-  key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key);
+  key_to_base64 (local->l_public, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " public-key:%s", key);
 
   s =
-    format (s, " %U", format_hex_bytes, wgi->local.l_public,
+    format (s, " %U", format_hex_bytes, local->l_public,
            NOISE_PUBLIC_KEY_LEN);
 
   s = format (s, " mac-key: %U", format_hex_bytes,
@@ -72,23 +73,28 @@ wg_if_find_by_sw_if_index (u32 sw_if_index)
   return (ti);
 }
 
+static walk_rc_t
+wg_if_find_peer_by_public_key (index_t peeri, void *data)
+{
+  uint8_t *public = data;
+  wg_peer_t *peer = wg_peer_get (peeri);
+
+  if (!memcmp (peer->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
+    return (WALK_STOP);
+  return (WALK_CONTINUE);
+}
+
 static noise_remote_t *
-wg_remote_get (uint8_t public[NOISE_PUBLIC_KEY_LEN])
+wg_remote_get (const uint8_t public[NOISE_PUBLIC_KEY_LEN])
 {
-  wg_main_t *wmp = &wg_main;
-  wg_peer_t *peer = NULL;
-  wg_peer_t *peer_iter;
-  /* *INDENT-OFF* */
-  pool_foreach (peer_iter, wmp->peers,
-  ({
-    if (!memcmp (peer_iter->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
-    {
-      peer = peer_iter;
-      break;
-    }
-  }));
-  /* *INDENT-ON* */
-  return peer ? &peer->remote : NULL;
+  index_t peeri;
+
+  peeri = wg_peer_walk (wg_if_find_peer_by_public_key, (void *) public);
+
+  if (INDEX_INVALID != peeri)
+    return &wg_peer_get (peeri)->remote;
+
+  return NULL;
 }
 
 static uint32_t
@@ -223,6 +229,7 @@ wg_if_create (u32 user_instance,
   u32 instance, hw_if_index;
   vnet_hw_interface_t *hi;
   wg_if_t *wg_if;
+  noise_local_t *local;
 
   ASSERT (sw_if_indexp);
 
@@ -236,6 +243,24 @@ wg_if_create (u32 user_instance,
   if (instance == ~0)
     return VNET_API_ERROR_INVALID_REGISTRATION;
 
+  /* *INDENT-OFF* */
+  struct noise_upcall upcall =  {
+    .u_remote_get = wg_remote_get,
+    .u_index_set = wg_index_set,
+    .u_index_drop = wg_index_drop,
+  };
+  /* *INDENT-ON* */
+
+  pool_get (noise_local_pool, local);
+
+  noise_local_init (local, &upcall);
+  if (!noise_local_set_private (local, private_key))
+    {
+      pool_put (noise_local_pool, local);
+      wg_if_instance_free (instance);
+      return VNET_API_ERROR_INVALID_REGISTRATION;
+    }
+
   pool_get (wg_if_pool, wg_if);
 
   /* tunnel index (or instance) */
@@ -251,18 +276,8 @@ wg_if_create (u32 user_instance,
   wg_if_index_by_port[port] = wg_if - wg_if_pool;
 
   wg_if->port = port;
-
-  /* *INDENT-OFF* */
-  struct noise_upcall upcall =  {
-    .u_remote_get = wg_remote_get,
-    .u_index_set = wg_index_set,
-    .u_index_drop = wg_index_drop,
-  };
-  /* *INDENT-ON* */
-
-  noise_local_init (&wg_if->local, &upcall);
-  noise_local_set_private (&wg_if->local, private_key);
-  cookie_checker_update (&wg_if->cookie_checker, wg_if->local.l_public);
+  wg_if->local_idx = local - noise_local_pool;
+  cookie_checker_update (&wg_if->cookie_checker, local->l_public);
 
   hw_if_index = vnet_register_interface (vnm,
                                         wg_if_device_class.index,
@@ -304,6 +319,7 @@ wg_if_delete (u32 sw_if_index)
   udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1);
   wg_if_index_by_port[wg_if->port] = INDEX_INVALID;
   vnet_delete_hw_interface (vnm, hw->hw_if_index);
+  pool_put_index (noise_local_pool, wg_if->local_idx);
   pool_put (wg_if_pool, wg_if);
 
   return 0;
@@ -343,7 +359,7 @@ wg_if_walk (wg_if_walk_cb_t fn, void *data)
   /* *INDENT-ON* */
 }
 
-void
+index_t
 wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
 {
   index_t peeri, val;
@@ -352,9 +368,11 @@ wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
   hash_foreach (peeri, val, wgi->peers,
   {
     if (WALK_STOP == fn(wgi, peeri, data))
-      break;
+      return peeri;
   });
   /* *INDENT-ON* */
+
+  return INDEX_INVALID;
 }