wireguard: Fix for tunnel encap
[vpp.git] / src / plugins / wireguard / wireguard_output_tun.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #include <vlib/vlib.h>
17 #include <vnet/vnet.h>
18 #include <vnet/pg/pg.h>
19 #include <vnet/fib/ip6_fib.h>
20 #include <vnet/fib/ip4_fib.h>
21 #include <vnet/fib/fib_entry.h>
22 #include <vppinfra/error.h>
23
24 #include <wireguard/wireguard.h>
25 #include <wireguard/wireguard_send.h>
26
27 #define foreach_wg_output_error                                         \
28  _(NONE, "No error")                                                    \
29  _(PEER, "Peer error")                                                  \
30  _(KEYPAIR, "Keypair error")                                            \
31  _(HANDSHAKE_SEND, "Handshake sending failed")                          \
32  _(TOO_BIG, "packet too big")                                           \
33
34 #define WG_OUTPUT_SCRATCH_SIZE 2048
35
36 typedef struct wg_output_scratch_t_
37 {
38   u8 scratch[WG_OUTPUT_SCRATCH_SIZE];
39 } wg_output_scratch_t;
40
41 /* Cache line aligned per-thread scratch space */
42 static wg_output_scratch_t *wg_output_scratchs;
43
44 typedef enum
45 {
46 #define _(sym,str) WG_OUTPUT_ERROR_##sym,
47   foreach_wg_output_error
48 #undef _
49     WG_OUTPUT_N_ERROR,
50 } wg_output_error_t;
51
52 static char *wg_output_error_strings[] = {
53 #define _(sym,string) string,
54   foreach_wg_output_error
55 #undef _
56 };
57
58 typedef enum
59 {
60   WG_OUTPUT_NEXT_ERROR,
61   WG_OUTPUT_NEXT_INTERFACE_OUTPUT,
62   WG_OUTPUT_N_NEXT,
63 } wg_output_next_t;
64
65 typedef struct
66 {
67   ip4_udp_header_t hdr;
68 } wg_output_tun_trace_t;
69
70 u8 *
71 format_ip4_udp_header (u8 * s, va_list * args)
72 {
73   ip4_udp_header_t *hdr = va_arg (*args, ip4_udp_header_t *);
74
75   s = format (s, "%U:$U",
76               format_ip4_header, &hdr->ip4, format_udp_header, &hdr->udp);
77
78   return (s);
79 }
80
81 /* packet trace format function */
82 static u8 *
83 format_wg_output_tun_trace (u8 * s, va_list * args)
84 {
85   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
86   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
87
88   wg_output_tun_trace_t *t = va_arg (*args, wg_output_tun_trace_t *);
89
90   s = format (s, "Encrypted packet: %U\n", format_ip4_udp_header, &t->hdr);
91   return s;
92 }
93
94 VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
95                                    vlib_node_runtime_t * node,
96                                    vlib_frame_t * frame)
97 {
98   u32 n_left_from;
99   u32 *from;
100   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
101   u16 nexts[VLIB_FRAME_SIZE], *next;
102   u32 thread_index = vm->thread_index;
103
104   from = vlib_frame_vector_args (frame);
105   n_left_from = frame->n_vectors;
106   b = bufs;
107   next = nexts;
108
109   vlib_get_buffers (vm, from, bufs, n_left_from);
110
111   wg_main_t *wmp = &wg_main;
112   u32 handsh_fails = 0;
113   wg_peer_t *peer = NULL;
114
115   while (n_left_from > 0)
116     {
117       ip4_udp_header_t *hdr = vlib_buffer_get_current (b[0]);
118       u8 *plain_data = (vlib_buffer_get_current (b[0]) +
119                         sizeof (ip4_udp_header_t));
120       u16 plain_data_len =
121         clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length);
122
123       next[0] = WG_OUTPUT_NEXT_ERROR;
124
125       peer =
126         wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]);
127
128       if (!peer || peer->is_dead)
129         {
130           b[0]->error = node->errors[WG_OUTPUT_ERROR_PEER];
131           goto out;
132         }
133
134       if (PREDICT_FALSE (!peer->remote.r_current))
135         {
136           if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
137             handsh_fails++;
138           b[0]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR];
139           goto out;
140         }
141
142       size_t encrypted_packet_len = message_data_len (plain_data_len);
143
144       /*
145        * Ensure there is enough space to write the encrypted data
146        * into the packet
147        */
148       if (PREDICT_FALSE (encrypted_packet_len >= WG_OUTPUT_SCRATCH_SIZE) ||
149           PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >=
150                          vlib_buffer_get_default_data_size (vm)))
151         {
152           b[0]->error = node->errors[WG_OUTPUT_ERROR_TOO_BIG];
153           goto out;
154         }
155
156       message_data_t *encrypted_packet =
157         (message_data_t *) wg_output_scratchs[thread_index].scratch;
158
159       enum noise_state_crypt state;
160       state =
161         noise_remote_encrypt (wmp->vlib_main,
162                               &peer->remote,
163                               &encrypted_packet->receiver_index,
164                               &encrypted_packet->counter, plain_data,
165                               plain_data_len,
166                               encrypted_packet->encrypted_data);
167       switch (state)
168         {
169         case SC_OK:
170           break;
171         case SC_KEEP_KEY_FRESH:
172           if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
173             handsh_fails++;
174           break;
175         case SC_FAILED:
176           //TODO: Maybe wrong
177           if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
178             handsh_fails++;
179           clib_mem_free (encrypted_packet);
180           goto out;
181         default:
182           break;
183         }
184
185       // Here we are sure that can send packet to next node.
186       next[0] = WG_OUTPUT_NEXT_INTERFACE_OUTPUT;
187       encrypted_packet->header.type = MESSAGE_DATA;
188
189       clib_memcpy (plain_data, (u8 *) encrypted_packet, encrypted_packet_len);
190
191       hdr->udp.length = clib_host_to_net_u16 (encrypted_packet_len +
192                                               sizeof (udp_header_t));
193       b[0]->current_length = (encrypted_packet_len +
194                               sizeof (ip4_header_t) + sizeof (udp_header_t));
195       ip4_header_set_len_w_chksum
196         (&hdr->ip4, clib_host_to_net_u16 (b[0]->current_length));
197
198       wg_timers_any_authenticated_packet_traversal (peer);
199       wg_timers_any_authenticated_packet_sent (peer);
200       wg_timers_data_sent (peer);
201
202     out:
203       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
204                          && (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
205         {
206           wg_output_tun_trace_t *t =
207             vlib_add_trace (vm, node, b[0], sizeof (*t));
208           t->hdr = *hdr;
209         }
210       n_left_from -= 1;
211       next += 1;
212       b += 1;
213     }
214
215   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
216
217   vlib_node_increment_counter (vm, node->node_index,
218                                WG_OUTPUT_ERROR_HANDSHAKE_SEND, handsh_fails);
219
220   return frame->n_vectors;
221 }
222
223 /* *INDENT-OFF* */
224 VLIB_REGISTER_NODE (wg_output_tun_node) =
225 {
226   .name = "wg-output-tun",
227   .vector_size = sizeof (u32),
228   .format_trace = format_wg_output_tun_trace,
229   .type = VLIB_NODE_TYPE_INTERNAL,
230   .n_errors = ARRAY_LEN (wg_output_error_strings),
231   .error_strings = wg_output_error_strings,
232   .n_next_nodes = WG_OUTPUT_N_NEXT,
233   .next_nodes = {
234         [WG_OUTPUT_NEXT_INTERFACE_OUTPUT] = "adj-midchain-tx",
235         [WG_OUTPUT_NEXT_ERROR] = "error-drop",
236   },
237 };
238 /* *INDENT-ON* */
239
240 static clib_error_t *
241 wireguard_output_module_init (vlib_main_t * vm)
242 {
243   vlib_thread_main_t *tm = vlib_get_thread_main ();
244
245   vec_validate_aligned (wg_output_scratchs, tm->n_vlib_mains,
246                         CLIB_CACHE_LINE_BYTES);
247   return (NULL);
248 }
249
250 VLIB_INIT_FUNCTION (wireguard_output_module_init);
251
252 /*
253  * fd.io coding-style-patch-verification: ON
254  *
255  * Local Variables:
256  * eval: (c-set-style "gnu")
257  * End:
258  */