wireguard: add handshake rate limiting support
[vpp.git] / src / plugins / wireguard / wireguard_input.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <vlib/vlib.h>
18 #include <vnet/vnet.h>
19 #include <vppinfra/error.h>
20 #include <wireguard/wireguard.h>
21
22 #include <wireguard/wireguard_send.h>
23 #include <wireguard/wireguard_if.h>
24
25 #define foreach_wg_input_error                                                \
26   _ (NONE, "No error")                                                        \
27   _ (HANDSHAKE_MAC, "Invalid MAC handshake")                                  \
28   _ (HANDSHAKE_RATELIMITED, "Handshake ratelimited")                          \
29   _ (PEER, "Peer error")                                                      \
30   _ (INTERFACE, "Interface error")                                            \
31   _ (DECRYPTION, "Failed during decryption")                                  \
32   _ (KEEPALIVE_SEND, "Failed while sending Keepalive")                        \
33   _ (HANDSHAKE_SEND, "Failed while sending Handshake")                        \
34   _ (HANDSHAKE_RECEIVE, "Failed while receiving Handshake")                   \
35   _ (COOKIE_DECRYPTION, "Failed during Cookie decryption")                    \
36   _ (COOKIE_SEND, "Failed during sending Cookie")                             \
37   _ (TOO_BIG, "Packet too big")                                               \
38   _ (UNDEFINED, "Undefined error")                                            \
39   _ (CRYPTO_ENGINE_ERROR, "crypto engine error (packet dropped)")
40
41 typedef enum
42 {
43 #define _(sym,str) WG_INPUT_ERROR_##sym,
44   foreach_wg_input_error
45 #undef _
46     WG_INPUT_N_ERROR,
47 } wg_input_error_t;
48
49 static char *wg_input_error_strings[] = {
50 #define _(sym,string) string,
51   foreach_wg_input_error
52 #undef _
53 };
54
55 typedef struct
56 {
57   message_type_t type;
58   u16 current_length;
59   bool is_keepalive;
60   index_t peer;
61 } wg_input_trace_t;
62
63 typedef struct
64 {
65   index_t peer;
66   u16 next;
67 } wg_input_post_trace_t;
68
69 u8 *
70 format_wg_message_type (u8 * s, va_list * args)
71 {
72   message_type_t type = va_arg (*args, message_type_t);
73
74   switch (type)
75     {
76 #define _(v,a) case MESSAGE_##v: return (format (s, "%s", a));
77       foreach_wg_message_type
78 #undef _
79     }
80   return (format (s, "unknown"));
81 }
82
83 /* packet trace format function */
84 static u8 *
85 format_wg_input_trace (u8 * s, va_list * args)
86 {
87   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
88   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
89
90   wg_input_trace_t *t = va_arg (*args, wg_input_trace_t *);
91
92   s = format (s, "Wireguard input: \n");
93   s = format (s, "    Type: %U\n", format_wg_message_type, t->type);
94   s = format (s, "    Peer: %d\n", t->peer);
95   s = format (s, "    Length: %d\n", t->current_length);
96   s = format (s, "    Keepalive: %s", t->is_keepalive ? "true" : "false");
97
98   return s;
99 }
100
101 /* post-node packet trace format function */
102 static u8 *
103 format_wg_input_post_trace (u8 *s, va_list *args)
104 {
105   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
106   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
107
108   wg_input_post_trace_t *t = va_arg (*args, wg_input_post_trace_t *);
109
110   s = format (s, "WG input post: \n");
111   s = format (s, "  peer: %u\n", t->peer);
112   s = format (s, "  next: %u\n", t->next);
113
114   return s;
115 }
116
117 typedef enum
118 {
119   WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
120   WG_INPUT_NEXT_HANDOFF_DATA,
121   WG_INPUT_NEXT_IP4_INPUT,
122   WG_INPUT_NEXT_IP6_INPUT,
123   WG_INPUT_NEXT_PUNT,
124   WG_INPUT_NEXT_ERROR,
125   WG_INPUT_N_NEXT,
126 } wg_input_next_t;
127
128 /* static void */
129 /* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
130 /* { */
131 /*   if (peer) */
132 /*     { */
133 /*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
134 /*       peer->dst.port = udp_port; */
135 /*     } */
136 /* } */
137
138 static u8
139 is_ip4_header (u8 *data)
140 {
141   return (data[0] >> 4) == 0x4;
142 }
143
144 static wg_input_error_t
145 wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
146                       u32 node_idx, u8 is_ip4)
147 {
148   ASSERT (vm->thread_index == 0);
149
150   enum cookie_mac_state mac_state;
151   bool packet_needs_cookie;
152   bool under_load;
153   index_t *wg_ifs;
154   wg_if_t *wg_if;
155   wg_peer_t *peer = NULL;
156
157   void *current_b_data = vlib_buffer_get_current (b);
158
159   ip46_address_t src_ip;
160   if (is_ip4)
161     {
162       ip4_header_t *iph4 =
163         current_b_data - sizeof (udp_header_t) - sizeof (ip4_header_t);
164       ip46_address_set_ip4 (&src_ip, &iph4->src_address);
165     }
166   else
167     {
168       ip6_header_t *iph6 =
169         current_b_data - sizeof (udp_header_t) - sizeof (ip6_header_t);
170       ip46_address_set_ip6 (&src_ip, &iph6->src_address);
171     }
172
173   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
174   u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
175   u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
176
177   message_header_t *header = current_b_data;
178
179   if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
180     {
181       message_handshake_cookie_t *packet =
182         (message_handshake_cookie_t *) current_b_data;
183       u32 *entry =
184         wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
185       if (entry)
186         peer = wg_peer_get (*entry);
187       else
188         return WG_INPUT_ERROR_PEER;
189
190       if (!cookie_maker_consume_payload (
191             vm, &peer->cookie_maker, packet->nonce, packet->encrypted_cookie))
192         return WG_INPUT_ERROR_COOKIE_DECRYPTION;
193
194       return WG_INPUT_ERROR_NONE;
195     }
196
197   u32 len = (header->type == MESSAGE_HANDSHAKE_INITIATION ?
198              sizeof (message_handshake_initiation_t) :
199              sizeof (message_handshake_response_t));
200
201   message_macs_t *macs = (message_macs_t *)
202     ((u8 *) current_b_data + len - sizeof (*macs));
203
204   index_t *ii;
205   wg_ifs = wg_if_indexes_get_by_port (udp_dst_port);
206   if (NULL == wg_ifs)
207     return WG_INPUT_ERROR_INTERFACE;
208
209   vec_foreach (ii, wg_ifs)
210     {
211       wg_if = wg_if_get (*ii);
212       if (NULL == wg_if)
213         continue;
214
215       under_load = wg_if_is_under_load (vm, wg_if);
216       mac_state = cookie_checker_validate_macs (
217         vm, &wg_if->cookie_checker, macs, current_b_data, len, under_load,
218         &src_ip, udp_src_port);
219       if (mac_state == INVALID_MAC)
220         {
221           wg_if_dec_handshake_num (wg_if);
222           wg_if = NULL;
223           continue;
224         }
225       break;
226     }
227
228   if (NULL == wg_if)
229     return WG_INPUT_ERROR_HANDSHAKE_MAC;
230
231   if ((under_load && mac_state == VALID_MAC_WITH_COOKIE)
232       || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
233     packet_needs_cookie = false;
234   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
235     packet_needs_cookie = true;
236   else if (mac_state == VALID_MAC_WITH_COOKIE_BUT_RATELIMITED)
237     return WG_INPUT_ERROR_HANDSHAKE_RATELIMITED;
238   else
239     return WG_INPUT_ERROR_HANDSHAKE_MAC;
240
241   switch (header->type)
242     {
243     case MESSAGE_HANDSHAKE_INITIATION:
244       {
245         message_handshake_initiation_t *message = current_b_data;
246
247         if (packet_needs_cookie)
248           {
249
250             if (!wg_send_handshake_cookie (vm, message->sender_index,
251                                            &wg_if->cookie_checker, macs,
252                                            &ip_addr_46 (&wg_if->src_ip),
253                                            wg_if->port, &src_ip, udp_src_port))
254               return WG_INPUT_ERROR_COOKIE_SEND;
255
256             return WG_INPUT_ERROR_NONE;
257           }
258
259         noise_remote_t *rp;
260         if (noise_consume_initiation
261             (vm, noise_local_get (wg_if->local_idx), &rp,
262              message->sender_index, message->unencrypted_ephemeral,
263              message->encrypted_static, message->encrypted_timestamp))
264           {
265             peer = wg_peer_get (rp->r_peer_idx);
266           }
267         else
268           {
269             return WG_INPUT_ERROR_PEER;
270           }
271
272         // set_peer_address (peer, ip4_src, udp_src_port);
273         if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
274           {
275             vlib_node_increment_counter (vm, node_idx,
276                                          WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
277           }
278         else
279           {
280             wg_peer_update_flags (rp->r_peer_idx, WG_PEER_ESTABLISHED, true);
281           }
282         break;
283       }
284     case MESSAGE_HANDSHAKE_RESPONSE:
285       {
286         message_handshake_response_t *resp = current_b_data;
287
288         if (packet_needs_cookie)
289           {
290             if (!wg_send_handshake_cookie (vm, resp->sender_index,
291                                            &wg_if->cookie_checker, macs,
292                                            &ip_addr_46 (&wg_if->src_ip),
293                                            wg_if->port, &src_ip, udp_src_port))
294               return WG_INPUT_ERROR_COOKIE_SEND;
295
296             return WG_INPUT_ERROR_NONE;
297           }
298
299         index_t peeri = INDEX_INVALID;
300         u32 *entry =
301           wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
302
303         if (PREDICT_TRUE (entry != NULL))
304           {
305             peeri = *entry;
306             peer = wg_peer_get (peeri);
307             if (wg_peer_is_dead (peer))
308               return WG_INPUT_ERROR_PEER;
309           }
310         else
311           return WG_INPUT_ERROR_PEER;
312
313         if (!noise_consume_response
314             (vm, &peer->remote, resp->sender_index,
315              resp->receiver_index, resp->unencrypted_ephemeral,
316              resp->encrypted_nothing))
317           {
318             return WG_INPUT_ERROR_PEER;
319           }
320
321         // set_peer_address (peer, ip4_src, udp_src_port);
322         if (noise_remote_begin_session (vm, &peer->remote))
323           {
324
325             wg_timers_session_derived (peer);
326             wg_timers_handshake_complete (peer);
327             if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
328               {
329                 vlib_node_increment_counter (vm, node_idx,
330                                              WG_INPUT_ERROR_KEEPALIVE_SEND, 1);
331               }
332             else
333               {
334                 wg_peer_update_flags (peeri, WG_PEER_ESTABLISHED, true);
335               }
336           }
337         break;
338       }
339     default:
340       return WG_INPUT_ERROR_HANDSHAKE_RECEIVE;
341     }
342
343   wg_timers_any_authenticated_packet_received (peer);
344   wg_timers_any_authenticated_packet_traversal (peer);
345   return WG_INPUT_ERROR_NONE;
346 }
347
348 static_always_inline int
349 wg_input_post_process (vlib_main_t *vm, vlib_buffer_t *b, u16 *next,
350                        wg_peer_t *peer, message_data_t *data,
351                        bool *is_keepalive)
352 {
353   next[0] = WG_INPUT_NEXT_PUNT;
354
355   noise_keypair_t *kp =
356     wg_get_active_keypair (&peer->remote, data->receiver_index);
357
358   if (!noise_counter_recv (&kp->kp_ctr, data->counter))
359     {
360       return -1;
361     }
362
363   u16 encr_len = b->current_length - sizeof (message_data_t);
364   u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
365
366   vlib_buffer_advance (b, sizeof (message_data_t));
367   b->current_length = decr_len;
368   vnet_buffer_offload_flags_clear (b, VNET_BUFFER_OFFLOAD_F_UDP_CKSUM);
369
370   /* Keepalive packet has zero length */
371   if (decr_len == 0)
372     {
373       *is_keepalive = true;
374       return -1;
375     }
376
377   wg_timers_data_received (peer);
378
379   ip46_address_t src_ip;
380   u8 is_ip4_inner = is_ip4_header (vlib_buffer_get_current (b));
381   if (is_ip4_inner)
382     {
383       ip46_address_set_ip4 (
384         &src_ip, &((ip4_header_t *) vlib_buffer_get_current (b))->src_address);
385     }
386   else
387     {
388       ip46_address_set_ip6 (
389         &src_ip, &((ip6_header_t *) vlib_buffer_get_current (b))->src_address);
390     }
391
392   const fib_prefix_t *allowed_ip;
393   bool allowed = false;
394
395   /*
396    * we could make this into an ACL, but the expectation
397    * is that there aren't many allowed IPs and thus a linear
398    * walk is faster than an ACL
399    */
400   vec_foreach (allowed_ip, peer->allowed_ips)
401     {
402       if (fib_prefix_is_cover_addr_46 (allowed_ip, &src_ip))
403         {
404           allowed = true;
405           break;
406         }
407     }
408   if (allowed)
409     {
410       vnet_buffer (b)->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
411       next[0] =
412         is_ip4_inner ? WG_INPUT_NEXT_IP4_INPUT : WG_INPUT_NEXT_IP6_INPUT;
413     }
414
415   return 0;
416 }
417
418 static_always_inline void
419 wg_input_process_ops (vlib_main_t *vm, vlib_node_runtime_t *node,
420                       vnet_crypto_op_t *ops, vlib_buffer_t *b[], u16 *nexts,
421                       u16 drop_next)
422 {
423   u32 n_fail, n_ops = vec_len (ops);
424   vnet_crypto_op_t *op = ops;
425
426   if (n_ops == 0)
427     return;
428
429   n_fail = n_ops - vnet_crypto_process_ops (vm, op, n_ops);
430
431   while (n_fail)
432     {
433       ASSERT (op - ops < n_ops);
434
435       if (op->status != VNET_CRYPTO_OP_STATUS_COMPLETED)
436         {
437           u32 bi = op->user_data;
438           b[bi]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
439           nexts[bi] = drop_next;
440           n_fail--;
441         }
442       op++;
443     }
444 }
445
446 always_inline void
447 wg_prepare_sync_dec_op (vlib_main_t *vm, vnet_crypto_op_t **crypto_ops,
448                         u8 *src, u32 src_len, u8 *dst, u8 *aad, u32 aad_len,
449                         vnet_crypto_key_index_t key_index, u32 bi, u8 *iv)
450 {
451   vnet_crypto_op_t _op, *op = &_op;
452   u8 src_[] = {};
453
454   vec_add2_aligned (crypto_ops[0], op, 1, CLIB_CACHE_LINE_BYTES);
455   vnet_crypto_op_init (op, VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC);
456
457   op->tag_len = NOISE_AUTHTAG_LEN;
458   op->tag = src + src_len;
459   op->src = !src ? src_ : src;
460   op->len = src_len;
461   op->dst = dst;
462   op->key_index = key_index;
463   op->aad = aad;
464   op->aad_len = aad_len;
465   op->iv = iv;
466   op->user_data = bi;
467   op->flags |= VNET_CRYPTO_OP_FLAG_HMAC_CHECK;
468 }
469
470 static_always_inline void
471 wg_input_add_to_frame (vlib_main_t *vm, vnet_crypto_async_frame_t *f,
472                        u32 key_index, u32 crypto_len, i16 crypto_start_offset,
473                        u32 buffer_index, u16 next_node, u8 *iv, u8 *tag,
474                        u8 flags)
475 {
476   vnet_crypto_async_frame_elt_t *fe;
477   u16 index;
478
479   ASSERT (f->n_elts < VNET_CRYPTO_FRAME_SIZE);
480
481   index = f->n_elts;
482   fe = &f->elts[index];
483   f->n_elts++;
484   fe->key_index = key_index;
485   fe->crypto_total_length = crypto_len;
486   fe->crypto_start_offset = crypto_start_offset;
487   fe->iv = iv;
488   fe->tag = tag;
489   fe->flags = flags;
490   f->buffer_indices[index] = buffer_index;
491   f->next_node_index[index] = next_node;
492 }
493
494 static_always_inline enum noise_state_crypt
495 wg_input_process (vlib_main_t *vm, wg_per_thread_data_t *ptd,
496                   vnet_crypto_op_t **crypto_ops,
497                   vnet_crypto_async_frame_t **async_frame, vlib_buffer_t *b,
498                   u32 buf_idx, noise_remote_t *r, uint32_t r_idx,
499                   uint64_t nonce, uint8_t *src, size_t srclen, uint8_t *dst,
500                   u32 from_idx, u8 *iv, f64 time, u8 is_async,
501                   u16 async_next_node)
502 {
503   noise_keypair_t *kp;
504   enum noise_state_crypt ret = SC_FAILED;
505
506   if ((kp = wg_get_active_keypair (r, r_idx)) == NULL)
507     {
508       goto error;
509     }
510
511   /* We confirm that our values are within our tolerances. These values
512    * are the same as the encrypt routine.
513    *
514    * kp_ctr isn't locked here, we're happy to accept a racy read. */
515   if (wg_birthdate_has_expired_opt (kp->kp_birthdate, REJECT_AFTER_TIME,
516                                     time) ||
517       kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
518     goto error;
519
520   /* Decrypt, then validate the counter. We don't want to validate the
521    * counter before decrypting as we do not know the message is authentic
522    * prior to decryption. */
523
524   clib_memset (iv, 0, 4);
525   clib_memcpy (iv + 4, &nonce, sizeof (nonce));
526
527   if (is_async)
528     {
529       if (NULL == *async_frame ||
530           vnet_crypto_async_frame_is_full (*async_frame))
531         {
532           *async_frame = vnet_crypto_async_get_frame (
533             vm, VNET_CRYPTO_OP_CHACHA20_POLY1305_TAG16_AAD0_DEC);
534           /* Save the frame to the list we'll submit at the end */
535           vec_add1 (ptd->async_frames, *async_frame);
536         }
537
538       wg_input_add_to_frame (vm, *async_frame, kp->kp_recv_index, srclen,
539                              src - b->data, buf_idx, async_next_node, iv,
540                              src + srclen, VNET_CRYPTO_OP_FLAG_HMAC_CHECK);
541     }
542   else
543     {
544       wg_prepare_sync_dec_op (vm, crypto_ops, src, srclen, dst, NULL, 0,
545                               kp->kp_recv_index, from_idx, iv);
546     }
547
548   /* If we've received the handshake confirming data packet then move the
549    * next keypair into current. If we do slide the next keypair in, then
550    * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
551    * data packet can't confirm a session that we are an INITIATOR of. */
552   if (kp == r->r_next)
553     {
554       clib_rwlock_writer_lock (&r->r_keypair_lock);
555       if (kp == r->r_next && kp->kp_local_index == r_idx)
556         {
557           noise_remote_keypair_free (vm, r, &r->r_previous);
558           r->r_previous = r->r_current;
559           r->r_current = r->r_next;
560           r->r_next = NULL;
561
562           ret = SC_CONN_RESET;
563           clib_rwlock_writer_unlock (&r->r_keypair_lock);
564           goto error;
565         }
566       clib_rwlock_writer_unlock (&r->r_keypair_lock);
567     }
568
569   /* Similar to when we encrypt, we want to notify the caller when we
570    * are approaching our tolerances. We notify if:
571    *  - we're the initiator and the current keypair is older than
572    *    REKEY_AFTER_TIME_RECV seconds. */
573   ret = SC_KEEP_KEY_FRESH;
574   kp = r->r_current;
575   if (kp != NULL && kp->kp_valid && kp->kp_is_initiator &&
576       wg_birthdate_has_expired_opt (kp->kp_birthdate, REKEY_AFTER_TIME_RECV,
577                                     time))
578     goto error;
579
580   ret = SC_OK;
581 error:
582   return ret;
583 }
584
585 always_inline uword
586 wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
587                  vlib_frame_t *frame, u8 is_ip4, u16 async_next_node)
588 {
589   vnet_main_t *vnm = vnet_get_main ();
590   vnet_interface_main_t *im = &vnm->interface_main;
591   wg_main_t *wmp = &wg_main;
592   wg_per_thread_data_t *ptd =
593     vec_elt_at_index (wmp->per_thread_data, vm->thread_index);
594   u32 *from = vlib_frame_vector_args (frame);
595   u32 n_left_from = frame->n_vectors;
596
597   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs;
598   u32 thread_index = vm->thread_index;
599   vnet_crypto_op_t **crypto_ops = &ptd->crypto_ops;
600   const u16 drop_next = WG_INPUT_NEXT_PUNT;
601   message_type_t header_type;
602   vlib_buffer_t *data_bufs[VLIB_FRAME_SIZE];
603   u32 data_bi[VLIB_FRAME_SIZE];  /* buffer index for data */
604   u32 other_bi[VLIB_FRAME_SIZE]; /* buffer index for drop or handoff */
605   u16 other_nexts[VLIB_FRAME_SIZE], *other_next = other_nexts, n_other = 0;
606   u16 data_nexts[VLIB_FRAME_SIZE], *data_next = data_nexts, n_data = 0;
607   u16 n_async = 0;
608   const u8 is_async = wg_op_mode_is_set_ASYNC ();
609   vnet_crypto_async_frame_t *async_frame = NULL;
610
611   vlib_get_buffers (vm, from, bufs, n_left_from);
612   vec_reset_length (ptd->crypto_ops);
613   vec_reset_length (ptd->async_frames);
614
615   f64 time = clib_time_now (&vm->clib_time) + vm->time_offset;
616
617   wg_peer_t *peer = NULL;
618   u32 *last_peer_time_idx = NULL;
619   u32 last_rec_idx = ~0;
620
621   bool is_keepalive = false;
622   u32 *peer_idx = NULL;
623
624   while (n_left_from > 0)
625     {
626       if (n_left_from > 2)
627         {
628           u8 *p;
629           vlib_prefetch_buffer_header (b[2], LOAD);
630           p = vlib_buffer_get_current (b[1]);
631           CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD);
632           CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES,
633                          LOAD);
634         }
635
636       other_next[n_other] = WG_INPUT_NEXT_PUNT;
637       data_nexts[n_data] = WG_INPUT_N_NEXT;
638
639       header_type =
640         ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
641
642       if (PREDICT_TRUE (header_type == MESSAGE_DATA))
643         {
644           message_data_t *data = vlib_buffer_get_current (b[0]);
645           u8 *iv_data = b[0]->pre_data;
646           u32 buf_idx = from[b - bufs];
647           peer_idx = wg_index_table_lookup (&wmp->index_table,
648                                             data->receiver_index);
649
650           if (data->receiver_index != last_rec_idx)
651             {
652               peer_idx = wg_index_table_lookup (&wmp->index_table,
653                                                 data->receiver_index);
654               if (PREDICT_TRUE (peer_idx != NULL))
655                 {
656                   peer = wg_peer_get (*peer_idx);
657                 }
658               last_rec_idx = data->receiver_index;
659             }
660
661           if (PREDICT_FALSE (!peer_idx))
662             {
663               other_next[n_other] = WG_INPUT_NEXT_ERROR;
664               b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
665               other_bi[n_other] = buf_idx;
666               n_other += 1;
667               goto out;
668             }
669
670           if (PREDICT_FALSE (~0 == peer->input_thread_index))
671             {
672               /* this is the first packet to use this peer, claim the peer
673                * for this thread.
674                */
675               clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
676                                         wg_peer_assign_thread (thread_index));
677             }
678
679           if (PREDICT_TRUE (thread_index != peer->input_thread_index))
680             {
681               other_next[n_other] = WG_INPUT_NEXT_HANDOFF_DATA;
682               other_bi[n_other] = buf_idx;
683               n_other += 1;
684               goto next;
685             }
686
687           u16 encr_len = b[0]->current_length - sizeof (message_data_t);
688           u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
689           if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
690             {
691               b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
692               other_bi[n_other] = buf_idx;
693               n_other += 1;
694               goto out;
695             }
696
697           enum noise_state_crypt state_cr = wg_input_process (
698             vm, ptd, crypto_ops, &async_frame, b[0], buf_idx, &peer->remote,
699             data->receiver_index, data->counter, data->encrypted_data,
700             decr_len, data->encrypted_data, n_data, iv_data, time, is_async,
701             async_next_node);
702
703           if (PREDICT_FALSE (state_cr == SC_FAILED))
704             {
705               wg_peer_update_flags (*peer_idx, WG_PEER_ESTABLISHED, false);
706               other_next[n_other] = WG_INPUT_NEXT_ERROR;
707               b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
708               other_bi[n_other] = buf_idx;
709               n_other += 1;
710               goto out;
711             }
712           if (!is_async)
713             {
714               data_bufs[n_data] = b[0];
715               data_bi[n_data] = buf_idx;
716               n_data += 1;
717             }
718           else
719             {
720               n_async += 1;
721             }
722
723           if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
724             {
725               wg_timers_handshake_complete (peer);
726               goto next;
727             }
728           else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
729             {
730               wg_send_handshake_from_mt (*peer_idx, false);
731               goto next;
732             }
733           else if (PREDICT_TRUE (state_cr == SC_OK))
734             goto next;
735         }
736       else
737         {
738           peer_idx = NULL;
739
740           /* Handshake packets should be processed in main thread */
741           if (thread_index != 0)
742             {
743               other_next[n_other] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
744               other_bi[n_other] = from[b - bufs];
745               n_other += 1;
746               goto next;
747             }
748
749           wg_input_error_t ret =
750             wg_handshake_process (vm, wmp, b[0], node->node_index, is_ip4);
751           if (ret != WG_INPUT_ERROR_NONE)
752             {
753               other_next[n_other] = WG_INPUT_NEXT_ERROR;
754               b[0]->error = node->errors[ret];
755               other_bi[n_other] = from[b - bufs];
756               n_other += 1;
757             }
758           else
759             {
760               other_bi[n_other] = from[b - bufs];
761               n_other += 1;
762             }
763         }
764
765     out:
766       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
767                          (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
768         {
769           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
770           t->type = header_type;
771           t->current_length = b[0]->current_length;
772           t->is_keepalive = is_keepalive;
773           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
774         }
775
776     next:
777       n_left_from -= 1;
778       b += 1;
779     }
780
781   /* decrypt packets */
782   wg_input_process_ops (vm, node, ptd->crypto_ops, data_bufs, data_nexts,
783                         drop_next);
784
785   /* process after decryption */
786   b = data_bufs;
787   n_left_from = n_data;
788   last_rec_idx = ~0;
789   last_peer_time_idx = NULL;
790
791   while (n_left_from > 0)
792     {
793       bool is_keepalive = false;
794       u32 *peer_idx = NULL;
795
796       if (PREDICT_FALSE (data_next[0] == WG_INPUT_NEXT_PUNT))
797         {
798           goto trace;
799         }
800       if (n_left_from > 2)
801         {
802           u8 *p;
803           vlib_prefetch_buffer_header (b[2], LOAD);
804           p = vlib_buffer_get_current (b[1]);
805           CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD);
806           CLIB_PREFETCH (vlib_buffer_get_tail (b[1]), CLIB_CACHE_LINE_BYTES,
807                          LOAD);
808         }
809
810       message_data_t *data = vlib_buffer_get_current (b[0]);
811
812       if (data->receiver_index != last_rec_idx)
813         {
814           peer_idx =
815             wg_index_table_lookup (&wmp->index_table, data->receiver_index);
816           peer = wg_peer_get (*peer_idx);
817           last_rec_idx = data->receiver_index;
818         }
819
820       if (PREDICT_FALSE (wg_input_post_process (vm, b[0], data_next, peer,
821                                                 data, &is_keepalive) < 0))
822         goto trace;
823
824       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
825         {
826           wg_timers_any_authenticated_packet_received_opt (peer, time);
827           wg_timers_any_authenticated_packet_traversal (peer);
828           last_peer_time_idx = peer_idx;
829         }
830
831       vlib_increment_combined_counter (im->combined_sw_if_counters +
832                                          VNET_INTERFACE_COUNTER_RX,
833                                        vm->thread_index, peer->wg_sw_if_index,
834                                        1 /* packets */, b[0]->current_length);
835
836     trace:
837       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
838                          (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
839         {
840           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
841           t->type = header_type;
842           t->current_length = b[0]->current_length;
843           t->is_keepalive = is_keepalive;
844           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
845         }
846
847       b += 1;
848       n_left_from -= 1;
849       data_next += 1;
850     }
851
852   if (n_async)
853     {
854       /* submit all of the open frames */
855       vnet_crypto_async_frame_t **async_frame;
856       vec_foreach (async_frame, ptd->async_frames)
857         {
858           if (PREDICT_FALSE (
859                 vnet_crypto_async_submit_open_frame (vm, *async_frame) < 0))
860             {
861               u32 n_drop = (*async_frame)->n_elts;
862               u32 *bi = (*async_frame)->buffer_indices;
863               u16 index = n_other;
864               while (n_drop--)
865                 {
866                   other_bi[index] = bi[0];
867                   vlib_buffer_t *b = vlib_get_buffer (vm, bi[0]);
868                   other_nexts[index] = drop_next;
869                   b->error = node->errors[WG_INPUT_ERROR_CRYPTO_ENGINE_ERROR];
870                   bi++;
871                   index++;
872                 }
873               n_other += (*async_frame)->n_elts;
874
875               vnet_crypto_async_reset_frame (*async_frame);
876               vnet_crypto_async_free_frame (vm, *async_frame);
877             }
878         }
879     }
880
881   /* enqueue other bufs */
882   if (n_other)
883     vlib_buffer_enqueue_to_next (vm, node, other_bi, other_next, n_other);
884
885   /* enqueue data bufs */
886   if (n_data)
887     vlib_buffer_enqueue_to_next (vm, node, data_bi, data_nexts, n_data);
888
889   return frame->n_vectors;
890 }
891
892 always_inline uword
893 wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
894 {
895   vnet_main_t *vnm = vnet_get_main ();
896   vnet_interface_main_t *im = &vnm->interface_main;
897   wg_main_t *wmp = &wg_main;
898   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs;
899   u16 nexts[VLIB_FRAME_SIZE], *next = nexts;
900   u32 *from = vlib_frame_vector_args (frame);
901   u32 n_left = frame->n_vectors;
902   wg_peer_t *peer = NULL;
903   u32 *peer_idx = NULL;
904   u32 *last_peer_time_idx = NULL;
905   u32 last_rec_idx = ~0;
906   f64 time = clib_time_now (&vm->clib_time) + vm->time_offset;
907
908   vlib_get_buffers (vm, from, b, n_left);
909
910   if (n_left >= 2)
911     {
912       vlib_prefetch_buffer_header (b[0], LOAD);
913       vlib_prefetch_buffer_header (b[1], LOAD);
914     }
915
916   while (n_left > 0)
917     {
918       if (n_left > 2)
919         {
920           u8 *p;
921           vlib_prefetch_buffer_header (b[2], LOAD);
922           p = vlib_buffer_get_current (b[1]);
923           CLIB_PREFETCH (p, CLIB_CACHE_LINE_BYTES, LOAD);
924         }
925
926       bool is_keepalive = false;
927       message_data_t *data = vlib_buffer_get_current (b[0]);
928
929       if (data->receiver_index != last_rec_idx)
930         {
931           peer_idx =
932             wg_index_table_lookup (&wmp->index_table, data->receiver_index);
933
934           peer = wg_peer_get (*peer_idx);
935           last_rec_idx = data->receiver_index;
936         }
937
938       if (PREDICT_TRUE (peer != NULL))
939         {
940           if (PREDICT_FALSE (wg_input_post_process (vm, b[0], next, peer, data,
941                                                     &is_keepalive) < 0))
942             goto trace;
943         }
944       else
945         {
946           next[0] = WG_INPUT_NEXT_PUNT;
947           goto trace;
948         }
949
950       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
951         {
952           wg_timers_any_authenticated_packet_received_opt (peer, time);
953           wg_timers_any_authenticated_packet_traversal (peer);
954           last_peer_time_idx = peer_idx;
955         }
956
957       vlib_increment_combined_counter (im->combined_sw_if_counters +
958                                          VNET_INTERFACE_COUNTER_RX,
959                                        vm->thread_index, peer->wg_sw_if_index,
960                                        1 /* packets */, b[0]->current_length);
961
962     trace:
963       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
964                          (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
965         {
966           wg_input_post_trace_t *t =
967             vlib_add_trace (vm, node, b[0], sizeof (*t));
968           t->next = next[0];
969           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
970         }
971
972       b += 1;
973       next += 1;
974       n_left -= 1;
975     }
976
977   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
978   return frame->n_vectors;
979 }
980
981 VLIB_NODE_FN (wg4_input_node)
982 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
983 {
984   return wg_input_inline (vm, node, frame, /* is_ip4 */ 1,
985                           wg_decrypt_async_next.wg4_post_next);
986 }
987
988 VLIB_NODE_FN (wg6_input_node)
989 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
990 {
991   return wg_input_inline (vm, node, frame, /* is_ip4 */ 0,
992                           wg_decrypt_async_next.wg6_post_next);
993 }
994
995 VLIB_NODE_FN (wg4_input_post_node)
996 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
997 {
998   return wg_input_post (vm, node, from_frame);
999 }
1000
1001 VLIB_NODE_FN (wg6_input_post_node)
1002 (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *from_frame)
1003 {
1004   return wg_input_post (vm, node, from_frame);
1005 }
1006
1007 /* *INDENT-OFF* */
1008 VLIB_REGISTER_NODE (wg4_input_node) =
1009 {
1010   .name = "wg4-input",
1011   .vector_size = sizeof (u32),
1012   .format_trace = format_wg_input_trace,
1013   .type = VLIB_NODE_TYPE_INTERNAL,
1014   .n_errors = ARRAY_LEN (wg_input_error_strings),
1015   .error_strings = wg_input_error_strings,
1016   .n_next_nodes = WG_INPUT_N_NEXT,
1017   /* edit / add dispositions here */
1018   .next_nodes = {
1019         [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg4-handshake-handoff",
1020         [WG_INPUT_NEXT_HANDOFF_DATA] = "wg4-input-data-handoff",
1021         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
1022         [WG_INPUT_NEXT_IP6_INPUT] = "ip6-input",
1023         [WG_INPUT_NEXT_PUNT] = "error-punt",
1024         [WG_INPUT_NEXT_ERROR] = "error-drop",
1025   },
1026 };
1027
1028 VLIB_REGISTER_NODE (wg6_input_node) =
1029 {
1030   .name = "wg6-input",
1031   .vector_size = sizeof (u32),
1032   .format_trace = format_wg_input_trace,
1033   .type = VLIB_NODE_TYPE_INTERNAL,
1034   .n_errors = ARRAY_LEN (wg_input_error_strings),
1035   .error_strings = wg_input_error_strings,
1036   .n_next_nodes = WG_INPUT_N_NEXT,
1037   /* edit / add dispositions here */
1038   .next_nodes = {
1039         [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg6-handshake-handoff",
1040         [WG_INPUT_NEXT_HANDOFF_DATA] = "wg6-input-data-handoff",
1041         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
1042         [WG_INPUT_NEXT_IP6_INPUT] = "ip6-input",
1043         [WG_INPUT_NEXT_PUNT] = "error-punt",
1044         [WG_INPUT_NEXT_ERROR] = "error-drop",
1045   },
1046 };
1047
1048 VLIB_REGISTER_NODE (wg4_input_post_node) = {
1049   .name = "wg4-input-post-node",
1050   .vector_size = sizeof (u32),
1051   .format_trace = format_wg_input_post_trace,
1052   .type = VLIB_NODE_TYPE_INTERNAL,
1053   .sibling_of = "wg4-input",
1054
1055   .n_errors = ARRAY_LEN (wg_input_error_strings),
1056   .error_strings = wg_input_error_strings,
1057 };
1058
1059 VLIB_REGISTER_NODE (wg6_input_post_node) = {
1060   .name = "wg6-input-post-node",
1061   .vector_size = sizeof (u32),
1062   .format_trace = format_wg_input_post_trace,
1063   .type = VLIB_NODE_TYPE_INTERNAL,
1064   .sibling_of = "wg6-input",
1065
1066   .n_errors = ARRAY_LEN (wg_input_error_strings),
1067   .error_strings = wg_input_error_strings,
1068 };
1069
1070 /* *INDENT-ON* */
1071
1072 /*
1073  * fd.io coding-style-patch-verification: ON
1074  *
1075  * Local Variables:
1076  * eval: (c-set-style "gnu")
1077  * End:
1078  */