2 * Copyright (c) 2020 Doc.ai and/or its affiliates.
3 * Copyright (c) 2020 Cisco and/or its affiliates.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at:
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <vlib/vlib.h>
18 #include <vnet/vnet.h>
19 #include <vppinfra/error.h>
20 #include <wireguard/wireguard.h>
22 #include <wireguard/wireguard_send.h>
23 #include <wireguard/wireguard_if.h>
/*
 * X-macro list of the errors reported by the wg-input node. Each
 * _(SYM, text) entry pairs a symbol suffix with its human-readable string.
 * The list is expanded further down with two different definitions of `_`:
 * once to produce the WG_INPUT_ERROR_<SYM> enumerators and once to produce
 * the wg_input_error_strings table.
 * NOTE(review): WG_INPUT_ERROR_NONE is used by this file but has no entry
 * among the lines visible here -- presumably cut from this view; confirm.
 */
25 #define foreach_wg_input_error \
27 _(HANDSHAKE_MAC, "Invalid MAC handshake") \
28 _(PEER, "Peer error") \
29 _(INTERFACE, "Interface error") \
30 _(DECRYPTION, "Failed during decryption") \
31 _(KEEPALIVE_SEND, "Failed while sending Keepalive") \
32 _(HANDSHAKE_SEND, "Failed while sending Handshake") \
33 _(TOO_BIG, "Packet too big") \
34 _(UNDEFINED, "Undefined error")
/*
 * Expand every entry of foreach_wg_input_error into a WG_INPUT_ERROR_<sym>
 * enumerator; the string half of each entry is ignored by this definition
 * of `_`. The enclosing enum declaration is not visible in this view.
 */
38 #define _(sym,str) WG_INPUT_ERROR_##sym,
39 foreach_wg_input_error
/*
 * Human-readable error strings, one per foreach_wg_input_error entry and in
 * the same order as the WG_INPUT_ERROR_<sym> enumerators, since both are
 * expansions of the same list. The node registration passes this table to
 * vlib through .error_strings and sizes it with ARRAY_LEN for .n_errors.
 */
