wireguard: initial implementation of wireguard protocol
[vpp.git] / src / plugins / wireguard / wireguard_input.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #include <vlib/vlib.h>
17 #include <vnet/vnet.h>
18 #include <vnet/pg/pg.h>
19 #include <vppinfra/error.h>
20 #include <wireguard/wireguard.h>
21
22 #include <wireguard/wireguard_send.h>
23 #include <wireguard/wireguard_if.h>
24
25 #define foreach_wg_input_error                          \
26   _(NONE, "No error")                                   \
27   _(HANDSHAKE_MAC, "Invalid MAC handshake")             \
28   _(PEER, "Peer error")                                 \
29   _(INTERFACE, "Interface error")                       \
30   _(DECRYPTION, "Failed during decryption")             \
31   _(KEEPALIVE_SEND, "Failed while sending Keepalive")   \
32   _(HANDSHAKE_SEND, "Failed while sending Handshake")   \
33   _(UNDEFINED, "Undefined error")
34
35 typedef enum
36 {
37 #define _(sym,str) WG_INPUT_ERROR_##sym,
38   foreach_wg_input_error
39 #undef _
40     WG_INPUT_N_ERROR,
41 } wg_input_error_t;
42
43 static char *wg_input_error_strings[] = {
44 #define _(sym,string) string,
45   foreach_wg_input_error
46 #undef _
47 };
48
49 typedef struct
50 {
51   message_type_t type;
52   u16 current_length;
53   bool is_keepalive;
54
55 } wg_input_trace_t;
56
57 u8 *
58 format_wg_message_type (u8 * s, va_list * args)
59 {
60   message_type_t type = va_arg (*args, message_type_t);
61
62   switch (type)
63     {
64 #define _(v,a) case MESSAGE_##v: return (format (s, "%s", a));
65       foreach_wg_message_type
66 #undef _
67     }
68   return (format (s, "unknown"));
69 }
70
71 /* packet trace format function */
72 static u8 *
73 format_wg_input_trace (u8 * s, va_list * args)
74 {
75   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
76   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
77
78   wg_input_trace_t *t = va_arg (*args, wg_input_trace_t *);
79
80   s = format (s, "WG input: \n");
81   s = format (s, "  Type: %U\n", format_wg_message_type, t->type);
82   s = format (s, "  Length: %d\n", t->current_length);
83   s = format (s, "  Keepalive: %s", t->is_keepalive ? "true" : "false");
84
85   return s;
86 }
87
88 typedef enum
89 {
90   WG_INPUT_NEXT_IP4_INPUT,
91   WG_INPUT_NEXT_PUNT,
92   WG_INPUT_NEXT_ERROR,
93   WG_INPUT_N_NEXT,
94 } wg_input_next_t;
95
96 /* static void */
97 /* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
98 /* { */
99 /*   if (peer) */
100 /*     { */
101 /*       ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
102 /*       peer->dst.port = udp_port; */
103 /*     } */
104 /* } */
105
106 static wg_input_error_t
107 wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
108 {
109   enum cookie_mac_state mac_state;
110   bool packet_needs_cookie;
111   bool under_load;
112   wg_if_t *wg_if;
113   wg_peer_t *peer = NULL;
114
115   void *current_b_data = vlib_buffer_get_current (b);
116
117   udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
118   ip4_header_t *iph =
119     current_b_data - sizeof (udp_header_t) - sizeof (ip4_header_t);
120   ip4_address_t ip4_src = iph->src_address;
121   u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
122   u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
123
124   message_header_t *header = current_b_data;
125   under_load = false;
126
127   wg_if = wg_if_get_by_port (udp_dst_port);
128
129   if (NULL == wg_if)
130     return WG_INPUT_ERROR_INTERFACE;
131
132   if (header->type == MESSAGE_HANDSHAKE_COOKIE)
133     {
134       message_handshake_cookie_t *packet =
135         (message_handshake_cookie_t *) current_b_data;
136       u32 *entry =
137         wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
138       if (entry)
139         {
140           peer = pool_elt_at_index (wmp->peers, *entry);
141         }
142       if (!peer)
143         return WG_INPUT_ERROR_PEER;
144
145       // TODO: Implement cookie_maker_consume_payload
146
147       return WG_INPUT_ERROR_NONE;
148     }
149
150   u32 len = (header->type == MESSAGE_HANDSHAKE_INITIATION ?
151              sizeof (message_handshake_initiation_t) :
152              sizeof (message_handshake_response_t));
153
154   message_macs_t *macs = (message_macs_t *)
155     ((u8 *) current_b_data + len - sizeof (*macs));
156
157   mac_state =
158     cookie_checker_validate_macs (vm, &wg_if->cookie_checker, macs,
159                                   current_b_data, len, under_load, ip4_src,
160                                   udp_src_port);
161
162   if ((under_load && mac_state == VALID_MAC_WITH_COOKIE)
163       || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
164     packet_needs_cookie = false;
165   else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
166     packet_needs_cookie = true;
167   else
168     return WG_INPUT_ERROR_HANDSHAKE_MAC;
169
170   switch (header->type)
171     {
172     case MESSAGE_HANDSHAKE_INITIATION:
173       {
174         message_handshake_initiation_t *message = current_b_data;
175
176         if (packet_needs_cookie)
177           {
178             // TODO: Add processing
179           }
180         noise_remote_t *rp;
181
182         if (noise_consume_initiation
183             (wmp->vlib_main, &wg_if->local, &rp, message->sender_index,
184              message->unencrypted_ephemeral, message->encrypted_static,
185              message->encrypted_timestamp))
186           {
187             peer = pool_elt_at_index (wmp->peers, rp->r_peer_idx);
188           }
189
190         if (!peer)
191           return WG_INPUT_ERROR_PEER;
192
193         // set_peer_address (peer, ip4_src, udp_src_port);
194         if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
195           {
196             vlib_node_increment_counter (vm, wg_input_node.index,
197                                          WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
198           }
199         break;
200       }
201     case MESSAGE_HANDSHAKE_RESPONSE:
202       {
203         message_handshake_response_t *resp = current_b_data;
204         u32 *entry =
205           wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
206         if (entry)
207           {
208             peer = pool_elt_at_index (wmp->peers, *entry);
209             if (!peer || peer->is_dead)
210               return WG_INPUT_ERROR_PEER;
211           }
212
213         if (!noise_consume_response
214             (wmp->vlib_main, &peer->remote, resp->sender_index,
215              resp->receiver_index, resp->unencrypted_ephemeral,
216              resp->encrypted_nothing))
217           {
218             return WG_INPUT_ERROR_PEER;
219           }
220         if (packet_needs_cookie)
221           {
222             // TODO: Add processing
223           }
224
225         // set_peer_address (peer, ip4_src, udp_src_port);
226         if (noise_remote_begin_session (wmp->vlib_main, &peer->remote))
227           {
228             wg_timers_session_derived (peer);
229             wg_timers_handshake_complete (peer);
230             if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
231               {
232                 vlib_node_increment_counter (vm, wg_input_node.index,
233                                              WG_INPUT_ERROR_KEEPALIVE_SEND,
234                                              1);
235               }
236           }
237         break;
238       }
239     default:
240       break;
241     }
242
243   wg_timers_any_authenticated_packet_received (peer);
244   wg_timers_any_authenticated_packet_traversal (peer);
245   return WG_INPUT_ERROR_NONE;
246 }
247
248 static_always_inline bool
249 fib_prefix_is_cover_addr_4 (const fib_prefix_t * p1,
250                             const ip4_address_t * ip4)
251 {
252   switch (p1->fp_proto)
253     {
254     case FIB_PROTOCOL_IP4:
255       return (ip4_destination_matches_route (&ip4_main,
256                                              &p1->fp_addr.ip4,
257                                              ip4, p1->fp_len) != 0);
258     case FIB_PROTOCOL_IP6:
259       return (false);
260     case FIB_PROTOCOL_MPLS:
261       break;
262     }
263   return (false);
264 }
265
266 VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
267                               vlib_node_runtime_t * node,
268                               vlib_frame_t * frame)
269 {
270   message_type_t header_type;
271   u32 n_left_from;
272   u32 *from;
273   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
274   u16 nexts[VLIB_FRAME_SIZE], *next;
275
276   from = vlib_frame_vector_args (frame);
277   n_left_from = frame->n_vectors;
278   b = bufs;
279   next = nexts;
280
281   vlib_get_buffers (vm, from, bufs, n_left_from);
282
283   wg_main_t *wmp = &wg_main;
284   wg_peer_t *peer = NULL;
285
286   while (n_left_from > 0)
287     {
288       bool is_keepalive = false;
289       next[0] = WG_INPUT_NEXT_PUNT;
290       header_type =
291         ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
292
293       switch (header_type)
294         {
295         case MESSAGE_HANDSHAKE_INITIATION:
296         case MESSAGE_HANDSHAKE_RESPONSE:
297         case MESSAGE_HANDSHAKE_COOKIE:
298           {
299             wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
300             if (ret != WG_INPUT_ERROR_NONE)
301               {
302                 next[0] = WG_INPUT_NEXT_ERROR;
303                 b[0]->error = node->errors[ret];
304               }
305             break;
306           }
307         case MESSAGE_DATA:
308           {
309             message_data_t *data = vlib_buffer_get_current (b[0]);
310             u32 *entry =
311               wg_index_table_lookup (&wmp->index_table, data->receiver_index);
312
313             if (entry)
314               {
315                 peer = pool_elt_at_index (wmp->peers, *entry);
316                 if (!peer)
317                   {
318                     next[0] = WG_INPUT_NEXT_ERROR;
319                     b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
320                     goto out;
321                   }
322               }
323
324             u16 encr_len = b[0]->current_length - sizeof (message_data_t);
325             u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
326             u8 *decr_data = clib_mem_alloc (decr_len);
327
328             enum noise_state_crypt state_cr =
329               noise_remote_decrypt (wmp->vlib_main,
330                                     &peer->remote,
331                                     data->receiver_index,
332                                     data->counter,
333                                     data->encrypted_data,
334                                     encr_len,
335                                     decr_data);
336
337             switch (state_cr)
338               {
339               case SC_OK:
340                 break;
341               case SC_CONN_RESET:
342                 wg_timers_handshake_complete (peer);
343                 break;
344               case SC_KEEP_KEY_FRESH:
345                 if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
346                   {
347                     vlib_node_increment_counter (vm, wg_input_node.index,
348                                                  WG_INPUT_ERROR_HANDSHAKE_SEND,
349                                                  1);
350                   }
351                 break;
352               case SC_FAILED:
353                 next[0] = WG_INPUT_NEXT_ERROR;
354                 b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
355                 goto out;
356               default:
357                 break;
358               }
359
360             clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
361             b[0]->current_length = decr_len;
362             b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
363
364             clib_mem_free (decr_data);
365
366             wg_timers_any_authenticated_packet_received (peer);
367             wg_timers_any_authenticated_packet_traversal (peer);
368
369             if (decr_len == 0)
370               {
371                 is_keepalive = true;
372                 goto out;
373               }
374
375             wg_timers_data_received (peer);
376
377             ip4_header_t *iph = vlib_buffer_get_current (b[0]);
378
379             const wg_peer_allowed_ip_t *allowed_ip;
380             bool allowed = false;
381
382             /*
383              * we could make this into an ACL, but the expectation
384              * is that there aren't many allowed IPs and thus a linear
385              * walk is fater than an ACL
386              */
387             vec_foreach (allowed_ip, peer->allowed_ips)
388             {
389               if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
390                                               &iph->src_address))
391                 {
392                   allowed = true;
393                   break;
394                 }
395             }
396             if (allowed)
397               {
398                 vnet_buffer (b[0])->sw_if_index[VLIB_RX] =
399                   peer->wg_sw_if_index;
400                 next[0] = WG_INPUT_NEXT_IP4_INPUT;
401               }
402             break;
403           }
404         default:
405           break;
406         }
407
408     out:
409       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
410                          && (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
411         {
412           wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
413           t->type = header_type;
414           t->current_length = b[0]->current_length;
415           t->is_keepalive = is_keepalive;
416         }
417       n_left_from -= 1;
418       next += 1;
419       b += 1;
420     }
421   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
422
423   return frame->n_vectors;
424 }
425
426 /* *INDENT-OFF* */
427 VLIB_REGISTER_NODE (wg_input_node) =
428 {
429   .name = "wg-input",
430   .vector_size = sizeof (u32),
431   .format_trace = format_wg_input_trace,
432   .type = VLIB_NODE_TYPE_INTERNAL,
433   .n_errors = ARRAY_LEN (wg_input_error_strings),
434   .error_strings = wg_input_error_strings,
435   .n_next_nodes = WG_INPUT_N_NEXT,
436   /* edit / add dispositions here */
437   .next_nodes = {
438         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
439         [WG_INPUT_NEXT_PUNT] = "error-punt",
440         [WG_INPUT_NEXT_ERROR] = "error-drop",
441   },
442 };
443 /* *INDENT-ON* */
444
445 /*
446  * fd.io coding-style-patch-verification: ON
447  *
448  * Local Variables:
449  * eval: (c-set-style "gnu")
450  * End:
451  */