3eba9cbf75f5bea2f64df14ce807a39e7819418b
[vpp.git] / src / plugins / wireguard / wireguard_input.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <vlib/vlib.h>
18 #include <vnet/vnet.h>
19 #include <vppinfra/error.h>
20 #include <wireguard/wireguard.h>
21
22 #include <wireguard/wireguard_send.h>
23 #include <wireguard/wireguard_if.h>
24
25 #define foreach_wg_input_error                                                \
26   _ (NONE, "No error")                                                        \
27   _ (HANDSHAKE_MAC, "Invalid MAC handshake")                                  \
28   _ (PEER, "Peer error")                                                      \
29   _ (INTERFACE, "Interface error")                                            \
30   _ (DECRYPTION, "Failed during decryption")                                  \
31   _ (KEEPALIVE_SEND, "Failed while sending Keepalive")                        \
32   _ (HANDSHAKE_SEND, "Failed while sending Handshake")                        \
33   _ (HANDSHAKE_RECEIVE, "Failed while receiving Handshake")                   \
34   _ (TOO_BIG, "Packet too big")                                               \
35   _ (UNDEFINED, "Undefined error")                                            \
36   _ (CRYPTO_ENGINE_ERROR, "crypto engine error (packet dropped)")
37
38 typedef enum
39 {
40 #define _(sym,str) WG_INPUT_ERROR_##sym,
41   foreach_wg_input_error
42 #undef _
43     WG_INPUT_N_ERROR,
44 } wg_input_error_t;
45
46 static char *wg_input_error_strings[] = {
47 #define _(sym,string) string,
48   foreach_wg_input_error
49 #undef _
50 };
51
52 typedef struct
53 {
54   message_type_t type;
55   u16 current_length;
56   bool is_keepalive;
57   index_t peer;
58 } wg_input_trace_t;
59
60 typedef struct
61 {
62   index_t peer;
63   u16 next;
64 } wg_input_post_trace_t;
65
66 u8 *
67 format_wg_message_type (u8 * s, va_list * args)
68 {
69   message_type_t type = va_arg (*args, message_type_t);
70
71   switch (type)
72     {
73 #define _(v,a) case MESSAGE_##v: return (format (s, "%s", a));
74       foreach_wg_message_type
75 #undef _
76     }
77   return (format (s, "unknown"));
78 }
79
80 /* packet trace format function */
81 static u8 *
82 format_wg_input_trace (u8 * s, va_list * args)
83 {
84   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
85   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
86
87   wg_input_trace_t *t = va_arg (*args, wg_input_trace_t *);
88
89   s = format (s, "Wireguard input: \n");
90   s = format (s, "    Type: %U\n", format_wg_message_type, t->type);
91   s = format (s, "    Peer: %d\n", t->peer);
92   s = format (s, "    Length: %d\n", t->current_length);
93   s = format (s, "    Keepalive: %s", t->is_keepalive ? "true" : "false");
94
95   return s;
96 }
97
98 /* post-node packet trace format function */
99 static u8 *
100 format_wg_input_post_trace (u8 *s, va_list *args)
101 {
102   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
103   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
104
105   wg_input_post_trace_t *t = va_arg (*args, wg_input_post_trace_t *);
106
107   s = format (s, "WG input post: \n");
108   s = format (s, "  peer: %u\n", t->peer);
109   s = format (s, "  next: %u\n", t->next);
110
111   return s;
112 }
113
114 typedef enum
115 {
116   WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
117   WG_INPUT_NEXT_HANDOFF_DATA,
118   WG_INPUT_NEXT_IP4_INPUT,
119   WG_INPUT_NEXT_IP6_INPUT,
120   WG_INPUT_NEXT_PUNT,
121   WG_INPUT_NEXT_ERROR,
122   WG_INPUT_N_NEXT,
123 } wg_input_next_t;
124
125 /* static void */
126 /* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
127 /* { */
128 /*   if (peer) */
129 /*     { */
130 /*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
131 /*       peer->dst.port = udp_port; */
132 /*     } */
133 /* } */
134
135 static u8
136 is_ip4_header (u8 *data)
137 {
138   return (data[0] >> 4) == 0x4;
139 }
140
141 static wg_input_error_t
142 wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
143                       u32 node_idx, u8 is_ip4)
144 {
145   ASSERT (vm->thread_index == 0);
146
147   enum cookie_mac_state mac_state;
148   bool packet_needs_cookie;
149   bool under_load;
150   index_t *wg_ifs;
151   wg_if_t *wg_if;
152   wg_peer_t *peer = NULL;
153
154   void *current_b_data = vlib_buffer_get_current (b);
155
156   ip46_address_t src_ip;
157   if (is_ip4)
158     {
159       ip4_header_t *iph4 =
160         current_b_data - sizeof (udp_header_t) - sizeof (ip4_header_t);
161       ip46_address_set_ip4 (&src_ip, &iph4->src_address);
162     }
163   else
164     {
165       ip6_header_t *iph6 =
166         current_b_data - sizeof (udp_header_t) - sizeof (ip6_header_t);
167       ip46_address_set_ip6 (&src_ip, &iph6->src_address);
168     }
169
170   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
171   u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
172   u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
173
174   message_header_t *header = current_b_data;
175   under_load = false;
176
177   if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
178     {
179       message_handshake_cookie_t *packet =
180         (message_handshake_cookie_t *) current_b_data;
181       u32 *entry =
182         wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
183       if (entry)
184         peer = wg_peer_get (*entry);
185       else
186         return WG_INPUT_ERROR_PEER;
187
188       // TODO: Implement cookie_maker_consume_payload
189
190       return WG_INPUT_ERROR_NONE;
191     }
192
193   u32 len = (header->type == MESSAGE_HANDSHAKE_INITIATION ?
194              sizeof (message_handshake_initiation_t) :
195              sizeof (message_handshake_response_t));
196
197   message_macs_t *macs = (message_macs_t *)
198     ((u8 *) current_b_data + len - sizeof (*macs));
199
200   index_t *ii;
201   wg_ifs = wg_if_indexes_get_by_port (udp_dst_port);
202   if (NULL == wg_ifs)
203     return WG_INPUT_ERROR_INTERFACE;
204
205   vec_foreach (ii, wg_ifs)
206     {
207       wg_if = wg_if_get (*ii);
208       if (NULL == wg_if)
209         continue;
210
211       mac_state = cookie_checker_validate_macs (
212         vm, &wg_if->cookie_checker, macs, current_b_data, len, under_load,
213         &src_ip, udp_src_port);
214       if (mac_state == INVALID_MAC)
215         {
216           wg_if = NULL;
217           continue;
218         }
219       break;
220     }
221
222   if (NULL == wg_if)
223     return WG_INPUT_ERROR_HANDSHAKE_MAC;
224
225   if ((under_load && mac_state == VALID_MAC_WITH_COOKIE)
226       || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
227     packet_needs_cookie = false;
228   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
229     packet_needs_cookie = true;
230   else
231     return WG_INPUT_ERROR_HANDSHAKE_MAC;
232
233   switch (header->type)
234     {
235     case MESSAGE_HANDSHAKE_INITIATION:
236       {
237         message_handshake_initiation_t *message = current_b_data;
238
239         if (packet_needs_cookie)
240           {
241             // TODO: Add processing
242           }
243         noise_remote_t *rp;
244         if (noise_consume_initiation
245             (vm, noise_local_get (wg_if->local_idx), &rp,
246              message->sender_index, message->unencrypted_ephemeral,
247              message->encrypted_static, message->encrypted_timestamp))
248           {
249             peer = wg_peer_get (rp->r_peer_idx);
250           }
251         else
252           {
253             return WG_INPUT_ERROR_PEER;
254           }
255
256         // set_peer_address (peer, ip4_src, udp_src_port);
257         if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
258           {
259             vlib_node_increment_counter (vm, node_idx,
260                                          WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
261           }
262         else
263           {
264             wg_peer_update_flags (rp->r_peer_idx, WG_PEER_ESTABLISHED, true);
265           }
266         break;
267       }
268     case MESSAGE_HANDSHAKE_RESPONSE:
269       {
270         message_handshake_response_t *resp = current_b_data;
271         index_t peeri = INDEX_INVALID;
272         u32 *entry =
273           wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
274
275         if (PREDICT_TRUE (entry != NULL))
276           {
277             peeri = *entry;
278             peer = wg_peer_get (peeri);
279             if (wg_peer_is_dead (peer))
280               return WG_INPUT_ERROR_PEER;
281           }
282         else
283           return WG_INPUT_ERROR_PEER;
284
285         if (!noise_consume_response
286             (vm, &peer->remote, resp->sender_index,
287              resp->receiver_index, resp->unencrypted_ephemeral,
288              resp->encrypted_nothing))
289           {
290             return WG_INPUT_ERROR_PEER;
291           }
292         if (packet_needs_cookie)
293           {
294             // TODO: Add processing
295           }
296
297         // set_peer_address (peer, ip4_src, udp_src_port);
298         if (noise_remote_begin_session (vm, &peer->remote))
299           {
300
301             wg_timers_session_derived (peer);
302             wg_timers_handshake_complete (peer);
303             if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
304               {
305                 vlib_node_increment_counter (vm, node_idx,
306                                              WG_INPUT_ERROR_KEEPALIVE_SEND, 1);
307               }
308             else
309               {
310                 wg_peer_update_flags (peeri, WG_PEER_ESTABLISHED, true);
311               }
312           }
313         break;
314       }
315     default:
316       return WG_INPUT_ERROR_HANDSHAKE_RECEIVE;
317     }
318
319   wg_timers_any_authenticated_packet_received (peer);
320   wg_timers_any_authenticated_packet_traversal (peer);
321   return WG_INPUT_ERROR_NONE;
322 }
323
324 static_always_inline int
325 wg_input_post_process (vlib_main_t *vm, vlib_buffer_t *b, u16 *next,
326                        wg_peer_t *peer, message_data_t *data,
327                        bool *is_keepalive)
328 {
329   next[0] = WG_INPUT_NEXT_PUNT;
330
331   noise_keypair_t *kp =
332     wg_get_active_keypair (&peer->remote, data->receiver_index);
333
334   if (!noise_counter_recv (&kp->kp_ctr, data->counter))
335     {
336       return -1;
337     }
338
339   u16 encr_len = b->current_length - sizeof (message_data_t);
340   u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
341
342   vlib_buffer_advance (b, sizeof (message_data_t));
343   b->current_length = decr_len;
344   vnet_buffer_offload_flags_clear (b, VNET_BUFFER_OFFLOAD_F_UDP_CKSUM);
345
346   /* Keepalive packet has zero length */
347   if (decr_len == 0)
348     {
349       *is_keepalive = true;
350       return -1;
351     }
352
353   wg_timers_data_received (peer);
354
355   ip46_address_t src_ip;
356   u8 is_ip4_inner = is_ip4_header (vlib_buffer_get_current (b));
357   if (is_ip4_inner)
358     {
359       ip46_address_set_ip4 (
360         &src_ip, &((ip4_header_t *) vlib_buffer_get_current (b))->src_address);
361     }
362   else
363     {
364       ip46_address_set_ip6 (
365         &src_ip, &((ip6_header_t *) vlib_buffer_get_current (b))->src_address);
366     }
367
368   const fib_prefix_t *allowed_ip;
369   bool allowed = false;
370
371   /*
372    * we could make this into an ACL, but the expectation
373    * is that there aren't many allowed IPs and thus a linear
374    * walk is faster than an ACL
375    */
376   vec_foreach (allowed_ip, peer->allowed_ips)
377     {
378       if (fib_prefix_is_cover_addr_46 (allowed_ip, &src_ip))
379         {
380           allowed = true;
381           break;
382         }
383     }
384   if (allowed)
385     {
386       vnet_buffer (b)->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
387       next[0] =
388         is_ip4_inner ? WG_INPUT_NEXT_IP4_INPUT : WG_INPUT_NEXT_IP6_INPUT;
389     }
390
391   return 0;
392 }
393
394 static_always_inline void
395 wg_input_process_ops (vlib_main_t *vm, vlib_node_runtime_t *node,
396                       vnet_crypto_op_t *ops, vlib_buffer_t *b[], u16 *nexts,
397                       u16 drop_next)
398 {
399   u32 n_fail, n_ops = vec_len (ops);
400   vnet_crypto_op_t *op = ops;
401
402   if (n_ops == 0)
403     return;
404
405   n_fail = n_ops - vnet_crypto_process_ops (vm, op, n_ops);
406
407   while (n_fail)
408     {
409       ASSERT (op - ops < n_ops);
410
411       if (op->status != VNET_CRYPTO_OP_STATUS_COMPLETED)
412         {
413           u32 bi = op->user_data;
414           b[bi]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
415           nexts[bi] = drop_next;
416           n_fail--;
417         }
418       op++;
419     }
420 }
421
422 always_inline void
423 wg_prepare_sync_dec_op (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops,
424                         u8 *src, u32 src_len, u8 *dst, u8 *aad, u32 aad_len,
425                         vnet_crypto_key_index_t key_index, u32 bi, u8 *iv)
426 {
427   vnet_crypto_op_t _op, *op = &_op;
428   u8 src_[] = {};
429
430   vec_add2_aligned (crypto_ops[0], op, 1, CLIB_CACHE_LINE_BYTES);
431   vnet_crypto_op_init (op, VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC);
432
433   op->tag_len = NOISE_AUTHTAG_LEN;
434   op->tag = src + src_len;
435   op->src = !src ? src_ : src;
436   op->len = src_len;
437   op->dst = dst;
438   op->key_index = key_index;
439   op->aad = aad;
440   op->aad_len = aad_len;
441   op->iv = iv;
442   op->user_data = bi;
443   op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
444 }
445
446 static_always_inline void
447 wg_input_add_to_frame (vlib_main_t *vm, vnet_crypto_async_frame_t *f,
448                        u32 key_index, u32 crypto_len, i16 crypto_start_offset,
449                        u32 buffer_index, u16 next_node, u8 *iv, u8 *tag,
450                        u8 flags)
451 {
452   vnet_crypto_async_frame_elt_t *fe;
453   u16 index;
454
455   ASSERT (f->n_elts < VNET_CRYPTO_FRAME_SIZE);
456
457   index = f->n_elts;
458   fe = &f->elts[index];
459   f->n_elts++;
460   fe->key_index = key_index;
461   fe->crypto_total_length = crypto_len;
462   fe->crypto_start_offset = crypto_start_offset;
463   fe->iv = iv;
464   fe->tag = tag;
465   fe->flags = flags;
466   f->buffer_indices[index] = buffer_index;
467   f->next_node_index[index] = next_node;
468 }
469
470 static_always_inline enum noise_state_crypt
471 wg_input_process (vlib_main_t *vm, wg_per_thread_data_t *ptd,
472                   vnet_crypto_op_t **crypto_ops,
473                   vnet_crypto_async_frame_t **async_frame, vlib_buffer_t *b,
474                   u32 buf_idx, noise_remote_t *r, uint32_t r_idx,
475                   uint64_t nonce, uint8_t *src, size_t srclen, uint8_t *dst,
476                   u32 from_idx, u8 *iv, f64 time, u8 is_async,
477                   u16 async_next_node)
478 {
479   noise_keypair_t *kp;
480   enum noise_state_crypt ret = SC_FAILED;
481
482   if ((kp = wg_get_active_keypair (r, r_idx)) == NULL)
483     {
484       goto error;
485     }
486
487   /* We confirm that our values are within our tolerances. These values
488    * are the same as the encrypt routine.
489    *
490    * kp_ctr isn't locked here, we're happy to accept a racy read. */
491   if (wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME,
492                                     time) ||
493       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
494     goto error;
495
496   /* Decrypt, then validate the counter. We don't want to validate the
497    * counter before decrypting as we do not know the message is authentic
498    * prior to decryption. */
499
500   clib_memset (iv, 0, 4);
501   clib_memcpy (iv + 4, &nonce, sizeof (nonce));
502
503   if (is_async)
504     {
505       if (NULL == *async_frame ||
506           vnet_crypto_async_frame_is_full (*async_frame))
507         {
508           *async_frame = vnet_crypto_async_get_frame (
509             vm, VNET_CRYPTO_OP_CHACHA20_POLY1305_TAG16_AAD0_DEC);
510           /* Save the frame to the list we'll submit at the end */
511           vec_add1 (ptd->async_frames, *async_frame);
512         }
513
514       wg_input_add_to_frame (vm, *async_frame, kp->kp_recv_index, srclen,
515                              src - b->data, buf_idx, async_next_node, iv,
516                              src + srclen, VNET_CRYPTO_OP_FLAG_HMAC_CHECK);
517     }
518   else
519     {
520       wg_prepare_sync_dec_op (vm, crypto_ops, src, srclen, dst, NULL, 0,
521                               kp->kp_recv_index, from_idx, iv);
522     }
523
524   /* If we've received the handshake confirming data packet then move the
525    * next keypair into current. If we do slide the next keypair in, then
526    * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
527    * data packet can't confirm a session that we are an INITIATOR of. */
528   if (kp == r->r_next)
529     {
530       clib_rwlock_writer_lock (&r->r_keypair_lock);
531       if (kp == r->r_next && kp->kp_local_index == r_idx)
532         {
533           noise_remote_keypair_free (vm, r, &r->r_previous);
534           r->r_previous = r->r_current;
535           r->r_current = r->r_next;
536           r->r_next = NULL;
537
538           ret = SC_CONN_RESET;
539           clib_rwlock_writer_unlock (&r->r_keypair_lock);
540           goto error;
541         }
542       clib_rwlock_writer_unlock (&r->r_keypair_lock);
543     }
544
545   /* Similar to when we encrypt, we want to notify the caller when we
546    * are approaching our tolerances. We notify if:
547    *  - we're the initiator and the current keypair is older than
548    *    REKEY_AFTER_TIME_RECV seconds. */
549   ret = SC_KEEP_KEY_FRESH;
550   kp = r->r_current;
551   if (kp != NULL && kp->kp_valid && kp->kp_is_initiator &&
552       wg_birthdate_has_expired_opt (kp->kp_birthdate, REKEY_AFTER_TIME_RECV,
553                                     time))
554     goto error;
555
556   ret = SC_OK;
557 error:
558   return ret;
559 }
560
561 always_inline uword
562 wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
563                  vlib_frame_t *frame, u8 is_ip4, u16 async_next_node)
564 {
565   vnet_main_t *vnm = vnet_get_main ();
566   vnet_interface_main_t *im = &vnm->interface_main;
567   wg_main_t *wmp = &wg_main;
568   wg_per_thread_data_t *ptd =
569     vec_elt_at_index (wmp->per_thread_data, vm->thread_index);
570   u32 *from = vlib_frame_vector_args (frame);
571   u32 n_left_from = frame->n_vectors;
572
573   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs;
574   u32 thread_index = vm->thread_index;
575   vnet_crypto_op_t **crypto_ops = &ptd->crypto_ops;
576   const u16 drop_next = WG_INPUT_NEXT_PUNT;
577   message_type_t header_type;
578   vlib_buffer_t *data_bufs[VLIB_FRAME_SIZE];
579   u32 data_bi[VLIB_FRAME_SIZE];  /* buffer index for data */
580   u32 other_bi[VLIB_FRAME_SIZE]; /* buffer index for drop or handoff */
581   u16 other_nexts[VLIB_FRAME_SIZE], *other_next = other_nexts, n_other = 0;
582   u16 data_nexts[VLIB_FRAME_SIZE], *data_next = data_nexts, n_data = 0;
583   u16 n_async = 0;
584   const u8 is_async = wg_op_mode_is_set_ASYNC ();
585   vnet_crypto_async_frame_t *async_frame = NULL;
586
587   vlib_get_buffers (vm, from, bufs, n_left_from);
588   vec_reset_length (ptd->crypto_ops);
589   vec_reset_length (ptd->async_frames);
590
591   f64 time = clib_time_now (&vm->clib_time) + vm->time_offset;
592
593   wg_peer_t *peer = NULL;
594   u32 *last_peer_time_idx = NULL;
595   u32 last_rec_idx = ~0;
596
597   bool is_keepalive = false;
598   u32 *peer_idx = NULL;
599
600   while (n_left_from > 0)
601     {
602       if (n_left_from > 2)
603         {
604           u8 *p;
605           vlib_prefetch_buffer_header (b[2], LOAD);
606           p = vlib_buffer_get_current (b[1]);
607           CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD);
608           CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES,
609                          LOAD);
610         }
611
612       other_next[n_other] = WG_INPUT_NEXT_PUNT;
613       data_nexts[n_data] = WG_INPUT_N_NEXT;
614
615       header_type =
616         ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
617
618       if (PREDICT_TRUE (header_type == MESSAGE_DATA))
619         {
620           message_data_t *data = vlib_buffer_get_current (b[0]);
621           u8 *iv_data = b[0]->pre_data;
622           u32 buf_idx = from[b - bufs];
623           peer_idx = wg_index_table_lookup (&wmp->index_table,
624                                             data->receiver_index);
625
626           if (data->receiver_index != last_rec_idx)
627             {
628               peer_idx = wg_index_table_lookup (&wmp->index_table,
629                                                 data->receiver_index);
630               if (PREDICT_TRUE (peer_idx != NULL))
631                 {
632                   peer = wg_peer_get (*peer_idx);
633                 }
634               last_rec_idx = data->receiver_index;
635             }
636
637           if (PREDICT_FALSE (!peer_idx))
638             {
639               other_next[n_other] = WG_INPUT_NEXT_ERROR;
640               b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
641               other_bi[n_other] = buf_idx;
642               n_other += 1;
643               goto out;
644             }
645
646           if (PREDICT_FALSE (~0 == peer->input_thread_index))
647             {
648               /* this is the first packet to use this peer, claim the peer
649                * for this thread.
650                */
651               clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
652                                         wg_peer_assign_thread (thread_index));
653             }
654
655           if (PREDICT_TRUE (thread_index != peer->input_thread_index))
656             {
657               other_next[n_other] = WG_INPUT_NEXT_HANDOFF_DATA;
658               other_bi[n_other] = buf_idx;
659               n_other += 1;
660               goto next;
661             }
662
663           u16 encr_len = b[0]->current_length - sizeof (message_data_t);
664           u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
665           if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
666             {
667               b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
668               other_bi[n_other] = buf_idx;
669               n_other += 1;
670               goto out;
671             }
672
673           enum noise_state_crypt state_cr = wg_input_process (
674             vm, ptd, crypto_ops, &async_frame, b[0], buf_idx, &peer->remote,
675             data->receiver_index, data->counter, data->encrypted_data,
676             decr_len, data->encrypted_data, n_data, iv_data, time, is_async,
677             async_next_node);
678
679           if (PREDICT_FALSE (state_cr == SC_FAILED))
680             {
681               wg_peer_update_flags (*peer_idx, WG_PEER_ESTABLISHED, false);
682               other_next[n_other] = WG_INPUT_NEXT_ERROR;
683               b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
684               other_bi[n_other] = buf_idx;
685               n_other += 1;
686               goto out;
687             }
688           if (!is_async)
689             {
690               data_bufs[n_data] = b[0];
691               data_bi[n_data] = buf_idx;
692               n_data += 1;
693             }
694           else
695             {
696               n_async += 1;
697             }
698
699           if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
700             {
701               wg_timers_handshake_complete (peer);
702               goto next;
703             }
704           else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
705             {
706               wg_send_handshake_from_mt (*peer_idx, false);
707               goto next;
708             }
709           else if (PREDICT_TRUE (state_cr == SC_OK))
710             goto next;
711         }
712       else
713         {
714           peer_idx = NULL;
715
716           /* Handshake packets should be processed in main thread */
717           if (thread_index != 0)
718             {
719               other_next[n_other] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
720               other_bi[n_other] = from[b - bufs];
721               n_other += 1;
722               goto next;
723             }
724
725           wg_input_error_t ret =
726             wg_handshake_process (vm, wmp, b[0], node->node_index, is_ip4);
727           if (ret != WG_INPUT_ERROR_NONE)
728             {
729               other_next[n_other] = WG_INPUT_NEXT_ERROR;
730               b[0]->error = node->errors[ret];
731               other_bi[n_other] = from[b - bufs];
732               n_other += 1;
733             }
734           else
735             {
736               other_bi[n_other] = from[b - bufs];
737               n_other += 1;
738             }
739         }
740
741     out:
742       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
743                          (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
744         {
745           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
746           t->type = header_type;
747           t->current_length = b[0]->current_length;
748           t->is_keepalive = is_keepalive;
749           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
750         }
751
752     next:
753       n_left_from -= 1;
754       b += 1;
755     }
756
757   /* decrypt packets */
758   wg_input_process_ops (vm, node, ptd->crypto_ops, data_bufs, data_nexts,
759                         drop_next);
760
761   /* process after decryption */
762   b = data_bufs;
763   n_left_from = n_data;
764   last_rec_idx = ~0;
765   last_peer_time_idx = NULL;
766
767   while (n_left_from > 0)
768     {
769       bool is_keepalive = false;
770       u32 *peer_idx = NULL;
771
772       if (PREDICT_FALSE (data_next[0] == WG_INPUT_NEXT_PUNT))
773         {
774           goto trace;
775         }
776       if (n_left_from > 2)
777         {
778           u8 *p;
779           vlib_prefetch_buffer_header (b[2], LOAD);
780           p = vlib_buffer_get_current (b[1]);
781           CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD);
782           CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES,
783                          LOAD);
784         }
785
786       message_data_t *data = vlib_buffer_get_current (b[0]);
787
788       if (data->receiver_index != last_rec_idx)
789         {
790           peer_idx =
791             wg_index_table_lookup (&wmp->index_table, data->receiver_index);
792           peer = wg_peer_get (*peer_idx);
793           last_rec_idx = data->receiver_index;
794         }
795
796       if (PREDICT_FALSE (wg_input_post_process (vm, b[0], data_next, peer,
797                                                 data, &is_keepalive) < 0))
798         goto trace;
799
800       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
801         {
802           wg_timers_any_authenticated_packet_received_opt (peer, time);
803           wg_timers_any_authenticated_packet_traversal (peer);
804           last_peer_time_idx = peer_idx;
805         }
806
807       vlib_increment_combined_counter (im->combined_sw_if_counters +
808                                          VNET_INTERFACE_COUNTER_RX,
809                                        vm->thread_index, peer->wg_sw_if_index,
810                                        1 /* packets */, b[0]->current_length);
811
812     trace:
813       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
814                          (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
815         {
816           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
817           t->type = header_type;
818           t->current_length = b[0]->current_length;
819           t->is_keepalive = is_keepalive;
820           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
821         }
822
823       b += 1;
824       n_left_from -= 1;
825       data_next += 1;
826     }
827
828   if (n_async)
829     {
830       /* submit all of the open frames */
831       vnet_crypto_async_frame_t **async_frame;
832       vec_foreach (async_frame, ptd->async_frames)
833         {
834           if (PREDICT_FALSE (
835                 vnet_crypto_async_submit_open_frame (vm, *async_frame) < 0))
836             {
837               u32 n_drop = (*async_frame)->n_elts;
838               u32 *bi = (*async_frame)->buffer_indices;
839               u16 index = n_other;
840               while (n_drop--)
841                 {
842                   other_bi[index] = bi[0];
843                   vlib_buffer_t *b = vlib_get_buffer (vm, bi[0]);
844                   other_nexts[index] = drop_next;
845                   b->error = node->errors[WG_INPUT_ERROR_CRYPTO_ENGINE_ERROR];
846                   bi++;
847                   index++;
848                 }
849               n_other += (*async_frame)->n_elts;
850
851               vnet_crypto_async_reset_frame (*async_frame);
852               vnet_crypto_async_free_frame (vm, *async_frame);
853             }
854         }
855     }
856
857   /* enqueue other bufs */
858   if (n_other)
859     vlib_buffer_enqueue_to_next (vm, node, other_bi, other_next, n_other);
860
861   /* enqueue data bufs */
862   if (n_data)
863     vlib_buffer_enqueue_to_next (vm, node, data_bi, data_nexts, n_data);
864
865   return frame->n_vectors;
866 }
867
868 always_inline uword
869 wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
870 {
871   vnet_main_t *vnm = vnet_get_main ();
872   vnet_interface_main_t *im = &vnm->interface_main;
873   wg_main_t *wmp = &wg_main;
874   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs;
875   u16 nexts[VLIB_FRAME_SIZE], *next = nexts;
876   u32 *from = vlib_frame_vector_args (frame);
877   u32 n_left = frame->n_vectors;
878   wg_peer_t *peer = NULL;
879   u32 *peer_idx = NULL;
880   u32 *last_peer_time_idx = NULL;
881   u32 last_rec_idx = ~0;
882   f64 time = clib_time_now (&vm->clib_time) + vm->time_offset;
883
884   vlib_get_buffers (vm, from, b, n_left);
885
886   if (n_left >= 2)
887     {
888       vlib_prefetch_buffer_header (b[0], LOAD);
889       vlib_prefetch_buffer_header (b[1], LOAD);
890     }
891
892   while (n_left > 0)
893     {
894       if (n_left > 2)
895         {
896           u8 *p;
897           vlib_prefetch_buffer_header (b[2], LOAD);
898           p = vlib_buffer_get_current (b[1]);
899           CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD);
900         }
901
902       bool is_keepalive = false;
903       message_data_t *data = vlib_buffer_get_current (b[0]);
904
905       if (data->receiver_index != last_rec_idx)
906         {
907           peer_idx =
908             wg_index_table_lookup (&wmp->index_table, data->receiver_index);
909
910           peer = wg_peer_get (*peer_idx);
911           last_rec_idx = data->receiver_index;
912         }
913
914       if (PREDICT_TRUE (peer != NULL))
915         {
916           if (PREDICT_FALSE (wg_input_post_process (vm, b[0], next, peer, data,
917                                                     &is_keepalive) < 0))
918             goto trace;
919         }
920       else
921         {
922           next[0] = WG_INPUT_NEXT_PUNT;
923           goto trace;
924         }
925
926       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
927         {
928           wg_timers_any_authenticated_packet_received_opt (peer, time);
929           wg_timers_any_authenticated_packet_traversal (peer);
930           last_peer_time_idx = peer_idx;
931         }
932
933       vlib_increment_combined_counter (im->combined_sw_if_counters +
934                                          VNET_INTERFACE_COUNTER_RX,
935                                        vm->thread_index, peer->wg_sw_if_index,
936                                        1 /* packets */, b[0]->current_length);
937
938     trace:
939       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
940                          (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
941         {
942           wg_input_post_trace_t *t =
943             vlib_add_trace (vm, node, b[0], sizeof (*t));
944           t->next = next[0];
945           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
946         }
947
948       b += 1;
949       next += 1;
950       n_left -= 1;
951     }
952
953   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
954   return frame->n_vectors;
955 }
956
957 VLIB_NODE_FN (wg4_input_node)
958 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
959 {
960   return wg_input_inline (vm, node, frame, /* is_ip4 */ 1,
961                           wg_decrypt_async_next.wg4_post_next);
962 }
963
964 VLIB_NODE_FN (wg6_input_node)
965 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
966 {
967   return wg_input_inline (vm, node, frame, /* is_ip4 */ 0,
968                           wg_decrypt_async_next.wg6_post_next);
969 }
970
971 VLIB_NODE_FN (wg4_input_post_node)
972 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
973 {
974   return wg_input_post (vm, node, from_frame);
975 }
976
977 VLIB_NODE_FN (wg6_input_post_node)
978 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
979 {
980   return wg_input_post (vm, node, from_frame);
981 }
982
983 /* *INDENT-OFF* */
984 VLIB_REGISTER_NODE (wg4_input_node) =
985 {
986   .name = "wg4-input",
987   .vector_size = sizeof (u32),
988   .format_trace = format_wg_input_trace,
989   .type = VLIB_NODE_TYPE_INTERNAL,
990   .n_errors = ARRAY_LEN (wg_input_error_strings),
991   .error_strings = wg_input_error_strings,
992   .n_next_nodes = WG_INPUT_N_NEXT,
993   /* edit / add dispositions here */
994   .next_nodes = {
995         [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg4-handshake-handoff",
996         [WG_INPUT_NEXT_HANDOFF_DATA] = "wg4-input-data-handoff",
997         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
998         [WG_INPUT_NEXT_IP6_INPUT] = "ip6-input",
999         [WG_INPUT_NEXT_PUNT] = "error-punt",
1000         [WG_INPUT_NEXT_ERROR] = "error-drop",
1001   },
1002 };
1003
1004 VLIB_REGISTER_NODE (wg6_input_node) =
1005 {
1006   .name = "wg6-input",
1007   .vector_size = sizeof (u32),
1008   .format_trace = format_wg_input_trace,
1009   .type = VLIB_NODE_TYPE_INTERNAL,
1010   .n_errors = ARRAY_LEN (wg_input_error_strings),
1011   .error_strings = wg_input_error_strings,
1012   .n_next_nodes = WG_INPUT_N_NEXT,
1013   /* edit / add dispositions here */
1014   .next_nodes = {
1015         [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg6-handshake-handoff",
1016         [WG_INPUT_NEXT_HANDOFF_DATA] = "wg6-input-data-handoff",
1017         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
1018         [WG_INPUT_NEXT_IP6_INPUT] = "ip6-input",
1019         [WG_INPUT_NEXT_PUNT] = "error-punt",
1020         [WG_INPUT_NEXT_ERROR] = "error-drop",
1021   },
1022 };
1023
1024 VLIB_REGISTER_NODE (wg4_input_post_node) = {
1025   .name = "wg4-input-post-node",
1026   .vector_size = sizeof (u32),
1027   .format_trace = format_wg_input_post_trace,
1028   .type = VLIB_NODE_TYPE_INTERNAL,
1029   .sibling_of = "wg4-input",
1030
1031   .n_errors = ARRAY_LEN (wg_input_error_strings),
1032   .error_strings = wg_input_error_strings,
1033 };
1034
1035 VLIB_REGISTER_NODE (wg6_input_post_node) = {
1036   .name = "wg6-input-post-node",
1037   .vector_size = sizeof (u32),
1038   .format_trace = format_wg_input_post_trace,
1039   .type = VLIB_NODE_TYPE_INTERNAL,
1040   .sibling_of = "wg6-input",
1041
1042   .n_errors = ARRAY_LEN (wg_input_error_strings),
1043   .error_strings = wg_input_error_strings,
1044 };
1045
1046 /* *INDENT-ON* */
1047
1048 /*
1049  * fd.io coding-style-patch-verification: ON
1050  *
1051  * Local Variables:
1052  * eval: (c-set-style "gnu")
1053  * End:
1054  */