wireguard: add handoff node
[vpp.git] / src / plugins / wireguard / wireguard_output_tun.c
1 /*
2  * Copyright (c) 2020 Doc.ai and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #include <vlib/vlib.h>
17 #include <vnet/vnet.h>
18 #include <vppinfra/error.h>
19
20 #include <wireguard/wireguard.h>
21 #include <wireguard/wireguard_send.h>
22
23 #define foreach_wg_output_error                                         \
24  _(NONE, "No error")                                                    \
25  _(PEER, "Peer error")                                                  \
26  _(KEYPAIR, "Keypair error")                                            \
27  _(TOO_BIG, "packet too big")                                           \
28
29 typedef enum
30 {
31 #define _(sym,str) WG_OUTPUT_ERROR_##sym,
32   foreach_wg_output_error
33 #undef _
34     WG_OUTPUT_N_ERROR,
35 } wg_output_error_t;
36
37 static char *wg_output_error_strings[] = {
38 #define _(sym,string) string,
39   foreach_wg_output_error
40 #undef _
41 };
42
43 typedef enum
44 {
45   WG_OUTPUT_NEXT_ERROR,
46   WG_OUTPUT_NEXT_HANDOFF,
47   WG_OUTPUT_NEXT_INTERFACE_OUTPUT,
48   WG_OUTPUT_N_NEXT,
49 } wg_output_next_t;
50
51 typedef struct
52 {
53   ip4_udp_header_t hdr;
54   index_t peer;
55 } wg_output_tun_trace_t;
56
57 u8 *
58 format_ip4_udp_header (u8 * s, va_list * args)
59 {
60   ip4_udp_header_t *hdr = va_arg (*args, ip4_udp_header_t *);
61
62   s = format (s, "%U:$U",
63               format_ip4_header, &hdr->ip4, format_udp_header, &hdr->udp);
64
65   return (s);
66 }
67
68 /* packet trace format function */
69 static u8 *
70 format_wg_output_tun_trace (u8 * s, va_list * args)
71 {
72   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
73   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
74
75   wg_output_tun_trace_t *t = va_arg (*args, wg_output_tun_trace_t *);
76
77   s = format (s, "peer: %d\n", t->peer);
78   s = format (s, "  Encrypted packet: %U", format_ip4_udp_header, &t->hdr);
79   return s;
80 }
81
82 VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
83                                    vlib_node_runtime_t * node,
84                                    vlib_frame_t * frame)
85 {
86   u32 n_left_from;
87   u32 *from;
88   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
89   u16 nexts[VLIB_FRAME_SIZE], *next;
90   u32 thread_index = vm->thread_index;
91
92   from = vlib_frame_vector_args (frame);
93   n_left_from = frame->n_vectors;
94   b = bufs;
95   next = nexts;
96
97   vlib_get_buffers (vm, from, bufs, n_left_from);
98
99   wg_main_t *wmp = &wg_main;
100   wg_peer_t *peer = NULL;
101
102   while (n_left_from > 0)
103     {
104       ip4_udp_header_t *hdr = vlib_buffer_get_current (b[0]);
105       u8 *plain_data = (vlib_buffer_get_current (b[0]) +
106                         sizeof (ip4_udp_header_t));
107       u16 plain_data_len =
108         clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length);
109       index_t peeri;
110
111       next[0] = WG_OUTPUT_NEXT_ERROR;
112       peeri =
113         wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]);
114       peer = wg_peer_get (peeri);
115
116       if (!peer || peer->is_dead)
117         {
118           b[0]->error = node->errors[WG_OUTPUT_ERROR_PEER];
119           goto out;
120         }
121
122       if (PREDICT_FALSE (~0 == peer->output_thread_index))
123         {
124           /* this is the first packet to use this peer, claim the peer
125            * for this thread.
126            */
127           clib_atomic_cmp_and_swap (&peer->output_thread_index, ~0,
128                                     wg_peer_assign_thread (thread_index));
129         }
130
131       if (PREDICT_TRUE (thread_index != peer->output_thread_index))
132         {
133           next[0] = WG_OUTPUT_NEXT_HANDOFF;
134           goto next;
135         }
136
137       if (PREDICT_FALSE (!peer->remote.r_current))
138         {
139           wg_send_handshake_from_mt (peeri, false);
140           b[0]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR];
141           goto out;
142         }
143       size_t encrypted_packet_len = message_data_len (plain_data_len);
144
145       /*
146        * Ensure there is enough space to write the encrypted data
147        * into the packet
148        */
149       if (PREDICT_FALSE (encrypted_packet_len >= WG_DEFAULT_DATA_SIZE) ||
150           PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >=
151                          vlib_buffer_get_default_data_size (vm)))
152         {
153           b[0]->error = node->errors[WG_OUTPUT_ERROR_TOO_BIG];
154           goto out;
155         }
156
157       message_data_t *encrypted_packet =
158         (message_data_t *) wmp->per_thread_data[thread_index].data;
159
160       enum noise_state_crypt state;
161       state =
162         noise_remote_encrypt (vm,
163                               &peer->remote,
164                               &encrypted_packet->receiver_index,
165                               &encrypted_packet->counter, plain_data,
166                               plain_data_len,
167                               encrypted_packet->encrypted_data);
168
169       if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH))
170         {
171           wg_send_handshake_from_mt (peeri, false);
172         }
173       else if (PREDICT_FALSE (state == SC_FAILED))
174         {
175           //TODO: Maybe wrong
176           wg_send_handshake_from_mt (peeri, false);
177           goto out;
178         }
179
180       /* Here we are sure that can send packet to next node */
181       next[0] = WG_OUTPUT_NEXT_INTERFACE_OUTPUT;
182       encrypted_packet->header.type = MESSAGE_DATA;
183
184       clib_memcpy (plain_data, (u8 *) encrypted_packet, encrypted_packet_len);
185
186       hdr->udp.length = clib_host_to_net_u16 (encrypted_packet_len +
187                                               sizeof (udp_header_t));
188       b[0]->current_length = (encrypted_packet_len +
189                               sizeof (ip4_header_t) + sizeof (udp_header_t));
190       ip4_header_set_len_w_chksum
191         (&hdr->ip4, clib_host_to_net_u16 (b[0]->current_length));
192
193       wg_timers_any_authenticated_packet_sent (peer);
194       wg_timers_data_sent (peer);
195       wg_timers_any_authenticated_packet_traversal (peer);
196
197     out:
198       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
199                          && (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
200         {
201           wg_output_tun_trace_t *t =
202             vlib_add_trace (vm, node, b[0], sizeof (*t));
203           t->hdr = *hdr;
204           t->peer = peeri;
205         }
206     next:
207       n_left_from -= 1;
208       next += 1;
209       b += 1;
210     }
211
212   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
213   return frame->n_vectors;
214 }
215
216 /* *INDENT-OFF* */
217 VLIB_REGISTER_NODE (wg_output_tun_node) =
218 {
219   .name = "wg-output-tun",
220   .vector_size = sizeof (u32),
221   .format_trace = format_wg_output_tun_trace,
222   .type = VLIB_NODE_TYPE_INTERNAL,
223   .n_errors = ARRAY_LEN (wg_output_error_strings),
224   .error_strings = wg_output_error_strings,
225   .n_next_nodes = WG_OUTPUT_N_NEXT,
226   .next_nodes = {
227         [WG_OUTPUT_NEXT_HANDOFF] = "wg-output-tun-handoff",
228         [WG_OUTPUT_NEXT_INTERFACE_OUTPUT] = "adj-midchain-tx",
229         [WG_OUTPUT_NEXT_ERROR] = "error-drop",
230   },
231 };
232 /* *INDENT-ON* */
233
234 /*
235  * fd.io coding-style-patch-verification: ON
236  *
237  * Local Variables:
238  * eval: (c-set-style "gnu")
239  * End:
240  */