44 static char *wg_input_error_strings[] = {
/* Keep only the string half of each list entry. */
45 #define _(sym,string) string,
46 foreach_wg_input_error
/*
 * format() callback that appends the name of a WireGuard message type to s.
 *
 * args supplies a single message_type_t. Every entry of
 * foreach_wg_message_type (not defined in the lines visible here) expands to
 * a `case MESSAGE_<v>` that returns at once with the string of that entry
 * appended through format(). A value matched by no case reaches the final
 * return, which appends the text `unknown`.
 *
 * Used with the %U conversion by format_wg_input_trace.
 */
59 format_wg_message_type (u8 * s, va_list * args)
61 message_type_t type = va_arg (*args, message_type_t);
/* One case label, with its own return, per known message type. */
65 #define _(v,a) case MESSAGE_##v: return (format (s, "%s", a));
66 foreach_wg_message_type
/* Fallback for values that no case label matched. */
69 return (format (s, "unknown"));
72 /* packet trace format function */
/*
 * Renders one wg_input_trace_t record as text appended to s.
 *
 * args supplies, in this order: a vlib_main_t *, a vlib_node_t * (both are
 * consumed from the list but otherwise unused) and the wg_input_trace_t *
 * to print. The output is a `WG input:` heading followed by the message
 * type (printed through format_wg_message_type), the peer value, the
 * recorded buffer length and the keepalive flag as true or false. The last
 * field is written without a trailing newline.
 */
74 format_wg_input_trace (u8 * s, va_list * args)
/* Both arguments must be pulled off the va_list to reach the trace record. */
76 CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
77 CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
79 wg_input_trace_t *t = va_arg (*args, wg_input_trace_t *);
81 s = format (s, "WG input: \n");
82 s = format (s, " Type: %U\n", format_wg_message_type, t->type);
83 s = format (s, " peer: %d\n", t->peer);
84 s = format (s, " Length: %d\n", t->current_length);
85 s = format (s, " Keepalive: %s", t->is_keepalive ? "true" : "false");
/*
 * Next-node slots of the wg-input node. The values are used as
 * designated-initializer indices in the node registration, where they are
 * bound to wg-handshake-handoff, wg-input-data-handoff and
 * ip4-input-no-checksum respectively.
 * NOTE(review): WG_INPUT_NEXT_PUNT, WG_INPUT_NEXT_ERROR and WG_INPUT_N_NEXT
 * are used elsewhere in this file but are not among the lines visible here.
 */
92 WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
93 WG_INPUT_NEXT_HANDOFF_DATA,
94 WG_INPUT_NEXT_IP4_INPUT,
101 /* set_peer_address (wg_peer_t * peer, ip4_address_t ip4, u16 udp_port) */
105 /* ip46_address_set_ip4 (&peer->dst.addr, &ip4); */
106 /* peer->dst.port = udp_port; */
/*
 * Process one non-data WireGuard message (cookie reply, handshake
 * initiation or handshake response) held in buffer b.
 *
 * vm  - vlib main of the calling thread; asserted to be thread 0.
 * wmp - plugin main; its index_table maps receiver indices to peers.
 * b   - buffer whose current data pointer is at the WireGuard message.
 *
 * Returns WG_INPUT_ERROR_NONE when the message was handled, otherwise one
 * of WG_INPUT_ERROR_INTERFACE, WG_INPUT_ERROR_PEER or
 * WG_INPUT_ERROR_HANDSHAKE_MAC. Failures to send the handshake response or
 * the keepalive are counted on wg_input_node; no error return is visible on
 * those two paths.
 *
 * NOTE(review): several lines of this function (declarations, guards, else
 * branches) are not visible in this view; comments below that depend on
 * them are marked as assumptions.
 */
110 static wg_input_error_t
111 wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
/* The node function hands handshake packets off unless it runs on thread 0. */
113 ASSERT (vm->thread_index == 0);
115 enum cookie_mac_state mac_state;
116 bool packet_needs_cookie;
119 wg_peer_t *peer = NULL;
/*
 * The UDP header is read from directly in front of the current data
 * pointer, and the IPv4 header from a further sizeof (ip4_header_t) bytes
 * before that.
 * NOTE(review): this presumes an option-less IPv4 header immediately
 * preceding UDP in the same buffer -- confirm against the caller.
 */
121 void *current_b_data = vlib_buffer_get_current (b);
123 udp_header_t *uhd = current_b_data - sizeof (udp_header_t);
125 current_b_data - sizeof (udp_header_t) - sizeof (ip4_header_t);
126 ip4_address_t ip4_src = iph->src_address;
/*
 * NOTE(review): the ports are taken from the packet header, so a
 * net-to-host conversion is what is meant; presumably clib_host_to_net_u16
 * performs the same 16-bit swap -- confirm. The doubled semicolon is an
 * empty statement.
 */
127 u16 udp_src_port = clib_host_to_net_u16 (uhd->src_port);;
128 u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
130 message_header_t *header = current_b_data;
/*
 * The WireGuard interface is selected by UDP destination port.
 * WG_INPUT_ERROR_INTERFACE is the result on the failure path (the guarding
 * condition is not visible here; presumably a failed lookup).
 */
133 wg_if = wg_if_get_by_port (udp_dst_port);
136 return WG_INPUT_ERROR_INTERFACE;
/*
 * Cookie replies: the peer is resolved through the receiver index carried
 * in the message, with WG_INPUT_ERROR_PEER on the failure path. The cookie
 * payload itself is not consumed yet (see the TODO) and the function
 * returns WG_INPUT_ERROR_NONE.
 */
138 if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
140 message_handshake_cookie_t *packet =
141 (message_handshake_cookie_t *) current_b_data;
143 wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
145 peer = wg_peer_get (*entry);
147 return WG_INPUT_ERROR_PEER;
149 // TODO: Implement cookie_maker_consume_payload
151 return WG_INPUT_ERROR_NONE;
/*
 * Expected message size: the initiation size for an initiation, and the
 * response size for every other type that reaches this point.
 */
154 u32 len = (header->type == MESSAGE_HANDSHAKE_INITIATION ?
155 sizeof (message_handshake_initiation_t) :
156 sizeof (message_handshake_response_t));
/*
 * The MACs are the last sizeof (*macs) bytes of the message.
 * NOTE(review): no comparison of len against the length of the buffer is
 * visible here -- verify the packet is known to hold at least len bytes
 * before macs is read.
 */
158 message_macs_t *macs = (message_macs_t *)
159 ((u8 *) current_b_data + len - sizeof (*macs));
/*
 * MAC validation and its outcome:
 *  - under load with a valid cookie, or not under load with a valid MAC
 *    and no cookie: accept, no cookie needed;
 *  - under load with a valid MAC but no cookie: accept, cookie needed;
 *  - WG_INPUT_ERROR_HANDSHAKE_MAC is the remaining result (presumably the
 *    else branch; that line is not visible here).
 */
162 cookie_checker_validate_macs (vm, &wg_if->cookie_checker, macs,
163 current_b_data, len, under_load, ip4_src,
166 if ((under_load && mac_state == VALID_MAC_WITH_COOKIE)
167 || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
168 packet_needs_cookie = false;
169 else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
170 packet_needs_cookie = true;
172 return WG_INPUT_ERROR_HANDSHAKE_MAC;
174 switch (header->type)
/*
 * Initiation: the cookie-needed case is still a TODO. The message is
 * passed to noise_consume_initiation together with the local noise state
 * of the interface; the peer is then taken from rp->r_peer_idx, with
 * WG_INPUT_ERROR_PEER on the failure path (guard not visible here).
 */
176 case MESSAGE_HANDSHAKE_INITIATION:
178 message_handshake_initiation_t *message = current_b_data;
180 if (packet_needs_cookie)
182 // TODO: Add processing
185 if (noise_consume_initiation
186 (vm, noise_local_get (wg_if->local_idx), &rp,
187 message->sender_index, message->unencrypted_ephemeral,
188 message->encrypted_static, message->encrypted_timestamp))
190 peer = wg_peer_get (rp->r_peer_idx);
194 return WG_INPUT_ERROR_PEER;
197 // set_peer_address (peer, ip4_src, udp_src_port);
/* A failed response send is only counted against the node. */
198 if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
200 vlib_node_increment_counter (vm, wg_input_node.index,
201 WG_INPUT_ERROR_HANDSHAKE_SEND, 1);
/*
 * Response: the peer is resolved through the receiver index; a missing
 * index-table entry or a response rejected by noise_consume_response both
 * yield WG_INPUT_ERROR_PEER.
 */
205 case MESSAGE_HANDSHAKE_RESPONSE:
207 message_handshake_response_t *resp = current_b_data;
209 wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
211 if (PREDICT_TRUE (entry != NULL))
213 peer = wg_peer_get (*entry);
215 return WG_INPUT_ERROR_PEER;
218 return WG_INPUT_ERROR_PEER;
220 if (!noise_consume_response
221 (vm, &peer->remote, resp->sender_index,
222 resp->receiver_index, resp->unencrypted_ephemeral,
223 resp->encrypted_nothing))
225 return WG_INPUT_ERROR_PEER;
227 if (packet_needs_cookie)
229 // TODO: Add processing
232 // set_peer_address (peer, ip4_src, udp_src_port);
/*
 * When the session can be started, the session-derived and
 * handshake-complete timers are notified and a keepalive is sent; a failed
 * keepalive send is only counted against the node.
 */
233 if (noise_remote_begin_session (vm, &peer->remote))
236 wg_timers_session_derived (peer);
237 wg_timers_handshake_complete (peer);
238 if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
240 vlib_node_increment_counter (vm, wg_input_node.index,
241 WG_INPUT_ERROR_KEEPALIVE_SEND,
/* Common tail: notify the authenticated-packet timers for the peer. */
251 wg_timers_any_authenticated_packet_received (peer);
252 wg_timers_any_authenticated_packet_traversal (peer);
253 return WG_INPUT_ERROR_NONE;
/*
 * Tests whether the IPv4 address ip4 is matched by the FIB prefix p1.
 *
 * For an IPv4 prefix the answer is whether ip4_destination_matches_route,
 * called with the prefix length p1->fp_len, returns non-zero. Case labels
 * exist for the IPv6 and MPLS protocols as well; their bodies are not
 * visible in this view (presumably they report no match -- confirm).
 */
256 static_always_inline bool
257 fib_prefix_is_cover_addr_4 (const fib_prefix_t * p1,
258 const ip4_address_t * ip4)
260 switch (p1->fp_proto)
262 case FIB_PROTOCOL_IP4:
263 return (ip4_destination_matches_route (&ip4_main,
265 ip4, p1->fp_len) != 0);
266 case FIB_PROTOCOL_IP6:
268 case FIB_PROTOCOL_MPLS:
/*
 * Node function of wg-input.
 *
 * For every buffer of the frame the message type is read from the current
 * data pointer and the next index starts out as WG_INPUT_NEXT_PUNT.
 *
 * Data messages: the peer is looked up by receiver index (PEER error on
 * the failure path). The peer is bound to an input thread on first use and
 * buffers arriving on any other thread go to the data handoff node. The
 * payload length is derived from the buffer length (TOO_BIG error when it
 * reaches WG_DEFAULT_DATA_SIZE), noise_remote_decrypt is called, and on
 * success decr_len bytes are copied from the per-thread decr_data area back
 * to the start of the buffer. The result is then treated as an IPv4 packet,
 * checked against the allowed_ips vector of the peer and sent to
 * WG_INPUT_NEXT_IP4_INPUT with the wg interface of the peer as RX
 * interface. A failed decrypt yields the DECRYPTION error.
 *
 * All other messages are handshake messages: on thread 0 they are passed
 * to wg_handshake_process, on any other thread they go to the handshake
 * handoff node.
 *
 * Traced buffers get a wg_input_trace_t record. Returns frame->n_vectors.
 *
 * NOTE(review): encr_len and decr_len are u16, so the subtractions wrap
 * for a buffer shorter than the data header plus auth tag; presumably the
 * wrapped value is then rejected by the WG_DEFAULT_DATA_SIZE check --
 * confirm.
 * NOTE(review): many lines of this function (declarations, guards, the
 * per-buffer advance of b, next and n_left_from) are not visible in this
 * view.
 */
274 VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
275 vlib_node_runtime_t * node,
276 vlib_frame_t * frame)
278 message_type_t header_type;
281 vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
282 u16 nexts[VLIB_FRAME_SIZE], *next;
283 u32 thread_index = vm->thread_index;
/* Resolve the buffer indices of the frame into buffer pointers in bufs. */
285 from = vlib_frame_vector_args (frame);
286 n_left_from = frame->n_vectors;
290 vlib_get_buffers (vm, from, bufs, n_left_from);
292 wg_main_t *wmp = &wg_main;
293 wg_peer_t *peer = NULL;
/* One buffer per iteration; punt is the default disposition. */
295 while (n_left_from > 0)
297 bool is_keepalive = false;
298 next[0] = WG_INPUT_NEXT_PUNT;
300 ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
/* Data path, hinted as the common case. */
303 if (PREDICT_TRUE (header_type == MESSAGE_DATA))
305 message_data_t *data = vlib_buffer_get_current (b[0]);
307 peer_idx = wg_index_table_lookup (&wmp->index_table,
308 data->receiver_index);
312 peer = wg_peer_get (*peer_idx);
316 next[0] = WG_INPUT_NEXT_ERROR;
317 b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
/*
 * An input_thread_index of ~0 is treated as not yet assigned: the
 * compare-and-swap below replaces ~0 with the value returned by
 * wg_peer_assign_thread (thread_index). A buffer seen on a thread other
 * than the recorded one is then redirected to the data handoff node;
 * PREDICT_TRUE marks that mismatch as the expected case.
 */
321 if (PREDICT_FALSE (~0 == peer->input_thread_index))
323 /* this is the first packet to use this peer, claim the peer
326 clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
327 wg_peer_assign_thread (thread_index));
330 if (PREDICT_TRUE (thread_index != peer->input_thread_index))
332 next[0] = WG_INPUT_NEXT_HANDOFF_DATA;
336 u16 encr_len = b[0]->current_length - sizeof (message_data_t);
337 u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
338 if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
340 b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
344 u8 *decr_data = wmp->per_thread_data[thread_index].data;
346 enum noise_state_crypt state_cr = noise_remote_decrypt (vm,
348 data->receiver_index,
350 data->encrypted_data,
354 if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
356 wg_timers_handshake_complete (peer);
358 else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
360 wg_send_handshake_from_mt (*peer_idx, false);
362 else if (PREDICT_FALSE (state_cr == SC_FAILED))
364 next[0] = WG_INPUT_NEXT_ERROR;
365 b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
369 clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
370 b[0]->current_length = decr_len;
371 b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
373 wg_timers_any_authenticated_packet_received (peer);
374 wg_timers_any_authenticated_packet_traversal (peer);
376 /* Keepalive packet has zero length */
383 wg_timers_data_received (peer);
/*
 * The start of the decrypted payload is read as an IPv4 header and the
 * allowed_ips vector of the peer is walked with
 * fib_prefix_is_cover_addr_4. The address argument of that call is not
 * visible in this view.
 */
385 ip4_header_t *iph = vlib_buffer_get_current (b[0]);
387 const wg_peer_allowed_ip_t *allowed_ip;
388 bool allowed = false;
391 * we could make this into an ACL, but the expectation
392 * is that there aren't many allowed IPs and thus a linear
393 * walk is fater than an ACL
395 vec_foreach (allowed_ip, peer->allowed_ips)
397 if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
/*
 * The packet is attributed to the wg interface of the peer as RX interface
 * and sent on to IPv4 input (presumably only when allowed was set; the
 * guard is not visible here).
 */
406 vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
407 next[0] = WG_INPUT_NEXT_IP4_INPUT;
414 /* Handshake packets should be processed in main thread */
415 if (thread_index != 0)
417 next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
/* Any result other than NONE sends the buffer to the error next with it. */
421 wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
422 if (ret != WG_INPUT_ERROR_NONE)
424 next[0] = WG_INPUT_NEXT_ERROR;
425 b[0]->error = node->errors[ret];
/*
 * Tracing: record the message type, the buffer length at this point, the
 * keepalive flag and the peer index (INDEX_INVALID when peer_idx is NULL).
 */
430 if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
431 && (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
433 wg_input_trace_t *t = vlib_add_trace (vm, node, b[0], sizeof (*t));
434 t->type = header_type;
435 t->current_length = b[0]->current_length;
436 t->is_keepalive = is_keepalive;
437 t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
/* Dispatch the whole frame using the per-buffer next indices in nexts. */
444 vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
446 return frame->n_vectors;
/*
 * Graph-node registration for wg_input_node: an internal node whose frame
 * elements are u32 sized, with format_wg_input_trace as trace formatter and
 * wg_input_error_strings as error table (its length supplies n_errors).
 * The name of the node is not among the lines visible in this view.
 */
450 VLIB_REGISTER_NODE (wg_input_node) =
453 .vector_size = sizeof (u32),
454 .format_trace = format_wg_input_trace,
455 .type = VLIB_NODE_TYPE_INTERNAL,
456 .n_errors = ARRAY_LEN (wg_input_error_strings),
457 .error_strings = wg_input_error_strings,
458 .n_next_nodes = WG_INPUT_N_NEXT,
459 /* edit / add dispositions here */
/*
 * Next-node names by slot: the two handoff nodes, IPv4 input without
 * checksum verification for decrypted traffic, and the punt and drop
 * error nodes.
 */
461 [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg-handshake-handoff",
462 [WG_INPUT_NEXT_HANDOFF_DATA] = "wg-input-data-handoff",
463 [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
464 [WG_INPUT_NEXT_PUNT] = "error-punt",
465 [WG_INPUT_NEXT_ERROR] = "error-drop",
471 * fd.io coding-style-patch-verification: ON
474 * eval: (c-set-style "gnu")