wireguard: add flag to check hmac for decryption
[vpp.git] / src / plugins / wireguard / wireguard_noise.c
index dc7d506..7b4c019 100755 (executable)
@@ -26,6 +26,8 @@
  * <- e, ee, se, psk, {}
  */
 
+noise_local_t *noise_local_pool;
+
 /* Private functions */
 static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
 static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
@@ -80,81 +82,31 @@ noise_local_set_private (noise_local_t * l,
                         const uint8_t private[NOISE_PUBLIC_KEY_LEN])
 {
   clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN);
-  l->l_has_identity = curve25519_gen_public (l->l_public, private);
 
-  return l->l_has_identity;
-}
-
-bool
-noise_local_keys (noise_local_t * l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
-                 uint8_t private[NOISE_PUBLIC_KEY_LEN])
-{
-  if (l->l_has_identity)
-    {
-      if (public != NULL)
-       clib_memcpy (public, l->l_public, NOISE_PUBLIC_KEY_LEN);
-      if (private != NULL)
-       clib_memcpy (private, l->l_private, NOISE_PUBLIC_KEY_LEN);
-    }
-  else
-    {
-      return false;
-    }
-  return true;
+  return curve25519_gen_public (l->l_public, private);
 }
 
 void
 noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
                   const uint8_t public[NOISE_PUBLIC_KEY_LEN],
-                  noise_local_t * l)
+                  u32 noise_local_idx)
 {
   clib_memset (r, 0, sizeof (*r));
   clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN);
+  clib_rwlock_init (&r->r_keypair_lock);
   r->r_peer_idx = peer_pool_idx;
-
-  ASSERT (l != NULL);
-  r->r_local = l;
+  r->r_local_idx = noise_local_idx;
   r->r_handshake.hs_state = HS_ZEROED;
-  noise_remote_precompute (r);
-}
 
-bool
-noise_remote_set_psk (noise_remote_t * r,
-                     uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
-  int same;
-  same = !clib_memcmp (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
-  if (!same)
-    {
-      clib_memcpy (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
-    }
-  return same == 0;
-}
-
-bool
-noise_remote_keys (noise_remote_t * r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
-                  uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
-  static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
-  int ret;
-
-  if (public != NULL)
-    clib_memcpy (public, r->r_public, NOISE_PUBLIC_KEY_LEN);
-
-  if (psk != NULL)
-    clib_memcpy (psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
-  ret = clib_memcmp (r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
-
-  return ret;
+  noise_remote_precompute (r);
 }
 
 void
 noise_remote_precompute (noise_remote_t * r)
 {
-  noise_local_t *l = r->r_local;
-  if (!l->l_has_identity)
-    clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
-  else if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
+  noise_local_t *l = noise_local_get (r->r_local_idx);
+
+  if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
 
   noise_remote_handshake_index_drop (r);
@@ -169,7 +121,7 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
                         uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
 {
   noise_handshake_t *hs = &r->r_handshake;
-  noise_local_t *l = r->r_local;
+  noise_local_t *l = noise_local_get (r->r_local_idx);
   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
   uint32_t key_idx;
   uint8_t *key;
@@ -180,8 +132,6 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
                         NOISE_SYMMETRIC_KEY_LEN);
   key = vnet_crypto_get_key (key_idx)->data;
 
-  if (!l->l_has_identity)
-    goto error;
   noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public);
 
   /* e */
@@ -211,8 +161,8 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
   *s_idx = hs->hs_local_index;
   ret = true;
 error:
-  vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  vnet_crypto_key_del (vm, key_idx);
   return ret;
 }
 
@@ -239,8 +189,6 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
                         NOISE_SYMMETRIC_KEY_LEN);
   key = vnet_crypto_get_key (key_idx)->data;
 
-  if (!l->l_has_identity)
-    goto error;
   noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public);
 
   /* e */
@@ -294,9 +242,10 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
   r->r_handshake = hs;
   *rp = r;
   ret = true;
+
 error:
-  vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (&hs, sizeof (hs));
   return ret;
 }
@@ -348,8 +297,8 @@ noise_create_response (vlib_main_t * vm, noise_remote_t * r, uint32_t * s_idx,
   *s_idx = hs->hs_local_index;
   ret = true;
 error:
-  vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (e, NOISE_PUBLIC_KEY_LEN);
   return ret;
 }
