misc: Purge unused pg includes
[vpp.git] / src / plugins / wireguard / wireguard_input.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <vlib/vlib.h>
18 #include <vnet/vnet.h>
19 #include <vppinfra/error.h>
20 #include <wireguard/wireguard.h>
21
22 #include <wireguard/wireguard_send.h>
23 #include <wireguard/wireguard_if.h>
24
25 #define foreach_wg_input_error                          \
26   _(NONE, "No error")                                   \
27   _(HANDSHAKE_MAC, "Invalid MAC handshake")             \
28   _(PEER, "Peer error")                                 \
29   _(INTERFACE, "Interface error")                       \
30   _(DECRYPTION, "Failed during decryption")             \
31   _(KEEPALIVE_SEND, "Failed while sending Keepalive")   \
32   _(HANDSHAKE_SEND, "Failed while sending Handshake")   \
33   _(TOO_BIG, "Packet too big")                          \
34   _(UNDEFINED, "Undefined error")
35
36 typedef enum
37 {
38 #define _(sym,str) WG_INPUT_ERROR_##sym,
39   foreach_wg_input_error
40 #undef _
41     WG_INPUT_N_ERROR,
42 } wg_input_error_t;
43
44 static char *wg_input_error_strings[] = {
45 #define _(sym,string) string,
46   foreach_wg_input_error
47 #undef _
48 };
49
50 typedef struct
51 {
52   message_type_t type;
53   u16 current_length;
54   bool is_keepalive;
55   index_t peer;
56 } wg_input_trace_t;
57
58 u8 *
59 format_wg_message_type (u8 * s, va_list * args)
60 {
61   message_type_t type = va_arg (*args, message_type_t);
62
63   switch (type)
64     {
65 #define _(v,a) case MESSAGE_##v: return (format (s, "%s", a));
66       foreach_wg_message_type
67 #undef _
68     }
69   return (format (s, "unknown"));
70 }
71
72 /* packet trace format function */
73 static u8 *
74 format_wg_input_trace (u8 * s, va_list * args)
75 {
76   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
77   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
78
79   wg_input_trace_t *t = va_arg (*args, wg_input_trace_t *);
80
81   s = format (s, "WG input: \n");
82   s = format (s, "  Type: %U\n", format_wg_message_type, t->type);
83   s = format (s, "  peer: %d\n", t->peer);
84   s = format (s, "  Length: %d\n", t->current_length);
85   s = format (s, "  Keepalive: %s", t->is_keepalive ? "true" : "false");
86
87   return s;
88 }
89
90 typedef enum
91 {
92   WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
93   WG_INPUT_NEXT_HANDOFF_DATA,
94   WG_INPUT_NEXT_IP4_INPUT,
95   WG_INPUT_NEXT_PUNT,
96   WG_INPUT_NEXT_ERROR,
97   WG_INPUT_N_NEXT,
98 } wg_input_next_t;
99
100 /* static void */
101 /* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
102 /* { */
103 /*   if (peer) */
104 /*     { */
105 /*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
106 /*       peer->dst.port = udp_port; */
107 /*     } */
108 /* } */
109
110 static wg_input_error_t
111 wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
112 {
113   ASSERT (vm->thread_index == 0);
114
115   enum cookie_mac_state mac_state;
116   bool packet_needs_cookie;
117   bool under_load;
118   wg_if_t *wg_if;
119   wg_peer_t *peer = NULL;
120
121   void *current_b_data = vlib_buffer_get_current (b);
122
123   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
124   ip4_header_t *iph =
125     current_b_data - sizeof (udp_header_t) - sizeof (ip4_header_t);
126   ip4_address_t ip4_src = iph->src_address;
127   u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
128   u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
129
130   message_header_t *header = current_b_data;
131   under_load = false;
132
133   wg_if = wg_if_get_by_port (udp_dst_port);
134
135   if (NULL == wg_if)
136     return WG_INPUT_ERROR_INTERFACE;
137
138   if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
139     {
140       message_handshake_cookie_t *packet =
141         (message_handshake_cookie_t *) current_b_data;
142       u32 *entry =
143         wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
144       if (entry)
145         peer = wg_peer_get (*entry);
146       else
147         return WG_INPUT_ERROR_PEER;
148
149       // TODO: Implement cookie_maker_consume_payload
150
151       return WG_INPUT_ERROR_NONE;
152     }
153
154   u32 len = (header->type == MESSAGE_HANDSHAKE_INITIATION ?
155              sizeof (message_handshake_initiation_t) :
156              sizeof (message_handshake_response_t));
157
158   message_macs_t *macs = (message_macs_t *)
159     ((u8 *) current_b_data + len - sizeof (*macs));
160
161   mac_state =
162     cookie_checker_validate_macs (vm, &wg_if->cookie_checker, macs,
163                                   current_b_data, len, under_load, ip4_src,
164                                   udp_src_port);
165
166   if ((under_load && mac_state == VALID_MAC_WITH_COOKIE)
167       || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
168     packet_needs_cookie = false;
169   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
170     packet_needs_cookie = true;
171   else
172     return WG_INPUT_ERROR_HANDSHAKE_MAC;
173
174   switch (header->type)
175     {
176     case MESSAGE_HANDSHAKE_INITIATION:
177       {
178         message_handshake_initiation_t *message = current_b_data;
179
180         if (packet_needs_cookie)
181           {
182             // TODO: Add processing
183           }
184         noise_remote_t *rp;
185         if (noise_consume_initiation
186             (vm, noise_local_get (wg_if->local_idx), &rp,
187              message->sender_index, message->unencrypted_ephemeral,
188              message->encrypted_static, message->encrypted_timestamp))
189           {
190             peer = wg_peer_get (rp->r_peer_idx);
191           }
192         else
193           {
194             return WG_INPUT_ERROR_PEER;
195           }
196
197         // set_peer_address (peer, ip4_src, udp_src_port);
198         if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
199           {
200             vlib_node_increment_counter (vm, wg_input_node.index,
201                                          WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
202           }
203         break;
204       }
205     case MESSAGE_HANDSHAKE_RESPONSE:
206       {
207         message_handshake_response_t *resp = current_b_data;
208         u32 *entry =
209           wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
210
211         if (PREDICT_TRUE (entry != NULL))
212           {
213             peer = wg_peer_get (*entry);
214             if (peer->is_dead)
215               return WG_INPUT_ERROR_PEER;
216           }
217         else
218           return WG_INPUT_ERROR_PEER;
219
220         if (!noise_consume_response
221             (vm, &peer->remote, resp->sender_index,
222              resp->receiver_index, resp->unencrypted_ephemeral,
223              resp->encrypted_nothing))
224           {
225             return WG_INPUT_ERROR_PEER;
226           }
227         if (packet_needs_cookie)
228           {
229             // TODO: Add processing
230           }
231
232         // set_peer_address (peer, ip4_src, udp_src_port);
233         if (noise_remote_begin_session (vm, &peer->remote))
234           {
235
236             wg_timers_session_derived (peer);
237             wg_timers_handshake_complete (peer);
238             if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
239               {
240                 vlib_node_increment_counter (vm, wg_input_node.index,
241                                              WG_INPUT_ERROR_KEEPALIVE_SEND,
242                                              1);
243               }
244           }
245         break;
246       }
247     default:
248       break;
249     }
250
251   wg_timers_any_authenticated_packet_received (peer);
252   wg_timers_any_authenticated_packet_traversal (peer);
253   return WG_INPUT_ERROR_NONE;
254 }
255
256 static_always_inline bool
257 fib_prefix_is_cover_addr_4 (const fib_prefix_t * p1,
258                             const ip4_address_t * ip4)
259 {
260   switch (p1->fp_proto)
261     {
262     case FIB_PROTOCOL_IP4:
263       return (ip4_destination_matches_route (&ip4_main,
264                                              &p1->fp_addr.ip4,
265                                              ip4, p1->fp_len) != 0);
266     case FIB_PROTOCOL_IP6:
267       return (false);
268     case FIB_PROTOCOL_MPLS:
269       break;
270     }
271   return (false);
272 }
273
274 VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
275                               vlib_node_runtime_t * node,
276                               vlib_frame_t * frame)
277 {
278   message_type_t header_type;
279   u32 n_left_from;
280   u32 *from;
281   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
282   u16 nexts[VLIB_FRAME_SIZE], *next;
283   u32 thread_index = vm->thread_index;
284
285   from = vlib_frame_vector_args (frame);
286   n_left_from = frame->n_vectors;
287   b = bufs;
288   next = nexts;
289
290   vlib_get_buffers (vm, from, bufs, n_left_from);
291
292   wg_main_t *wmp = &wg_main;
293   wg_peer_t *peer = NULL;
294
295   while (n_left_from > 0)
296     {
297       bool is_keepalive = false;
298       next[0] = WG_INPUT_NEXT_PUNT;
299       header_type =
300         ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
301       u32 *peer_idx;
302
303       if (PREDICT_TRUE (header_type == MESSAGE_DATA))
304         {
305           message_data_t *data = vlib_buffer_get_current (b[0]);
306
307           peer_idx = wg_index_table_lookup (&wmp->index_table,
308                                             data->receiver_index);
309
310           if (peer_idx)
311             {
312               peer = wg_peer_get (*peer_idx);
313             }
314           else
315             {
316               next[0] = WG_INPUT_NEXT_ERROR;
317               b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
318               goto out;
319             }
320
321           if (PREDICT_FALSE (~0 == peer->input_thread_index))
322             {
323               /* this is the first packet to use this peer, claim the peer
324                * for this thread.
325                */
326               clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
327                                         wg_peer_assign_thread (thread_index));
328             }
329
330           if (PREDICT_TRUE (thread_index != peer->input_thread_index))
331             {
332               next[0] = WG_INPUT_NEXT_HANDOFF_DATA;
333               goto next;
334             }
335
336           u16 encr_len = b[0]->current_length - sizeof (message_data_t);
337           u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
338           if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
339             {
340               b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
341               goto out;
342             }
343
344           u8 *decr_data = wmp->per_thread_data[thread_index].data;
345
346           enum noise_state_crypt state_cr = noise_remote_decrypt (vm,
347                                                                   &peer->remote,
348                                                                   data->receiver_index,
349                                                                   data->counter,
350                                                                   data->encrypted_data,
351                                                                   encr_len,
352                                                                   decr_data);
353
354           if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
355             {
356               wg_timers_handshake_complete (peer);
357             }
358           else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
359             {
360               wg_send_handshake_from_mt (*peer_idx, false);
361             }
362           else if (PREDICT_FALSE (state_cr == SC_FAILED))
363             {
364               next[0] = WG_INPUT_NEXT_ERROR;
365               b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
366               goto out;
367             }
368
369           clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
370           b[0]->current_length = decr_len;
371           b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
372
373           wg_timers_any_authenticated_packet_received (peer);
374           wg_timers_any_authenticated_packet_traversal (peer);
375
376           /* Keepalive packet has zero length */
377           if (decr_len == 0)
378             {
379               is_keepalive = true;
380               goto out;
381             }
382
383           wg_timers_data_received (peer);
384
385           ip4_header_t *iph = vlib_buffer_get_current (b[0]);
386
387           const wg_peer_allowed_ip_t *allowed_ip;
388           bool allowed = false;
389
390           /*
391            * we could make this into an ACL, but the expectation
392            * is that there aren't many allowed IPs and thus a linear
393            * walk is fater than an ACL
394            */
395           vec_foreach (allowed_ip, peer->allowed_ips)
396           {
397             if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
398                                             &iph->src_address))
399               {
400                 allowed = true;
401                 break;
402               }
403           }
404           if (allowed)
405             {
406               vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
407               next[0] = WG_INPUT_NEXT_IP4_INPUT;
408             }
409         }
410       else
411         {
412           peer_idx = NULL;
413
414           /* Handshake packets should be processed in main thread */
415           if (thread_index != 0)
416             {
417               next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
418               goto next;
419             }
420
421           wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
422           if (ret != WG_INPUT_ERROR_NONE)
423             {
424               next[0] = WG_INPUT_NEXT_ERROR;
425               b[0]->error = node->errors[ret];
426             }
427         }
428
429     out:
430       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
431                          && (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
432         {
433           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
434           t->type = header_type;
435           t->current_length = b[0]->current_length;
436           t->is_keepalive = is_keepalive;
437           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
438         }
439     next:
440       n_left_from -= 1;
441       next += 1;
442       b += 1;
443     }
444   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
445
446   return frame->n_vectors;
447 }
448
449 /* *INDENT-OFF* */
450 VLIB_REGISTER_NODE (wg_input_node) =
451 {
452   .name = "wg-input",
453   .vector_size = sizeof (u32),
454   .format_trace = format_wg_input_trace,
455   .type = VLIB_NODE_TYPE_INTERNAL,
456   .n_errors = ARRAY_LEN (wg_input_error_strings),
457   .error_strings = wg_input_error_strings,
458   .n_next_nodes = WG_INPUT_N_NEXT,
459   /* edit / add dispositions here */
460   .next_nodes = {
461         [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg-handshake-handoff",
462         [WG_INPUT_NEXT_HANDOFF_DATA] = "wg-input-data-handoff",
463         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
464         [WG_INPUT_NEXT_PUNT] = "error-punt",
465         [WG_INPUT_NEXT_ERROR] = "error-drop",
466   },
467 };
468 /* *INDENT-ON* */
469
470 /*
471  * fd.io coding-style-patch-verification: ON
472  *
473  * Local Variables:
474  * eval: (c-set-style "gnu")
475  * End:
476  */