wireguard: add handshake rate limiting support
[vpp.git] / src / plugins / wireguard / wireguard_if.c
index ff8ed35..a869df0 100644 (file)
@@ -1,3 +1,18 @@
+/*
+ * Copyright (c) 2020 Cisco and/or its affiliates.
+ * Copyright (c) 2020 Doc.ai and/or its affiliates.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 #include <vnet/adj/adj_midchain.h>
 #include <vnet/udp/udp.h>
@@ -5,6 +20,7 @@
 #include <wireguard/wireguard_messages.h>
 #include <wireguard/wireguard_if.h>
 #include <wireguard/wireguard.h>
+#include <wireguard/wireguard_peer.h>
 
 /* pool of interfaces */
 wg_if_t *wg_if_pool;
@@ -16,13 +32,17 @@ static uword *wg_if_instances;
 static index_t *wg_if_index_by_sw_if_index;
 
 /* vector of interfaces key'd on their UDP port (in network order) */
-index_t *wg_if_index_by_port;
+index_t **wg_if_indexes_by_port;
+
+/* pool of ratelimit entries */
+static ratelimit_entry_t *wg_ratelimit_pool;
 
 static u8 *
 format_wg_if_name (u8 * s, va_list * args)
 {
   u32 dev_instance = va_arg (*args, u32);
-  return format (s, "wg%d", dev_instance);
+  wg_if_t *wgi = wg_if_get (dev_instance);
+  return format (s, "wg%d", wgi->user_instance);
 }
 
 u8 *
@@ -30,23 +50,32 @@ format_wg_if (u8 * s, va_list * args)
 {
   index_t wgii = va_arg (*args, u32);
   wg_if_t *wgi = wg_if_get (wgii);
+  noise_local_t *local = noise_local_get (wgi->local_idx);
   u8 key[NOISE_KEY_LEN_BASE64];
 
-  key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
-
   s = format (s, "[%d] %U src:%U port:%d",
              wgii,
              format_vnet_sw_if_index_name, vnet_get_main (),
              wgi->sw_if_index, format_ip_address, &wgi->src_ip, wgi->port);
 
-  key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
+  key_to_base64 (local->l_private, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " private-key:%s", key);
+  s =
+    format (s, " %U", format_hex_bytes, local->l_private,
+           NOISE_PUBLIC_KEY_LEN);
 
-  key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key);
+  key_to_base64 (local->l_public, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " public-key:%s", key);
 
+  s =
+    format (s, " %U", format_hex_bytes, local->l_public,
+           NOISE_PUBLIC_KEY_LEN);
+
+  s = format (s, " mac-key: %U", format_hex_bytes,
+             &wgi->cookie_checker.cc_mac1_key, NOISE_PUBLIC_KEY_LEN);
+
   return (s);
 }
 
@@ -62,23 +91,28 @@ wg_if_find_by_sw_if_index (u32 sw_if_index)
   return (ti);
 }
 
