wireguard: Fix for tunnel encap
[vpp.git] / src / plugins / wireguard / wireguard_noise.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Copyright (c) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>.
4  * Copyright (c) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>.
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at:
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include <openssl/hmac.h>
19 #include <wireguard/wireguard.h>
20
21 /* This implements Noise_IKpsk2:
22  *
23  * <- s
24  * ******
25  * -> e, es, s, ss, {t}
26  * <- e, ee, se, psk, {}
27  */
28
29 /* Private functions */
30 static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
31 static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
32                                        noise_keypair_t **);
33 static uint32_t noise_remote_handshake_index_get (noise_remote_t *);
34 static void noise_remote_handshake_index_drop (noise_remote_t *);
35
36 static uint64_t noise_counter_send (noise_counter_t *);
37 static bool noise_counter_recv (noise_counter_t *, uint64_t);
38
39 static void noise_kdf (uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
40                        size_t, size_t, size_t, size_t,
41                        const uint8_t[NOISE_HASH_LEN]);
42 static bool noise_mix_dh (uint8_t[NOISE_HASH_LEN],
43                           uint8_t[NOISE_SYMMETRIC_KEY_LEN],
44                           const uint8_t[NOISE_PUBLIC_KEY_LEN],
45                           const uint8_t[NOISE_PUBLIC_KEY_LEN]);
46 static bool noise_mix_ss (uint8_t ck[NOISE_HASH_LEN],
47                           uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
48                           const uint8_t ss[NOISE_PUBLIC_KEY_LEN]);
49 static void noise_mix_hash (uint8_t[NOISE_HASH_LEN], const uint8_t *, size_t);
50 static void noise_mix_psk (uint8_t[NOISE_HASH_LEN],
51                            uint8_t[NOISE_HASH_LEN],
52                            uint8_t[NOISE_SYMMETRIC_KEY_LEN],
53                            const uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
54 static void noise_param_init (uint8_t[NOISE_HASH_LEN],
55                               uint8_t[NOISE_HASH_LEN],
56                               const uint8_t[NOISE_PUBLIC_KEY_LEN]);
57
58 static void noise_msg_encrypt (vlib_main_t * vm, uint8_t *, uint8_t *, size_t,
59                                uint32_t key_idx, uint8_t[NOISE_HASH_LEN]);
60 static bool noise_msg_decrypt (vlib_main_t * vm, uint8_t *, uint8_t *, size_t,
61                                uint32_t key_idx, uint8_t[NOISE_HASH_LEN]);
62 static void noise_msg_ephemeral (uint8_t[NOISE_HASH_LEN],
63                                  uint8_t[NOISE_HASH_LEN],
64                                  const uint8_t src[NOISE_PUBLIC_KEY_LEN]);
65
66 static void noise_tai64n_now (uint8_t[NOISE_TIMESTAMP_LEN]);
67
68 static void secure_zero_memory (void *v, size_t n);
69
70 /* Set/Get noise parameters */
71 void
72 noise_local_init (noise_local_t * l, struct noise_upcall *upcall)
73 {
74   clib_memset (l, 0, sizeof (*l));
75   l->l_upcall = *upcall;
76 }
77
78 bool
79 noise_local_set_private (noise_local_t * l,
80                          const uint8_t private[NOISE_PUBLIC_KEY_LEN])
81 {
82   clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN);
83   l->l_has_identity = curve25519_gen_public (l->l_public, private);
84
85   return l->l_has_identity;
86 }
87
88 bool
89 noise_local_keys (noise_local_t * l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
90                   uint8_t private[NOISE_PUBLIC_KEY_LEN])
91 {
92   if (l->l_has_identity)
93     {
94       if (public != NULL)
95         clib_memcpy (public, l->l_public, NOISE_PUBLIC_KEY_LEN);
96       if (private != NULL)
97         clib_memcpy (private, l->l_private, NOISE_PUBLIC_KEY_LEN);
98     }
99   else
100     {
101       return false;
102     }
103   return true;
104 }
105
106 void
107 noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
108                    const uint8_t public[NOISE_PUBLIC_KEY_LEN],
109                    noise_local_t * l)
110 {
111   clib_memset (r, 0, sizeof (*r));
112   clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN);
113   r->r_peer_idx = peer_pool_idx;
114
115   ASSERT (l != NULL);
116   r->r_local = l;
117   r->r_handshake.hs_state = HS_ZEROED;
118   noise_remote_precompute (r);
119 }
120
121 bool
122 noise_remote_set_psk (noise_remote_t * r,
123                       uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
124 {
125   int same;
126   same = !clib_memcmp (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
127   if (!same)
128     {
129       clib_memcpy (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
130     }
131   return same == 0;
132 }
133
134 bool
135 noise_remote_keys (noise_remote_t * r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
136                    uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
137 {
138   static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
139   int ret;
140
141   if (public != NULL)
142     clib_memcpy (public, r->r_public, NOISE_PUBLIC_KEY_LEN);
143
144   if (psk != NULL)
145     clib_memcpy (psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
146   ret = clib_memcmp (r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
147
148   return ret;
149 }
150
151 void
152 noise_remote_precompute (noise_remote_t * r)
153 {
154   noise_local_t *l = r->r_local;
155   if (!l->l_has_identity)
156     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
157   else if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
158     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
159
160   noise_remote_handshake_index_drop (r);
161   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
162 }
163
164 /* Handshake functions */
165 bool
166 noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
167                          uint32_t * s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
168                          uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
169                          uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
170 {
171   noise_handshake_t *hs = &r->r_handshake;
172   noise_local_t *l = r->r_local;
173   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
174   uint32_t key_idx;
175   uint8_t *key;
176   int ret = false;
177
178   key_idx =
179     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
180                          NOISE_SYMMETRIC_KEY_LEN);
181   key = vnet_crypto_get_key (key_idx)->data;
182
183   if (!l->l_has_identity)
184     goto error;
185   noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public);
186
187   /* e */
188   curve25519_gen_secret (hs->hs_e);
189   if (!curve25519_gen_public (ue, hs->hs_e))
190     goto error;
191   noise_msg_ephemeral (hs->hs_ck, hs->hs_hash, ue);
192
193   /* es */
194   if (!noise_mix_dh (hs->hs_ck, key, hs->hs_e, r->r_public))
195     goto error;
196
197   /* s */
198   noise_msg_encrypt (vm, es, l->l_public, NOISE_PUBLIC_KEY_LEN, key_idx,
199                      hs->hs_hash);
200
201   /* ss */
202   if (!noise_mix_ss (hs->hs_ck, key, r->r_ss))
203     goto error;
204
205   /* {t} */
206   noise_tai64n_now (ets);
207   noise_msg_encrypt (vm, ets, ets, NOISE_TIMESTAMP_LEN, key_idx, hs->hs_hash);
208   noise_remote_handshake_index_drop (r);
209   hs->hs_state = CREATED_INITIATION;
210   hs->hs_local_index = noise_remote_handshake_index_get (r);
211   *s_idx = hs->hs_local_index;
212   ret = true;
213 error:
214   vnet_crypto_key_del (vm, key_idx);
215   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
216   return ret;
217 }
218
219 bool
220 noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
221                           noise_remote_t ** rp, uint32_t s_idx,
222                           uint8_t ue[NOISE_PUBLIC_KEY_LEN],
223                           uint8_t es[NOISE_PUBLIC_KEY_LEN +
224                                      NOISE_AUTHTAG_LEN],
225                           uint8_t ets[NOISE_TIMESTAMP_LEN +
226                                       NOISE_AUTHTAG_LEN])
227 {
228   noise_remote_t *r;
229   noise_handshake_t hs;
230   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
231   uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
232   uint8_t timestamp[NOISE_TIMESTAMP_LEN];
233   u32 key_idx;
234   uint8_t *key;
235   int ret = false;
236
237   key_idx =
238     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
239                          NOISE_SYMMETRIC_KEY_LEN);
240   key = vnet_crypto_get_key (key_idx)->data;
241
242   if (!l->l_has_identity)
243     goto error;
244   noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public);
245
246   /* e */
247   noise_msg_ephemeral (hs.hs_ck, hs.hs_hash, ue);
248
249   /* es */
250   if (!noise_mix_dh (hs.hs_ck, key, l->l_private, ue))
251     goto error;
252
253   /* s */
254
255   if (!noise_msg_decrypt (vm, r_public, es,
256                           NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key_idx,
257                           hs.hs_hash))
258     goto error;
259
260   /* Lookup the remote we received from */
261   if ((r = l->l_upcall.u_remote_get (r_public)) == NULL)
262     goto error;
263
264   /* ss */
265   if (!noise_mix_ss (hs.hs_ck, key, r->r_ss))
266     goto error;
267
268   /* {t} */
269   if (!noise_msg_decrypt (vm, timestamp, ets,
270                           NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key_idx,
271                           hs.hs_hash))
272     goto error;
273   ;
274
275   hs.hs_state = CONSUMED_INITIATION;
276   hs.hs_local_index = 0;
277   hs.hs_remote_index = s_idx;
278   clib_memcpy (hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN);
279
280   /* Replay */
281   if (clib_memcmp (timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0)
282     clib_memcpy (r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
283   else
284     goto error;
285
286   /* Flood attack */
287   if (wg_birthdate_has_expired (r->r_last_init, REJECT_INTERVAL))
288     r->r_last_init = vlib_time_now (vm);
289   else
290     goto error;
291
292   /* Ok, we're happy to accept this initiation now */
293   noise_remote_handshake_index_drop (r);
294   r->r_handshake = hs;
295   *rp = r;
296   ret = true;
297 error:
298   vnet_crypto_key_del (vm, key_idx);
299   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
300   secure_zero_memory (&hs, sizeof (hs));
301   return ret;
302 }
303
304 bool
305 noise_create_response (vlib_main_t * vm, noise_remote_t * r, uint32_t * s_idx,
306                        uint32_t * r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
307                        uint8_t en[0 + NOISE_AUTHTAG_LEN])
308 {
309   noise_handshake_t *hs = &r->r_handshake;
310   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
311   uint8_t e[NOISE_PUBLIC_KEY_LEN];
312   uint32_t key_idx;
313   uint8_t *key;
314   int ret = false;
315
316   key_idx =
317     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
318                          NOISE_SYMMETRIC_KEY_LEN);
319   key = vnet_crypto_get_key (key_idx)->data;
320
321   if (hs->hs_state != CONSUMED_INITIATION)
322     goto error;
323
324   /* e */
325   curve25519_gen_secret (e);
326   if (!curve25519_gen_public (ue, e))
327     goto error;
328   noise_msg_ephemeral (hs->hs_ck, hs->hs_hash, ue);
329
330   /* ee */
331   if (!noise_mix_dh (hs->hs_ck, NULL, e, hs->hs_e))
332     goto error;
333
334   /* se */
335   if (!noise_mix_dh (hs->hs_ck, NULL, e, r->r_public))
336     goto error;
337
338   /* psk */
339   noise_mix_psk (hs->hs_ck, hs->hs_hash, key, r->r_psk);
340
341   /* {} */
342   noise_msg_encrypt (vm, en, NULL, 0, key_idx, hs->hs_hash);
343
344
345   hs->hs_state = CREATED_RESPONSE;
346   hs->hs_local_index = noise_remote_handshake_index_get (r);
347   *r_idx = hs->hs_remote_index;
348   *s_idx = hs->hs_local_index;
349   ret = true;
350 error:
351   vnet_crypto_key_del (vm, key_idx);
352   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
353   secure_zero_memory (e, NOISE_PUBLIC_KEY_LEN);
354   return ret;
355 }
356
357 bool
358 noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
359                         uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
360                         uint8_t en[0 + NOISE_AUTHTAG_LEN])
361 {
362   noise_local_t *l = r->r_local;
363   noise_handshake_t hs;
364   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
365   uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN];
366   uint32_t key_idx;
367   uint8_t *key;
368   int ret = false;
369
370   key_idx =
371     vnet_crypto_key_add (vm, VNET_CRYPTO_ALG_CHACHA20_POLY1305, _key,
372                          NOISE_SYMMETRIC_KEY_LEN);
373   key = vnet_crypto_get_key (key_idx)->data;
374
375   if (!l->l_has_identity)
376     goto error;
377
378   hs = r->r_handshake;
379   clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
380
381   if (hs.hs_state != CREATED_INITIATION || hs.hs_local_index != r_idx)
382     goto error;
383
384   /* e */
385   noise_msg_ephemeral (hs.hs_ck, hs.hs_hash, ue);
386
387   /* ee */
388   if (!noise_mix_dh (hs.hs_ck, NULL, hs.hs_e, ue))
389     goto error;
390
391   /* se */
392   if (!noise_mix_dh (hs.hs_ck, NULL, l->l_private, ue))
393     goto error;
394
395   /* psk */
396   noise_mix_psk (hs.hs_ck, hs.hs_hash, key, preshared_key);
397
398   /* {} */
399
400   if (!noise_msg_decrypt
401       (vm, NULL, en, 0 + NOISE_AUTHTAG_LEN, key_idx, hs.hs_hash))
402     goto error;
403
404
405   hs.hs_remote_index = s_idx;
406
407   if (r->r_handshake.hs_state == hs.hs_state &&
408       r->r_handshake.hs_local_index == hs.hs_local_index)
409     {
410       r->r_handshake = hs;
411       r->r_handshake.hs_state = CONSUMED_RESPONSE;
412       ret = true;
413     }
414 error:
415   vnet_crypto_key_del (vm, key_idx);
416   secure_zero_memory (&hs, sizeof (hs));
417   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
418   return ret;
419 }
420
421 bool
422 noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
423 {
424   noise_handshake_t *hs = &r->r_handshake;
425   noise_keypair_t kp, *next, *current, *previous;
426
427   uint8_t key_send[NOISE_SYMMETRIC_KEY_LEN];
428   uint8_t key_recv[NOISE_SYMMETRIC_KEY_LEN];
429
430   /* We now derive the keypair from the handshake */
431   if (hs->hs_state == CONSUMED_RESPONSE)
432     {
433       kp.kp_is_initiator = 1;
434       noise_kdf (key_send, key_recv, NULL, NULL,
435                  NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
436                  hs->hs_ck);
437     }
438   else if (hs->hs_state == CREATED_RESPONSE)
439     {
440       kp.kp_is_initiator = 0;
441       noise_kdf (key_recv, key_send, NULL, NULL,
442                  NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
443                  hs->hs_ck);
444     }
445   else
446     {
447       return false;
448     }
449
450   kp.kp_valid = 1;
451   kp.kp_send_index = vnet_crypto_key_add (vm,
452                                           VNET_CRYPTO_ALG_CHACHA20_POLY1305,
453                                           key_send, NOISE_SYMMETRIC_KEY_LEN);
454   kp.kp_recv_index = vnet_crypto_key_add (vm,
455                                           VNET_CRYPTO_ALG_CHACHA20_POLY1305,
456                                           key_recv, NOISE_SYMMETRIC_KEY_LEN);
457   kp.kp_local_index = hs->hs_local_index;
458   kp.kp_remote_index = hs->hs_remote_index;
459   kp.kp_birthdate = vlib_time_now (vm);
460   clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr));
461
462   /* Now we need to add_new_keypair */
463   next = r->r_next;
464   current = r->r_current;
465   previous = r->r_previous;
466
467   if (kp.kp_is_initiator)
468     {
469       if (next != NULL)
470         {
471           r->r_next = NULL;
472           r->r_previous = next;
473           noise_remote_keypair_free (vm, r, &current);
474         }
475       else
476         {
477           r->r_previous = current;
478         }
479
480       noise_remote_keypair_free (vm, r, &previous);
481
482       r->r_current = noise_remote_keypair_allocate (r);
483       *r->r_current = kp;
484     }
485   else
486     {
487       noise_remote_keypair_free (vm, r, &next);
488       r->r_previous = NULL;
489       noise_remote_keypair_free (vm, r, &previous);
490
491       r->r_next = noise_remote_keypair_allocate (r);
492       *r->r_next = kp;
493     }
494   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
495   secure_zero_memory (&kp, sizeof (kp));
496   return true;
497 }
498
499 void
500 noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
501 {
502   noise_remote_handshake_index_drop (r);
503   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
504
505   noise_remote_keypair_free (vm, r, &r->r_next);
506   noise_remote_keypair_free (vm, r, &r->r_current);
507   noise_remote_keypair_free (vm, r, &r->r_previous);
508   r->r_next = NULL;
509   r->r_current = NULL;
510   r->r_previous = NULL;
511 }
512
513 void
514 noise_remote_expire_current (noise_remote_t * r)
515 {
516   if (r->r_next != NULL)
517     r->r_next->kp_valid = 0;
518   if (r->r_current != NULL)
519     r->r_current->kp_valid = 0;
520 }
521
522 bool
523 noise_remote_ready (noise_remote_t * r)
524 {
525   noise_keypair_t *kp;
526   int ret;
527
528   if ((kp = r->r_current) == NULL ||
529       !kp->kp_valid ||
530       wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
531       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
532       kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES)
533     ret = false;
534   else
535     ret = true;
536   return ret;
537 }
538
539 static bool
540 chacha20poly1305_calc (vlib_main_t * vm,
541                        u8 * src,
542                        u32 src_len,
543                        u8 * dst,
544                        u8 * aad,
545                        u32 aad_len,
546                        u64 nonce,
547                        vnet_crypto_op_id_t op_id,
548                        vnet_crypto_key_index_t key_index)
549 {
550   vnet_crypto_op_t _op, *op = &_op;
551   u8 iv[12];
552   u8 tag_[NOISE_AUTHTAG_LEN] = { };
553   u8 src_[] = { };
554
555   clib_memset (iv, 0, 12);
556   clib_memcpy (iv + 4, &nonce, sizeof (nonce));
557
558   vnet_crypto_op_init (op, op_id);
559
560   op->tag_len = NOISE_AUTHTAG_LEN;
561   if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC)
562     {
563       op->tag = src + src_len - NOISE_AUTHTAG_LEN;
564       src_len -= NOISE_AUTHTAG_LEN;
565     }
566   else
567     op->tag = tag_;
568
569   op->src = !src ? src_ : src;
570   op->len = src_len;
571
572   op->dst = dst;
573   op->key_index = key_index;
574   op->aad = aad;
575   op->aad_len = aad_len;
576   op->iv = iv;
577
578   vnet_crypto_process_ops (vm, op, 1);
579   if (op_id == VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC)
580     {
581       clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN);
582     }
583
584   return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
585 }
586
587 enum noise_state_crypt
588 noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
589                       uint64_t * nonce, uint8_t * src, size_t srclen,
590                       uint8_t * dst)
591 {
592   noise_keypair_t *kp;
593   enum noise_state_crypt ret = SC_FAILED;
594
595   if ((kp = r->r_current) == NULL)
596     goto error;
597
598   /* We confirm that our values are within our tolerances. We want:
599    *  - a valid keypair
600    *  - our keypair to be less than REJECT_AFTER_TIME seconds old
601    *  - our receive counter to be less than REJECT_AFTER_MESSAGES
602    *  - our send counter to be less than REJECT_AFTER_MESSAGES
603    */
604   if (!kp->kp_valid ||
605       wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
606       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
607       ((*nonce = noise_counter_send (&kp->kp_ctr)) > REJECT_AFTER_MESSAGES))
608     goto error;
609
610   /* We encrypt into the same buffer, so the caller must ensure that buf
611    * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index
612    * are passed back out to the caller through the provided data pointer. */
613   *r_idx = kp->kp_remote_index;
614
615   chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, *nonce,
616                          VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC,
617                          kp->kp_send_index);
618
619   /* If our values are still within tolerances, but we are approaching
620    * the tolerances, we notify the caller with ESTALE that they should
621    * establish a new keypair. The current keypair can continue to be used
622    * until the tolerances are hit. We notify if:
623    *  - our send counter is valid and not less than REKEY_AFTER_MESSAGES
624    *  - we're the initiator and our keypair is older than
625    *    REKEY_AFTER_TIME seconds */
626   ret = SC_KEEP_KEY_FRESH;
627   if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) ||
628       (kp->kp_is_initiator &&
629        wg_birthdate_has_expired (kp->kp_birthdate, REKEY_AFTER_TIME)))
630     goto error;
631
632   ret = SC_OK;
633 error:
634   return ret;
635 }
636
637 enum noise_state_crypt
638 noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
639                       uint64_t nonce, uint8_t * src, size_t srclen,
640                       uint8_t * dst)
641 {
642   noise_keypair_t *kp;
643   enum noise_state_crypt ret = SC_FAILED;
644
645   if (r->r_current != NULL && r->r_current->kp_local_index == r_idx)
646     {
647       kp = r->r_current;
648     }
649   else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx)
650     {
651       kp = r->r_previous;
652     }
653   else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx)
654     {
655       kp = r->r_next;
656     }
657   else
658     {
659       goto error;
660     }
661
662   /* We confirm that our values are within our tolerances. These values
663    * are the same as the encrypt routine.
664    *
665    * kp_ctr isn't locked here, we're happy to accept a racy read. */
666   if (wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
667       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
668     goto error;
669
670   /* Decrypt, then validate the counter. We don't want to validate the
671    * counter before decrypting as we do not know the message is authentic
672    * prior to decryption. */
673   if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce,
674                               VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
675                               kp->kp_recv_index))
676     goto error;
677
678   if (!noise_counter_recv (&kp->kp_ctr, nonce))
679     goto error;
680
681   /* If we've received the handshake confirming data packet then move the
682    * next keypair into current. If we do slide the next keypair in, then
683    * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
684    * data packet can't confirm a session that we are an INITIATOR of. */
685   if (kp == r->r_next && kp->kp_local_index == r_idx)
686     {
687       noise_remote_keypair_free (vm, r, &r->r_previous);
688       r->r_previous = r->r_current;
689       r->r_current = r->r_next;
690       r->r_next = NULL;
691
692       ret = SC_CONN_RESET;
693       goto error;
694     }
695
696
697   /* Similar to when we encrypt, we want to notify the caller when we
698    * are approaching our tolerances. We notify if:
699    *  - we're the initiator and the current keypair is older than
700    *    REKEY_AFTER_TIME_RECV seconds. */
701   ret = SC_KEEP_KEY_FRESH;
702   kp = r->r_current;
703   if (kp != NULL &&
704       kp->kp_valid &&
705       kp->kp_is_initiator &&
706       wg_birthdate_has_expired (kp->kp_birthdate, REKEY_AFTER_TIME_RECV))
707     goto error;
708
709   ret = SC_OK;
710 error:
711   return ret;
712 }
713
714 /* Private functions - these should not be called outside this file under any
715  * circumstances. */
716 static noise_keypair_t *
717 noise_remote_keypair_allocate (noise_remote_t * r)
718 {
719   noise_keypair_t *kp;
720   kp = clib_mem_alloc (sizeof (*kp));
721   return kp;
722 }
723
724 static void
725 noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
726                            noise_keypair_t ** kp)
727 {
728   struct noise_upcall *u = &r->r_local->l_upcall;
729   if (*kp)
730     {
731       u->u_index_drop ((*kp)->kp_local_index);
732       vnet_crypto_key_del (vm, (*kp)->kp_send_index);
733       vnet_crypto_key_del (vm, (*kp)->kp_recv_index);
734       clib_mem_free (*kp);
735     }
736 }
737
738 static uint32_t
739 noise_remote_handshake_index_get (noise_remote_t * r)
740 {
741   struct noise_upcall *u = &r->r_local->l_upcall;
742   return u->u_index_set (r);
743 }
744
745 static void
746 noise_remote_handshake_index_drop (noise_remote_t * r)
747 {
748   noise_handshake_t *hs = &r->r_handshake;
749   struct noise_upcall *u = &r->r_local->l_upcall;
750   if (hs->hs_state != HS_ZEROED)
751     u->u_index_drop (hs->hs_local_index);
752 }
753
754 static uint64_t
755 noise_counter_send (noise_counter_t * ctr)
756 {
757   uint64_t ret = ctr->c_send++;
758   return ret;
759 }
760
761 static bool
762 noise_counter_recv (noise_counter_t * ctr, uint64_t recv)
763 {
764   uint64_t i, top, index_recv, index_ctr;
765   unsigned long bit;
766   bool ret = false;
767
768
769   /* Check that the recv counter is valid */
770   if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES)
771     goto error;
772
773   /* If the packet is out of the window, invalid */
774   if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv)
775     goto error;
776
777   /* If the new counter is ahead of the current counter, we'll need to
778    * zero out the bitmap that has previously been used */
779   index_recv = recv / COUNTER_BITS;
780   index_ctr = ctr->c_recv / COUNTER_BITS;
781
782   if (recv > ctr->c_recv)
783     {
784       top = clib_min (index_recv - index_ctr, COUNTER_NUM);
785       for (i = 1; i <= top; i++)
786         ctr->c_backtrack[(i + index_ctr) & (COUNTER_NUM - 1)] = 0;
787       ctr->c_recv = recv;
788     }
789
790   index_recv %= COUNTER_NUM;
791   bit = 1ul << (recv % COUNTER_BITS);
792
793   if (ctr->c_backtrack[index_recv] & bit)
794     goto error;
795
796   ctr->c_backtrack[index_recv] |= bit;
797
798   ret = true;
799 error:
800   return ret;
801 }
802
803 static void
804 noise_kdf (uint8_t * a, uint8_t * b, uint8_t * c, const uint8_t * x,
805            size_t a_len, size_t b_len, size_t c_len, size_t x_len,
806            const uint8_t ck[NOISE_HASH_LEN])
807 {
808   uint8_t out[BLAKE2S_HASH_SIZE + 1];
809   uint8_t sec[BLAKE2S_HASH_SIZE];
810
811   /* Extract entropy from "x" into sec */
812   u32 l = 0;
813   HMAC (EVP_blake2s256 (), ck, NOISE_HASH_LEN, x, x_len, sec, &l);
814   ASSERT (l == BLAKE2S_HASH_SIZE);
815   if (a == NULL || a_len == 0)
816     goto out;
817
818   /* Expand first key: key = sec, data = 0x1 */
819   out[0] = 1;
820   HMAC (EVP_blake2s256 (), sec, BLAKE2S_HASH_SIZE, out, 1, out, &l);
821   ASSERT (l == BLAKE2S_HASH_SIZE);
822   clib_memcpy (a, out, a_len);
823
824   if (b == NULL || b_len == 0)
825     goto out;
826
827   /* Expand second key: key = sec, data = "a" || 0x2 */
828   out[BLAKE2S_HASH_SIZE] = 2;
829   HMAC (EVP_blake2s256 (), sec, BLAKE2S_HASH_SIZE, out, BLAKE2S_HASH_SIZE + 1,
830         out, &l);
831   ASSERT (l == BLAKE2S_HASH_SIZE);
832   clib_memcpy (b, out, b_len);
833
834   if (c == NULL || c_len == 0)
835     goto out;
836
837   /* Expand third key: key = sec, data = "b" || 0x3 */
838   out[BLAKE2S_HASH_SIZE] = 3;
839   HMAC (EVP_blake2s256 (), sec, BLAKE2S_HASH_SIZE, out, BLAKE2S_HASH_SIZE + 1,
840         out, &l);
841   ASSERT (l == BLAKE2S_HASH_SIZE);
842
843   clib_memcpy (c, out, c_len);
844
845 out:
846   /* Clear sensitive data from stack */
847   secure_zero_memory (sec, BLAKE2S_HASH_SIZE);
848   secure_zero_memory (out, BLAKE2S_HASH_SIZE + 1);
849 }
850
851 static bool
852 noise_mix_dh (uint8_t ck[NOISE_HASH_LEN],
853               uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
854               const uint8_t private[NOISE_PUBLIC_KEY_LEN],
855               const uint8_t public[NOISE_PUBLIC_KEY_LEN])
856 {
857   uint8_t dh[NOISE_PUBLIC_KEY_LEN];
858   if (!curve25519_gen_shared (dh, private, public))
859     return false;
860   noise_kdf (ck, key, NULL, dh,
861              NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
862              ck);
863   secure_zero_memory (dh, NOISE_PUBLIC_KEY_LEN);
864   return true;
865 }
866
867 static bool
868 noise_mix_ss (uint8_t ck[NOISE_HASH_LEN],
869               uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
870               const uint8_t ss[NOISE_PUBLIC_KEY_LEN])
871 {
872   static uint8_t null_point[NOISE_PUBLIC_KEY_LEN];
873   if (clib_memcmp (ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0)
874     return false;
875   noise_kdf (ck, key, NULL, ss,
876              NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
877              ck);
878   return true;
879 }
880
881 static void
882 noise_mix_hash (uint8_t hash[NOISE_HASH_LEN], const uint8_t * src,
883                 size_t src_len)
884 {
885   blake2s_state_t blake;
886
887   blake2s_init (&blake, NOISE_HASH_LEN);
888   blake2s_update (&blake, hash, NOISE_HASH_LEN);
889   blake2s_update (&blake, src, src_len);
890   blake2s_final (&blake, hash, NOISE_HASH_LEN);
891 }
892
893 static void
894 noise_mix_psk (uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
895                uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
896                const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
897 {
898   uint8_t tmp[NOISE_HASH_LEN];
899
900   noise_kdf (ck, tmp, key, psk,
901              NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
902              NOISE_SYMMETRIC_KEY_LEN, ck);
903   noise_mix_hash (hash, tmp, NOISE_HASH_LEN);
904   secure_zero_memory (tmp, NOISE_HASH_LEN);
905 }
906
907 static void
908 noise_param_init (uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
909                   const uint8_t s[NOISE_PUBLIC_KEY_LEN])
910 {
911   blake2s_state_t blake;
912
913   blake2s (ck, NOISE_HASH_LEN, (uint8_t *) NOISE_HANDSHAKE_NAME,
914            strlen (NOISE_HANDSHAKE_NAME), NULL, 0);
915
916   blake2s_init (&blake, NOISE_HASH_LEN);
917   blake2s_update (&blake, ck, NOISE_HASH_LEN);
918   blake2s_update (&blake, (uint8_t *) NOISE_IDENTIFIER_NAME,
919                   strlen (NOISE_IDENTIFIER_NAME));
920   blake2s_final (&blake, hash, NOISE_HASH_LEN);
921
922   noise_mix_hash (hash, s, NOISE_PUBLIC_KEY_LEN);
923 }
924
925 static void
926 noise_msg_encrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src,
927                    size_t src_len, uint32_t key_idx,
928                    uint8_t hash[NOISE_HASH_LEN])
929 {
930   /* Nonce always zero for Noise_IK */
931   chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
932                          VNET_CRYPTO_OP_CHACHA20_POLY1305_ENC, key_idx);
933   noise_mix_hash (hash, dst, src_len + NOISE_AUTHTAG_LEN);
934 }
935
936 static bool
937 noise_msg_decrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src,
938                    size_t src_len, uint32_t key_idx,
939                    uint8_t hash[NOISE_HASH_LEN])
940 {
941   /* Nonce always zero for Noise_IK */
942   if (!chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
943                               VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx))
944     return false;
945   noise_mix_hash (hash, src, src_len);
946   return true;
947 }
948
949 static void
950 noise_msg_ephemeral (uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
951                      const uint8_t src[NOISE_PUBLIC_KEY_LEN])
952 {
953   noise_mix_hash (hash, src, NOISE_PUBLIC_KEY_LEN);
954   noise_kdf (ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0,
955              NOISE_PUBLIC_KEY_LEN, ck);
956 }
957
958 static void
959 noise_tai64n_now (uint8_t output[NOISE_TIMESTAMP_LEN])
960 {
961   uint32_t unix_sec;
962   uint32_t unix_nanosec;
963
964   uint64_t sec;
965   uint32_t nsec;
966
967   unix_time_now_nsec_fraction (&unix_sec, &unix_nanosec);
968
969   /* Round down the nsec counter to limit precise timing leak. */
970   unix_nanosec &= REJECT_INTERVAL_MASK;
971
972   /* https://cr.yp.to/libtai/tai64.html */
973   sec = htobe64 (0x400000000000000aULL + unix_sec);
974   nsec = htobe32 (unix_nanosec);
975
976   /* memcpy to output buffer, assuming output could be unaligned. */
977   clib_memcpy (output, &sec, sizeof (sec));
978   clib_memcpy (output + sizeof (sec), &nsec, sizeof (nsec));
979 }
980
981 static void
982 secure_zero_memory (void *v, size_t n)
983 {
984   static void *(*const volatile memset_v) (void *, int, size_t) = &memset;
985   memset_v (v, 0, n);
986 }
987
988 /*
989  * fd.io coding-style-patch-verification: ON
990  *
991  * Local Variables:
992  * eval: (c-set-style "gnu")
993  * End:
994  */