wireguard: notify key changes to crypto engine
[vpp.git] / src / plugins / wireguard / wireguard_noise.c
index c9d8e31..5fe2e44 100644 (file)
@@ -33,8 +33,10 @@ noise_local_t *noise_local_pool;
 static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
 static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
                                       noise_keypair_t **);
-static uint32_t noise_remote_handshake_index_get (noise_remote_t *);
-static void noise_remote_handshake_index_drop (noise_remote_t *);
+static uint32_t noise_remote_handshake_index_get (vlib_main_t *vm,
+                                                 noise_remote_t *);
+static void noise_remote_handshake_index_drop (vlib_main_t *vm,
+                                              noise_remote_t *);
 
 static uint64_t noise_counter_send (noise_counter_t *);
 bool noise_counter_recv (noise_counter_t *, uint64_t);
@@ -86,7 +88,7 @@ noise_local_set_private (noise_local_t * l,
 }
 
 void
-noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
+noise_remote_init (vlib_main_t *vm, noise_remote_t *r, uint32_t peer_pool_idx,
                   const uint8_t public[NOISE_PUBLIC_KEY_LEN],
                   u32 noise_local_idx)
 {
@@ -97,18 +99,18 @@ noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
   r->r_local_idx = noise_local_idx;
   r->r_handshake.hs_state = HS_ZEROED;
 
-  noise_remote_precompute (r);
+  noise_remote_precompute (vm, r);
 }
 
 void
-noise_remote_precompute (noise_remote_t * r)
+noise_remote_precompute (vlib_main_t *vm, noise_remote_t *r)
 {
   noise_local_t *l = noise_local_get (r->r_local_idx);
 
   if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
 
-  noise_remote_handshake_index_drop (r);
+  noise_remote_handshake_index_drop (vm, r);
   wg_secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 }
 
@@ -142,6 +144,7 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
   /* es */
   if (!noise_mix_dh (hs->hs_ck, key, hs->hs_e, r->r_public))
     goto error;
+  vnet_crypto_key_update (vm, key_idx);
 
   /* s */
   noise_msg_encrypt (vm, es, l->l_public, NOISE_PUBLIC_KEY_LEN, key_idx,
@@ -150,13 +153,14 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
   /* ss */
   if (!noise_mix_ss (hs->hs_ck, key, r->r_ss))
     goto error;
+  vnet_crypto_key_update (vm, key_idx);
 
   /* {t} */
   noise_tai64n_now (ets);
   noise_msg_encrypt (vm, ets, ets, NOISE_TIMESTAMP_LEN, key_idx, hs->hs_hash);
-  noise_remote_handshake_index_drop (r);
+  noise_remote_handshake_index_drop (vm, r);
   hs->hs_state = CREATED_INITIATION;
-  hs->hs_local_index = noise_remote_handshake_index_get (r);
+  hs->hs_local_index = noise_remote_handshake_index_get (vm, r);
   *s_idx = hs->hs_local_index;
   ret = true;
 error:
@@ -196,6 +200,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
   /* es */
   if (!noise_mix_dh (hs.hs_ck, key, l->l_private, ue))
     goto error;
+  vnet_crypto_key_update (vm, key_idx);
 
   /* s */
 
@@ -211,6 +216,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
   /* ss */
   if (!noise_mix_ss (hs.hs_ck, key, r->r_ss))
     goto error;
+  vnet_crypto_key_update (vm, key_idx);
 
   /* {t} */
   if (!noise_msg_decrypt (vm, timestamp, ets,
@@ -237,7 +243,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
     goto error;
 
   /* Ok, we're happy to accept this initiation now */
-  noise_remote_handshake_index_drop (r);
+  noise_remote_handshake_index_drop (vm, r);
   r->r_handshake = hs;
   *rp = r;
   ret = true;
@@ -285,13 +291,14 @@ noise_create_response (vlib_main_t * vm, noise_remote_t * r, uint32_t * s_idx,
 
   /* psk */
   noise_mix_psk (hs->hs_ck, hs->hs_hash, key, r->r_psk);
+  vnet_crypto_key_update (vm, key_idx);
 
   /* {} */
   noise_msg_encrypt (vm, en, NULL, 0, key_idx, hs->hs_hash);
 
 
   hs->hs_state = CREATED_RESPONSE;
-  hs->hs_local_index = noise_remote_handshake_index_get (r);
+  hs->hs_local_index = noise_remote_handshake_index_get (vm, r);
   *r_idx = hs->hs_remote_index;
   *s_idx = hs->hs_local_index;
   ret = true;
@@ -339,6 +346,7 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
 
   /* psk */
   noise_mix_psk (hs.hs_ck, hs.hs_hash, key, preshared_key);
+  vnet_crypto_key_update (vm, key_idx);
 
   /* {} */
 
@@ -451,7 +459,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
 void
 noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
 {
-  noise_remote_handshake_index_drop (r);
+  noise_remote_handshake_index_drop (vm, r);
   wg_secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 
   clib_rwlock_writer_lock (&r->r_keypair_lock);
@@ -555,21 +563,21 @@ noise_remote_keypair_allocate (noise_remote_t * r)
 }
 
 static uint32_t
-noise_remote_handshake_index_get (noise_remote_t * r)
+noise_remote_handshake_index_get (vlib_main_t *vm, noise_remote_t *r)
 {
   noise_local_t *local = noise_local_get (r->r_local_idx);
   struct noise_upcall *u = &local->l_upcall;
-  return u->u_index_set (r);
+  return u->u_index_set (vm, r);
 }
 
 static void
-noise_remote_handshake_index_drop (noise_remote_t * r)
+noise_remote_handshake_index_drop (vlib_main_t *vm, noise_remote_t *r)
 {
   noise_handshake_t *hs = &r->r_handshake;
   noise_local_t *local = noise_local_get (r->r_local_idx);
   struct noise_upcall *u = &local->l_upcall;
   if (hs->hs_state != HS_ZEROED)
-    u->u_index_drop (hs->hs_local_index);
+    u->u_index_drop (vm, hs->hs_local_index);
 }
 
 static void