c8605f117cddefef087e92145b5556138a2ad600
[vpp.git] / src / plugins / wireguard / wireguard_noise.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Copyright (c) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>.
4  * Copyright (c) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>.
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at:
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include <openssl/hmac.h>
19 #include <wireguard/wireguard.h>
20
21 /* This implements Noise_IKpsk2:
22  *
23  * <- s
24  * ******
25  * -> e, es, s, ss, {t}
26  * <- e, ee, se, psk, {}
27  */
28
29 noise_local_t *noise_local_pool;
30
31 /* Private functions */
32 static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
33 static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
34                                        noise_keypair_t **);
35 static uint32_t noise_remote_handshake_index_get (noise_remote_t *);
36 static void noise_remote_handshake_index_drop (noise_remote_t *);
37
38 static uint64_t noise_counter_send (noise_counter_t *);
39 bool noise_counter_recv (noise_counter_t *, uint64_t);
40
41 static void noise_kdf (uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
42                        size_t, size_t, size_t, size_t,
43                        const uint8_t[NOISE_HASH_LEN]);
44 static bool noise_mix_dh (uint8_t[NOISE_HASH_LEN],
45                           uint8_t[NOISE_SYMMETRIC_KEY_LEN],
46                           const uint8_t[NOISE_PUBLIC_KEY_LEN],
47                           const uint8_t[NOISE_PUBLIC_KEY_LEN]);
48 static bool noise_mix_ss (uint8_t ck[NOISE_HASH_LEN],
49                           uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
50                           const uint8_t ss[NOISE_PUBLIC_KEY_LEN]);
51 static void noise_mix_hash (uint8_t[NOISE_HASH_LEN], const uint8_t *, size_t);
52 static void noise_mix_psk (uint8_t[NOISE_HASH_LEN],
53                            uint8_t[NOISE_HASH_LEN],
54                            uint8_t[NOISE_SYMMETRIC_KEY_LEN],
55                            const uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
56 static void noise_param_init (uint8_t[NOISE_HASH_LEN],
57                               uint8_t[NOISE_HASH_LEN],
58                               const uint8_t[NOISE_PUBLIC_KEY_LEN]);
59
60 static void noise_msg_encrypt (vlib_main_t * vm, uint8_t *, uint8_t *, size_t,
61                                uint32_t key_idx, uint8_t[NOISE_HASH_LEN]);
62 static bool noise_msg_decrypt (vlib_main_t * vm, uint8_t *, uint8_t *, size_t,
63                                uint32_t key_idx, uint8_t[NOISE_HASH_LEN]);
64 static void noise_msg_ephemeral (uint8_t[NOISE_HASH_LEN],
65                                  uint8_t[NOISE_HASH_LEN],
66                                  const uint8_t src[NOISE_PUBLIC_KEY_LEN]);
67
68 static void noise_tai64n_now (uint8_t[NOISE_TIMESTAMP_LEN]);
69
70 static void secure_zero_memory (void *v, size_t n);
71
72 /* Set/Get noise parameters */
73 void
74 noise_local_init (noise_local_t * l, struct noise_upcall *upcall)
75 {
76   clib_memset (l, 0, sizeof (*l));
77   l->l_upcall = *upcall;
78 }
79
80 bool
81 noise_local_set_private (noise_local_t * l,
82                          const uint8_t private[NOISE_PUBLIC_KEY_LEN])
83 {
84   clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN);
85
86   return curve25519_gen_public (l->l_public, private);
87 }
88
89 void
90 noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
91                    const uint8_t public[NOISE_PUBLIC_KEY_LEN],
92                    u32 noise_local_idx)
93 {
94   clib_memset (r, 0, sizeof (*r));
95   clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN);
96   clib_rwlock_init (&r->r_keypair_lock);
97   r->r_peer_idx = peer_pool_idx;
98   r->r_local_idx = noise_local_idx;
99   r->r_handshake.hs_state = HS_ZEROED;
100
101   noise_remote_precompute (r);
102 }
103
104 void
105 noise_remote_precompute (noise_remote_t * r)
106 {
107   noise_local_t *l = noise_local_get (r->r_local_idx);
108
109   if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
110     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
111
112   noise_remote_handshake_index_drop (r);
113   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
114 }
115
116 /* Handshake functions */
117 bool
118 noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
119                          uint32_t * s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
120                          uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
121                          uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
122 {
123   noise_handshake_t *hs = &r->r_handshake;
124   noise_local_t *l = noise_local_get (r->r_local_idx);
125   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
126   uint32_t key_idx;
127   uint8_t *key;
128   int ret = false;
129
130   key_idx =
131     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
132                          NOISE_SYMMETRIC_KEY_LEN);
133   key = vnet_crypto_get_key (key_idx)->data;
134
135   noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public);
136
137   /* e */
138   curve25519_gen_secret (hs->hs_e);
139   if (!curve25519_gen_public (ue, hs->hs_e))
140     goto error;
141   noise_msg_ephemeral (hs->hs_ck, hs->hs_hash, ue);
142
143   /* es */
144   if (!noise_mix_dh (hs->hs_ck, key, hs->hs_e, r->r_public))
145     goto error;
146
147   /* s */
148   noise_msg_encrypt (vm, es, l->l_public, NOISE_PUBLIC_KEY_LEN, key_idx,
149                      hs->hs_hash);
150
151   /* ss */
152   if (!noise_mix_ss (hs->hs_ck, key, r->r_ss))
153     goto error;
154
155   /* {t} */
156   noise_tai64n_now (ets);
157   noise_msg_encrypt (vm, ets, ets, NOISE_TIMESTAMP_LEN, key_idx, hs->hs_hash);
158   noise_remote_handshake_index_drop (r);
159   hs->hs_state = CREATED_INITIATION;
160   hs->hs_local_index = noise_remote_handshake_index_get (r);
161   *s_idx = hs->hs_local_index;
162   ret = true;
163 error:
164   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
165   vnet_crypto_key_del (vm, key_idx);
166   return ret;
167 }
168
169 bool
170 noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
171                           noise_remote_t ** rp, uint32_t s_idx,
172                           uint8_t ue[NOISE_PUBLIC_KEY_LEN],
173                           uint8_t es[NOISE_PUBLIC_KEY_LEN +
174                                      NOISE_AUTHTAG_LEN],
175                           uint8_t ets[NOISE_TIMESTAMP_LEN +
176                                       NOISE_AUTHTAG_LEN])
177 {
178   noise_remote_t *r;
179   noise_handshake_t hs;
180   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
181   uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
182   uint8_t timestamp[NOISE_TIMESTAMP_LEN];
183   u32 key_idx;
184   uint8_t *key;
185   int ret = false;
186
187   key_idx =
188     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
189                          NOISE_SYMMETRIC_KEY_LEN);
190   key = vnet_crypto_get_key (key_idx)->data;
191
192   noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public);
193
194   /* e */
195   noise_msg_ephemeral (hs.hs_ck, hs.hs_hash, ue);
196
197   /* es */
198   if (!noise_mix_dh (hs.hs_ck, key, l->l_private, ue))
199     goto error;
200
201   /* s */
202
203   if (!noise_msg_decrypt (vm, r_public, es,
204                           NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key_idx,
205                           hs.hs_hash))
206     goto error;
207
208   /* Lookup the remote we received from */
209   if ((r = l->l_upcall.u_remote_get (r_public)) == NULL)
210     goto error;
211
212   /* ss */
213   if (!noise_mix_ss (hs.hs_ck, key, r->r_ss))
214     goto error;
215
216   /* {t} */
217   if (!noise_msg_decrypt (vm, timestamp, ets,
218                           NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key_idx,
219                           hs.hs_hash))
220     goto error;
221   ;
222
223   hs.hs_state = CONSUMED_INITIATION;
224   hs.hs_local_index = 0;
225   hs.hs_remote_index = s_idx;
226   clib_memcpy (hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN);
227
228   /* Replay */
229   if (clib_memcmp (timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0)
230     clib_memcpy (r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
231   else
232     goto error;
233
234   /* Flood attack */
235   if (wg_birthdate_has_expired (r->r_last_init, REJECT_INTERVAL))
236     r->r_last_init = vlib_time_now (vm);
237   else
238     goto error;
239
240   /* Ok, we're happy to accept this initiation now */
241   noise_remote_handshake_index_drop (r);
242   r->r_handshake = hs;
243   *rp = r;
244   ret = true;
245
246 error:
247   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
248   vnet_crypto_key_del (vm, key_idx);
249   secure_zero_memory (&hs, sizeof (hs));
250   return ret;
251 }
252
253 bool
254 noise_create_response (vlib_main_t * vm, noise_remote_t * r, uint32_t * s_idx,
255                        uint32_t * r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
256                        uint8_t en[0 + NOISE_AUTHTAG_LEN])
257 {
258   noise_handshake_t *hs = &r->r_handshake;
259   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
260   uint8_t e[NOISE_PUBLIC_KEY_LEN];
261   uint32_t key_idx;
262   uint8_t *key;
263   int ret = false;
264
265   key_idx =
266     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
267                          NOISE_SYMMETRIC_KEY_LEN);
268   key = vnet_crypto_get_key (key_idx)->data;
269
270   if (hs->hs_state != CONSUMED_INITIATION)
271     goto error;
272
273   /* e */
274   curve25519_gen_secret (e);
275   if (!curve25519_gen_public (ue, e))
276     goto error;
277   noise_msg_ephemeral (hs->hs_ck, hs->hs_hash, ue);
278
279   /* ee */
280   if (!noise_mix_dh (hs->hs_ck, NULL, e, hs->hs_e))
281     goto error;
282
283   /* se */
284   if (!noise_mix_dh (hs->hs_ck, NULL, e, r->r_public))
285     goto error;
286
287   /* psk */
288   noise_mix_psk (hs->hs_ck, hs->hs_hash, key, r->r_psk);
289
290   /* {} */
291   noise_msg_encrypt (vm, en, NULL, 0, key_idx, hs->hs_hash);
292
293
294   hs->hs_state = CREATED_RESPONSE;
295   hs->hs_local_index = noise_remote_handshake_index_get (r);
296   *r_idx = hs->hs_remote_index;
297   *s_idx = hs->hs_local_index;
298   ret = true;
299 error:
300   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
301   vnet_crypto_key_del (vm, key_idx);
302   secure_zero_memory (e, NOISE_PUBLIC_KEY_LEN);
303   return ret;
304 }
305
306 bool
307 noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
308                         uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
309                         uint8_t en[0 + NOISE_AUTHTAG_LEN])
310 {
311   noise_local_t *l = noise_local_get (r->r_local_idx);
312   noise_handshake_t hs;
313   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
314   uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN];
315   uint32_t key_idx;
316   uint8_t *key;
317   int ret = false;
318
319   key_idx =
320     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
321                          NOISE_SYMMETRIC_KEY_LEN);
322   key = vnet_crypto_get_key (key_idx)->data;
323
324   hs = r->r_handshake;
325   clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
326
327   if (hs.hs_state != CREATED_INITIATION || hs.hs_local_index != r_idx)
328     goto error;
329
330   /* e */
331   noise_msg_ephemeral (hs.hs_ck, hs.hs_hash, ue);
332
333   /* ee */
334   if (!noise_mix_dh (hs.hs_ck, NULL, hs.hs_e, ue))
335     goto error;
336
337   /* se */
338   if (!noise_mix_dh (hs.hs_ck, NULL, l->l_private, ue))
339     goto error;
340
341   /* psk */
342   noise_mix_psk (hs.hs_ck, hs.hs_hash, key, preshared_key);
343
344   /* {} */
345
346   if (!noise_msg_decrypt
347       (vm, NULL, en, 0 + NOISE_AUTHTAG_LEN, key_idx, hs.hs_hash))
348     goto error;
349
350
351   hs.hs_remote_index = s_idx;
352
353   if (r->r_handshake.hs_state == hs.hs_state &&
354       r->r_handshake.hs_local_index == hs.hs_local_index)
355     {
356       r->r_handshake = hs;
357       r->r_handshake.hs_state = CONSUMED_RESPONSE;
358       ret = true;
359     }
360 error:
361   secure_zero_memory (&hs, sizeof (hs));
362   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
363   vnet_crypto_key_del (vm, key_idx);
364   return ret;
365 }
366
367 bool
368 noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
369 {
370   noise_handshake_t *hs = &r->r_handshake;
371   noise_keypair_t kp, *next, *current, *previous;
372
373   uint8_t key_send[NOISE_SYMMETRIC_KEY_LEN];
374   uint8_t key_recv[NOISE_SYMMETRIC_KEY_LEN];
375
376   /* We now derive the keypair from the handshake */
377   if (hs->hs_state == CONSUMED_RESPONSE)
378     {
379       kp.kp_is_initiator = 1;
380       noise_kdf (key_send, key_recv, NULL, NULL,
381                  NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
382                  hs->hs_ck);
383     }
384   else if (hs->hs_state == CREATED_RESPONSE)
385     {
386       kp.kp_is_initiator = 0;
387       noise_kdf (key_recv, key_send, NULL, NULL,
388                  NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
389                  hs->hs_ck);
390     }
391   else
392     {
393       return false;
394     }
395
396   kp.kp_valid = 1;
397   kp.kp_send_index = vnet_crypto_key_add (vm,
398                                           VNET_CRYPTO_ALG_CHACHA20_POLY1305,
399                                           key_send, NOISE_SYMMETRIC_KEY_LEN);
400   kp.kp_recv_index = vnet_crypto_key_add (vm,
401                                           VNET_CRYPTO_ALG_CHACHA20_POLY1305,
402                                           key_recv, NOISE_SYMMETRIC_KEY_LEN);
403   kp.kp_local_index = hs->hs_local_index;
404   kp.kp_remote_index = hs->hs_remote_index;
405   kp.kp_birthdate = vlib_time_now (vm);
406   clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr));
407
408   /* Now we need to add_new_keypair */
409   clib_rwlock_writer_lock (&r->r_keypair_lock);
410   /* Activate barrier to synchronization keys between threads */
411   vlib_worker_thread_barrier_sync (vm);
412   next = r->r_next;
413   current = r->r_current;
414   previous = r->r_previous;
415
416   if (kp.kp_is_initiator)
417     {
418       if (next != NULL)
419         {
420           r->r_next = NULL;
421           r->r_previous = next;
422           noise_remote_keypair_free (vm, r, &current);
423         }
424       else
425         {
426           r->r_previous = current;
427         }
428
429       noise_remote_keypair_free (vm, r, &previous);
430
431       r->r_current = noise_remote_keypair_allocate (r);
432       *r->r_current = kp;
433     }
434   else
435     {
436       noise_remote_keypair_free (vm, r, &next);
437       r->r_previous = NULL;
438       noise_remote_keypair_free (vm, r, &previous);
439
440       r->r_next = noise_remote_keypair_allocate (r);
441       *r->r_next = kp;
442     }
443   vlib_worker_thread_barrier_release (vm);
444   clib_rwlock_writer_unlock (&r->r_keypair_lock);
445
446   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
447
448   secure_zero_memory (&kp, sizeof (kp));
449   return true;
450 }
451
452 void
453 noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
454 {
455   noise_remote_handshake_index_drop (r);
456   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
457
458   clib_rwlock_writer_lock (&r->r_keypair_lock);
459   noise_remote_keypair_free (vm, r, &r->r_next);
460   noise_remote_keypair_free (vm, r, &r->r_current);
461   noise_remote_keypair_free (vm, r, &r->r_previous);
462   r->r_next = NULL;
463   r->r_current = NULL;
464   r->r_previous = NULL;
465   clib_rwlock_writer_unlock (&r->r_keypair_lock);
466 }
467
468 void
469 noise_remote_expire_current (noise_remote_t * r)
470 {
471   clib_rwlock_writer_lock (&r->r_keypair_lock);
472   if (r->r_next != NULL)
473     r->r_next->kp_valid = 0;
474   if (r->r_current != NULL)
475     r->r_current->kp_valid = 0;
476   clib_rwlock_writer_unlock (&r->r_keypair_lock);
477 }
478
479 bool
480 noise_remote_ready (noise_remote_t * r)
481 {
482   noise_keypair_t *kp;
483   int ret;
484
485   clib_rwlock_reader_lock (&r->r_keypair_lock);
486   if ((kp = r->r_current) == NULL ||
487       !kp->kp_valid ||
488       wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
489       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
490       kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES)
491     ret = false;
492   else
493     ret = true;
494   clib_rwlock_reader_unlock (&r->r_keypair_lock);
495   return ret;
496 }
497
498 static bool
499 chacha20poly1305_calc (vlib_main_t * vm,
500                        u8 * src,
501                        u32 src_len,
502                        u8 * dst,
503                        u8 * aad,
504                        u32 aad_len,
505                        u64 nonce,
506                        vnet_crypto_op_id_t op_id,
507                        vnet_crypto_key_index_t key_index)
508 {
509   vnet_crypto_op_t _op, *op = &_op;
510   u8 iv[12];
511   u8 tag_[NOISE_AUTHTAG_LEN] = { };
512   u8 src_[] = { };
513
514   clib_memset (iv, 0, 12);
515   clib_memcpy (iv + 4, &nonce, sizeof (nonce));
516
517   vnet_crypto_op_init (op, op_id);
518
519   op->tag_len = NOISE_AUTHTAG_LEN;
520   if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC)
521     {
522       op->tag = src + src_len - NOISE_AUTHTAG_LEN;
523       src_len -= NOISE_AUTHTAG_LEN;
524       op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
525     }
526   else
527     op->tag = tag_;
528
529   op->src = !src ? src_ : src;
530   op->len = src_len;
531
532   op->dst = dst;
533   op->key_index = key_index;
534   op->aad = aad;
535   op->aad_len = aad_len;
536   op->iv = iv;
537
538   vnet_crypto_process_ops (vm, op, 1);
539   if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC)
540     {
541       clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN);
542     }
543
544   return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
545 }
546
547 always_inline void
548 wg_prepare_sync_op (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops, u8 *src,
549                     u32 src_len, u8 *dst, u8 *aad, u32 aad_len, u64 nonce,
550                     vnet_crypto_op_id_t op_id,
551                     vnet_crypto_key_index_t key_index, u32 bi, u8 *iv)
552 {
553   vnet_crypto_op_t _op, *op = &_op;
554   u8 src_[] = {};
555
556   clib_memset (iv, 0, 4);
557   clib_memcpy (iv + 4, &nonce, sizeof (nonce));
558
559   vec_add2_aligned (crypto_ops[0], op, 1, CLIB_CACHE_LINE_BYTES);
560   vnet_crypto_op_init (op, op_id);
561
562   op->tag_len = NOISE_AUTHTAG_LEN;
563   if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC)
564     {
565       op->tag = src + src_len;
566       op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
567     }
568   else
569     op->tag = dst + src_len;
570
571   op->src = !src ? src_ : src;
572   op->len = src_len;
573
574   op->dst = dst;
575   op->key_index = key_index;
576   op->aad = aad;
577   op->aad_len = aad_len;
578   op->iv = iv;
579   op->user_data = bi;
580 }
581
582 enum noise_state_crypt
583 noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
584                       uint64_t * nonce, uint8_t * src, size_t srclen,
585                       uint8_t * dst)
586 {
587   noise_keypair_t *kp;
588   enum noise_state_crypt ret = SC_FAILED;
589
590   if ((kp = r->r_current) == NULL)
591     goto error;
592
593   /* We confirm that our values are within our tolerances. We want:
594    *  - a valid keypair
595    *  - our keypair to be less than REJECT_AFTER_TIME seconds old
596    *  - our receive counter to be less than REJECT_AFTER_MESSAGES
597    *  - our send counter to be less than REJECT_AFTER_MESSAGES
598    */
599   if (!kp->kp_valid ||
600       wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
601       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
602       ((*nonce = noise_counter_send (&kp->kp_ctr)) > REJECT_AFTER_MESSAGES))
603     goto error;
604
605   /* We encrypt into the same buffer, so the caller must ensure that buf
606    * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index
607    * are passed back out to the caller through the provided data pointer. */
608   *r_idx = kp->kp_remote_index;
609
610   chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, *nonce,
611                          VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC,
612                          kp->kp_send_index);
613
614   /* If our values are still within tolerances, but we are approaching
615    * the tolerances, we notify the caller with ESTALE that they should
616    * establish a new keypair. The current keypair can continue to be used
617    * until the tolerances are hit. We notify if:
618    *  - our send counter is valid and not less than REKEY_AFTER_MESSAGES
619    *  - we're the initiator and our keypair is older than
620    *    REKEY_AFTER_TIME seconds */
621   ret = SC_KEEP_KEY_FRESH;
622   if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) ||
623       (kp->kp_is_initiator &&
624        wg_birthdate_has_expired (kp->kp_birthdate, REKEY_AFTER_TIME)))
625     goto error;
626
627   ret = SC_OK;
628 error:
629   return ret;
630 }
631
632 enum noise_state_crypt
633 noise_sync_remote_encrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops,
634                            noise_remote_t *r, uint32_t *r_idx, uint64_t *nonce,
635                            uint8_t *src, size_t srclen, uint8_t *dst, u32 bi,
636                            u8 *iv, f64 time)
637 {
638   noise_keypair_t *kp;
639   enum noise_state_crypt ret = SC_FAILED;
640
641   if ((kp = r->r_current) == NULL)
642     goto error;
643
644   /* We confirm that our values are within our tolerances. We want:
645    *  - a valid keypair
646    *  - our keypair to be less than REJECT_AFTER_TIME seconds old
647    *  - our receive counter to be less than REJECT_AFTER_MESSAGES
648    *  - our send counter to be less than REJECT_AFTER_MESSAGES
649    */
650   if (!kp->kp_valid ||
651       wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME,
652                                     time) ||
653       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
654       ((*nonce = noise_counter_send (&kp->kp_ctr)) > REJECT_AFTER_MESSAGES))
655     goto error;
656
657   /* We encrypt into the same buffer, so the caller must ensure that buf
658    * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index
659    * are passed back out to the caller through the provided data pointer. */
660   *r_idx = kp->kp_remote_index;
661
662   wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, *nonce,
663                       VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, kp->kp_send_index,
664                       bi, iv);
665
666   /* If our values are still within tolerances, but we are approaching
667    * the tolerances, we notify the caller with ESTALE that they should
668    * establish a new keypair. The current keypair can continue to be used
669    * until the tolerances are hit. We notify if:
670    *  - our send counter is valid and not less than REKEY_AFTER_MESSAGES
671    *  - we're the initiator and our keypair is older than
672    *    REKEY_AFTER_TIME seconds */
673   ret = SC_KEEP_KEY_FRESH;
674   if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) ||
675       (kp->kp_is_initiator && wg_birthdate_has_expired_opt (
676                                 kp->kp_birthdate, REKEY_AFTER_TIME, time)))
677     goto error;
678
679   ret = SC_OK;
680 error:
681   return ret;
682 }
683
684 enum noise_state_crypt
685 noise_sync_remote_decrypt (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops,
686                            noise_remote_t *r, uint32_t r_idx, uint64_t nonce,
687                            uint8_t *src, size_t srclen, uint8_t *dst, u32 bi,
688                            u8 *iv, f64 time)
689 {
690   noise_keypair_t *kp;
691   enum noise_state_crypt ret = SC_FAILED;
692
693   if ((kp = wg_get_active_keypair (r, r_idx)) == NULL)
694     {
695       goto error;
696     }
697
698   /* We confirm that our values are within our tolerances. These values
699    * are the same as the encrypt routine.
700    *
701    * kp_ctr isn't locked here, we're happy to accept a racy read. */
702   if (wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME,
703                                     time) ||
704       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
705     goto error;
706
707   /* Decrypt, then validate the counter. We don't want to validate the
708    * counter before decrypting as we do not know the message is authentic
709    * prior to decryption. */
710   wg_prepare_sync_op (vm, crypto_ops, src, srclen, dst, NULL, 0, nonce,
711                       VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, kp->kp_recv_index,
712                       bi, iv);
713
714   /* If we've received the handshake confirming data packet then move the
715    * next keypair into current. If we do slide the next keypair in, then
716    * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
717    * data packet can't confirm a session that we are an INITIATOR of. */
718   if (kp == r->r_next)
719     {
720       clib_rwlock_writer_lock (&r->r_keypair_lock);
721       if (kp == r->r_next && kp->kp_local_index == r_idx)
722         {
723           noise_remote_keypair_free (vm, r, &r->r_previous);
724           r->r_previous = r->r_current;
725           r->r_current = r->r_next;
726           r->r_next = NULL;
727
728           ret = SC_CONN_RESET;
729           clib_rwlock_writer_unlock (&r->r_keypair_lock);
730           goto error;
731         }
732       clib_rwlock_writer_unlock (&r->r_keypair_lock);
733     }
734
735   /* Similar to when we encrypt, we want to notify the caller when we
736    * are approaching our tolerances. We notify if:
737    *  - we're the initiator and the current keypair is older than
738    *    REKEY_AFTER_TIME_RECV seconds. */
739   ret = SC_KEEP_KEY_FRESH;
740   kp = r->r_current;
741   if (kp != NULL && kp->kp_valid && kp->kp_is_initiator &&
742       wg_birthdate_has_expired_opt (kp->kp_birthdate, REKEY_AFTER_TIME_RECV,
743                                     time))
744     goto error;
745
746   ret = SC_OK;
747 error:
748   return ret;
749 }
750
751 /* Private functions - these should not be called outside this file under any
752  * circumstances. */
753 static noise_keypair_t *
754 noise_remote_keypair_allocate (noise_remote_t * r)
755 {
756   noise_keypair_t *kp;
757   kp = clib_mem_alloc (sizeof (*kp));
758   return kp;
759 }
760
761 static void
762 noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
763                            noise_keypair_t ** kp)
764 {
765   noise_local_t *local = noise_local_get (r->r_local_idx);
766   struct noise_upcall *u = &local->l_upcall;
767   if (*kp)
768     {
769       u->u_index_drop ((*kp)->kp_local_index);
770       vnet_crypto_key_del (vm, (*kp)->kp_send_index);
771       vnet_crypto_key_del (vm, (*kp)->kp_recv_index);
772       clib_mem_free (*kp);
773     }
774 }
775
776 static uint32_t
777 noise_remote_handshake_index_get (noise_remote_t * r)
778 {
779   noise_local_t *local = noise_local_get (r->r_local_idx);
780   struct noise_upcall *u = &local->l_upcall;
781   return u->u_index_set (r);
782 }
783
784 static void
785 noise_remote_handshake_index_drop (noise_remote_t * r)
786 {
787   noise_handshake_t *hs = &r->r_handshake;
788   noise_local_t *local = noise_local_get (r->r_local_idx);
789   struct noise_upcall *u = &local->l_upcall;
790   if (hs->hs_state != HS_ZEROED)
791     u->u_index_drop (hs->hs_local_index);
792 }
793
794 static uint64_t
795 noise_counter_send (noise_counter_t * ctr)
796 {
797   uint64_t ret;
798   ret = ctr->c_send++;
799   return ret;
800 }
801
802 static void
803 noise_kdf (uint8_t * a, uint8_t * b, uint8_t * c, const uint8_t * x,
804            size_t a_len, size_t b_len, size_t c_len, size_t x_len,
805            const uint8_t ck[NOISE_HASH_LEN])
806 {
807   uint8_t out[BLAKE2S_HASH_SIZE + 1];
808   uint8_t sec[BLAKE2S_HASH_SIZE];
809
810   /* Extract entropy from "x" into sec */
811   u32 l = 0;
812   HMAC (EVP_blake2s256 (), ck, NOISE_HASH_LEN, x, x_len, sec, &l);
813   ASSERT (l == BLAKE2S_HASH_SIZE);
814   if (a == NULL || a_len == 0)
815     goto out;
816
817   /* Expand first key: key = sec, data = 0x1 */
818   out[0] = 1;
819   HMAC (EVP_blake2s256 (), sec, BLAKE2S_HASH_SIZE, out, 1, out, &l);
820   ASSERT (l == BLAKE2S_HASH_SIZE);
821   clib_memcpy (a, out, a_len);
822
823   if (b == NULL || b_len == 0)
824     goto out;
825
826   /* Expand second key: key = sec, data = "a" || 0x2 */
827   out[BLAKE2S_HASH_SIZE] = 2;
828   HMAC (EVP_blake2s256 (), sec, BLAKE2S_HASH_SIZE, out, BLAKE2S_HASH_SIZE + 1,
829         out, &l);
830   ASSERT (l == BLAKE2S_HASH_SIZE);
831   clib_memcpy (b, out, b_len);
832
833   if (c == NULL || c_len == 0)
834     goto out;
835
836   /* Expand third key: key = sec, data = "b" || 0x3 */
837   out[BLAKE2S_HASH_SIZE] = 3;
838   HMAC (EVP_blake2s256 (), sec, BLAKE2S_HASH_SIZE, out, BLAKE2S_HASH_SIZE + 1,
839         out, &l);
840   ASSERT (l == BLAKE2S_HASH_SIZE);
841
842   clib_memcpy (c, out, c_len);
843
844 out:
845   /* Clear sensitive data from stack */
846   secure_zero_memory (sec, BLAKE2S_HASH_SIZE);
847   secure_zero_memory (out, BLAKE2S_HASH_SIZE + 1);
848 }
849
850 static bool
851 noise_mix_dh (uint8_t ck[NOISE_HASH_LEN],
852               uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
853               const uint8_t private[NOISE_PUBLIC_KEY_LEN],
854               const uint8_t public[NOISE_PUBLIC_KEY_LEN])
855 {
856   uint8_t dh[NOISE_PUBLIC_KEY_LEN];
857   if (!curve25519_gen_shared (dh, private, public))
858     return false;
859   noise_kdf (ck, key, NULL, dh,
860              NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
861              ck);
862   secure_zero_memory (dh, NOISE_PUBLIC_KEY_LEN);
863   return true;
864 }
865
866 static bool
867 noise_mix_ss (uint8_t ck[NOISE_HASH_LEN],
868               uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
869               const uint8_t ss[NOISE_PUBLIC_KEY_LEN])
870 {
871   static uint8_t null_point[NOISE_PUBLIC_KEY_LEN];
872   if (clib_memcmp (ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0)
873     return false;
874   noise_kdf (ck, key, NULL, ss,
875              NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
876              ck);
877   return true;
878 }
879
880 static void
881 noise_mix_hash (uint8_t hash[NOISE_HASH_LEN], const uint8_t * src,
882                 size_t src_len)
883 {
884   blake2s_state_t blake;
885
886   blake2s_init (&blake, NOISE_HASH_LEN);
887   blake2s_update (&blake, hash, NOISE_HASH_LEN);
888   blake2s_update (&blake, src, src_len);
889   blake2s_final (&blake, hash, NOISE_HASH_LEN);
890 }
891
892 static void
893 noise_mix_psk (uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
894                uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
895                const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
896 {
897   uint8_t tmp[NOISE_HASH_LEN];
898
899   noise_kdf (ck, tmp, key, psk,
900              NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
901              NOISE_SYMMETRIC_KEY_LEN, ck);
902   noise_mix_hash (hash, tmp, NOISE_HASH_LEN);
903   secure_zero_memory (tmp, NOISE_HASH_LEN);
904 }
905
906 static void
907 noise_param_init (uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
908                   const uint8_t s[NOISE_PUBLIC_KEY_LEN])
909 {
910   blake2s_state_t blake;
911
912   blake2s (ck, NOISE_HASH_LEN, (uint8_t *) NOISE_HANDSHAKE_NAME,
913            strlen (NOISE_HANDSHAKE_NAME), NULL, 0);
914
915   blake2s_init (&blake, NOISE_HASH_LEN);
916   blake2s_update (&blake, ck, NOISE_HASH_LEN);
917   blake2s_update (&blake, (uint8_t *) NOISE_IDENTIFIER_NAME,
918                   strlen (NOISE_IDENTIFIER_NAME));
919   blake2s_final (&blake, hash, NOISE_HASH_LEN);
920
921   noise_mix_hash (hash, s, NOISE_PUBLIC_KEY_LEN);
922 }
923
924 static void
925 noise_msg_encrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src,
926                    size_t src_len, uint32_t key_idx,
927                    uint8_t hash[NOISE_HASH_LEN])
928 {
929   /* Nonce always zero for Noise_IK */
930   chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
931                          VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, key_idx);
932   noise_mix_hash (hash, dst, src_len + NOISE_AUTHTAG_LEN);
933 }
934
935 static bool
936 noise_msg_decrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src,
937                    size_t src_len, uint32_t key_idx,
938                    uint8_t hash[NOISE_HASH_LEN])
939 {
940   /* Nonce always zero for Noise_IK */
941   if (!chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
942                               VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx))
943     return false;
944   noise_mix_hash (hash, src, src_len);
945   return true;
946 }
947
948 static void
949 noise_msg_ephemeral (uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
950                      const uint8_t src[NOISE_PUBLIC_KEY_LEN])
951 {
952   noise_mix_hash (hash, src, NOISE_PUBLIC_KEY_LEN);
953   noise_kdf (ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0,
954              NOISE_PUBLIC_KEY_LEN, ck);
955 }
956
957 static void
958 noise_tai64n_now (uint8_t output[NOISE_TIMESTAMP_LEN])
959 {
960   uint32_t unix_sec;
961   uint32_t unix_nanosec;
962
963   uint64_t sec;
964   uint32_t nsec;
965
966   unix_time_now_nsec_fraction (&unix_sec, &unix_nanosec);
967
968   /* Round down the nsec counter to limit precise timing leak. */
969   unix_nanosec &= REJECT_INTERVAL_MASK;
970
971   /* https://cr.yp.to/libtai/tai64.html */
972   sec = htobe64 (0x400000000000000aULL + unix_sec);
973   nsec = htobe32 (unix_nanosec);
974
975   /* memcpy to output buffer, assuming output could be unaligned. */
976   clib_memcpy (output, &sec, sizeof (sec));
977   clib_memcpy (output + sizeof (sec), &nsec, sizeof (nsec));
978 }
979
980 static void
981 secure_zero_memory (void *v, size_t n)
982 {
983   static void *(*const volatile memset_v) (void *, int, size_t) = &memset;
984   memset_v (v, 0, n);
985 }
986
987 /*
988  * fd.io coding-style-patch-verification: ON
989  *
990  * Local Variables:
991  * eval: (c-set-style "gnu")
992  * End:
993  */