wireguard: add handoff node 48/28848/11
author Artem Glazychev <artem.glazychev@xored.com>
Mon, 14 Sep 2020 04:36:01 +0000 (11:36 +0700)
committer Neale Ranns <nranns@cisco.com>
Wed, 23 Sep 2020 10:11:13 +0000 (10:11 +0000)
All timer and control plane functions happen from main thread

Type: fix

Change-Id: I4fc333c644485cd17e6f426493feef91688d9b24
Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
18 files changed:
src/plugins/wireguard/CMakeLists.txt
src/plugins/wireguard/test/test_wireguard.py
src/plugins/wireguard/wireguard.c
src/plugins/wireguard/wireguard.h
src/plugins/wireguard/wireguard_api.c
src/plugins/wireguard/wireguard_handoff.c [new file with mode: 0644]
src/plugins/wireguard/wireguard_if.c
src/plugins/wireguard/wireguard_if.h
src/plugins/wireguard/wireguard_input.c
src/plugins/wireguard/wireguard_noise.c
src/plugins/wireguard/wireguard_noise.h
src/plugins/wireguard/wireguard_output_tun.c
src/plugins/wireguard/wireguard_peer.c
src/plugins/wireguard/wireguard_peer.h
src/plugins/wireguard/wireguard_send.c
src/plugins/wireguard/wireguard_send.h
src/plugins/wireguard/wireguard_timer.c
src/plugins/wireguard/wireguard_timer.h

index db5bb2d..db74f9c 100755 (executable)
@@ -30,6 +30,7 @@ add_vpp_plugin(wireguard
   wireguard_if.h
   wireguard_input.c
   wireguard_output_tun.c
+  wireguard_handoff.c
   wireguard_key.c
   wireguard_key.h
   wireguard_cli.c
index 7734939..cee1e93 100755 (executable)
@@ -327,6 +327,14 @@ class VppWgPeer(VppObject):
     def encrypt_transport(self, p):
         return self.noise.encrypt(bytes(p))
 
+    def validate_encapped(self, rxs, tx):
+        """Check that each received packet carries the original packet.
+
+        Every packet in rxs is passed through self.decrypt_transport()
+        and parsed as IP. Its destination must equal the destination of
+        tx and its TTL must be exactly one less than the TTL of tx.
+        Failures are reported through the owning test case (self._test).
+        """
+        for rx in rxs:
+            rx = IP(self.decrypt_transport(rx))
+
+            # check the original packet is present
+            self._test.assertEqual(rx[IP].dst, tx[IP].dst)
+            self._test.assertEqual(rx[IP].ttl, tx[IP].ttl-1)
+
 
 class TestWg(VppTestCase):
     """ Wireguard Test Case """
@@ -455,11 +463,7 @@ class TestWg(VppTestCase):
 
         rxs = self.send_and_expect(self.pg0, p * 255, self.pg1)
 
-        for rx in rxs:
-            rx = IP(peer_1.decrypt_transport(rx))
-            # chech the oringial packet is present
-            self.assertEqual(rx[IP].dst, p[IP].dst)
-            self.assertEqual(rx[IP].ttl, p[IP].ttl-1)
+        peer_1.validate_encapped(rxs, p)
 
         # send packets into the tunnel, expect to receive them on
         # the other side
@@ -655,3 +659,90 @@ class TestWg(VppTestCase):
 
         wg0.remove_vpp_config()
         wg1.remove_vpp_config()
+
+
+class WireguardHandoffTests(TestWg):
+    """ Wireguard Tests in multi worker setup """
+    worker_config = "workers 2"
+
+    def test_wg_peer_init(self):
+        """ Handoff """
+        # NOTE(review): these two names are not referenced again in this
+        # test
+        wg_output_node_name = '/err/wg-output-tun/'
+        wg_input_node_name = '/err/wg-input/'
+
+        port = 12323
+
+        # Create interfaces
+        wg0 = VppWgInterface(self,
+                             self.pg1.local_ip4,
+                             port).add_vpp_config()
+        wg0.admin_up()
+        wg0.config_ip4()
+
+        peer_1 = VppWgPeer(self,
+                           wg0,
+                           self.pg1.remote_ip4,
+                           port+1,
+                           ["10.11.2.0/24",
+                            "10.11.3.0/24"]).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+        # send a valid handshake init for which we expect a response
+        p = peer_1.mk_handshake(self.pg1)
+
+        rx = self.send_and_expect(self.pg1, [p], self.pg1)
+
+        peer_1.consume_response(rx[0])
+
+        # send a data packet from the peer through the tunnel
+        # this completes the handshake and pins the peer to worker 0
+        p = (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+             UDP(sport=222, dport=223) /
+             Raw())
+        d = peer_1.encrypt_transport(p)
+        p = (peer_1.mk_tunnel_header(self.pg1) /
+             (Wireguard(message_type=4, reserved_zero=0) /
+              WireguardTransport(receiver_index=peer_1.sender,
+                                 counter=0,
+                                 encrypted_encapsulated_packet=d)))
+        rxs = self.send_and_expect(self.pg1, [p], self.pg0,
+                                   worker=0)
+
+        # the decrypted packet is forwarded, so its TTL drops from 20 to 19
+        for rx in rxs:
+            self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+            self.assertEqual(rx[IP].ttl, 19)
+
+        # send packets that are routed into the tunnel,
+        # which pins the peer to worker 1
+        pe = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+              IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
+              UDP(sport=555, dport=556) /
+              Raw(b'\x00' * 80))
+        rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=1)
+        peer_1.validate_encapped(rxs, pe)
+
+        # send packets into the tunnel, from the other worker
+        # (counter starts at 1 because counter 0 was used above)
+        p = [(peer_1.mk_tunnel_header(self.pg1) /
+              Wireguard(message_type=4, reserved_zero=0) /
+              WireguardTransport(
+                  receiver_index=peer_1.sender,
+                  counter=ii+1,
+                  encrypted_encapsulated_packet=peer_1.encrypt_transport(
+                      (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+                       UDP(sport=222, dport=223) /
+                       Raw())))) for ii in range(255)]
+
+        rxs = self.send_and_expect(self.pg1, p, self.pg0, worker=1)
+
+        for rx in rxs:
+            self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+            self.assertEqual(rx[IP].ttl, 19)
+
+        # send packets that are routed into the tunnel
+        # from worker 0
+        rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=0)
+
+        peer_1.validate_encapped(rxs, pe)
+
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
index 0092181..9510a0a 100755 (executable)
@@ -32,7 +32,17 @@ wg_init (vlib_main_t * vm)
   wg_main_t *wmp = &wg_main;
 
   wmp->vlib_main = vm;
-  wmp->peers = 0;
+
+  wmp->in_fq_index = vlib_frame_queue_main_init (wg_input_node.index, 0);
+  wmp->out_fq_index =
+    vlib_frame_queue_main_init (wg_output_tun_node.index, 0);
+
+  vlib_thread_main_t *tm = vlib_get_thread_main ();
+
+  vec_validate_aligned (wmp->per_thread_data, tm->n_vlib_mains,
+                       CLIB_CACHE_LINE_BYTES);
+
+  wg_timer_wheel_init ();
 
   return (NULL);
 }
index 70a692e..2c892a3 100755 (executable)
 
 #include <wireguard/wireguard_index_table.h>
 #include <wireguard/wireguard_messages.h>
-#include <wireguard/wireguard_peer.h>
+#include <wireguard/wireguard_timer.h>
+
+#define WG_DEFAULT_DATA_SIZE 2048
 
 extern vlib_node_registration_t wg_input_node;
 extern vlib_node_registration_t wg_output_tun_node;
 
-
-
+/*
+ * Per-thread scratch storage; wg_main_t holds a vector of these in
+ * per_thread_data.
+ */
+typedef struct wg_per_thread_data_t_
+{
+  /* scratch buffer of WG_DEFAULT_DATA_SIZE bytes */
+  u8 data[WG_DEFAULT_DATA_SIZE];
+} wg_per_thread_data_t;
 typedef struct
 {
   /* convenience */
@@ -31,10 +35,14 @@ typedef struct
 
   u16 msg_id_base;
 
-  // Peers pool
-  wg_peer_t *peers;
   wg_index_table_t index_table;
 
+  u32 in_fq_index;
+  u32 out_fq_index;
+
+  wg_per_thread_data_t *per_thread_data;
+
+  tw_timer_wheel_16t_2w_512sl_t timer_wheel;
 } wg_main_t;
 
 extern wg_main_t wg_main;
index 8bbacdd..27ed6ea 100755 (executable)
@@ -97,15 +97,17 @@ wireguard_if_send_details (index_t wgii, void *data)
   vl_api_wireguard_interface_details_t *rmp;
   wg_deatils_walk_t *ctx = data;
   const wg_if_t *wgi;
+  const noise_local_t *local;
 
   wgi = wg_if_get (wgii);
+  local = noise_local_get (wgi->local_idx);
 
   rmp = vl_msg_api_alloc_zero (sizeof (*rmp));
   rmp->_vl_msg_id = htons (VL_API_WIREGUARD_INTERFACE_DETAILS +
                           wg_main.msg_id_base);
 
   clib_memcpy (rmp->interface.private_key,
-              wgi->local.l_private, NOISE_PUBLIC_KEY_LEN);
+              local->l_private, NOISE_PUBLIC_KEY_LEN);
   rmp->interface.sw_if_index = htonl (wgi->sw_if_index);
   rmp->interface.port = htons (wgi->port);
   ip_address_encode2 (&wgi->src_ip, &rmp->interface.src_ip);
