wireguard: use the same udp-port for multi-tunnel
[vpp.git] / src / plugins / wireguard / wireguard_input.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <vlib/vlib.h>
18 #include <vnet/vnet.h>
19 #include <vppinfra/error.h>
20 #include <wireguard/wireguard.h>
21
22 #include <wireguard/wireguard_send.h>
23 #include <wireguard/wireguard_if.h>
24
25 #define foreach_wg_input_error                                                \
26   _ (NONE, "No error")                                                        \
27   _ (HANDSHAKE_MAC, "Invalid MAC handshake")                                  \
28   _ (PEER, "Peer error")                                                      \
29   _ (INTERFACE, "Interface error")                                            \
30   _ (DECRYPTION, "Failed during decryption")                                  \
31   _ (KEEPALIVE_SEND, "Failed while sending Keepalive")                        \
32   _ (HANDSHAKE_SEND, "Failed while sending Handshake")                        \
33   _ (HANDSHAKE_RECEIVE, "Failed while receiving Handshake")                   \
34   _ (TOO_BIG, "Packet too big")                                               \
35   _ (UNDEFINED, "Undefined error")
36
37 typedef enum
38 {
39 #define _(sym,str) WG_INPUT_ERROR_##sym,
40   foreach_wg_input_error
41 #undef _
42     WG_INPUT_N_ERROR,
43 } wg_input_error_t;
44
45 static char *wg_input_error_strings[] = {
46 #define _(sym,string) string,
47   foreach_wg_input_error
48 #undef _
49 };
50
51 typedef struct
52 {
53   message_type_t type;
54   u16 current_length;
55   bool is_keepalive;
56   index_t peer;
57 } wg_input_trace_t;
58
59 u8 *
60 format_wg_message_type (u8 * s, va_list * args)
61 {
62   message_type_t type = va_arg (*args, message_type_t);
63
64   switch (type)
65     {
66 #define _(v,a) case MESSAGE_##v: return (format (s, "%s", a));
67       foreach_wg_message_type
68 #undef _
69     }
70   return (format (s, "unknown"));
71 }
72
73 /* packet trace format function */
74 static u8 *
75 format_wg_input_trace (u8 * s, va_list * args)
76 {
77   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
78   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
79
80   wg_input_trace_t *t = va_arg (*args, wg_input_trace_t *);
81
82   s = format (s, "WG input: \n");
83   s = format (s, "  Type: %U\n", format_wg_message_type, t->type);
84   s = format (s, "  peer: %d\n", t->peer);
85   s = format (s, "  Length: %d\n", t->current_length);
86   s = format (s, "  Keepalive: %s", t->is_keepalive ? "true" : "false");
87
88   return s;
89 }
90
91 typedef enum
92 {
93   WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
94   WG_INPUT_NEXT_HANDOFF_DATA,
95   WG_INPUT_NEXT_IP4_INPUT,
96   WG_INPUT_NEXT_PUNT,
97   WG_INPUT_NEXT_ERROR,
98   WG_INPUT_N_NEXT,
99 } wg_input_next_t;
100
101 /* static void */
102 /* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
103 /* { */
104 /*   if (peer) */
105 /*     { */
106 /*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
107 /*       peer->dst.port = udp_port; */
108 /*     } */
109 /* } */
110
111 static wg_input_error_t
112 wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
113 {
114   ASSERT (vm->thread_index == 0);
115
116   enum cookie_mac_state mac_state;
117   bool packet_needs_cookie;
118   bool under_load;
119   index_t *wg_ifs;
120   wg_if_t *wg_if;
121   wg_peer_t *peer = NULL;
122
123   void *current_b_data = vlib_buffer_get_current (b);
124
125   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
126   ip4_header_t *iph =
127     current_b_data - sizeof (udp_header_t) - sizeof (ip4_header_t);
128   ip4_address_t ip4_src = iph->src_address;
129   u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
130   u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
131
132   message_header_t *header = current_b_data;
133   under_load = false;
134
135   if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
136     {
137       message_handshake_cookie_t *packet =
138         (message_handshake_cookie_t *) current_b_data;
139       u32 *entry =
140         wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
141       if (entry)
142         peer = wg_peer_get (*entry);
143       else
144         return WG_INPUT_ERROR_PEER;
145
146       // TODO: Implement cookie_maker_consume_payload
147
148       return WG_INPUT_ERROR_NONE;
149     }
150
151   u32 len = (header->type == MESSAGE_HANDSHAKE_INITIATION ?
152              sizeof (message_handshake_initiation_t) :
153              sizeof (message_handshake_response_t));
154
155   message_macs_t *macs = (message_macs_t *)
156     ((u8 *) current_b_data + len - sizeof (*macs));
157
158   index_t *ii;
159   wg_ifs = wg_if_indexes_get_by_port (udp_dst_port);
160   if (NULL == wg_ifs)
161     return WG_INPUT_ERROR_INTERFACE;
162
163   vec_foreach (ii, wg_ifs)
164     {
165       wg_if = wg_if_get (*ii);
166       if (NULL == wg_if)
167         continue;
168
169       mac_state = cookie_checker_validate_macs (
170         vm, &wg_if->cookie_checker, macs, current_b_data, len, under_load,
171         ip4_src, udp_src_port);
172       if (mac_state == INVALID_MAC)
173         {
174           wg_if = NULL;
175           continue;
176         }
177       break;
178     }
179
180   if (NULL == wg_if)
181     return WG_INPUT_ERROR_HANDSHAKE_MAC;
182
183   if ((under_load && mac_state == VALID_MAC_WITH_COOKIE)
184       || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
185     packet_needs_cookie = false;
186   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
187     packet_needs_cookie = true;
188   else
189     return WG_INPUT_ERROR_HANDSHAKE_MAC;
190
191   switch (header->type)
192     {
193     case MESSAGE_HANDSHAKE_INITIATION:
194       {
195         message_handshake_initiation_t *message = current_b_data;
196
197         if (packet_needs_cookie)
198           {
199             // TODO: Add processing
200           }
201         noise_remote_t *rp;
202         if (noise_consume_initiation
203             (vm, noise_local_get (wg_if->local_idx), &rp,
204              message->sender_index, message->unencrypted_ephemeral,
205              message->encrypted_static, message->encrypted_timestamp))
206           {
207             peer = wg_peer_get (rp->r_peer_idx);
208           }
209         else
210           {
211             return WG_INPUT_ERROR_PEER;
212           }
213
214         // set_peer_address (peer, ip4_src, udp_src_port);
215         if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
216           {
217             vlib_node_increment_counter (vm, wg_input_node.index,
218                                          WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
219           }
220         break;
221       }
222     case MESSAGE_HANDSHAKE_RESPONSE:
223       {
224         message_handshake_response_t *resp = current_b_data;
225         u32 *entry =
226           wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
227
228         if (PREDICT_TRUE (entry != NULL))
229           {
230             peer = wg_peer_get (*entry);
231             if (peer->is_dead)
232               return WG_INPUT_ERROR_PEER;
233           }
234         else
235           return WG_INPUT_ERROR_PEER;
236
237         if (!noise_consume_response
238             (vm, &peer->remote, resp->sender_index,
239              resp->receiver_index, resp->unencrypted_ephemeral,
240              resp->encrypted_nothing))
241           {
242             return WG_INPUT_ERROR_PEER;
243           }
244         if (packet_needs_cookie)
245           {
246             // TODO: Add processing
247           }
248
249         // set_peer_address (peer, ip4_src, udp_src_port);
250         if (noise_remote_begin_session (vm, &peer->remote))
251           {
252
253             wg_timers_session_derived (peer);
254             wg_timers_handshake_complete (peer);
255             if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
256               {
257                 vlib_node_increment_counter (vm, wg_input_node.index,
258                                              WG_INPUT_ERROR_KEEPALIVE_SEND,
259                                              1);
260               }
261           }
262         break;
263       }
264     default:
265       return WG_INPUT_ERROR_HANDSHAKE_RECEIVE;
266     }
267
268   wg_timers_any_authenticated_packet_received (peer);
269   wg_timers_any_authenticated_packet_traversal (peer);
270   return WG_INPUT_ERROR_NONE;
271 }
272
273 VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
274                               vlib_node_runtime_t * node,
275                               vlib_frame_t * frame)
276 {
277   message_type_t header_type;
278   u32 n_left_from;
279   u32 *from;
280   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
281   u16 nexts[VLIB_FRAME_SIZE], *next;
282   u32 thread_index = vm->thread_index;
283
284   from = vlib_frame_vector_args (frame);
285   n_left_from = frame->n_vectors;
286   b = bufs;
287   next = nexts;
288
289   vlib_get_buffers (vm, from, bufs, n_left_from);
290
291   wg_main_t *wmp = &wg_main;
292   wg_peer_t *peer = NULL;
293
294   while (n_left_from > 0)
295     {
296       bool is_keepalive = false;
297       next[0] = WG_INPUT_NEXT_PUNT;
298       header_type =
299         ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
300       u32 *peer_idx;
301
302       if (PREDICT_TRUE (header_type == MESSAGE_DATA))
303         {
304           message_data_t *data = vlib_buffer_get_current (b[0]);
305
306           peer_idx = wg_index_table_lookup (&wmp->index_table,
307                                             data->receiver_index);
308
309           if (peer_idx)
310             {
311               peer = wg_peer_get (*peer_idx);
312             }
313           else
314             {
315               next[0] = WG_INPUT_NEXT_ERROR;
316               b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
317               goto out;
318             }
319
320           if (PREDICT_FALSE (~0 == peer->input_thread_index))
321             {
322               /* this is the first packet to use this peer, claim the peer
323                * for this thread.
324                */
325               clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
326                                         wg_peer_assign_thread (thread_index));
327             }
328
329           if (PREDICT_TRUE (thread_index != peer->input_thread_index))
330             {
331               next[0] = WG_INPUT_NEXT_HANDOFF_DATA;
332               goto next;
333             }
334
335           u16 encr_len = b[0]->current_length - sizeof (message_data_t);
336           u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
337           if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
338             {
339               b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
340               goto out;
341             }
342
343           u8 *decr_data = wmp->per_thread_data[thread_index].data;
344
345           enum noise_state_crypt state_cr = noise_remote_decrypt (vm,
346                                                                   &peer->remote,
347                                                                   data->receiver_index,
348                                                                   data->counter,
349                                                                   data->encrypted_data,
350                                                                   encr_len,
351                                                                   decr_data);
352
353           if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
354             {
355               wg_timers_handshake_complete (peer);
356             }
357           else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
358             {
359               wg_send_handshake_from_mt (*peer_idx, false);
360             }
361           else if (PREDICT_FALSE (state_cr == SC_FAILED))
362             {
363               next[0] = WG_INPUT_NEXT_ERROR;
364               b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
365               goto out;
366             }
367
368           clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
369           b[0]->current_length = decr_len;
370           vnet_buffer_offload_flags_clear (b[0],
371                                            VNET_BUFFER_OFFLOAD_F_UDP_CKSUM);
372
373           wg_timers_any_authenticated_packet_received (peer);
374           wg_timers_any_authenticated_packet_traversal (peer);
375
376           /* Keepalive packet has zero length */
377           if (decr_len == 0)
378             {
379               is_keepalive = true;
380               goto out;
381             }
382
383           wg_timers_data_received (peer);
384
385           ip4_header_t *iph = vlib_buffer_get_current (b[0]);
386
387           const fib_prefix_t *allowed_ip;
388           bool allowed = false;
389
390           /*
391            * we could make this into an ACL, but the expectation
392            * is that there aren't many allowed IPs and thus a linear
393            * walk is fater than an ACL
394            */
395           vec_foreach (allowed_ip, peer->allowed_ips)
396           {
397             if (fib_prefix_is_cover_addr_4 (allowed_ip, &iph->src_address))
398               {
399                 allowed = true;
400                 break;
401               }
402           }
403           if (allowed)
404             {
405               vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
406               next[0] = WG_INPUT_NEXT_IP4_INPUT;
407             }
408         }
409       else
410         {
411           peer_idx = NULL;
412
413           /* Handshake packets should be processed in main thread */
414           if (thread_index != 0)
415             {
416               next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
417               goto next;
418             }
419
420           wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
421           if (ret != WG_INPUT_ERROR_NONE)
422             {
423               next[0] = WG_INPUT_NEXT_ERROR;
424               b[0]->error = node->errors[ret];
425             }
426         }
427
428     out:
429       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
430                          && (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
431         {
432           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
433           t->type = header_type;
434           t->current_length = b[0]->current_length;
435           t->is_keepalive = is_keepalive;
436           t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
437         }
438     next:
439       n_left_from -= 1;
440       next += 1;
441       b += 1;
442     }
443   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
444
445   return frame->n_vectors;
446 }
447
448 /* *INDENT-OFF* */
449 VLIB_REGISTER_NODE (wg_input_node) =
450 {
451   .name = "wg-input",
452   .vector_size = sizeof (u32),
453   .format_trace = format_wg_input_trace,
454   .type = VLIB_NODE_TYPE_INTERNAL,
455   .n_errors = ARRAY_LEN (wg_input_error_strings),
456   .error_strings = wg_input_error_strings,
457   .n_next_nodes = WG_INPUT_N_NEXT,
458   /* edit / add dispositions here */
459   .next_nodes = {
460         [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg-handshake-handoff",
461         [WG_INPUT_NEXT_HANDOFF_DATA] = "wg-input-data-handoff",
462         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
463         [WG_INPUT_NEXT_PUNT] = "error-punt",
464         [WG_INPUT_NEXT_ERROR] = "error-drop",
465   },
466 };
467 /* *INDENT-ON* */
468
469 /*
470  * fd.io coding-style-patch-verification: ON
471  *
472  * Local Variables:
473  * eval: (c-set-style "gnu")
474  * End:
475  */