wireguard: stop sending handshakes when wg intf is down
[vpp.git] / test / test_wireguard.py
index 7055b7a..7e5ef5b 100644 (file)
@@ -38,9 +38,12 @@ from Crypto.Random import get_random_bytes
 
 from vpp_ipip_tun_interface import VppIpIpTunInterface
 from vpp_interface import VppInterface
+from vpp_pg_interface import is_ipv6_misc
 from vpp_ip_route import VppIpRoute, VppRoutePath
 from vpp_object import VppObject
 from vpp_papi import VppEnum
+from framework import tag_fixme_ubuntu2204, tag_fixme_debian11
+from framework import is_distro_ubuntu2204, is_distro_debian11
 from framework import VppTestCase
 from re import compile
 import unittest
@@ -171,7 +174,7 @@ class VppWgPeer(VppObject):
         self.endpoint = endpoint
         self.port = port
 
-    def add_vpp_config(self, is_ip6=False):
+    def add_vpp_config(self):
         rv = self._test.vapi.wireguard_peer_add(
             peer={
                 "public_key": self.public_key_bytes(),
@@ -493,6 +496,14 @@ class VppWgPeer(VppObject):
         self._test.assertEqual(rv.peer_index, self.index)
 
 
+def is_handshake_init(p):
+    wg_p = Wireguard(p[Raw])
+
+    return wg_p[Wireguard].message_type == 1
+
+
+@tag_fixme_ubuntu2204
+@tag_fixme_debian11
 class TestWg(VppTestCase):
     """Wireguard Test Case"""
 
@@ -518,6 +529,10 @@ class TestWg(VppTestCase):
     @classmethod
     def setUpClass(cls):
         super(TestWg, cls).setUpClass()
+        if (is_distro_ubuntu2204 == True or is_distro_debian11 == True) and not hasattr(
+            cls, "vpp"
+        ):
+            return
         try:
             cls.create_pg_interfaces(range(3))
             for i in cls.pg_interfaces:
@@ -558,6 +573,25 @@ class TestWg(VppTestCase):
             self.ratelimited6_err
         )
 
+    def send_and_assert_no_replies_ignoring_init(
+        self, intf, pkts, remark="", timeout=None
+    ):
+        self.pg_send(intf, pkts)
+
+        def _filter_out_fn(p):
+            return is_ipv6_misc(p) or is_handshake_init(p)
+
+        try:
+            if not timeout:
+                timeout = 1
+            for i in self.pg_interfaces:
+                i.assert_nothing_captured(
+                    timeout=timeout, remark=remark, filter_out_fn=_filter_out_fn
+                )
+                timeout = 0.1
+        finally:
+            pass
+
     def test_wg_interface(self):
         """Simple interface creation"""
         port = 12312
@@ -1334,7 +1368,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp4_err + 1, self.statistics.get_err_counter(self.kp4_error)
         )
@@ -1348,7 +1382,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_peer4_out_err + 1,
             self.statistics.get_err_counter(self.peer4_out_err),
@@ -1357,7 +1391,7 @@ class TestWg(VppTestCase):
         # send a handsake from the peer with an invalid MAC
         p = peer_1.mk_handshake(self.pg1)
         p[WireguardInitiation].mac1 = b"foobar"
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
         self.assertEqual(
             self.base_mac4_err + 1, self.statistics.get_err_counter(self.mac4_error)
         )
@@ -1366,7 +1400,7 @@ class TestWg(VppTestCase):
         p = peer_1.mk_handshake(
             self.pg1, False, X25519PrivateKey.generate().public_key()
         )
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
         self.assertEqual(
             self.base_peer4_in_err + 1,
             self.statistics.get_err_counter(self.peer4_in_err),
@@ -1387,7 +1421,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp4_err + 2, self.statistics.get_err_counter(self.kp4_error)
         )