diff --git a/src/plugins/wireguard/wireguard_handoff.c b/src/plugins/wireguard/wireguard_handoff.c
new file mode 100644 (file)
index 0000000..b0b7422
--- /dev/null
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2020 Doc.ai and/or its affiliates.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at:
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <wireguard/wireguard.h>
+#include <wireguard/wireguard_peer.h>
+
+#define foreach_wg_handoff_error  \
+_(CONGESTION_DROP, "congestion drop")
+
+typedef enum
+{
+#define _(sym,str) WG_HANDOFF_ERROR_##sym,
+  foreach_wg_handoff_error
+#undef _
+    HANDOFF_N_ERROR,
+} ipsec_handoff_error_t;
+
+static char *wg_handoff_error_strings[] = {
+#define _(sym,string) string,
+  foreach_wg_handoff_error
+#undef _
+};
+
+/* Selects how wg_handoff() chooses the destination thread of a buffer */
+typedef enum
+{
+  /* handshake messages: always handed to thread 0 */
+  WG_HANDOFF_HANDSHAKE,
+  /* inbound data: handed to the peer's input_thread_index */
+  WG_HANDOFF_INP_DATA,
+  /* outbound tunnel traffic: handed to the peer's output_thread_index */
+  WG_HANDOFF_OUT_TUN,
+} wg_handoff_mode_t;
+
+/* Per-packet trace record written by wg_handoff() */
+typedef struct wg_handoff_trace_t_
+{
+  /* index of the thread the buffer is handed to */
+  u32 next_worker_index;
+  /* index of the peer the buffer was matched to */
+  index_t peer;
+} wg_handoff_trace_t;
+
+/*
+ * Trace formatter for the handoff nodes: appends the destination thread
+ * and the peer index stored in a wg_handoff_trace_t to s and returns the
+ * resulting vector.
+ */
+static u8 *
+format_wg_handoff_trace (u8 * s, va_list * args)
+{
+  /* the first two variadic arguments are consumed but not used */
+  CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
+  CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
+  wg_handoff_trace_t *trace = va_arg (*args, wg_handoff_trace_t *);
+
+  return format (s, "next-worker %d peer %d", trace->next_worker_index,
+                 trace->peer);
+}
+
+/*
+ * Hand every buffer of a frame to the thread that has to process it.
+ *
+ * The destination thread depends on mode:
+ *  - WG_HANDOFF_HANDSHAKE: always thread 0;
+ *  - WG_HANDOFF_INP_DATA:  input_thread_index of the peer found through
+ *                          the receiver index of the data message;
+ *  - WG_HANDOFF_OUT_TUN:   output_thread_index of the peer found through
+ *                          the TX adjacency of the buffer.
+ *
+ * Buffers are enqueued on frame queue fq_index. Buffers that could not be
+ * enqueued are counted as WG_HANDOFF_ERROR_CONGESTION_DROP.
+ *
+ * Returns the number of buffers that were enqueued.
+ */
+static_always_inline uword
+wg_handoff (vlib_main_t * vm,
+           vlib_node_runtime_t * node,
+           vlib_frame_t * frame, u32 fq_index, wg_handoff_mode_t mode)
+{
+  vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
+  u16 thread_indices[VLIB_FRAME_SIZE], *ti;
+  u32 n_enq, n_left_from, *from;
+  wg_main_t *wmp;
+
+  wmp = &wg_main;
+  from = vlib_frame_vector_args (frame);
+  n_left_from = frame->n_vectors;
+  vlib_get_buffers (vm, from, bufs, n_left_from);
+
+  b = bufs;
+  ti = thread_indices;
+
+  while (n_left_from > 0)
+    {
+      const wg_peer_t *peer;
+      /* stays INDEX_INVALID when no peer is resolved (e.g. handshakes),
+       * so the trace record never carries an uninitialized value */
+      index_t peeri = INDEX_INVALID;
+
+      if (PREDICT_FALSE (mode == WG_HANDOFF_HANDSHAKE))
+       {
+         ti[0] = 0;
+       }
+      else if (mode == WG_HANDOFF_INP_DATA)
+       {
+         message_data_t *data = vlib_buffer_get_current (b[0]);
+         u32 *entry =
+           wg_index_table_lookup (&wmp->index_table, data->receiver_index);
+
+         if (PREDICT_TRUE (entry != NULL))
+           {
+             peeri = *entry;
+             peer = wg_peer_get (peeri);
+
+             ti[0] = peer->input_thread_index;
+           }
+         else
+           {
+             /* the receiver index is not (or no longer) known: send the
+              * buffer to thread 0, where wg-input repeats the lookup and
+              * drops the packet with the PEER error */
+             ti[0] = 0;
+           }
+       }
+      else
+       {
+         peeri =
+           wg_peer_get_by_adj_index (vnet_buffer (b[0])->
+                                     ip.adj_index[VLIB_TX]);
+         peer = wg_peer_get (peeri);
+         ti[0] = peer->output_thread_index;
+       }
+
+      if (PREDICT_FALSE (b[0]->flags & VLIB_BUFFER_IS_TRACED))
+       {
+         wg_handoff_trace_t *t =
+           vlib_add_trace (vm, node, b[0], sizeof (*t));
+         t->next_worker_index = ti[0];
+         t->peer = peeri;
+       }
+
+      n_left_from -= 1;
+      ti += 1;
+      b += 1;
+    }
+
+  n_enq = vlib_buffer_enqueue_to_thread (vm, fq_index, from,
+                                        thread_indices, frame->n_vectors, 1);
+
+  if (n_enq < frame->n_vectors)
+    vlib_node_increment_counter (vm, node->node_index,
+                                WG_HANDOFF_ERROR_CONGESTION_DROP,
+                                frame->n_vectors - n_enq);
+
+  return n_enq;
+}
+
+/*
+ * wg-handshake-handoff node function: enqueues the frame on the input
+ * frame queue (wg_main.in_fq_index) using handshake thread selection.
+ */
+VLIB_NODE_FN (wg_handshake_handoff) (vlib_main_t * vm,
+                                    vlib_node_runtime_t * node,
+                                    vlib_frame_t * from_frame)
+{
+  const u32 fq_index = wg_main.in_fq_index;
+
+  return wg_handoff (vm, node, from_frame, fq_index, WG_HANDOFF_HANDSHAKE);
+}
+
+/*
+ * wg-input-data-handoff node function: enqueues the frame on the input
+ * frame queue (wg_main.in_fq_index) using inbound-data thread selection.
+ */
+VLIB_NODE_FN (wg_input_data_handoff) (vlib_main_t * vm,
+                                     vlib_node_runtime_t * node,
+                                     vlib_frame_t * from_frame)
+{
+  const u32 fq_index = wg_main.in_fq_index;
+
+  return wg_handoff (vm, node, from_frame, fq_index, WG_HANDOFF_INP_DATA);
+}
+
+/*
+ * wg-output-tun-handoff node function: enqueues the frame on the output
+ * frame queue (wg_main.out_fq_index) using outbound thread selection.
+ */
+VLIB_NODE_FN (wg_output_tun_handoff) (vlib_main_t * vm,
+                                     vlib_node_runtime_t * node,
+                                     vlib_frame_t * from_frame)
+{
+  const u32 fq_index = wg_main.out_fq_index;
+
+  return wg_handoff (vm, node, from_frame, fq_index, WG_HANDOFF_OUT_TUN);
+}
+
+/* Registration of the wg-handshake-handoff internal node */
+/* *INDENT-OFF* */
+VLIB_REGISTER_NODE (wg_handshake_handoff) =
+{
+  .name = "wg-handshake-handoff",
+  .vector_size = sizeof (u32),
+  .format_trace = format_wg_handoff_trace,
+  .type = VLIB_NODE_TYPE_INTERNAL,
+  .n_errors = ARRAY_LEN (wg_handoff_error_strings),
+  .error_strings = wg_handoff_error_strings,
+  .n_next_nodes = 1,
+  .next_nodes = {
+    [0] = "error-drop",
+  },
+};
+/* *INDENT-ON* */
+
+/* Registration of the wg-input-data-handoff internal node */
+/* *INDENT-OFF* */
+VLIB_REGISTER_NODE (wg_input_data_handoff) =
+{
+  .name = "wg-input-data-handoff",
+  .vector_size = sizeof (u32),
+  .format_trace = format_wg_handoff_trace,
+  .type = VLIB_NODE_TYPE_INTERNAL,
+  .n_errors = ARRAY_LEN (wg_handoff_error_strings),
+  .error_strings = wg_handoff_error_strings,
+  .n_next_nodes = 1,
+  .next_nodes = {
+    [0] = "error-drop",
+  },
+};
+/* *INDENT-ON* */
+
+/* Registration of the wg-output-tun-handoff internal node */
+/* *INDENT-OFF* */
+VLIB_REGISTER_NODE (wg_output_tun_handoff) =
+{
+  .name = "wg-output-tun-handoff",
+  .vector_size = sizeof (u32),
+  .format_trace = format_wg_handoff_trace,
+  .type = VLIB_NODE_TYPE_INTERNAL,
+  .n_errors = ARRAY_LEN (wg_handoff_error_strings),
+  .error_strings = wg_handoff_error_strings,
+  .n_next_nodes = 1,
+  .next_nodes = {
+    [0] = "error-drop",
+  },
+};
+/* *INDENT-ON* */
+
+/*
+ * fd.io coding-style-patch-verification: ON
+ *
+ * Local Variables:
+ * eval: (c-set-style "gnu")
+ * End:
+ */
index c91667b..7509923 100644 (file)
@@ -5,6 +5,7 @@
 #include <wireguard/wireguard_messages.h>
 #include <wireguard/wireguard_if.h>
 #include <wireguard/wireguard.h>
+#include <wireguard/wireguard_peer.h>
 
 /* pool of interfaces */
 wg_if_t *wg_if_pool;
@@ -30,28 +31,28 @@ format_wg_if (u8 * s, va_list * args)
 {
   index_t wgii = va_arg (*args, u32);
   wg_if_t *wgi = wg_if_get (wgii);
+  noise_local_t *local = noise_local_get (wgi->local_idx);
   u8 key[NOISE_KEY_LEN_BASE64];
 
-  key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, "[%d] %U src:%U port:%d",
              wgii,
              format_vnet_sw_if_index_name, vnet_get_main (),
              wgi->sw_if_index, format_ip_address, &wgi->src_ip, wgi->port);
 
-  key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
+  key_to_base64 (local->l_private, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " private-key:%s", key);
   s =
-    format (s, " %U", format_hex_bytes, wgi->local.l_private,
+    format (s, " %U", format_hex_bytes, local->l_private,
            NOISE_PUBLIC_KEY_LEN);
 
-  key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key);
+  key_to_base64 (local->l_public, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " public-key:%s", key);
 
   s =
-    format (s, " %U", format_hex_bytes, wgi->local.l_public,
+    format (s, " %U", format_hex_bytes, local->l_public,
            NOISE_PUBLIC_KEY_LEN);
 
   s = format (s, " mac-key: %U", format_hex_bytes,
@@ -72,23 +73,28 @@ wg_if_find_by_sw_if_index (u32 sw_if_index)
   return (ti);
 }
 
+/*
+ * Peer walk callback: returns WALK_STOP when the remote public key of
+ * peer peeri equals the NOISE_PUBLIC_KEY_LEN-byte key that data points
+ * to, and WALK_CONTINUE otherwise.
+ */
+static walk_rc_t
+wg_if_find_peer_by_public_key (index_t peeri, void *data)
+{
+  const uint8_t *wanted_key = data;
+  const wg_peer_t *candidate = wg_peer_get (peeri);
+
+  if (memcmp (candidate->remote.r_public, wanted_key,
+             NOISE_PUBLIC_KEY_LEN) == 0)
+    return (WALK_STOP);
+
+  return (WALK_CONTINUE);
+}
+
 static noise_remote_t *
-wg_remote_get (uint8_t public[NOISE_PUBLIC_KEY_LEN])
+wg_remote_get (const uint8_t public[NOISE_PUBLIC_KEY_LEN])
 {
-  wg_main_t *wmp = &wg_main;
-  wg_peer_t *peer = NULL;
-  wg_peer_t *peer_iter;
-  /* *INDENT-OFF* */
-  pool_foreach (peer_iter, wmp->peers,
-  ({
-    if (!memcmp (peer_iter->remote.r_public, public, NOISE_PUBLIC_KEY_LEN))
-    {
-      peer = peer_iter;
-      break;
-    }
-  }));
-  /* *INDENT-ON* */
-  return peer ? &peer->remote : NULL;
+  index_t peeri;
+
+  peeri = wg_peer_walk (wg_if_find_peer_by_public_key, (void *) public);
+
+  if (INDEX_INVALID != peeri)
+    return &wg_peer_get (peeri)->remote;
+
+  return NULL;
 }
 
 static uint32_t
@@ -223,6 +229,7 @@ wg_if_create (u32 user_instance,
   u32 instance, hw_if_index;
   vnet_hw_interface_t *hi;
   wg_if_t *wg_if;
+  noise_local_t *local;
 
   ASSERT (sw_if_indexp);
 
@@ -236,6 +243,24 @@ wg_if_create (u32 user_instance,
   if (instance == ~0)
     return VNET_API_ERROR_INVALID_REGISTRATION;
 
+  /* *INDENT-OFF* */
+  struct noise_upcall upcall =  {
+    .u_remote_get = wg_remote_get,
+    .u_index_set = wg_index_set,
+    .u_index_drop = wg_index_drop,
+  };
+  /* *INDENT-ON* */
+
+  pool_get (noise_local_pool, local);
+
+  noise_local_init (local, &upcall);
+  if (!noise_local_set_private (local, private_key))
+    {
+      pool_put (noise_local_pool, local);
+      wg_if_instance_free (instance);
+      return VNET_API_ERROR_INVALID_REGISTRATION;
+    }
+
   pool_get (wg_if_pool, wg_if);
 
   /* tunnel index (or instance) */
@@ -251,18 +276,8 @@ wg_if_create (u32 user_instance,
   wg_if_index_by_port[port] = wg_if - wg_if_pool;
 
   wg_if->port = port;
-
-  /* *INDENT-OFF* */
-  struct noise_upcall upcall =  {
-    .u_remote_get = wg_remote_get,
-    .u_index_set = wg_index_set,
-    .u_index_drop = wg_index_drop,
-  };
-  /* *INDENT-ON* */
-
-  noise_local_init (&wg_if->local, &upcall);
-  noise_local_set_private (&wg_if->local, private_key);
-  cookie_checker_update (&wg_if->cookie_checker, wg_if->local.l_public);
+  wg_if->local_idx = local - noise_local_pool;
+  cookie_checker_update (&wg_if->cookie_checker, local->l_public);
 
   hw_if_index = vnet_register_interface (vnm,
                                         wg_if_device_class.index,
@@ -304,6 +319,7 @@ wg_if_delete (u32 sw_if_index)
   udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1);
   wg_if_index_by_port[wg_if->port] = INDEX_INVALID;
   vnet_delete_hw_interface (vnm, hw->hw_if_index);
+  pool_put_index (noise_local_pool, wg_if->local_idx);
   pool_put (wg_if_pool, wg_if);
 
   return 0;
@@ -343,7 +359,7 @@ wg_if_walk (wg_if_walk_cb_t fn, void *data)
   /* *INDENT-ON* */
 }
 
-void
+index_t
 wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
 {
   index_t peeri, val;
@@ -352,9 +368,11 @@ wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data)
   hash_foreach (peeri, val, wgi->peers,
   {
     if (WALK_STOP == fn(wgi, peeri, data))
-      break;
+      return peeri;
   });
   /* *INDENT-ON* */
+
+  return INDEX_INVALID;
 }
 
 
index 9e6b619..d8c2a87 100644 (file)
@@ -25,7 +25,8 @@ typedef struct wg_if_t_
   u32 sw_if_index;
 
   // Interface params
-  noise_local_t local;
+  /* noise_local_pool elt index */
+  u32 local_idx;
   cookie_checker_t cookie_checker;
   u16 port;
 
@@ -52,7 +53,7 @@ void wg_if_walk (wg_if_walk_cb_t fn, void *data);
 
 typedef walk_rc_t (*wg_if_peer_walk_cb_t) (wg_if_t * wgi, index_t peeri,
                                           void *data);
-void wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data);
+index_t wg_if_peer_walk (wg_if_t * wgi, wg_if_peer_walk_cb_t fn, void *data);
 
 void wg_if_peer_add (wg_if_t * wgi, index_t peeri);
 void wg_if_peer_remove (wg_if_t * wgi, index_t peeri);
index cdd65f8..b15c265 100755 (executable)
@@ -30,6 +30,7 @@
   _(DECRYPTION, "Failed during decryption")             \
   _(KEEPALIVE_SEND, "Failed while sending Keepalive")   \
   _(HANDSHAKE_SEND, "Failed while sending Handshake")   \
+  _(TOO_BIG, "Packet too big")                          \
   _(UNDEFINED, "Undefined error")
 
 typedef enum
@@ -51,7 +52,7 @@ typedef struct
   message_type_t type;
   u16 current_length;
   bool is_keepalive;
-
+  index_t peer;
 } wg_input_trace_t;
 
 u8 *
