wireguard: add dos mitigation support
[vpp.git] / src / plugins / wireguard / wireguard_input.c
index dbdcaa0..3f546cc 100644 (file)
@@ -31,6 +31,8 @@
   _ (KEEPALIVE_SEND, "Failed while sending Keepalive")                        \
   _ (HANDSHAKE_SEND, "Failed while sending Handshake")                        \
   _ (HANDSHAKE_RECEIVE, "Failed while receiving Handshake")                   \
+  _ (COOKIE_DECRYPTION, "Failed during Cookie decryption")                    \
+  _ (COOKIE_SEND, "Failed during sending Cookie")                             \
   _ (TOO_BIG, "Packet too big")                                               \
   _ (UNDEFINED, "Undefined error")                                            \
   _ (CRYPTO_ENGINE_ERROR, "crypto engine error (packet dropped)")
@@ -172,7 +174,6 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
   u16 udp_dst_port = clib_host_to_net_u16 (uhd->dst_port);;
 
   message_header_t *header = current_b_data;
-  under_load = false;
 
   if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
     {
@@ -185,7 +186,9 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
       else
        return WG_INPUT_ERROR_PEER;
 
-      // TODO: Implement cookie_maker_consume_payload
+      if (!cookie_maker_consume_payload (
+           vm, &peer->cookie_maker, packet->nonce, packet->encrypted_cookie))
+       return WG_INPUT_ERROR_COOKIE_DECRYPTION;
 
       return WG_INPUT_ERROR_NONE;
     }
@@ -208,11 +211,13 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
       if (NULL == wg_if)
        continue;
 
+      under_load = wg_if_is_under_load (vm, wg_if);
       mac_state = cookie_checker_validate_macs (
        vm, &wg_if->cookie_checker, macs, current_b_data, len, under_load,
        &src_ip, udp_src_port);
       if (mac_state == INVALID_MAC)
        {
+         wg_if_dec_handshake_num (wg_if);
          wg_if = NULL;
          continue;
        }
@@ -238,8 +243,16 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
 
        if (packet_needs_cookie)
          {
-           // TODO: Add processing
+
+           if (!wg_send_handshake_cookie (vm, message->sender_index,
+                                          &wg_if->cookie_checker, macs,
+                                          &ip_addr_46 (&wg_if->src_ip),
+                                          wg_if->port, &src_ip, udp_src_port))
+             return WG_INPUT_ERROR_COOKIE_SEND;
+
+           return WG_INPUT_ERROR_NONE;
          }
+
        noise_remote_t *rp;
        if (noise_consume_initiation
            (vm, noise_local_get (wg_if->local_idx), &rp,
@@ -268,6 +281,18 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
     case MESSAGE_HANDSHAKE_RESPONSE:
       {
        message_handshake_response_t *resp = current_b_data;
+
+       if (packet_needs_cookie)
+         {
+           if (!wg_send_handshake_cookie (vm, resp->sender_index,
+                                          &wg_if->cookie_checker, macs,
+                                          &ip_addr_46 (&wg_if->src_ip),
+                                          wg_if->port, &src_ip, udp_src_port))
+             return WG_INPUT_ERROR_COOKIE_SEND;
+
+           return WG_INPUT_ERROR_NONE;
+         }
+
        index_t peeri = INDEX_INVALID;
        u32 *entry =
          wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
@@ -289,10 +314,6 @@ wg_handshake_process (vlib_main_t *vm, wg_main_t *wmp, vlib_buffer_t *b,
          {
            return WG_INPUT_ERROR_PEER;
          }
-       if (packet_needs_cookie)
-         {
-           // TODO: Add processing
-         }
 
        // set_peer_address (peer, ip4_src, udp_src_port);
        if (noise_remote_begin_session (vm, &peer->remote))
@@ -562,6 +583,8 @@ always_inline uword
 wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
                 vlib_frame_t *frame, u8 is_ip4, u16 async_next_node)
 {
+  vnet_main_t *vnm = vnet_get_main ();
+  vnet_interface_main_t *im = &vnm->interface_main;
   wg_main_t *wmp = &wg_main;
   wg_per_thread_data_t *ptd =
     vec_elt_at_index (wmp->per_thread_data, vm->thread_index);
@@ -802,6 +825,11 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
          last_peer_time_idx = peer_idx;
        }
 
+      vlib_increment_combined_counter (im->combined_sw_if_counters +
+                                        VNET_INTERFACE_COUNTER_RX,
+                                      vm->thread_index, peer->wg_sw_if_index,
+                                      1 /* packets */, b[0]->current_length);
+
     trace:
       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
                         (b[0]->flags & VLIB_BUFFER_IS_TRACED)))
@@ -861,6 +889,8 @@ wg_input_inline (vlib_main_t *vm, vlib_node_runtime_t *node,
 always_inline uword
 wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
 {
+  vnet_main_t *vnm = vnet_get_main ();
+  vnet_interface_main_t *im = &vnm->interface_main;
   wg_main_t *wmp = &wg_main;
   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs;
   u16 nexts[VLIB_FRAME_SIZE], *next = nexts;
@@ -902,9 +932,17 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
          last_rec_idx = data->receiver_index;
        }
 
-      if (PREDICT_FALSE (wg_input_post_process (vm, b[0], next, peer, data,
-                                               &is_keepalive) < 0))
-       goto trace;
+      if (PREDICT_TRUE (peer != NULL))
+       {
+         if (PREDICT_FALSE (wg_input_post_process (vm, b[0], next, peer, data,
+                                                   &is_keepalive) < 0))
+           goto trace;
+       }
+      else
+       {
+         next[0] = WG_INPUT_NEXT_PUNT;
+         goto trace;
+       }
 
       if (PREDICT_FALSE (peer_idx && (last_peer_time_idx != peer_idx)))
        {
@@ -912,6 +950,12 @@ wg_input_post (vlib_main_t *vm, vlib_node_runtime_t *node, vlib_frame_t *frame)
          wg_timers_any_authenticated_packet_traversal (peer);
          last_peer_time_idx = peer_idx;
        }
+
+      vlib_increment_combined_counter (im->combined_sw_if_counters +
+                                        VNET_INTERFACE_COUNTER_RX,
+                                      vm->thread_index, peer->wg_sw_if_index,
+                                      1 /* packets */, b[0]->current_length);
+
     trace:
       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE) &&
                         (b[0]->flags & VLIB_BUFFER_IS_TRACED)))