wireguard: add events for peer
[vpp.git] / test / test_wireguard.py
index 65ebd8d..e844b1d 100755 (executable)
@@ -3,6 +3,7 @@
 
 import datetime
 import base64
+import os
 
 from hashlib import blake2s
 from scapy.packet import Packet
@@ -25,6 +26,7 @@ from vpp_ipip_tun_interface import VppIpIpTunInterface
 from vpp_interface import VppInterface
 from vpp_ip_route import VppIpRoute, VppRoutePath
 from vpp_object import VppObject
+from vpp_papi import VppEnum
 from framework import VppTestCase
 from re import compile
 import unittest
@@ -92,6 +94,19 @@ class VppWgInterface(VppInterface):
                 return True
         return False
 
+    def want_events(self, peer_index=0xffffffff):
+        self.test.vapi.want_wireguard_peer_events(
+            enable_disable=1,
+            pid=os.getpid(),
+            sw_if_index=self._sw_if_index,
+            peer_index=peer_index)
+
+    def wait_events(self, expect, peers, timeout=5):
+        for i in range(len(peers)):
+            rv = self.test.vapi.wait_for_event(timeout, "wireguard_peer_event")
+            self.test.assertEqual(rv.peer_index, peers[i])
+            self.test.assertEqual(rv.flags, expect)
+
     def __str__(self):
         return self.object_id()
 
@@ -343,6 +358,18 @@ class VppWgPeer(VppObject):
                 self._test.assertEqual(rx[IPv6].dst, tx[IPv6].dst)
                 self._test.assertEqual(rx[IPv6].ttl, tx[IPv6].ttl-1)
 
+    def want_events(self):
+        self._test.vapi.want_wireguard_peer_events(
+            enable_disable=1,
+            pid=os.getpid(),
+            peer_index=self.index,
+            sw_if_index=self.itf.sw_if_index)
+
+    def wait_event(self, expect, timeout=5):
+        rv = self._test.vapi.wait_for_event(timeout, "wireguard_peer_event")
+        self._test.assertEqual(rv.flags, expect)
+        self._test.assertEqual(rv.peer_index, self.index)
+
 
 class TestWg(VppTestCase):
     """ Wireguard Test Case """
@@ -1176,6 +1203,107 @@ class TestWg(VppTestCase):
         for i in wg_ifs:
             i.remove_vpp_config()
 
+    def test_wg_event(self):
+        """ Test events """
+        port = 12600
+        ESTABLISHED_FLAG = VppEnum.\
+            vl_api_wireguard_peer_flags_t.\
+            WIREGUARD_PEER_ESTABLISHED
+        DEAD_FLAG = VppEnum.\
+            vl_api_wireguard_peer_flags_t.\
+            WIREGUARD_PEER_STATUS_DEAD
+
+        # Create interfaces
+        wg0 = VppWgInterface(self,
+                             self.pg1.local_ip4,
+                             port).add_vpp_config()
+        wg1 = VppWgInterface(self,
+                             self.pg2.local_ip4,
+                             port+1).add_vpp_config()
+        wg0.admin_up()
+        wg1.admin_up()
+
+        # Check peer counter
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # Create peers
+        NUM_PEERS = 2
+        self.pg2.generate_remote_hosts(NUM_PEERS)
+        self.pg2.configure_ipv4_neighbors()
+        self.pg1.generate_remote_hosts(NUM_PEERS)
+        self.pg1.configure_ipv4_neighbors()
+
+        peers_0 = []
+        peers_1 = []
+        routes_0 = []
+        routes_1 = []
+        for i in range(NUM_PEERS):
+            peers_0.append(VppWgPeer(self,
+                                     wg0,
+                                     self.pg1.remote_hosts[i].ip4,
+                                     port+1+i,
+                                     ["10.0.%d.4/32" % i]).add_vpp_config())
+            routes_0.append(VppIpRoute(self, "10.0.%d.4" % i, 32,
+                            [VppRoutePath(self.pg1.remote_hosts[i].ip4,
+                                          wg0.sw_if_index)]).add_vpp_config())
+
+            peers_1.append(VppWgPeer(self,
+                                     wg1,
+                                     self.pg2.remote_hosts[i].ip4,
+                                     port+100+i,
+                                     ["10.100.%d.4/32" % i]).add_vpp_config())
+            routes_1.append(VppIpRoute(self, "10.100.%d.4" % i, 32,
+                            [VppRoutePath(self.pg2.remote_hosts[i].ip4,
+                                          wg1.sw_if_index)]).add_vpp_config())
+
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_PEERS*2)
+
+        # Want events from the first perr of wg0
+        # and from all wg1 peers
+        peers_0[0].want_events()
+        wg1.want_events()
+
+        for i in range(NUM_PEERS):
+            # send a valid handsake init for which we expect a response
+            p = peers_0[i].mk_handshake(self.pg1)
+            rx = self.send_and_expect(self.pg1, [p], self.pg1)
+            peers_0[i].consume_response(rx[0])
+            if (i == 0):
+                peers_0[0].wait_event(ESTABLISHED_FLAG)
+
+            p = peers_1[i].mk_handshake(self.pg2)
+            rx = self.send_and_expect(self.pg2, [p], self.pg2)
+            peers_1[i].consume_response(rx[0])
+
+        wg1.wait_events(
+            ESTABLISHED_FLAG,
+            [peers_1[0].index, peers_1[1].index])
+
+        # remove routes
+        for r in routes_0:
+            r.remove_vpp_config()
+        for r in routes_1:
+            r.remove_vpp_config()
+
+        # remove peers
+        for i in range(NUM_PEERS):
+            self.assertTrue(peers_0[i].query_vpp_config())
+            peers_0[i].remove_vpp_config()
+            if (i == 0):
+                peers_0[i].wait_event(0)
+                peers_0[i].wait_event(DEAD_FLAG)
+        for p in peers_1:
+            self.assertTrue(p.query_vpp_config())
+            p.remove_vpp_config()
+            p.wait_event(0)
+            p.wait_event(DEAD_FLAG)
+
+        wg0.remove_vpp_config()
+        wg1.remove_vpp_config()
+
 
 class WireguardHandoffTests(TestWg):
     """ Wireguard Tests in multi worker setup """
@@ -1241,14 +1369,14 @@ class WireguardHandoffTests(TestWg):
 
         # send packets into the tunnel, from the other worker
         p = [(peer_1.mk_tunnel_header(self.pg1) /
-             Wireguard(message_type=4, reserved_zero=0) /
-             WireguardTransport(
-                 receiver_index=peer_1.sender,
-                 counter=ii+1,
-                 encrypted_encapsulated_packet=peer_1.encrypt_transport(
-                     (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
-                      UDP(sport=222, dport=223) /
-                      Raw())))) for ii in range(255)]
+              Wireguard(message_type=4, reserved_zero=0) /
+              WireguardTransport(
+                    receiver_index=peer_1.sender,
+                    counter=ii+1,
+                    encrypted_encapsulated_packet=peer_1.encrypt_transport(
+                        (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+                         UDP(sport=222, dport=223) /
+                         Raw())))) for ii in range(255)]
 
         rxs = self.send_and_expect(self.pg1, p, self.pg0, worker=1)