@@ -79,6 +80,7 @@ format_wg_input_trace (u8 * s, va_list * args)
 
   s = format (s, "WG input: \n");
   s = format (s, "  Type: %U\n", format_wg_message_type, t->type);
+  s = format (s, "  peer: %d\n", t->peer);
   s = format (s, "  Length: %d\n", t->current_length);
   s = format (s, "  Keepalive: %s", t->is_keepalive ? "true" : "false");
 
@@ -87,6 +89,8 @@ format_wg_input_trace (u8 * s, va_list * args)
 
 typedef enum
 {
+  WG_INPUT_NEXT_HANDOFF_HANDSHAKE,
+  WG_INPUT_NEXT_HANDOFF_DATA,
   WG_INPUT_NEXT_IP4_INPUT,
   WG_INPUT_NEXT_PUNT,
   WG_INPUT_NEXT_ERROR,
@@ -106,6 +110,8 @@ typedef enum
 static wg_input_error_t
 wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
 {
+  ASSERT (vm->thread_index == 0);
+
   enum cookie_mac_state mac_state;
   bool packet_needs_cookie;
   bool under_load;
@@ -129,17 +135,15 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
   if (NULL == wg_if)
     return WG_INPUT_ERROR_INTERFACE;
 
-  if (header->type == MESSAGE_HANDSHAKE_COOKIE)
+  if (PREDICT_FALSE (header->type == MESSAGE_HANDSHAKE_COOKIE))
     {
       message_handshake_cookie_t *packet =
        (message_handshake_cookie_t *) current_b_data;
       u32 *entry =
        wg_index_table_lookup (&wmp->index_table, packet->receiver_index);
       if (entry)
-       {
-         peer = pool_elt_at_index (wmp->peers, *entry);
-       }
-      if (!peer)
+       peer = wg_peer_get (*entry);
+      else
        return WG_INPUT_ERROR_PEER;
 
       // TODO: Implement cookie_maker_consume_payload
@@ -178,17 +182,17 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
            // TODO: Add processing
          }
        noise_remote_t *rp;
-
        if (noise_consume_initiation
-           (wmp->vlib_main, &wg_if->local, &rp, message->sender_index,
-            message->unencrypted_ephemeral, message->encrypted_static,
-            message->encrypted_timestamp))
+           (vm, noise_local_get (wg_if->local_idx), &rp,
+            message->sender_index, message->unencrypted_ephemeral,
+            message->encrypted_static, message->encrypted_timestamp))
          {
-           peer = pool_elt_at_index (wmp->peers, rp->r_peer_idx);
+           peer = wg_peer_get (rp->r_peer_idx);
+         }
+       else
+         {
+           return WG_INPUT_ERROR_PEER;
          }
-
-       if (!peer)
-         return WG_INPUT_ERROR_PEER;
 
        // set_peer_address (peer, ip4_src, udp_src_port);
        if (PREDICT_FALSE (!wg_send_handshake_response (vm, peer)))
@@ -203,15 +207,18 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
        message_handshake_response_t *resp = current_b_data;
        u32 *entry =
          wg_index_table_lookup (&wmp->index_table, resp->receiver_index);
-       if (entry)
+
+       if (PREDICT_TRUE (entry != NULL))
          {
-           peer = pool_elt_at_index (wmp->peers, *entry);
-           if (!peer || peer->is_dead)
+           peer = wg_peer_get (*entry);
+           if (peer->is_dead)
              return WG_INPUT_ERROR_PEER;
          }
+       else
+         return WG_INPUT_ERROR_PEER;
 
        if (!noise_consume_response
-           (wmp->vlib_main, &peer->remote, resp->sender_index,
+           (vm, &peer->remote, resp->sender_index,
             resp->receiver_index, resp->unencrypted_ephemeral,
             resp->encrypted_nothing))
          {
@@ -223,8 +230,9 @@ wg_handshake_process (vlib_main_t * vm, wg_main_t * wmp, vlib_buffer_t * b)
          }
 
        // set_peer_address (peer, ip4_src, udp_src_port);
-       if (noise_remote_begin_session (wmp->vlib_main, &peer->remote))
+       if (noise_remote_begin_session (vm, &peer->remote))
          {
+
            wg_timers_session_derived (peer);
            wg_timers_handshake_complete (peer);
            if (PREDICT_FALSE (!wg_send_keepalive (vm, peer)))
@@ -272,6 +280,7 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
   u32 *from;
   vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b;
   u16 nexts[VLIB_FRAME_SIZE], *next;
+  u32 thread_index = vm->thread_index;
 
   from = vlib_frame_vector_args (frame);
   n_left_from = frame->n_vectors;
@@ -289,120 +298,132 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
       next[0] = WG_INPUT_NEXT_PUNT;
       header_type =
        ((message_header_t *) vlib_buffer_get_current (b[0]))->type;
+      u32 *peer_idx;
 
-      switch (header_type)
+      if (PREDICT_TRUE (header_type == MESSAGE_DATA))
        {
-       case MESSAGE_HANDSHAKE_INITIATION:
-       case MESSAGE_HANDSHAKE_RESPONSE:
-       case MESSAGE_HANDSHAKE_COOKIE:
-         {
-           wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
-           if (ret != WG_INPUT_ERROR_NONE)
-             {
-               next[0] = WG_INPUT_NEXT_ERROR;
-               b[0]->error = node->errors[ret];
-             }
-           break;
-         }
-       case MESSAGE_DATA:
-         {
-           message_data_t *data = vlib_buffer_get_current (b[0]);
-           u32 *entry =
-             wg_index_table_lookup (&wmp->index_table, data->receiver_index);
-
-           if (entry)
-             {
-               peer = pool_elt_at_index (wmp->peers, *entry);
-             }
-           else
-             {
-               next[0] = WG_INPUT_NEXT_ERROR;
-               b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
-               goto out;
-             }
+         message_data_t *data = vlib_buffer_get_current (b[0]);
 
-           u16 encr_len = b[0]->current_length - sizeof (message_data_t);
-           u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
-           u8 *decr_data = clib_mem_alloc (decr_len);
+         peer_idx = wg_index_table_lookup (&wmp->index_table,
+                                           data->receiver_index);
 
-           enum noise_state_crypt state_cr =
-             noise_remote_decrypt (wmp->vlib_main,
-                                   &peer->remote,
-                                   data->receiver_index,
-                                   data->counter,
-                                   data->encrypted_data,
-                                   encr_len,
-                                   decr_data);
+         if (peer_idx)
+           {
+             peer = wg_peer_get (*peer_idx);
+           }
+         else
+           {
+             next[0] = WG_INPUT_NEXT_ERROR;
+             b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
+             goto out;
+           }
 
-           switch (state_cr)
-             {
-             case SC_OK:
-               break;
-             case SC_CONN_RESET:
-               wg_timers_handshake_complete (peer);
-               break;
-             case SC_KEEP_KEY_FRESH:
-               if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
-                 {
-                   vlib_node_increment_counter (vm, wg_input_node.index,
-                                                WG_INPUT_ERROR_HANDSHAKE_SEND,
-                                                1);
-                 }
-               break;
-             case SC_FAILED:
-               next[0] = WG_INPUT_NEXT_ERROR;
-               b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
-               goto out;
-             default:
-               break;
-             }
+         if (PREDICT_FALSE (~0 == peer->input_thread_index))
+           {
+             /* this is the first packet to use this peer, claim the peer
+              * for this thread.
+              */
+             clib_atomic_cmp_and_swap (&peer->input_thread_index, ~0,
+                                       wg_peer_assign_thread (thread_index));
+           }
 
-           clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
-           b[0]->current_length = decr_len;
-           b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
+         if (PREDICT_TRUE (thread_index != peer->input_thread_index))
+           {
+             next[0] = WG_INPUT_NEXT_HANDOFF_DATA;
+             goto next;
+           }
 
-           clib_mem_free (decr_data);
+         u16 encr_len = b[0]->current_length - sizeof (message_data_t);
+         u16 decr_len = encr_len - NOISE_AUTHTAG_LEN;
+         if (PREDICT_FALSE (decr_len >= WG_DEFAULT_DATA_SIZE))
+           {
+             b[0]->error = node->errors[WG_INPUT_ERROR_TOO_BIG];
+             goto out;
+           }
 
-           wg_timers_any_authenticated_packet_received (peer);
-           wg_timers_any_authenticated_packet_traversal (peer);
+         u8 *decr_data = wmp->per_thread_data[thread_index].data;
 
-           if (decr_len == 0)
-             {
-               is_keepalive = true;
-               goto out;
-             }
+         enum noise_state_crypt state_cr = noise_remote_decrypt (vm,
+                                                                 &peer->remote,
+                                                                 data->receiver_index,
+                                                                 data->counter,
+                                                                 data->encrypted_data,
+                                                                 encr_len,
+                                                                 decr_data);
 
-           wg_timers_data_received (peer);
+         if (PREDICT_FALSE (state_cr == SC_CONN_RESET))
+           {
+             wg_timers_handshake_complete (peer);
+           }
+         else if (PREDICT_FALSE (state_cr == SC_KEEP_KEY_FRESH))
+           {
+             wg_send_handshake_from_mt (*peer_idx, false);
+           }
+         else if (PREDICT_FALSE (state_cr == SC_FAILED))
+           {
+             next[0] = WG_INPUT_NEXT_ERROR;
+             b[0]->error = node->errors[WG_INPUT_ERROR_DECRYPTION];
+             goto out;
+           }
 
-           ip4_header_t *iph = vlib_buffer_get_current (b[0]);
+         clib_memcpy (vlib_buffer_get_current (b[0]), decr_data, decr_len);
+         b[0]->current_length = decr_len;
+         b[0]->flags &= ~VNET_BUFFER_F_OFFLOAD_UDP_CKSUM;
 
-           const wg_peer_allowed_ip_t *allowed_ip;
-           bool allowed = false;
+         wg_timers_any_authenticated_packet_received (peer);
+         wg_timers_any_authenticated_packet_traversal (peer);
 
-           /*
-            * we could make this into an ACL, but the expectation
-            * is that there aren't many allowed IPs and thus a linear
-            * walk is fater than an ACL
-            */
-           vec_foreach (allowed_ip, peer->allowed_ips)
+         /* Keepalive packet has zero length */
+         if (decr_len == 0)
            {
-             if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
-                                             &iph->src_address))
-               {
-                 allowed = true;
-                 break;
-               }
+             is_keepalive = true;
+             goto out;
            }
-           if (allowed)
+
+         wg_timers_data_received (peer);
+
+         ip4_header_t *iph = vlib_buffer_get_current (b[0]);
+
+         const wg_peer_allowed_ip_t *allowed_ip;
+         bool allowed = false;
+
+         /*
+          * we could make this into an ACL, but the expectation
+          * is that there aren't many allowed IPs and thus a linear
+          * walk is fater than an ACL
+          */
+         vec_foreach (allowed_ip, peer->allowed_ips)
+         {
+           if (fib_prefix_is_cover_addr_4 (&allowed_ip->prefix,
+                                           &iph->src_address))
              {
-               vnet_buffer (b[0])->sw_if_index[VLIB_RX] =
-                 peer->wg_sw_if_index;
-               next[0] = WG_INPUT_NEXT_IP4_INPUT;
+               allowed = true;
+               break;
              }
-           break;
          }
-       default:
-         break;
+         if (allowed)
+           {
+             vnet_buffer (b[0])->sw_if_index[VLIB_RX] = peer->wg_sw_if_index;
+             next[0] = WG_INPUT_NEXT_IP4_INPUT;
+           }
+       }
+      else
+       {
+         peer_idx = NULL;
+
+         /* Handshake packets should be processed in main thread */
+         if (thread_index != 0)
+           {
+             next[0] = WG_INPUT_NEXT_HANDOFF_HANDSHAKE;
+             goto next;
+           }
+
+         wg_input_error_t ret = wg_handshake_process (vm, wmp, b[0]);
+         if (ret != WG_INPUT_ERROR_NONE)
+           {
+             next[0] = WG_INPUT_NEXT_ERROR;
+             b[0]->error = node->errors[ret];
+           }
        }
 
     out:
@@ -413,7 +434,9 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
          t->type = header_type;
          t->current_length = b[0]->current_length;
          t->is_keepalive = is_keepalive;
+         t->peer = peer_idx ? *peer_idx : INDEX_INVALID;
        }
+    next:
       n_left_from -= 1;
       next += 1;
       b += 1;
@@ -435,6 +458,8 @@ VLIB_REGISTER_NODE (wg_input_node) =
   .n_next_nodes = WG_INPUT_N_NEXT,
   /* edit / add dispositions here */
   .next_nodes = {
+        [WG_INPUT_NEXT_HANDOFF_HANDSHAKE] = "wg-handshake-handoff",
+        [WG_INPUT_NEXT_HANDOFF_DATA] = "wg-input-data-handoff",
         [WG_INPUT_NEXT_IP4_INPUT] = "ip4-input-no-checksum",
         [WG_INPUT_NEXT_PUNT] = "error-punt",
         [WG_INPUT_NEXT_ERROR] = "error-drop",
index b47bb57..00b6710 100755 (executable)
@@ -26,6 +26,8 @@
  * <- e, ee, se, psk, {}
  */
 
+noise_local_t *noise_local_pool;
+
 /* Private functions */
 static noise_keypair_t *noise_remote_keypair_allocate (noise_remote_t *);
 static void noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t *,
@@ -80,81 +82,31 @@ noise_local_set_private (noise_local_t * l,
                         const uint8_t private[NOISE_PUBLIC_KEY_LEN])
 {
   clib_memcpy (l->l_private, private, NOISE_PUBLIC_KEY_LEN);
-  l->l_has_identity = curve25519_gen_public (l->l_public, private);
-
-  return l->l_has_identity;
-}
 
-bool
-noise_local_keys (noise_local_t * l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
-                 uint8_t private[NOISE_PUBLIC_KEY_LEN])
-{
-  if (l->l_has_identity)
-    {
-      if (public != NULL)
-       clib_memcpy (public, l->l_public, NOISE_PUBLIC_KEY_LEN);
-      if (private != NULL)
-       clib_memcpy (private, l->l_private, NOISE_PUBLIC_KEY_LEN);
-    }
-  else
-    {
-      return false;
-    }
-  return true;
+  return curve25519_gen_public (l->l_public, private);
 }
 
 void
 noise_remote_init (noise_remote_t * r, uint32_t peer_pool_idx,
                   const uint8_t public[NOISE_PUBLIC_KEY_LEN],
-                  noise_local_t * l)
+                  u32 noise_local_idx)
 {
   clib_memset (r, 0, sizeof (*r));
   clib_memcpy (r->r_public, public, NOISE_PUBLIC_KEY_LEN);
+  clib_rwlock_init (&r->r_keypair_lock);
   r->r_peer_idx = peer_pool_idx;
-
-  ASSERT (l != NULL);
-  r->r_local = l;
+  r->r_local_idx = noise_local_idx;
   r->r_handshake.hs_state = HS_ZEROED;
-  noise_remote_precompute (r);
-}
-
-bool
-noise_remote_set_psk (noise_remote_t * r,
-                     uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
-  int same;
-  same = !clib_memcmp (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
-  if (!same)
-    {
-      clib_memcpy (r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
-    }
-  return same == 0;
-}
-
-bool
-noise_remote_keys (noise_remote_t * r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
-                  uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
-{
-  static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
-  int ret;
 
-  if (public != NULL)
-    clib_memcpy (public, r->r_public, NOISE_PUBLIC_KEY_LEN);
-
-  if (psk != NULL)
-    clib_memcpy (psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
-  ret = clib_memcmp (r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
-
-  return ret;
+  noise_remote_precompute (r);
 }
 
 void
 noise_remote_precompute (noise_remote_t * r)
 {
-  noise_local_t *l = r->r_local;
-  if (!l->l_has_identity)
-    clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
-  else if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
+  noise_local_t *l = noise_local_get (r->r_local_idx);
+
+  if (!curve25519_gen_shared (r->r_ss, l->l_private, r->r_public))
     clib_memset (r->r_ss, 0, NOISE_PUBLIC_KEY_LEN);
 
   noise_remote_handshake_index_drop (r);
@@ -169,7 +121,7 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
                         uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
 {
   noise_handshake_t *hs = &r->r_handshake;
-  noise_local_t *l = r->r_local;
+  noise_local_t *l = noise_local_get (r->r_local_idx);
   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
   uint32_t key_idx;
   uint8_t *key;
@@ -180,8 +132,6 @@ noise_create_initiation (vlib_main_t * vm, noise_remote_t * r,
                         NOISE_SYMMETRIC_KEY_LEN);
   key = vnet_crypto_get_key (key_idx)->data;
 
-  if (!l->l_has_identity)
-    goto error;
   noise_param_init (hs->hs_ck, hs->hs_hash, r->r_public);
 
   /* e */
@@ -239,8 +189,6 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
                         NOISE_SYMMETRIC_KEY_LEN);
   key = vnet_crypto_get_key (key_idx)->data;
 
-  if (!l->l_has_identity)
-    goto error;
   noise_param_init (hs.hs_ck, hs.hs_hash, l->l_public);
 
   /* e */
@@ -294,6 +242,7 @@ noise_consume_initiation (vlib_main_t * vm, noise_local_t * l,
   r->r_handshake = hs;
   *rp = r;
   ret = true;
+
 error:
   vnet_crypto_key_del (vm, key_idx);
   secure_zero_memory (key, NOISE_SYMMETRIC_KEY_LEN);
@@ -359,7 +308,7 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
                        uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
                        uint8_t en[0 + NOISE_AUTHTAG_LEN])
 {
-  noise_local_t *l = r->r_local;
+  noise_local_t *l = noise_local_get (r->r_local_idx);
   noise_handshake_t hs;
   uint8_t _key[NOISE_SYMMETRIC_KEY_LEN];
   uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN];
@@ -372,9 +321,6 @@ noise_consume_response (vlib_main_t * vm, noise_remote_t * r, uint32_t s_idx,
                         NOISE_SYMMETRIC_KEY_LEN);
   key = vnet_crypto_get_key (key_idx)->data;
 
-  if (!l->l_has_identity)
-    goto error;
-
   hs = r->r_handshake;
   clib_memcpy (preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
 
@@ -460,6 +406,7 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
   clib_memset (&kp.kp_ctr, 0, sizeof (kp.kp_ctr));
 
   /* Now we need to add_new_keypair */
+  clib_rwlock_writer_lock (&r->r_keypair_lock);
   next = r->r_next;
   current = r->r_current;
   previous = r->r_previous;
@@ -491,7 +438,10 @@ noise_remote_begin_session (vlib_main_t * vm, noise_remote_t * r)
       r->r_next = noise_remote_keypair_allocate (r);
       *r->r_next = kp;
     }
+  clib_rwlock_writer_unlock (&r->r_keypair_lock);
+
   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
+
   secure_zero_memory (&kp, sizeof (kp));
   return true;
 }
@@ -502,21 +452,25 @@ noise_remote_clear (vlib_main_t * vm, noise_remote_t * r)
   noise_remote_handshake_index_drop (r);
   secure_zero_memory (&r->r_handshake, sizeof (r->r_handshake));
 
+  clib_rwlock_writer_lock (&r->r_keypair_lock);
   noise_remote_keypair_free (vm, r, &r->r_next);
   noise_remote_keypair_free (vm, r, &r->r_current);
   noise_remote_keypair_free (vm, r, &r->r_previous);
   r->r_next = NULL;
   r->r_current = NULL;
   r->r_previous = NULL;
+  clib_rwlock_writer_unlock (&r->r_keypair_lock);
 }
 
 void
 noise_remote_expire_current (noise_remote_t * r)
 {
+  clib_rwlock_writer_lock (&r->r_keypair_lock);
   if (r->r_next != NULL)
     r->r_next->kp_valid = 0;
   if (r->r_current != NULL)
     r->r_current->kp_valid = 0;
+  clib_rwlock_writer_unlock (&r->r_keypair_lock);
 }
 
 bool
@@ -525,6 +479,7 @@ noise_remote_ready (noise_remote_t * r)
   noise_keypair_t *kp;
   int ret;
 
+  clib_rwlock_reader_lock (&r->r_keypair_lock);
   if ((kp = r->r_current) == NULL ||
       !kp->kp_valid ||
       wg_birthdate_has_expired (kp->kp_birthdate, REJECT_AFTER_TIME) ||
@@ -533,6 +488,7 @@ noise_remote_ready (noise_remote_t * r)
     ret = false;
   else
     ret = true;
+  clib_rwlock_reader_unlock (&r->r_keypair_lock);
   return ret;
 }
 
@@ -592,6 +548,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
   noise_keypair_t *kp;
   enum noise_state_crypt ret = SC_FAILED;
 
+  clib_rwlock_reader_lock (&r->r_keypair_lock);
   if ((kp = r->r_current) == NULL)
     goto error;
 
@@ -631,6 +588,7 @@ noise_remote_encrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t * r_idx,
 
   ret = SC_OK;
 error:
+  clib_rwlock_reader_unlock (&r->r_keypair_lock);
   return ret;
 }
 
@@ -641,6 +599,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
 {
   noise_keypair_t *kp;
   enum noise_state_crypt ret = SC_FAILED;
+  clib_rwlock_reader_lock (&r->r_keypair_lock);
 
   if (r->r_current != NULL && r->r_current->kp_local_index == r_idx)
     {
@@ -682,18 +641,26 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
    * next keypair into current. If we do slide the next keypair in, then
    * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
    * data packet can't confirm a session that we are an INITIATOR of. */
-  if (kp == r->r_next && kp->kp_local_index == r_idx)
+  if (kp == r->r_next)
     {
-      noise_remote_keypair_free (vm, r, &r->r_previous);
-      r->r_previous = r->r_current;
-      r->r_current = r->r_next;
-      r->r_next = NULL;
+      clib_rwlock_reader_unlock (&r->r_keypair_lock);
+      clib_rwlock_writer_lock (&r->r_keypair_lock);
+      if (kp == r->r_next && kp->kp_local_index == r_idx)
+       {
+         noise_remote_keypair_free (vm, r, &r->r_previous);
+         r->r_previous = r->r_current;
+         r->r_current = r->r_next;
+         r->r_next = NULL;
 
-      ret = SC_CONN_RESET;
-      goto error;
+         ret = SC_CONN_RESET;
+         clib_rwlock_writer_unlock (&r->r_keypair_lock);
+         clib_rwlock_reader_lock (&r->r_keypair_lock);
+         goto error;
+       }
+      clib_rwlock_writer_unlock (&r->r_keypair_lock);
+      clib_rwlock_reader_lock (&r->r_keypair_lock);
     }
 
-
   /* Similar to when we encrypt, we want to notify the caller when we
    * are approaching our tolerances. We notify if:
    *  - we're the initiator and the current keypair is older than
@@ -708,6 +675,7 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
 
   ret = SC_OK;
 error:
+  clib_rwlock_reader_unlock (&r->r_keypair_lock);
   return ret;
 }
 
@@ -725,7 +693,8 @@ static void
 noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
                           noise_keypair_t ** kp)
 {
-  struct noise_upcall *u = &r->r_local->l_upcall;
+  noise_local_t *local = noise_local_get (r->r_local_idx);
+  struct noise_upcall *u = &local->l_upcall;
   if (*kp)
     {
       u->u_index_drop ((*kp)->kp_local_index);
@@ -738,7 +707,8 @@ noise_remote_keypair_free (vlib_main_t * vm, noise_remote_t * r,
 static uint32_t
 noise_remote_handshake_index_get (noise_remote_t * r)
 {
-  struct noise_upcall *u = &r->r_local->l_upcall;
+  noise_local_t *local = noise_local_get (r->r_local_idx);
+  struct noise_upcall *u = &local->l_upcall;
   return u->u_index_set (r);
 }
 
@@ -746,7 +716,8 @@ static void
 noise_remote_handshake_index_drop (noise_remote_t * r)
 {
   noise_handshake_t *hs = &r->r_handshake;
-  struct noise_upcall *u = &r->r_local->l_upcall;
+  noise_local_t *local = noise_local_get (r->r_local_idx);
+  struct noise_upcall *u = &local->l_upcall;
   if (hs->hs_state != HS_ZEROED)
     u->u_index_drop (hs->hs_local_index);
 }
@@ -754,7 +725,8 @@ noise_remote_handshake_index_drop (noise_remote_t * r)
 static uint64_t
 noise_counter_send (noise_counter_t * ctr)
 {
-  uint64_t ret = ctr->c_send++;
+  uint64_t ret;
+  ret = ctr->c_send++;
   return ret;
 }
 
@@ -765,7 +737,6 @@ noise_counter_recv (noise_counter_t * ctr, uint64_t recv)
   unsigned long bit;
   bool ret = false;
 
-
   /* Check that the recv counter is valid */
   if (ctr->c_recv >= REJECT_AFTER_MESSAGES || recv >= REJECT_AFTER_MESSAGES)
     goto error;
index 1f6804c..5b5a88f 100755 (executable)
@@ -100,7 +100,7 @@ typedef struct noise_remote
 {
   uint32_t r_peer_idx;
   uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
-  noise_local_t *r_local;
+  uint32_t r_local_idx;
   uint8_t r_ss[NOISE_PUBLIC_KEY_LEN];
 
   noise_handshake_t r_handshake;
@@ -108,37 +108,40 @@ typedef struct noise_remote
   uint8_t r_timestamp[NOISE_TIMESTAMP_LEN];
   f64 r_last_init;
 
+  clib_rwlock_t r_keypair_lock;
   noise_keypair_t *r_next, *r_current, *r_previous;
 } noise_remote_t;
 
 typedef struct noise_local
 {
-  bool l_has_identity;
   uint8_t l_public[NOISE_PUBLIC_KEY_LEN];
   uint8_t l_private[NOISE_PUBLIC_KEY_LEN];
 
   struct noise_upcall
   {
     void *u_arg;
-    noise_remote_t *(*u_remote_get) (uint8_t[NOISE_PUBLIC_KEY_LEN]);
+    noise_remote_t *(*u_remote_get) (const uint8_t[NOISE_PUBLIC_KEY_LEN]);
       uint32_t (*u_index_set) (noise_remote_t *);
     void (*u_index_drop) (uint32_t);
   } l_upcall;
 } noise_local_t;
 
+/* pool of noise_local */
+extern noise_local_t *noise_local_pool;
+
 /* Set/Get noise parameters */
+static_always_inline noise_local_t *
+noise_local_get (uint32_t locali)
+{
+  return (pool_elt_at_index (noise_local_pool, locali));
+}
+
 void noise_local_init (noise_local_t *, struct noise_upcall *);
 bool noise_local_set_private (noise_local_t *,
                              const uint8_t[NOISE_PUBLIC_KEY_LEN]);
-bool noise_local_keys (noise_local_t *, uint8_t[NOISE_PUBLIC_KEY_LEN],
-                      uint8_t[NOISE_PUBLIC_KEY_LEN]);
 
 void noise_remote_init (noise_remote_t *, uint32_t,
-                       const uint8_t[NOISE_PUBLIC_KEY_LEN], noise_local_t *);
-bool noise_remote_set_psk (noise_remote_t *,
-                          uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
-bool noise_remote_keys (noise_remote_t *, uint8_t[NOISE_PUBLIC_KEY_LEN],
-                       uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
+                       const uint8_t[NOISE_PUBLIC_KEY_LEN], uint32_t);
 
 /* Should be called anytime noise_local_set_private is called */
 void noise_remote_precompute (noise_remote_t *);
index cdfd9d7..9a8710b 100755 (executable)
 
 #include <vlib/vlib.h>
 #include <vnet/vnet.h>
-#include <vnet/pg/pg.h>
-#include <vnet/fib/ip6_fib.h>
-#include <vnet/fib/ip4_fib.h>
-#include <vnet/fib/fib_entry.h>
 #include <vppinfra/error.h>
 
 #include <wireguard/wireguard.h>
  _(NONE, "No error")                                                   \
  _(PEER, "Peer error")                                                  \
  _(KEYPAIR, "Keypair error")                                            \
- _(HANDSHAKE_SEND, "Handshake sending failed")                          \
  _(TOO_BIG, "packet too big")                                           \
 
-#define WG_OUTPUT_SCRATCH_SIZE 2048
-
-typedef struct wg_output_scratch_t_
-{
-  u8 scratch[WG_OUTPUT_SCRATCH_SIZE];
-} wg_output_scratch_t;
-
-/* Cache line aligned per-thread scratch space */
-static wg_output_scratch_t *wg_output_scratchs;
-
 typedef enum
 {
 #define _(sym,str) WG_OUTPUT_ERROR_##sym,
@@ -58,6 +43,7 @@ static char *wg_output_error_strings[] = {
 typedef enum
 {
   WG_OUTPUT_NEXT_ERROR,
+  WG_OUTPUT_NEXT_HANDOFF,
   WG_OUTPUT_NEXT_INTERFACE_OUTPUT,
   WG_OUTPUT_N_NEXT,
 } wg_output_next_t;
@@ -65,6 +51,7 @@ typedef enum
 typedef struct
 {
   ip4_udp_header_t hdr;
+  index_t peer;
 } wg_output_tun_trace_t;
 
 u8 *
@@ -87,7 +74,8 @@ format_wg_output_tun_trace (u8 * s, va_list * args)
 
   wg_output_tun_trace_t *t = va_arg (*args, wg_output_tun_trace_t *);
 
-  s = format (s, "Encrypted packet: %U\n", format_ip4_udp_header, &t->hdr);
+  s = format (s, "peer: %d\n", t->peer);
+  s = format (s, "  Encrypted packet: %U", format_ip4_udp_header, &t->hdr);
   return s;
 }
 
@@ -109,7 +97,6 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
   vlib_get_buffers (vm, from, bufs, n_left_from);
 
   wg_main_t *wmp = &wg_main;
-  u32 handsh_fails = 0;
   wg_peer_t *peer = NULL;
 
   while (n_left_from > 0)
@@ -119,11 +106,12 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
                        sizeof (ip4_udp_header_t));
       u16 plain_data_len =
        clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length);
+      index_t peeri;
 
       next[0] = WG_OUTPUT_NEXT_ERROR;
-
-      peer =
+      peeri =
        wg_peer_get_by_adj_index (vnet_buffer (b[0])->ip.adj_index[VLIB_TX]);
+      peer = wg_peer_get (peeri);
 
       if (!peer || peer->is_dead)
        {
@@ -131,21 +119,34 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
          goto out;
        }
 
+      if (PREDICT_FALSE (~0 == peer->output_thread_index))
+       {
+         /* this is the first packet to use this peer, claim the peer
+          * for this thread.
+          */
+         clib_atomic_cmp_and_swap (&peer->output_thread_index, ~0,
+                                   wg_peer_assign_thread (thread_index));
+       }
+
+      if (PREDICT_TRUE (thread_index != peer->output_thread_index))
+       {
+         next[0] = WG_OUTPUT_NEXT_HANDOFF;
+         goto next;
+       }
+
       if (PREDICT_FALSE (!peer->remote.r_current))
        {
-         if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
-           handsh_fails++;
+         wg_send_handshake_from_mt (peeri, false);
          b[0]->error = node->errors[WG_OUTPUT_ERROR_KEYPAIR];
          goto out;
        }
-
       size_t encrypted_packet_len = message_data_len (plain_data_len);
 
       /*
        * Ensure there is enough space to write the encrypted data
        * into the packet
        */
-      if (PREDICT_FALSE (encrypted_packet_len >= WG_OUTPUT_SCRATCH_SIZE) ||
+      if (PREDICT_FALSE (encrypted_packet_len >= WG_DEFAULT_DATA_SIZE) ||
          PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >=
                         vlib_buffer_get_default_data_size (vm)))
        {
@@ -154,35 +155,29 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
        }
 
       message_data_t *encrypted_packet =
-       (message_data_t *) wg_output_scratchs[thread_index].scratch;
+       (message_data_t *) wmp->per_thread_data[thread_index].data;
 
       enum noise_state_crypt state;
       state =
-       noise_remote_encrypt (wmp->vlib_main,
+       noise_remote_encrypt (vm,
                              &peer->remote,
                              &encrypted_packet->receiver_index,
                              &encrypted_packet->counter, plain_data,
                              plain_data_len,
                              encrypted_packet->encrypted_data);
-      switch (state)
+
+      if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH))
+       {
+         wg_send_handshake_from_mt (peeri, false);
+       }
+      else if (PREDICT_FALSE (state == SC_FAILED))
        {
-       case SC_OK:
-         break;
-       case SC_KEEP_KEY_FRESH:
-         if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
-           handsh_fails++;
-         break;
-       case SC_FAILED:
          //TODO: Maybe wrong
-         if (PREDICT_FALSE (!wg_send_handshake (vm, peer, false)))
-           handsh_fails++;
-         clib_mem_free (encrypted_packet);
+         wg_send_handshake_from_mt (peeri, false);
          goto out;
-       default:
-         break;
        }
 
-      // Here we are sure that can send packet to next node.
+      /* Here we are sure that we can send the packet to the next node */
       next[0] = WG_OUTPUT_NEXT_INTERFACE_OUTPUT;
       encrypted_packet->header.type = MESSAGE_DATA;
 
@@ -195,9 +190,9 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
       ip4_header_set_len_w_chksum
        (&hdr->ip4, clib_host_to_net_u16 (b[0]->current_length));
 
-      wg_timers_any_authenticated_packet_traversal (peer);
       wg_timers_any_authenticated_packet_sent (peer);
       wg_timers_data_sent (peer);
+      wg_timers_any_authenticated_packet_traversal (peer);
 
     out:
       if (PREDICT_FALSE ((node->flags & VLIB_NODE_FLAG_TRACE)
@@ -206,17 +201,15 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
          wg_output_tun_trace_t *t =
            vlib_add_trace (vm, node, b[0], sizeof (*t));
          t->hdr = *hdr;
+         t->peer = peeri;
        }
+    next:
       n_left_from -= 1;
       next += 1;
       b += 1;
     }
 
   vlib_buffer_enqueue_to_next (vm, node, from, nexts, frame->n_vectors);
-
-  vlib_node_increment_counter (vm, node->node_index,
-                              WG_OUTPUT_ERROR_HANDSHAKE_SEND, handsh_fails);
-
   return frame->n_vectors;
 }
 
@@ -231,24 +224,13 @@ VLIB_REGISTER_NODE (wg_output_tun_node) =
   .error_strings = wg_output_error_strings,
   .n_next_nodes = WG_OUTPUT_N_NEXT,
   .next_nodes = {
+        [WG_OUTPUT_NEXT_HANDOFF] = "wg-output-tun-handoff",
         [WG_OUTPUT_NEXT_INTERFACE_OUTPUT] = "adj-midchain-tx",
         [WG_OUTPUT_NEXT_ERROR] = "error-drop",
   },
 };
 /* *INDENT-ON* */
 
-static clib_error_t *
-wireguard_output_module_init (vlib_main_t * vm)
-{
-  vlib_thread_main_t *tm = vlib_get_thread_main ();
-
-  vec_validate_aligned (wg_output_scratchs, tm->n_vlib_mains,
-                       CLIB_CACHE_LINE_BYTES);
-  return (NULL);
-}
-
-VLIB_INIT_FUNCTION (wireguard_output_module_init);
-
 /*
  * fd.io coding-style-patch-verification: ON
  *
index 30adea8..b41118f 100755 (executable)
 #include <wireguard/wireguard.h>
 
 static fib_source_t wg_fib_source;
+wg_peer_t *wg_peer_pool;
 
 index_t *wg_peer_by_adj_index;
 
-wg_peer_t *
-wg_peer_get (index_t peeri)
-{
-  return (pool_elt_at_index (wg_main.peers, peeri));
-}
-
 static void
 wg_peer_endpoint_reset (wg_peer_endpoint_t * ep)
 {
@@ -82,7 +77,11 @@ static void
 wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
 {
   wg_timers_stop (peer);
-  noise_remote_clear (vm, &peer->remote);
+  for (int i = 0; i < WG_N_TIMERS; i++)
+    {
+      peer->timers[i] = ~0;
+    }
+
   peer->last_sent_handshake = vlib_time_now (vm) - (REKEY_TIMEOUT + 1);
 
   clib_memset (&peer->cookie_maker, 0, sizeof (peer->cookie_maker));
@@ -97,9 +96,18 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
     }
   wg_peer_fib_flush (peer);
 
+  peer->input_thread_index = ~0;
+  peer->output_thread_index = ~0;
   peer->adj_index = INDEX_INVALID;
+  peer->timer_wheel = 0;
   peer->persistent_keepalive_interval = 0;
   peer->timer_handshake_attempts = 0;
+  peer->last_sent_packet = 0;
+  peer->last_received_packet = 0;
+  peer->session_derived = 0;
+  peer->rehandshake_started = 0;
+  peer->new_handshake_interval_tick = 0;
+  peer->rehandshake_interval_tick = 0;
   peer->timer_need_another_keepalive = false;
   peer->is_dead = true;
   vec_free (peer->allowed_ips);
@@ -108,7 +116,7 @@ wg_peer_clear (vlib_main_t * vm, wg_peer_t * peer)
 static void
 wg_peer_init (vlib_main_t * vm, wg_peer_t * peer)
 {
-  wg_timers_init (peer, vlib_time_now (vm));
+  peer->adj_index = INDEX_INVALID;
   wg_peer_clear (vm, peer);
 }
 
@@ -205,8 +213,9 @@ wg_peer_fill (vlib_main_t * vm, wg_peer_t * peer,
   wg_peer_endpoint_init (&peer->dst, dst, port);
 
   peer->table_id = table_id;
-  peer->persistent_keepalive_interval = persistent_keepalive_interval;
   peer->wg_sw_if_index = wg_sw_if_index;
+  peer->timer_wheel = &wg_main.timer_wheel;
+  peer->persistent_keepalive_interval = persistent_keepalive_interval;
   peer->last_sent_handshake = vlib_time_now (vm) - (REKEY_TIMEOUT + 1);
   peer->is_dead = false;
 
@@ -230,7 +239,7 @@ wg_peer_fill (vlib_main_t * vm, wg_peer_t * peer,
 
   vec_validate_init_empty (wg_peer_by_adj_index,
                           peer->adj_index, INDEX_INVALID);
-  wg_peer_by_adj_index[peer->adj_index] = peer - wg_main.peers;
+  wg_peer_by_adj_index[peer->adj_index] = peer - wg_peer_pool;
 
   adj_nbr_midchain_update_rewrite (peer->adj_index,
                                   NULL,
@@ -280,7 +289,7 @@ wg_peer_add (u32 tun_sw_if_index,
     return (VNET_API_ERROR_INVALID_SW_IF_INDEX);
 
   /* *INDENT-OFF* */
-  pool_foreach (peer, wg_main.peers,
+  pool_foreach (peer, wg_peer_pool,
   ({
     if (!memcmp (peer->remote.r_public, public_key, NOISE_PUBLIC_KEY_LEN))
     {
@@ -289,10 +298,10 @@ wg_peer_add (u32 tun_sw_if_index,
   }));
   /* *INDENT-ON* */
 
-  if (pool_elts (wg_main.peers) > MAX_PEERS)
+  if (pool_elts (wg_peer_pool) > MAX_PEERS)
     return (VNET_API_ERROR_LIMIT_EXCEEDED);
 
-  pool_get (wg_main.peers, peer);
+  pool_get (wg_peer_pool, peer);
 
   wg_peer_init (vm, peer);
 
@@ -302,12 +311,12 @@ wg_peer_add (u32 tun_sw_if_index,
   if (rv)
     {
       wg_peer_clear (vm, peer);
-      pool_put (wg_main.peers, peer);
+      pool_put (wg_peer_pool, peer);
       return (rv);
     }
 
-  noise_remote_init (&peer->remote, peer - wg_main.peers, public_key,
-                    &wg_if->local);
+  noise_remote_init (&peer->remote, peer - wg_peer_pool, public_key,
+                    wg_if->local_idx);
   cookie_maker_init (&peer->cookie_maker, public_key);
 
   if (peer->persistent_keepalive_interval != 0)
@@ -315,7 +324,7 @@ wg_peer_add (u32 tun_sw_if_index,
       wg_send_keepalive (vm, peer);
     }
 
-  *peer_index = peer - wg_main.peers;
+  *peer_index = peer - wg_peer_pool;
   wg_if_peer_add (wg_if, *peer_index);
 
   return (0);
@@ -328,34 +337,37 @@ wg_peer_remove (index_t peeri)
   wg_peer_t *peer = NULL;
   wg_if_t *wgi;
 
-  if (pool_is_free_index (wmp->peers, peeri))
+  if (pool_is_free_index (wg_peer_pool, peeri))
     return VNET_API_ERROR_NO_SUCH_ENTRY;
 
-  peer = pool_elt_at_index (wmp->peers, peeri);
+  peer = pool_elt_at_index (wg_peer_pool, peeri);
 
   wgi = wg_if_get (wg_if_find_by_sw_if_index (peer->wg_sw_if_index));
   wg_if_peer_remove (wgi, peeri);
 
   vnet_feature_enable_disable ("ip4-output", "wg-output-tun",
                               peer->wg_sw_if_index, 0, 0, 0);
+
+  noise_remote_clear (wmp->vlib_main, &peer->remote);
   wg_peer_clear (wmp->vlib_main, peer);
-  pool_put (wmp->peers, peer);
+  pool_put (wg_peer_pool, peer);
 
   return (0);
 }
 
-void
+index_t
 wg_peer_walk (wg_peer_walk_cb_t fn, void *data)
 {
   index_t peeri;
 
   /* *INDENT-OFF* */
-  pool_foreach_index(peeri, wg_main.peers,
+  pool_foreach_index(peeri, wg_peer_pool,
   {
     if (WALK_STOP == fn(peeri, data))
-      break;
+      return peeri;
   });
   /* *INDENT-ON* */
+  return INDEX_INVALID;
 }
 
 static u8 *
index 99c73f3..009a6f6 100755 (executable)
@@ -49,6 +49,9 @@ typedef struct wg_peer
   noise_remote_t remote;
   cookie_maker_t cookie_maker;
 
+  u32 input_thread_index;
+  u32 output_thread_index;
+
   /* Peer addresses */
   wg_peer_endpoint_t dst;
   wg_peer_endpoint_t src;
@@ -65,11 +68,22 @@ typedef struct wg_peer
   u32 wg_sw_if_index;
 
   /* Timers */
-  tw_timer_wheel_16t_2w_512sl_t timer_wheel;
+  tw_timer_wheel_16t_2w_512sl_t *timer_wheel;
   u32 timers[WG_N_TIMERS];
   u32 timer_handshake_attempts;
   u16 persistent_keepalive_interval;
+
+  /* Timestamps */
   f64 last_sent_handshake;
+  f64 last_sent_packet;
+  f64 last_received_packet;
+  f64 session_derived;
+  f64 rehandshake_started;
+
+  /* Variable intervals */
+  u32 new_handshake_interval_tick;
+  u32 rehandshake_interval_tick;
+
   bool timer_need_another_keepalive;
 
   bool is_dead;
@@ -91,10 +105,9 @@ int wg_peer_add (u32 tun_sw_if_index,
 int wg_peer_remove (u32 peer_index);
 
 typedef walk_rc_t (*wg_peer_walk_cb_t) (index_t peeri, void *arg);
-void wg_peer_walk (wg_peer_walk_cb_t fn, void *data);
+index_t wg_peer_walk (wg_peer_walk_cb_t fn, void *data);
 
 u8 *format_wg_peer (u8 * s, va_list * va);
-wg_peer_t *wg_peer_get (index_t peeri);
 
 walk_rc_t wg_peer_if_admin_state_change (wg_if_t * wgi, index_t peeri,
                                         void *data);
@@ -104,11 +117,30 @@ walk_rc_t wg_peer_if_table_change (wg_if_t * wgi, index_t peeri, void *data);
  * Expoed for the data-plane
  */
 extern index_t *wg_peer_by_adj_index;
+extern wg_peer_t *wg_peer_pool;
 
 static inline wg_peer_t *
+wg_peer_get (index_t peeri)
+{
+  return (pool_elt_at_index (wg_peer_pool, peeri));
+}
+
+static inline index_t
 wg_peer_get_by_adj_index (index_t ai)
 {
-  return wg_peer_get (wg_peer_by_adj_index[ai]);
+  return (wg_peer_by_adj_index[ai]);
+}
+
+/*
+ * Chooses which thread_id should be assigned.
+*/
+static inline u32
+wg_peer_assign_thread (u32 thread_id)
+{
+  return ((thread_id) ? thread_id
+         : (vlib_num_workers ()?
+            ((unix_time_now_nsec () % vlib_num_workers ()) +
+             1) : thread_id));
 }
 
 #endif // __included_wg_peer_h__
index a5d8aaf..2e29a9b 100755 (executable)
  */
 
 #include <vnet/vnet.h>
-#include <vnet/fib/ip6_fib.h>
-#include <vnet/fib/ip4_fib.h>
-#include <vnet/fib/fib_entry.h>
 #include <vnet/ip/ip6_link.h>
 #include <vnet/pg/pg.h>
 #include <vnet/udp/udp.h>
 #include <vppinfra/error.h>
+#include <vlibmemory/api.h>
 #include <wireguard/wireguard.h>
 #include <wireguard/wireguard_send.h>
 
@@ -86,7 +84,8 @@ wg_create_buffer (vlib_main_t * vm,
 bool
 wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry)
 {
-  wg_main_t *wmp = &wg_main;
+  ASSERT (vm->thread_index == 0);
+
   message_handshake_initiation_t packet;
 
   if (!is_retry)
@@ -94,41 +93,73 @@ wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry)
 
   if (!wg_birthdate_has_expired (peer->last_sent_handshake,
                                 REKEY_TIMEOUT) || peer->is_dead)
-    {
-      return true;
-    }
-  if (noise_create_initiation (wmp->vlib_main,
+    return true;
+
+  if (noise_create_initiation (vm,
                               &peer->remote,
                               &packet.sender_index,
                               packet.unencrypted_ephemeral,
                               packet.encrypted_static,
                               packet.encrypted_timestamp))
     {
-      f64 now = vlib_time_now (vm);
       packet.header.type = MESSAGE_HANDSHAKE_INITIATION;
       cookie_maker_mac (&peer->cookie_maker, &packet.macs, &packet,
                        sizeof (packet));
-      wg_timers_any_authenticated_packet_traversal (peer);
       wg_timers_any_authenticated_packet_sent (peer);
-      peer->last_sent_handshake = now;
       wg_timers_handshake_initiated (peer);
+      wg_timers_any_authenticated_packet_traversal (peer);
+
+      peer->last_sent_handshake = vlib_time_now (vm);
     }
   else
     return false;
+
   u32 bi0 = 0;
   if (!wg_create_buffer (vm, peer, (u8 *) & packet, sizeof (packet), &bi0))
     return false;
-  ip46_enqueue_packet (vm, bi0, false);
 
+  ip46_enqueue_packet (vm, bi0, false);
   return true;
 }
 
+typedef struct
+{
+  u32 peer_idx;
+  bool is_retry;
+} wg_send_args_t;
+
+static void *
+wg_send_handshake_thread_fn (void *arg)
+{
+  wg_send_args_t *a = arg;
+
+  wg_main_t *wmp = &wg_main;
+  wg_peer_t *peer = wg_peer_get (a->peer_idx);
+
+  wg_send_handshake (wmp->vlib_main, peer, a->is_retry);
+  return 0;
+}
+
+void
+wg_send_handshake_from_mt (u32 peer_idx, bool is_retry)
+{
+  wg_send_args_t a = {
+    .peer_idx = peer_idx,
+    .is_retry = is_retry,
+  };
+
+  vl_api_rpc_call_main_thread (wg_send_handshake_thread_fn,
+                              (u8 *) & a, sizeof (a));
+}
+
 bool
 wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
 {
-  wg_main_t *wmp = &wg_main;
+  ASSERT (vm->thread_index == 0);
+
   u32 size_of_packet = message_data_len (0);
-  message_data_t *packet = clib_mem_alloc (size_of_packet);
+  message_data_t *packet =
+    (message_data_t *) wg_main.per_thread_data[vm->thread_index].data;
   u32 bi0 = 0;
   bool ret = true;
   enum noise_state_crypt state;
@@ -140,23 +171,21 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
     }
 
   state =
-    noise_remote_encrypt (wmp->vlib_main,
+    noise_remote_encrypt (vm,
                          &peer->remote,
                          &packet->receiver_index,
                          &packet->counter, NULL, 0, packet->encrypted_data);
-  switch (state)
+
+  if (PREDICT_FALSE (state == SC_KEEP_KEY_FRESH))
     {
-    case SC_OK:
-      break;
-    case SC_KEEP_KEY_FRESH:
       wg_send_handshake (vm, peer, false);
-      break;
-    case SC_FAILED:
+    }
+  else if (PREDICT_FALSE (state == SC_FAILED))
+    {
       ret = false;
       goto out;
-    default:
-      break;
     }
+
   packet->header.type = MESSAGE_DATA;
 
   if (!wg_create_buffer (vm, peer, (u8 *) packet, size_of_packet, &bi0))
@@ -166,22 +195,19 @@ wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
     }
 
   ip46_enqueue_packet (vm, bi0, false);
-  wg_timers_any_authenticated_packet_traversal (peer);
+
   wg_timers_any_authenticated_packet_sent (peer);
+  wg_timers_any_authenticated_packet_traversal (peer);
 
 out:
-  clib_mem_free (packet);
   return ret;
 }
 
 bool
 wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
 {
-  wg_main_t *wmp = &wg_main;
   message_handshake_response_t packet;
 
-  peer->last_sent_handshake = vlib_time_now (vm);
-
   if (noise_create_response (vm,
                             &peer->remote,
                             &packet.sender_index,
@@ -189,17 +215,16 @@ wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer)
                             packet.unencrypted_ephemeral,
                             packet.encrypted_nothing))
     {
-      f64 now = vlib_time_now (vm);
       packet.header.type = MESSAGE_HANDSHAKE_RESPONSE;
       cookie_maker_mac (&peer->cookie_maker, &packet.macs, &packet,
                        sizeof (packet));
 
-      if (noise_remote_begin_session (wmp->vlib_main, &peer->remote))
+      if (noise_remote_begin_session (vm, &peer->remote))
        {
          wg_timers_session_derived (peer);
-         wg_timers_any_authenticated_packet_traversal (peer);
          wg_timers_any_authenticated_packet_sent (peer);
-         peer->last_sent_handshake = now;
+         wg_timers_any_authenticated_packet_traversal (peer);
+         peer->last_sent_handshake = vlib_time_now (vm);
 
          u32 bi0 = 0;
          if (!wg_create_buffer (vm, peer, (u8 *) & packet,
index 4ea1f6e..efe4194 100755 (executable)
@@ -20,6 +20,7 @@
 
 bool wg_send_keepalive (vlib_main_t * vm, wg_peer_t * peer);
 bool wg_send_handshake (vlib_main_t * vm, wg_peer_t * peer, bool is_retry);
+void wg_send_handshake_from_mt (u32 peer_idx, bool is_retry);
 bool wg_send_handshake_response (vlib_main_t * vm, wg_peer_t * peer);
 
 always_inline void
index e4d4030..b7fd689 100755 (executable)
@@ -13,6 +13,7 @@
  * limitations under the License.
  */
 
+#include <vlibmemory/api.h>
 #include <wireguard/wireguard.h>
 #include <wireguard/wireguard_send.h>
 #include <wireguard/wireguard_timer.h>
@@ -30,31 +31,77 @@ stop_timer (wg_peer_t * peer, u32 timer_id)
 {
   if (peer->timers[timer_id] != ~0)
     {
-      tw_timer_stop_16t_2w_512sl (&peer->timer_wheel, peer->timers[timer_id]);
+      tw_timer_stop_16t_2w_512sl (peer->timer_wheel, peer->timers[timer_id]);
       peer->timers[timer_id] = ~0;
     }
 }
 
 static void
-start_or_update_timer (wg_peer_t * peer, u32 timer_id, u32 interval)
+start_timer (wg_peer_t * peer, u32 timer_id, u32 interval_ticks)
 {
+  ASSERT (vlib_get_thread_index () == 0);  /* thread index 0 only */
+  /* Arms the timer only when none is pending; never reschedules one. */
   if (peer->timers[timer_id] == ~0)
     {
-      wg_main_t *wmp = &wg_main;
       peer->timers[timer_id] =
-       tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers,
-                                    timer_id, interval);
-    }
-  else
-    {
-      tw_timer_update_16t_2w_512sl (&peer->timer_wheel,
-                                   peer->timers[timer_id], interval);
+       tw_timer_start_16t_2w_512sl (peer->timer_wheel, peer - wg_peer_pool,
+                                    timer_id, interval_ticks);
     }
 }
 
+typedef struct  /* args for start_timer_thread_fn () */
+{
+  u32 peer_idx;  /* passed to wg_peer_get () */
+  u32 timer_id;  /* which entry of peer->timers[] to arm */
+  u32 interval_ticks;  /* delay handed to start_timer () */
+
+} wg_timers_args;
+
+static void *
+start_timer_thread_fn (void *arg)
+{
+  wg_timers_args *a = arg;
+  wg_peer_t *peer = wg_peer_get (a->peer_idx);
+  /* Callback for start_timer_from_mt (): unpack args and arm the timer. */
+  start_timer (peer, a->timer_id, a->interval_ticks);
+  return 0;
+}
+
+static void  /* arm a peer timer by way of vl_api_rpc_call_main_thread () */
+start_timer_from_mt (u32 peer_idx, u32 timer_id, u32 interval_ticks)
+{
+  wg_timers_args a = {  /* presumably copied by the RPC call - TODO confirm */
+    .peer_idx = peer_idx,
+    .timer_id = timer_id,
+    .interval_ticks = interval_ticks,
+  };
+  /* start_timer_thread_fn () performs the actual start_timer () call. */
+  vl_api_rpc_call_main_thread (start_timer_thread_fn, (u8 *) & a, sizeof (a));
+}
+
+static inline u32  /* ticks left until init_time_sec + interval_ticks, or 0 */
+timer_ticks_left (vlib_main_t * vm, f64 init_time_sec, u32 interval_ticks)
+{
+  static const int32_t rounding = (int32_t) (WHZ / 2);
+  int32_t ticks_remain;  /* signed: goes negative once the deadline passed */
+  /* Elapsed seconds scale by WHZ; WHZ / 2 ticks or fewer counts as due. */
+  ticks_remain = (init_time_sec - vlib_time_now (vm)) * WHZ + interval_ticks;
+  return (ticks_remain > rounding) ? (u32) ticks_remain : 0;
+}
+
 static void
 wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer)
 {
+  if (peer->rehandshake_started == ~0)
+    return;
+
+  u32 ticks = timer_ticks_left (vm, peer->rehandshake_started,
+                               peer->rehandshake_interval_tick);
+  if (ticks)
+    {
+      start_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE, ticks);
+      return;
+    }
 
   if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES)
     {
@@ -63,17 +110,8 @@ wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer)
       /* We set a timer for destroying any residue that might be left
        * of a partial exchange.
        */
+      start_timer (peer, WG_TIMER_KEY_ZEROING, REJECT_AFTER_TIME * 3 * WHZ);
 
-      if (peer->timers[WG_TIMER_KEY_ZEROING] == ~0)
-       {
-         wg_main_t *wmp = &wg_main;
-
-         peer->timers[WG_TIMER_KEY_ZEROING] =
-           tw_timer_start_16t_2w_512sl (&peer->timer_wheel,
-                                        peer - wmp->peers,
-                                        WG_TIMER_KEY_ZEROING,
-                                        REJECT_AFTER_TIME * 3 * WHZ);
-       }
     }
   else
     {
@@ -85,13 +123,23 @@ wg_expired_retransmit_handshake (vlib_main_t * vm, wg_peer_t * peer)
 static void
 wg_expired_send_keepalive (vlib_main_t * vm, wg_peer_t * peer)
 {
-  wg_send_keepalive (vm, peer);
-
-  if (peer->timer_need_another_keepalive)
+  if (peer->last_sent_packet < peer->last_received_packet)
     {
-      peer->timer_need_another_keepalive = false;
-      start_or_update_timer (peer, WG_TIMER_SEND_KEEPALIVE,
-                            KEEPALIVE_TIMEOUT * WHZ);
+      u32 ticks = timer_ticks_left (vm, peer->last_received_packet,
+                                   KEEPALIVE_TIMEOUT * WHZ);
+      if (ticks)  /* not due yet: re-arm for the remainder */
+       {
+         start_timer (peer, WG_TIMER_SEND_KEEPALIVE, ticks);
+         return;
+       }
+      /* Due: send, then re-arm once if another keepalive was requested. */
+      wg_send_keepalive (vm, peer);
+      if (peer->timer_need_another_keepalive)
+       {
+         peer->timer_need_another_keepalive = false;
+         start_timer (peer, WG_TIMER_SEND_KEEPALIVE,
+                      KEEPALIVE_TIMEOUT * WHZ);
+       }
     }
 }
 
@@ -100,6 +148,18 @@ wg_expired_send_persistent_keepalive (vlib_main_t * vm, wg_peer_t * peer)
 {
   if (peer->persistent_keepalive_interval)
     {
+      f64 latest_time = peer->last_sent_packet > peer->last_received_packet
+       ? peer->last_sent_packet : peer->last_received_packet;
+
+      u32 ticks = timer_ticks_left (vm, latest_time,
+                                   peer->persistent_keepalive_interval *
+                                   WHZ);
+      if (ticks)
+       {
+         start_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE, ticks);
+         return;
+       }
+
       wg_send_keepalive (vm, peer);
     }
 }
@@ -107,64 +167,81 @@ wg_expired_send_persistent_keepalive (vlib_main_t * vm, wg_peer_t * peer)
 static void
 wg_expired_new_handshake (vlib_main_t * vm, wg_peer_t * peer)
 {
+  u32 ticks = timer_ticks_left (vm, peer->last_sent_packet,
+                               peer->new_handshake_interval_tick);
+  if (ticks)  /* a packet was sent since arming: wait out the remainder */
+    {
+      start_timer (peer, WG_TIMER_NEW_HANDSHAKE, ticks);
+      return;
+    }
+  /* Interval since the last sent packet has elapsed: start a handshake. */
   wg_send_handshake (vm, peer, false);
 }
 
 static void
 wg_expired_zero_key_material (vlib_main_t * vm, wg_peer_t * peer)
 {
+  u32 ticks =
+    timer_ticks_left (vm, peer->session_derived, REJECT_AFTER_TIME * 3 * WHZ);
+  if (ticks)  /* session was re-derived since arming: wait the remainder */
+    {
+      start_timer (peer, WG_TIMER_KEY_ZEROING, ticks);
+      return;
+    }
+  /* Full interval since session_derived has elapsed: clear noise state. */
   if (!peer->is_dead)
     {
       noise_remote_clear (vm, &peer->remote);
     }
 }
 
-
 void
 wg_timers_any_authenticated_packet_traversal (wg_peer_t * peer)
 {
   if (peer->persistent_keepalive_interval)
     {
-      start_or_update_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE,
-                            peer->persistent_keepalive_interval * WHZ);
+      start_timer_from_mt (peer - wg_peer_pool,  /* pool index of peer */
+                          WG_TIMER_PERSISTENT_KEEPALIVE,
+                          peer->persistent_keepalive_interval * WHZ);
     }
 }
 
 void
 wg_timers_any_authenticated_packet_sent (wg_peer_t * peer)
 {
-  stop_timer (peer, WG_TIMER_SEND_KEEPALIVE);
+  peer->last_sent_packet = vlib_time_now (vlib_get_main ());
 }
 
 void
 wg_timers_handshake_initiated (wg_peer_t * peer)
 {
-  u32 interval =
+  peer->rehandshake_started = vlib_time_now (vlib_get_main ());
+  peer->rehandshake_interval_tick =
     REKEY_TIMEOUT * WHZ + get_random_u32_max (REKEY_TIMEOUT_JITTER);
-  start_or_update_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE, interval);
+  /* Start time and interval are kept for wg_expired_retransmit_handshake. */
+  start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_RETRANSMIT_HANDSHAKE,
+                      peer->rehandshake_interval_tick);
 }
 
 void
 wg_timers_session_derived (wg_peer_t * peer)
 {
-  start_or_update_timer (peer, WG_TIMER_KEY_ZEROING,
-                        REJECT_AFTER_TIME * 3 * WHZ);
+  peer->session_derived = vlib_time_now (vlib_get_main ());
+  /* Timestamp is what wg_expired_zero_key_material () measures from. */
+  start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_KEY_ZEROING,
+                      REJECT_AFTER_TIME * 3 * WHZ);
 }
 
 /* Should be called after an authenticated data packet is sent. */
 void
 wg_timers_data_sent (wg_peer_t * peer)
 {
-  u32 interval = (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ +
+  peer->new_handshake_interval_tick =  /* kept for wg_expired_new_handshake */
+    (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * WHZ +
     get_random_u32_max (REKEY_TIMEOUT_JITTER);
 
-  if (peer->timers[WG_TIMER_NEW_HANDSHAKE] == ~0)
-    {
-      wg_main_t *wmp = &wg_main;
-      peer->timers[WG_TIMER_NEW_HANDSHAKE] =
-       tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers,
-                                    WG_TIMER_NEW_HANDSHAKE, interval);
-    }
+  start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_NEW_HANDSHAKE,
+                      peer->new_handshake_interval_tick);
 }
 
 /* Should be called after an authenticated data packet is received. */
@@ -173,16 +250,11 @@ wg_timers_data_received (wg_peer_t * peer)
 {
   if (peer->timers[WG_TIMER_SEND_KEEPALIVE] == ~0)
     {
-      wg_main_t *wmp = &wg_main;
-      peer->timers[WG_TIMER_SEND_KEEPALIVE] =
-       tw_timer_start_16t_2w_512sl (&peer->timer_wheel, peer - wmp->peers,
-                                    WG_TIMER_SEND_KEEPALIVE,
-                                    KEEPALIVE_TIMEOUT * WHZ);
+      start_timer_from_mt (peer - wg_peer_pool, WG_TIMER_SEND_KEEPALIVE,
+                          KEEPALIVE_TIMEOUT * WHZ);
     }
   else
-    {
-      peer->timer_need_another_keepalive = true;
-    }
+    peer->timer_need_another_keepalive = true;
 }
 
 /* Should be called after a handshake response message is received and processed
@@ -191,15 +263,14 @@ wg_timers_data_received (wg_peer_t * peer)
 void
 wg_timers_handshake_complete (wg_peer_t * peer)
 {
-  stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE);
-
+  peer->rehandshake_started = ~0;  /* ~0: retransmit handler returns early */
   peer->timer_handshake_attempts = 0;
 }
 
 void
 wg_timers_any_authenticated_packet_received (wg_peer_t * peer)
 {
-  stop_timer (peer, WG_TIMER_NEW_HANDSHAKE);
+  peer->last_received_packet = vlib_time_now (vlib_get_main ());
 }
 
 static vlib_node_registration_t wg_timer_mngr_node;
@@ -222,7 +293,7 @@ expired_timer_callback (u32 * expired_timers)
       pool_index = expired_timers[i] & 0x0FFFFFFF;
       timer_id = expired_timers[i] >> 28;
 
-      peer = pool_elt_at_index (wmp->peers, pool_index);
+      peer = wg_peer_get (pool_index);
       peer->timers[timer_id] = ~0;
     }
 
@@ -231,7 +302,7 @@ expired_timer_callback (u32 * expired_timers)
       pool_index = expired_timers[i] & 0x0FFFFFFF;
       timer_id = expired_timers[i] >> 28;
 
-      peer = pool_elt_at_index (wmp->peers, pool_index);
+      peer = wg_peer_get (pool_index);
       switch (timer_id)
        {
        case WG_TIMER_RETRANSMIT_HANDSHAKE:
@@ -256,18 +327,14 @@ expired_timer_callback (u32 * expired_timers)
 }
 
 void
-wg_timers_init (wg_peer_t * peer, f64 now)
+wg_timer_wheel_init ()
 {
-  for (int i = 0; i < WG_N_TIMERS; i++)
-    {
-      peer->timers[i] = ~0;
-    }
-  tw_timer_wheel_16t_2w_512sl_t *tw = &peer->timer_wheel;
+  wg_main_t *wmp = &wg_main;
+  tw_timer_wheel_16t_2w_512sl_t *tw = &wmp->timer_wheel;
   tw_timer_wheel_init_16t_2w_512sl (tw,
                                    expired_timer_callback,
                                    WG_TICK /* timer period in s */ , ~0);
-  tw->last_run_time = now;
-  peer->adj_index = INDEX_INVALID;
+  tw->last_run_time = vlib_time_now (wmp->vlib_main);
 }
 
 static uword
@@ -275,22 +342,13 @@ wg_timer_mngr_fn (vlib_main_t * vm, vlib_node_runtime_t * rt,
                  vlib_frame_t * f)
 {
   wg_main_t *wmp = &wg_main;
-  wg_peer_t *peers;
-  wg_peer_t *peer;
-
   while (1)
     {
       vlib_process_wait_for_event_or_clock (vm, WG_TICK);
       vlib_process_get_events (vm, NULL);
 
-      peers = wmp->peers;
-      /* *INDENT-OFF* */
-      pool_foreach (peer, peers,
-      ({
-        tw_timer_expire_timers_16t_2w_512sl
-        (&peer->timer_wheel, vlib_time_now (vm));
-      }));
-      /* *INDENT-ON* */
+      tw_timer_expire_timers_16t_2w_512sl (&wmp->timer_wheel,
+                                          vlib_time_now (vm));
     }
 
   return 0;
@@ -299,11 +357,15 @@ wg_timer_mngr_fn (vlib_main_t * vm, vlib_node_runtime_t * rt,
 void
 wg_timers_stop (wg_peer_t * peer)
 {
-  stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE);
-  stop_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE);
-  stop_timer (peer, WG_TIMER_SEND_KEEPALIVE);
-  stop_timer (peer, WG_TIMER_NEW_HANDSHAKE);
-  stop_timer (peer, WG_TIMER_KEY_ZEROING);
+  ASSERT (vlib_get_thread_index () == 0);  /* thread index 0 only */
+  if (peer->timer_wheel)  /* nothing to stop without a wheel pointer */
+    {
+      stop_timer (peer, WG_TIMER_RETRANSMIT_HANDSHAKE);
+      stop_timer (peer, WG_TIMER_PERSISTENT_KEEPALIVE);
+      stop_timer (peer, WG_TIMER_SEND_KEEPALIVE);
+      stop_timer (peer, WG_TIMER_NEW_HANDSHAKE);
+      stop_timer (peer, WG_TIMER_KEY_ZEROING);
+    }
 }
 
 /* *INDENT-OFF* */
index 457dce2..2cc5dd0 100755 (executable)
@@ -38,7 +38,7 @@ typedef enum _wg_timers
 
 typedef struct wg_peer wg_peer_t;
 
-void wg_timers_init (wg_peer_t * peer, f64 now);
+void wg_timer_wheel_init (void);
 void wg_timers_stop (wg_peer_t * peer);
 void wg_timers_data_sent (wg_peer_t * peer);
 void wg_timers_data_received (wg_peer_t * peer);