wireguard: add handshake rate limiting support
[vpp.git] / src / plugins / wireguard / wireguard_cookie.c
index 595b877..4ebbfa0 100644 (file)
@@ -34,6 +34,11 @@ static void cookie_checker_make_cookie (vlib_main_t *vm, cookie_checker_t *,
                                        uint8_t[COOKIE_COOKIE_SIZE],
                                        ip46_address_t *ip, u16 udp_port);
 
+static void ratelimit_init (ratelimit_t *, ratelimit_entry_t *);
+static void ratelimit_deinit (ratelimit_t *);
+static void ratelimit_gc (ratelimit_t *, bool);
+static bool ratelimit_allow (ratelimit_t *, ip46_address_t *);
+
 /* Public Functions */
 void
 cookie_maker_init (cookie_maker_t * cp, const uint8_t key[COOKIE_INPUT_SIZE])
@@ -43,6 +48,14 @@ cookie_maker_init (cookie_maker_t * cp, const uint8_t key[COOKIE_INPUT_SIZE])
   cookie_precompute_key (cp->cp_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
 }
 
+void
+cookie_checker_init (cookie_checker_t *cc, ratelimit_entry_t *pool)
+{
+  clib_memset (cc, 0, sizeof (*cc));
+  ratelimit_init (&cc->cc_ratelimit_v4, pool);
+  ratelimit_init (&cc->cc_ratelimit_v6, pool);
+}
+
 void
 cookie_checker_update (cookie_checker_t * cc, uint8_t key[COOKIE_INPUT_SIZE])
 {
@@ -58,6 +71,13 @@ cookie_checker_update (cookie_checker_t * cc, uint8_t key[COOKIE_INPUT_SIZE])
     }
 }
 
+void
+cookie_checker_deinit (cookie_checker_t *cc)
+{
+  ratelimit_deinit (&cc->cc_ratelimit_v4);
+  ratelimit_deinit (&cc->cc_ratelimit_v6);
+}
+
 void
 cookie_checker_create_payload (vlib_main_t *vm, cookie_checker_t *cc,
                               message_macs_t *cm,
@@ -146,6 +166,13 @@ cookie_checker_validate_macs (vlib_main_t *vm, cookie_checker_t *cc,
   if (clib_memcmp (our_cm.mac2, cm->mac2, COOKIE_MAC_SIZE) != 0)
     return VALID_MAC_BUT_NO_COOKIE;
 
+  /* If the mac2 is valid, we may want to rate limit the peer */
+  ratelimit_t *rl;
+  rl = ip46_address_is_ip4 (ip) ? &cc->cc_ratelimit_v4 : &cc->cc_ratelimit_v6;
+
+  if (!ratelimit_allow (rl, ip))
+    return VALID_MAC_WITH_COOKIE_BUT_RATELIMITED;
+
   return VALID_MAC_WITH_COOKIE;
 }
 
@@ -213,6 +240,126 @@ cookie_checker_make_cookie (vlib_main_t *vm, cookie_checker_t *cc,
   blake2s_final (&state, cookie, COOKIE_COOKIE_SIZE);
 }
 
+static void
+ratelimit_init (ratelimit_t *rl, ratelimit_entry_t *pool)
+{
+  rl->rl_pool = pool;
+}
+
+static void
+ratelimit_deinit (ratelimit_t *rl)
+{
+  ratelimit_gc (rl, /* force */ true);
+  hash_free (rl->rl_table);
+}
+
+static void
+ratelimit_gc (ratelimit_t *rl, bool force)
+{
+  u32 r_key;
+  u32 r_idx;
+  ratelimit_entry_t *r;
+
+  if (force)
+    {
+      /* clang-format off */
+      hash_foreach (r_key, r_idx, rl->rl_table, {
+       r = pool_elt_at_index (rl->rl_pool, r_idx);
+       pool_put (rl->rl_pool, r);
+      });
+      /* clang-format on */
+      return;
+    }
+
+  f64 now = vlib_time_now (vlib_get_main ());
+
+  if ((rl->rl_last_gc + ELEMENT_TIMEOUT) < now)
+    {
+      u32 *r_key_to_del = NULL;
+      u32 *pr_key;
+
+      rl->rl_last_gc = now;
+
+      /* clang-format off */
+      hash_foreach (r_key, r_idx, rl->rl_table, {
+       r = pool_elt_at_index (rl->rl_pool, r_idx);
+       if ((r->r_last_time + ELEMENT_TIMEOUT) < now)
+         {
+           vec_add1 (r_key_to_del, r_key);
+           pool_put (rl->rl_pool, r);
+         }
+      });
+      /* clang-format on */
+
+      vec_foreach (pr_key, r_key_to_del)
+       {
+         hash_unset (rl->rl_table, *pr_key);
+       }
+
+      vec_free (r_key_to_del);
+    }
+}
+
+static bool
+ratelimit_allow (ratelimit_t *rl, ip46_address_t *ip)
+{
+  u32 r_key;
+  uword *p;
+  u32 r_idx;
+  ratelimit_entry_t *r;
+  f64 now = vlib_time_now (vlib_get_main ());
+
+  if (ip46_address_is_ip4 (ip))
+    /* Use all 4 bytes of IPv4 address */
+    r_key = ip->ip4.as_u32;
+  else
+    /* Use top 8 bytes (/64) of IPv6 address */
+    r_key = ip->ip6.as_u32[0] ^ ip->ip6.as_u32[1];
+
+  /* Check if there is already an entry for the IP address */
+  p = hash_get (rl->rl_table, r_key);
+  if (p)
+    {
+      u64 tokens;
+      f64 diff;
+
+      r_idx = p[0];
+      r = pool_elt_at_index (rl->rl_pool, r_idx);
+
+      diff = now - r->r_last_time;
+      r->r_last_time = now;
+
+      tokens = r->r_tokens + diff * NSEC_PER_SEC;
+
+      if (tokens > TOKEN_MAX)
+       tokens = TOKEN_MAX;
+
+      if (tokens >= INITIATION_COST)
+       {
+         r->r_tokens = tokens - INITIATION_COST;
+         return true;
+       }
+
+      r->r_tokens = tokens;
+      return false;
+    }
+
+  /* No entry for the IP address */
+  ratelimit_gc (rl, /* force */ false);
+
+  if (hash_elts (rl->rl_table) >= RATELIMIT_SIZE_MAX)
+    return false;
+
+  pool_get (rl->rl_pool, r);
+  r_idx = r - rl->rl_pool;
+  hash_set (rl->rl_table, r_key, r_idx);
+
+  r->r_last_time = now;
+  r->r_tokens = TOKEN_MAX - INITIATION_COST;
+
+  return true;
+}
+
 /*
  * fd.io coding-style-patch-verification: ON
  *