@@ -1473,7 +1507,7 @@ class TestWg(VppTestCase):
 
         peer_1 = VppWgPeer(
             self, wg0, self.pg1.remote_ip6, port + 1, ["1::3:0/112"]
-        ).add_vpp_config(True)
+        ).add_vpp_config()
         self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
 
         r1 = VppIpRoute(
@@ -1493,7 +1527,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
 
         self.assertEqual(
             self.base_kp6_err + 1, self.statistics.get_err_counter(self.kp6_error)
@@ -1508,7 +1542,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_peer6_out_err + 1,
             self.statistics.get_err_counter(self.peer6_out_err),
@@ -1517,7 +1551,7 @@ class TestWg(VppTestCase):
         # send a handsake from the peer with an invalid MAC
         p = peer_1.mk_handshake(self.pg1, True)
         p[WireguardInitiation].mac1 = b"foobar"
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
 
         self.assertEqual(
             self.base_mac6_err + 1, self.statistics.get_err_counter(self.mac6_error)
@@ -1527,7 +1561,7 @@ class TestWg(VppTestCase):
         p = peer_1.mk_handshake(
             self.pg1, True, X25519PrivateKey.generate().public_key()
         )
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
         self.assertEqual(
             self.base_peer6_in_err + 1,
             self.statistics.get_err_counter(self.peer6_in_err),
@@ -1548,7 +1582,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp6_err + 2, self.statistics.get_err_counter(self.kp6_error)
         )
@@ -1634,7 +1668,7 @@ class TestWg(VppTestCase):
 
         peer_1 = VppWgPeer(
             self, wg0, self.pg1.remote_ip4, port + 1, ["1::3:0/112"]
-        ).add_vpp_config(True)
+        ).add_vpp_config()
         self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
 
         r1 = VppIpRoute(
@@ -1650,7 +1684,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp6_err + 1, self.statistics.get_err_counter(self.kp6_error)
         )
@@ -1658,7 +1692,7 @@ class TestWg(VppTestCase):
         # send a handsake from the peer with an invalid MAC
         p = peer_1.mk_handshake(self.pg1)
         p[WireguardInitiation].mac1 = b"foobar"
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
 
         self.assertEqual(
             self.base_mac4_err + 1, self.statistics.get_err_counter(self.mac4_error)
@@ -1668,7 +1702,7 @@ class TestWg(VppTestCase):
         p = peer_1.mk_handshake(
             self.pg1, False, X25519PrivateKey.generate().public_key()
         )
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
         self.assertEqual(
             self.base_peer4_in_err + 1,
             self.statistics.get_err_counter(self.peer4_in_err),
@@ -1689,7 +1723,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp6_err + 2, self.statistics.get_err_counter(self.kp6_error)
         )
@@ -1790,7 +1824,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp4_err + 1, self.statistics.get_err_counter(self.kp4_error)
         )
@@ -1798,7 +1832,7 @@ class TestWg(VppTestCase):
         # send a handsake from the peer with an invalid MAC
         p = peer_1.mk_handshake(self.pg1, True)
         p[WireguardInitiation].mac1 = b"foobar"
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
         self.assertEqual(
             self.base_mac6_err + 1, self.statistics.get_err_counter(self.mac6_error)
         )
@@ -1807,7 +1841,7 @@ class TestWg(VppTestCase):
         p = peer_1.mk_handshake(
             self.pg1, True, X25519PrivateKey.generate().public_key()
         )
-        self.send_and_assert_no_replies(self.pg1, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg1, [p])
         self.assertEqual(
             self.base_peer6_in_err + 1,
             self.statistics.get_err_counter(self.peer6_in_err),
@@ -1828,7 +1862,7 @@ class TestWg(VppTestCase):
             / UDP(sport=555, dport=556)
             / Raw()
         )
-        self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies_ignoring_init(self.pg0, [p])
         self.assertEqual(
             self.base_kp4_err + 2, self.statistics.get_err_counter(self.kp4_error)
         )
@@ -2226,7 +2260,141 @@ class TestWg(VppTestCase):
         wg0.remove_vpp_config()
         wg1.remove_vpp_config()
 
+    def test_wg_sending_handshake_when_admin_down(self):
+        """Sending handshake when admin down"""
+        port = 12323
+
+        # create wg interface
+        wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+        wg0.config_ip4()
+
+        # create a peer
+        peer_1 = VppWgPeer(
+            self, wg0, self.pg1.remote_ip4, port + 1, ["10.11.3.0/24"]
+        ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
 
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # wait for the peer to send a handshake initiation
+        # expect no handshakes
+        for i in range(2):
+            self.pg1.assert_nothing_captured(remark="handshake packet(s) sent")
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # administratively enable the wg interface
+        # expect the peer to send a handshake initiation
+        wg0.admin_up()
+        rxs = self.pg1.get_capture(1, timeout=2)
+        peer_1.consume_init(rxs[0], self.pg1)
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # administratively disable the wg interface
+        # expect no handshakes
+        wg0.admin_down()
+        for i in range(6):
+            self.pg1.assert_nothing_captured(remark="handshake packet(s) sent")
+
+        # remove configs
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_sending_data_when_admin_down(self):
+        """Sending data when admin down"""
+        port = 12323
+
+        # create wg interface
+        wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+        wg0.admin_up()
+        wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create a peer
+        peer_1 = VppWgPeer(
+            self, wg0, self.pg1.remote_ip4, port + 1, ["10.11.3.0/24"]
+        ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+        # create a route to rewrite traffic into the wg interface
+        r1 = VppIpRoute(
+            self, "10.11.3.0", 24, [VppRoutePath("10.11.3.1", wg0.sw_if_index)]
+        ).add_vpp_config()
+
+        # wait for the peer to send a handshake initiation
+        rxs = self.pg1.get_capture(1, timeout=2)
+
+        # prepare and send a handshake response
+        # expect a keepalive message
+        resp = peer_1.consume_init(rxs[0], self.pg1)
+        rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+
+        # verify the keepalive message
+        b = peer_1.decrypt_transport(rxs[0])
+        self.assertEqual(0, len(b))
+
+        # prepare and send a packet that will be rewritten into the wg interface
+        # expect a data packet sent
+        p = (
+            Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac)
+            / IP(src=self.pg0.remote_ip4, dst="10.11.3.2")
+            / UDP(sport=555, dport=556)
+            / Raw()
+        )
+        rxs = self.send_and_expect(self.pg0, [p], self.pg1)
+
+        # verify the data packet
+        peer_1.validate_encapped(rxs, p)
+
+        # administratively disable the wg interface
+        wg0.admin_down()
+
+        # send a packet that will be rewritten into the wg interface
+        # expect no data packets sent
+        self.send_and_assert_no_replies(self.pg0, [p])
+
+        # administratively enable the wg interface
+        # expect the peer to send a handshake initiation
+        wg0.admin_up()
+        peer_1.noise_reset()
+        rxs = self.pg1.get_capture(1, timeout=2)
+        resp = peer_1.consume_init(rxs[0], self.pg1)
+
+        # send a packet that will be rewritten into the wg interface
+        # expect no data packets sent because the peer is not initiated
+        self.send_and_assert_no_replies(self.pg0, [p])
+        self.assertEqual(
+            self.base_kp4_err + 1, self.statistics.get_err_counter(self.kp4_error)
+        )
+
+        # send a handshake response and expect a keepalive message
+        rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+
+        # verify the keepalive message
+        b = peer_1.decrypt_transport(rxs[0])
+        self.assertEqual(0, len(b))
+
+        # send a packet that will be rewritten into the wg interface
+        # expect a data packet sent
+        rxs = self.send_and_expect(self.pg0, [p], self.pg1)
+
+        # verify the data packet
+        peer_1.validate_encapped(rxs, p)
+
+        # remove configs
+        r1.remove_vpp_config()
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+
+@tag_fixme_ubuntu2204
+@tag_fixme_debian11
 class WireguardHandoffTests(TestWg):
     """Wireguard Tests in multi worker setup"""
 
@@ -2316,7 +2484,7 @@ class WireguardHandoffTests(TestWg):
             self.assertEqual(rx[IP].ttl, 19)
 
         # send a packets that are routed into the tunnel
-        # from owrker 0
+        # from worker 0
         rxs = self.send_and_expect(self.pg0, pe * 255, self.pg1, worker=0)
 
         peer_1.validate_encapped(rxs, pe)
@@ -2395,7 +2563,7 @@ class TestWgFIB(VppTestCase):
         self.assertEqual(0, len(b))
 
         # prepare and send a packet that will be rewritten into the wg interface
-        # expect a data packet sent to the new endpoint
+        # expect a data packet sent
         p = (
             Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac)
             / IP(src=self.pg0.remote_ip4, dst="10.11.3.2")