@@ -359,7 +308,7 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
                        uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
                        uint8_t en[0 + NOISE_AUTHTAG_LEN])
 {
-  noise_local_t *l = r->r_local;
+  noise_local_t *l = noise_local_get (r->r_local_idx);
   noise_handshake_t hs;
   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
   uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN];
@@ -372,9 +321,6 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
                         NOISE_SYMMETRIC_KEY_LEN);
   key = vnet_crypto_get_key (key_idx)->data;
 
-  if (!l->l_has_identity)
-    goto error;
-
   hs = r->r_handshake;
   clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
 
@@ -412,9 +358,9 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
       ret = true;
     }
 error:
-  vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (&hs, sizeof (hs));
   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
+  vnet_crypto_key_del (vm, key_idx);
   return ret;
 }
 
@@ -460,6 +406,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
   clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr));
 
   /* Now we need to add_new_keypair */
+  clib_rwlock_writer_lock (&r->r_keypair_lock);
   next = r->r_next;
   current = r->r_current;
   previous = r->r_previous;
@@ -491,7 +438,10 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
       r->r_next = noise_remote_keypair_allocate (r);
       *r->r_next = kp;
     }
+  clib_rwlock_writer_unlock (&r->r_keypair_lock);
+
   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+
   secure_zero_memory (&kp, sizeof (kp));
   return true;
 }
@@ -502,21 +452,25 @@ noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
   noise_remote_handshake_index_drop (r);
   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 
+  clib_rwlock_writer_lock (&r->r_keypair_lock);
   noise_remote_keypair_free (vm, r, &r->r_next);
   noise_remote_keypair_free (vm, r, &r->r_current);
   noise_remote_keypair_free (vm, r, &r->r_previous);
   r->r_next = NULL;
   r->r_current = NULL;
   r->r_previous = NULL;
+  clib_rwlock_writer_unlock (&r->r_keypair_lock);
 }
 
 void
 noise_remote_expire_current (noise_remote_t * r)
 {
+  clib_rwlock_writer_lock (&r->r_keypair_lock);
   if (r->r_next != NULL)
     r->r_next->kp_valid = 0;
   if (r->r_current != NULL)
     r->r_current->kp_valid = 0;
+  clib_rwlock_writer_unlock (&r->r_keypair_lock);
 }
 
 bool
@@ -525,6 +479,7 @@ noise_remote_ready (noise_remote_t * r)
   noise_keypair_t *kp;
   int ret;
 
+  clib_rwlock_reader_lock (&r->r_keypair_lock);
   if ((kp = r->r_current) == NULL ||
       !kp->kp_valid ||
       wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
@@ -533,10 +488,11 @@ noise_remote_ready (noise_remote_t * r)
     ret = false;
   else
     ret = true;
+  clib_rwlock_reader_unlock (&r->r_keypair_lock);
   return ret;
 }
 
-static void
+static bool
 chacha20poly1305_calc (vlib_main_t * vm,
                       u8 * src,
                       u32 src_len,
@@ -562,6 +518,7 @@ chacha20poly1305_calc (vlib_main_t * vm,
     {
       op->tag = src + src_len - NOISE_AUTHTAG_LEN;
       src_len -= NOISE_AUTHTAG_LEN;
+      op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
     }
   else
     op->tag = tag_;
@@ -580,6 +537,8 @@ chacha20poly1305_calc (vlib_main_t * vm,
     {
       clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN);
     }
+
+  return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
 }
 
 enum noise_state_crypt
@@ -590,6 +549,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
   noise_keypair_t *kp;
   enum noise_state_crypt ret = SC_FAILED;
 
+  clib_rwlock_reader_lock (&r->r_keypair_lock);
   if ((kp = r->r_current) == NULL)
     goto error;
 
@@ -629,6 +589,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
 
   ret = SC_OK;
 error:
+  clib_rwlock_reader_unlock (&r->r_keypair_lock);
   return ret;
 }
 