+static walk_rc_t
+wg_if_find_peer_by_public_key (index_t peeri, void *data)
+{
+  uint8_t *public = data;
+  wg_peer_t *peer = wg_peer_get (peeri);
+
+  if (!memcmp (peer->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
+    return (WALK_STOP);
+  return (WALK_CONTINUE);
+}
+
 static noise_remote_t *
-wg_remote_get (uint8_t public[NOISE_PUBLIC_KEY_LEN])
+wg_remote_get (const uint8_t public[NOISE_PUBLIC_KEY_LEN])
 {
-  wg_main_t *wmp = &wg_main;
-  wg_peer_t *peer = NULL;
-  wg_peer_t *peer_iter;
-  /* *INDENT-OFF* */
-  pool_foreach (peer_iter, wmp->peers,
-  ({
-    if (!memcmp (peer_iter->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
-    {
-      peer = peer_iter;
-      break;
-    }
-  }));
-  /* *INDENT-ON* */
-  return peer ? &peer->remote : NULL;
+  index_t peeri;
+
+  peeri = wg_peer_walk (wg_if_find_peer_by_public_key, (void *) public);
+
+  if (INDEX_INVALID != peeri)
+    return &wg_peer_get (peeri)->remote;
+
+  return NULL;
 }
 
 static uint32_t
@@ -120,7 +154,18 @@ wg_if_admin_up_down (vnet_main_t * vnm, u32 hw_if_index, u32 flags)
 void
 wg_if_update_adj (vnet_main_t * vnm, u32 sw_if_index, adj_index_t ai)
 {
-  /* The peers manage the adjacencies */
+  index_t wgii;
+
+  /* Convert any neighbour adjacency that has a next-hop reachable through
+   * the wg interface into a midchain. This is to avoid sending ARP/ND to
+   * resolve the next-hop address via the wg interface. Then, if one of the
+   * peers has matching prefix among allowed prefixes, the midchain will be
+   * updated to the corresponding one.
+   */
+  adj_nbr_midchain_update_rewrite (ai, NULL, NULL, ADJ_FLAG_NONE, NULL);
+
+  wgii = wg_if_find_by_sw_if_index (sw_if_index);
+  wg_if_peer_walk (wg_if_get (wgii), wg_peer_if_adj_change, &ai);
 }
 
 
@@ -213,6 +258,7 @@ wg_if_create (u32 user_instance,
   u32 instance, hw_if_index;
   vnet_hw_interface_t *hi;
   wg_if_t *wg_if;
+  noise_local_t *local;
 
   ASSERT (sw_if_indexp);
 
@@ -226,7 +272,25 @@ wg_if_create (u32 user_instance,
   if (instance == ~0)
     return VNET_API_ERROR_INVALID_REGISTRATION;
 
-  pool_get (wg_if_pool, wg_if);
+  /* *INDENT-OFF* */
+  struct noise_upcall upcall =  {
+    .u_remote_get = wg_remote_get,
+    .u_index_set = wg_index_set,
+    .u_index_drop = wg_index_drop,
+  };
+  /* *INDENT-ON* */
+
+  pool_get (noise_local_pool, local);
+
+  noise_local_init (local, &upcall);
+  if (!noise_local_set_private (local, private_key))
+    {
+      pool_put (noise_local_pool, local);
+      wg_if_instance_free (instance);
+      return VNET_API_ERROR_INVALID_REGISTRATION;
+    }
+
+  pool_get_zero (wg_if_pool, wg_if);
 
   /* tunnel index (or instance) */
   u32 t_idx = wg_if - wg_if_pool;
@@ -235,23 +299,21 @@ wg_if_create (u32 user_instance,
   if (~0 == wg_if->user_instance)
     wg_if->user_instance = t_idx;
 
-  udp_dst_port_info_t *pi = udp_get_dst_port_info (&udp_main, port, UDP_IP4);
-  if (pi)
-    return (VNET_API_ERROR_VALUE_EXIST);
-  udp_register_dst_port (vlib_get_main (), port, wg_input_node.index, 1);
+  vec_validate_init_empty (wg_if_indexes_by_port, port, NULL);
+  if (vec_len (wg_if_indexes_by_port[port]) == 0)
+    {
+      udp_register_dst_port (vlib_get_main (), port, wg4_input_node.index,
+                            UDP_IP4);
+      udp_register_dst_port (vlib_get_main (), port, wg6_input_node.index,
+                            UDP_IP6);
+    }
 
-  vec_validate_init_empty (wg_if_index_by_port, port, INDEX_INVALID);
-  wg_if_index_by_port[port] = wg_if - wg_if_pool;
+  vec_add1 (wg_if_indexes_by_port[port], t_idx);
 
   wg_if->port = port;
-  struct noise_upcall upcall;
-  upcall.u_remote_get = wg_remote_get;
-  upcall.u_index_set = wg_index_set;
-  upcall.u_index_drop = wg_index_drop;
-
-  noise_local_init (&wg_if->local, &upcall);
-  noise_local_set_private (&wg_if->local, private_key);
-  cookie_checker_update (&wg_if->cookie_checker, wg_if->local.l_public);
+  wg_if->local_idx = local - noise_local_pool;
+  cookie_checker_init (&wg_if->cookie_checker, wg_ratelimit_pool);
+  cookie_checker_update (&wg_if->cookie_checker, local->l_public);
 
   hw_if_index = vnet_register_interface (vnm,
                                         wg_if_device_class.index,
@@ -266,6 +328,8 @@ wg_if_create (u32 user_instance,
 
   ip_address_copy (&wg_if->src_ip, src_ip);
   wg_if->sw_if_index = *sw_if_indexp = hi->sw_if_index;
+  vnet_set_interface_l3_output_node (vnm->vlib_main, hi->sw_if_index,
+                                    (u8 *) "tunnel-output");
 
   return 0;
 }
@@ -280,18 +344,43 @@ wg_if_delete (u32 sw_if_index)
 
   vnet_hw_interface_t *hw = vnet_get_sup_hw_interface (vnm, sw_if_index);
   if (hw == 0 || hw->dev_class_index != wg_if_device_class.index)
-    return VNET_API_ERROR_INVALID_SW_IF_INDEX;
+    return VNET_API_ERROR_INVALID_VALUE;
 
   wg_if_t *wg_if;
-  wg_if = wg_if_get (wg_if_find_by_sw_if_index (sw_if_index));
+  index_t wgii = wg_if_find_by_sw_if_index (sw_if_index);
+  wg_if = wg_if_get (wgii);
   if (NULL == wg_if)
-    return VNET_API_ERROR_INVALID_SW_IF_INDEX;
+    return VNET_API_ERROR_INVALID_SW_IF_INDEX_2;
 
-  if (wg_if_instance_free (hw->dev_instance) < 0)
-    return VNET_API_ERROR_INVALID_SW_IF_INDEX;
+  if (wg_if_instance_free (wg_if->user_instance) < 0)
+    return VNET_API_ERROR_INVALID_VALUE_2;
+
+  // Remove peers before interface deletion
+  wg_if_peer_walk (wg_if, wg_peer_if_delete, NULL);
+
+  hash_free (wg_if->peers);
+
+  index_t *ii;
+  index_t *ifs = wg_if_indexes_get_by_port (wg_if->port);
+  vec_foreach (ii, ifs)
+    {
+      if (*ii == wgii)
+       {
+         vec_del1 (ifs, ifs - ii);
+         break;
+       }
+    }
+  if (vec_len (ifs) == 0)
+    {
+      udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1);
+      udp_unregister_dst_port (vlib_get_main (), wg_if->port, 0);
+    }
 
-  wg_if_index_by_port[wg_if->port] = INDEX_INVALID;
+  cookie_checker_deinit (&wg_if->cookie_checker);
+
+  vnet_reset_interface_l3_output_node (vnm->vlib_main, sw_if_index);
   vnet_delete_hw_interface (vnm, hw->hw_if_index);
+  pool_put_index (noise_local_pool, wg_if->local_idx);
   pool_put (wg_if_pool, wg_if);
 
   return 0;
@@ -303,8 +392,12 @@ wg_if_peer_add (wg_if_t * wgi, index_t peeri)
   hash_set (wgi->peers, peeri, peeri);
 
   if (1 == hash_elts (wgi->peers))
-    vnet_feature_enable_disable ("ip4-output", "wg-output-tun",
-                                wgi->sw_if_index, 1, 0, 0);
+    {
+      vnet_feature_enable_disable ("ip4-output", "wg4-output-tun",
+                                  wgi->sw_if_index, 1, 0, 0);
+      vnet_feature_enable_disable ("ip6-output", "wg6-output-tun",
+                                  wgi->sw_if_index, 1, 0, 0);
+    }
 }
 
 void
@@ -313,8 +406,12 @@ wg_if_peer_remove (wg_if_t * wgi, index_t peeri)
   hash_unset (wgi->peers, peeri);
 
   if (0 == hash_elts (wgi->peers))
-    vnet_feature_enable_disable ("ip4-output", "wg-output-tun",
-                                wgi->sw_if_index, 0, 0, 0);
+    {
+      vnet_feature_enable_disable ("ip4-output", "wg4-output-tun",
+                                  wgi->sw_if_index, 0, 0, 0);
+      vnet_feature_enable_disable ("ip6-output", "wg6-output-tun",
+                                  wgi->sw_if_index, 0, 0, 0);
+    }
 }
 
 void
@@ -323,96 +420,29 @@ wg_if_walk (wg_if_walk_cb_t fn, void *data)
   index_t wgii;
 
   /* *INDENT-OFF* */
-  pool_foreach_index (wgii, wg_if_pool,
+  pool_foreach_index (wgii, wg_if_pool)
   {
     if (WALK_STOP == fn(wgii, data))
       break;
-  });
+  }
   /* *INDENT-ON* */
 }
 
-void
+index_t
 wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
 {
   index_t peeri, val;
 
   /* *INDENT-OFF* */
-  hash_foreach (peeri, val, wgi->peers,
-  {
-    if (WALK_STOP == fn(wgi, peeri, data))
-      break;
+  hash_foreach (peeri, val, wgi->peers, {
+    if (WALK_STOP == fn (peeri, data))
+      return peeri;
   });
   /* *INDENT-ON* */
-}
 
-
-static void
-wg_if_table_bind_v4 (ip4_main_t * im,
-                    uword opaque,
-                    u32 sw_if_index, u32 new_fib_index, u32 old_fib_index)
-{
-  wg_if_t *wg_if;
-
-  wg_if = wg_if_get (wg_if_find_by_sw_if_index (sw_if_index));
-  if (NULL == wg_if)
-    return;
-
-  wg_peer_table_bind_ctx_t ctx = {
-    .af = AF_IP4,
-    .old_fib_index = old_fib_index,
-    .new_fib_index = new_fib_index,
-  };
-
-  wg_if_peer_walk (wg_if, wg_peer_if_table_change, &ctx);
+  return INDEX_INVALID;
 }
 
-static void
-wg_if_table_bind_v6 (ip6_main_t * im,
-                    uword opaque,
-                    u32 sw_if_index, u32 new_fib_index, u32 old_fib_index)
-{
-  wg_if_t *wg_if;
-
-  wg_if = wg_if_get (wg_if_find_by_sw_if_index (sw_if_index));
-  if (NULL == wg_if)
-    return;
-
-  wg_peer_table_bind_ctx_t ctx = {
-    .af = AF_IP6,
-    .old_fib_index = old_fib_index,
-    .new_fib_index = new_fib_index,
-  };
-
-  wg_if_peer_walk (wg_if, wg_peer_if_table_change, &ctx);
-}
-
-static clib_error_t *
-wg_if_module_init (vlib_main_t * vm)
-{
-  {
-    ip4_table_bind_callback_t cb = {
-      .function = wg_if_table_bind_v4,
-    };
-    vec_add1 (ip4_main.table_bind_callbacks, cb);
-  }
-  {
-    ip6_table_bind_callback_t cb = {
-      .function = wg_if_table_bind_v6,
-    };
-    vec_add1 (ip6_main.table_bind_callbacks, cb);
-  }
-
-  return (NULL);
-}
-
-/* *INDENT-OFF* */
-VLIB_INIT_FUNCTION (wg_if_module_init) =
-{
-  .runs_after = VLIB_INITS("ip_main_init"),
-};
-/* *INDENT-ON* */
-
-
 /*
  * fd.io coding-style-patch-verification: ON
  *