wireguard: add burst mode
[vpp.git] / src / plugins / wireguard / wireguard_noise.c
index 36de8ae..c8605f1 100644 (file)
@@ -36,7 +36,7 @@ static uint32_t noise_remote_handshake_index_get (noise_remote_t *);
 static void noise_remote_handshake_index_drop (noise_remote_t *);
 
 static uint64_t noise_counter_send (noise_counter_t *);
-static bool noise_counter_recv (noise_counter_t *, uint64_t);
+bool noise_counter_recv (noise_counter_t *, uint64_t);
 
 static void noise_kdf (uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
                       size_t, size_t, size_t, size_t,
@@ -407,6 +407,8 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
 
   /* Now we need to add_new_keypair */
   clib_rwlock_writer_lock (&r->r_keypair_lock);
+  /* Activate barrier to synchronization keys between threads */
+  vlib_worker_thread_barrier_sync (vm);
   next = r->r_next;
   current = r->r_current;
   previous = r->r_previous;
@@ -438,6 +440,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
       r->r_next = noise_remote_keypair_allocate (r);
       *r->r_next = kp;
     }
+  vlib_worker_thread_barrier_release (vm);
   clib_rwlock_writer_unlock (&r->r_keypair_lock);
 
   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
@@ -541,6 +544,41 @@ chacha20poly1305_calc (vlib_main_t * vm,
   return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
 }
 
+always_inline void
+wg_prepare_sync_op (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, u8 *src,
+                   u32 src_len, u8 *dst, u8 *aad, u32 aad_len, u64 nonce,
+                   vnet_crypto_op_id_t op_id,
+                   vnet_crypto_key_index_t key_index, u32 bi, u8 *iv)
+{
+  vnet_crypto_op_t _op, *op = &_op;
+  u8 src_[] = {};
+
+  clib_memset (iv, 0, 4);
+  clib_memcpy (iv + 4, &nonce, sizeof (nonce));
+
+  vec_add2_aligned (crypto_ops[0], op, 1, CLIB_CACHE_LINE_BYTES);
+  vnet_crypto_op_init (op, op_id);
+
+  op->tag_len = NOISE_AUTHTAG_LEN;
+  if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC)
+    {
+      op->tag = src + src_len;
+      op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
+    }
+  else
+    op->tag = dst + src_len;
+
+  op->src = !src ? src_ : src;
+  op->len = src_len;
+
+  op->dst = dst;
+  op->key_index = key_index;
+  op->aad = aad;
+  op->aad_len = aad_len;
+  op->iv = iv;
+  op->user_data = bi;
+}
+
 enum noise_state_crypt
 noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
                      uint64_t * nonce, uint8_t * src, size_t srclen,
@@ -592,26 +630,67 @@ error:
 }
 
 enum noise_state_crypt
-noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
-                     uint64_t nonce, uint8_t * src, size_t srclen,
-                     uint8_t * dst)
+noise_sync_remote_encrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops,
+                          noise_remote_t *r, uint32_t *r_idx, uint64_t *nonce,
+                          uint8_t *src, size_t srclen, uint8_t *dst, u32 bi,
+                          u8 *iv, f64 time)
 {
   noise_keypair_t *kp;
   enum noise_state_crypt ret = SC_FAILED;
 
-  if (r->r_current != NULL && r->r_current->kp_local_index == r_idx)
-    {
-      kp = r->r_current;
-    }
-  else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx)
-    {
-      kp = r->r_previous;
-    }
-  else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx)
-    {
-      kp = r->r_next;
-    }
-  else
+  if ((kp = r->r_current) == NULL)
+    goto error;
+
+  /* We confirm that our values are within our tolerances. We want:
+   *  - a valid keypair
+   *  - our keypair to be less than REJECT_AFTER_TIME seconds old
+   *  - our receive counter to be less than REJECT_AFTER_MESSAGES
+   *  - our send counter to be less than REJECT_AFTER_MESSAGES
+   */
+  if (!kp->kp_valid ||
+      wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME,
+                                   time) ||
+      kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
+      ((*nonce = noise_counter_send (&kp->kp_ctr)) > REJECT_AFTER_MESSAGES))
+    goto error;
+
+  /* We encrypt into the same buffer, so the caller must ensure that buf
+   * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index
+   * are passed back out to the caller through the provided data pointer. */
+  *r_idx = kp->kp_remote_index;
+
+  wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, *nonce,
+                     VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, kp->kp_send_index,
+                     bi, iv);
+
+  /* If our values are still within tolerances, but we are approaching
+   * the tolerances, we notify the caller with ESTALE that they should
+   * establish a new keypair. The current keypair can continue to be used
+   * until the tolerances are hit. We notify if:
+   *  - our send counter is valid and not less than REKEY_AFTER_MESSAGES
+   *  - we're the initiator and our keypair is older than
+   *    REKEY_AFTER_TIME seconds */
+  ret = SC_KEEP_KEY_FRESH;
+  if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) ||
+      (kp->kp_is_initiator && wg_birthdate_has_expired_opt (
+                               kp->kp_birthdate, REKEY_AFTER_TIME, time)))
+    goto error;
+
+  ret = SC_OK;
+error:
+  return ret;
+}
+
+enum noise_state_crypt
+noise_sync_remote_decrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops,
+                          noise_remote_t *r, uint32_t r_idx, uint64_t nonce,
+                          uint8_t *src, size_t srclen, uint8_t *dst, u32 bi,
+                          u8 *iv, f64 time)
+{
+  noise_keypair_t *kp;
+  enum noise_state_crypt ret = SC_FAILED;
+
+  if ((kp = wg_get_active_keypair (r, r_idx)) == NULL)
     {
       goto error;
     }
@@ -620,20 +699,17 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
    * are the same as the encrypt routine.
    *
    * kp_ctr isn't locked here, we're happy to accept a racy read. */
-  if (wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
+  if (wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME,
+                                   time) ||
       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
     goto error;
 
   /* Decrypt, then validate the counter. We don't want to validate the
    * counter before decrypting as we do not know the message is authentic
    * prior to decryption. */
-  if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce,
-                             VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
-                             kp->kp_recv_index))
-    goto error;
-
-  if (!noise_counter_recv (&kp->kp_ctr, nonce))
-    goto error;
+  wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, nonce,
+                     VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, kp->kp_recv_index,
+                     bi, iv);
 
   /* If we've received the handshake confirming data packet then move the
    * next keypair into current. If we do slide the next keypair in, then
@@ -662,10 +738,9 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
    *    REKEY_AFTER_TIME_RECV seconds. */
   ret = SC_KEEP_KEY_FRESH;
   kp = r->r_current;
-  if (kp != NULL &&
-      kp->kp_valid &&
-      kp->kp_is_initiator &&
-      wg_birthdate_has_expired (kp->kp_birthdate, REKEY_AFTER_TIME_RECV))
+  if (kp != NULL && kp->kp_valid && kp->kp_is_initiator &&
+      wg_birthdate_has_expired_opt (kp->kp_birthdate, REKEY_AFTER_TIME_RECV,
+                                   time))
     goto error;
 
   ret = SC_OK;
@@ -724,47 +799,6 @@ noise_counter_send (noise_counter_t * ctr)
   return ret;
 }
 
-static bool
-noise_counter_recv (noise_counter_t * ctr, uint64_t recv)
-{
-  uint64_t i, top, index_recv, index_ctr;
-  unsigned long bit;
-  bool ret = false;
-
-  /* Check that the recv counter is valid */
-  if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES)
-    goto error;
-
-  /* If the packet is out of the window, invalid */
-  if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv)
-    goto error;
-
-  /* If the new counter is ahead of the current counter, we'll need to
-   * zero out the bitmap that has previously been used */
-  index_recv = recv / COUNTER_BITS;
-  index_ctr = ctr->c_recv / COUNTER_BITS;
-
-  if (recv > ctr->c_recv)
-    {
-      top = clib_min (index_recv - index_ctr, COUNTER_NUM);
-      for (i = 1; i <= top; i++)
-       ctr->c_backtrack[(i + index_ctr) & (COUNTER_NUM - 1)] = 0;
-      ctr->c_recv = recv;
-    }
-
-  index_recv %= COUNTER_NUM;
-  bit = 1ul << (recv % COUNTER_BITS);
-
-  if (ctr->c_backtrack[index_recv] & bit)
-    goto error;
-
-  ctr->c_backtrack[index_recv] |= bit;
-
-  ret = true;
-error:
-  return ret;
-}
-
 static void
 noise_kdf (uint8_t * a, uint8_t * b, uint8_t * c, const uint8_t * x,
           size_t a_len, size_t b_len, size_t c_len, size_t x_len,