@@ -639,6 +600,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
 {
   noise_keypair_t *kp;
   enum noise_state_crypt ret = SC_FAILED;
+  clib_rwlock_reader_lock (&r->r_keypair_lock);
 
   if (r->r_current != NULL && r->r_current->kp_local_index == r_idx)
     {
@@ -668,9 +630,10 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
   /* Decrypt, then validate the counter. We don't want to validate the
    * counter before decrypting as we do not know the message is authentic
    * prior to decryption. */
-  chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce,
-                        VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
-                        kp->kp_recv_index);
+  if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce,
+                             VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
+                             kp->kp_recv_index))
+    goto error;
 
   if (!noise_counter_recv (&kp->kp_ctr, nonce))
     goto error;
@@ -679,18 +642,26 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
    * next keypair into current. If we do slide the next keypair in, then
    * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
    * data packet can't confirm a session that we are an INITIATOR of. */
-  if (kp == r->r_next && kp->kp_local_index == r_idx)
+  if (kp == r->r_next)
     {
-      noise_remote_keypair_free (vm, r, &r->r_previous);
-      r->r_previous = r->r_current;
-      r->r_current = r->r_next;
-      r->r_next = NULL;
+      clib_rwlock_reader_unlock (&r->r_keypair_lock);
+      clib_rwlock_writer_lock (&r->r_keypair_lock);
+      if (kp == r->r_next && kp->kp_local_index == r_idx)
+       {
+         noise_remote_keypair_free (vm, r, &r->r_previous);
+         r->r_previous = r->r_current;
+         r->r_current = r->r_next;
+         r->r_next = NULL;
 
-      ret = SC_CONN_RESET;
-      goto error;
+         ret = SC_CONN_RESET;
+         clib_rwlock_writer_unlock (&r->r_keypair_lock);
+         clib_rwlock_reader_lock (&r->r_keypair_lock);
+         goto error;
+       }
+      clib_rwlock_writer_unlock (&r->r_keypair_lock);
+      clib_rwlock_reader_lock (&r->r_keypair_lock);
     }
 
-
   /* Similar to when we encrypt, we want to notify the caller when we
    * are approaching our tolerances. We notify if:
    *  - we're the initiator and the current keypair is older than
@@ -705,6 +676,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
 
   ret = SC_OK;
 error:
+  clib_rwlock_reader_unlock (&r->r_keypair_lock);
   return ret;
 }
 
@@ -722,7 +694,8 @@ static void
 noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
                           noise_keypair_t ** kp)
 {
-  struct noise_upcall *u = &r->r_local->l_upcall;
+  noise_local_t *local = noise_local_get (r->r_local_idx);
+  struct noise_upcall *u = &local->l_upcall;
   if (*kp)
     {
       u->u_index_drop ((*kp)->kp_local_index);
@@ -735,7 +708,8 @@ noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
 static uint32_t
 noise_remote_handshake_index_get (noise_remote_t * r)
 {
-  struct noise_upcall *u = &r->r_local->l_upcall;
+  noise_local_t *local = noise_local_get (r->r_local_idx);
+  struct noise_upcall *u = &local->l_upcall;
   return u->u_index_set (r);
 }
 
@@ -743,7 +717,8 @@ static void
 noise_remote_handshake_index_drop (noise_remote_t * r)
 {
   noise_handshake_t *hs = &r->r_handshake;
-  struct noise_upcall *u = &r->r_local->l_upcall;
+  noise_local_t *local = noise_local_get (r->r_local_idx);
+  struct noise_upcall *u = &local->l_upcall;
   if (hs->hs_state != HS_ZEROED)
     u->u_index_drop (hs->hs_local_index);
 }
@@ -751,7 +726,8 @@ noise_remote_handshake_index_drop (noise_remote_t * r)
 static uint64_t
 noise_counter_send (noise_counter_t * ctr)
 {
-  uint64_t ret = ctr->c_send++;
+  uint64_t ret;
+  ret = ctr->c_send++;
   return ret;
 }
 
@@ -762,7 +738,6 @@ noise_counter_recv (noise_counter_t * ctr, uint64_t recv)
   unsigned long bit;
   bool ret = false;
 
-
   /* Check that the recv counter is valid */
   if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES)
     goto error;
@@ -936,8 +911,9 @@ noise_msg_decrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src,
                   uint8_t hash[NOISE_HASH_LEN])
 {
   /* Nonce always zero for Noise_IK */
-  chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
-                        VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx);
+  if (!chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
+                             VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx))
+    return false;
   noise_mix_hash (hash, src, src_len);
   return true;
 }