wireguard: add support for chained buffers
[vpp.git] / test / test_wireguard.py
index b9713f6..e63508a 100644 (file)
@@ -11,6 +11,7 @@ from scapy.packet import Raw
 from scapy.layers.l2 import Ether, ARP
 from scapy.layers.inet import IP, UDP
 from scapy.layers.inet6 import IPv6
 from scapy.layers.l2 import Ether, ARP
 from scapy.layers.inet import IP, UDP
 from scapy.layers.inet6 import IPv6
+from scapy.layers.vxlan import VXLAN
 from scapy.contrib.wireguard import (
     Wireguard,
     WireguardResponse,
 from scapy.contrib.wireguard import (
     Wireguard,
     WireguardResponse,
@@ -40,6 +41,8 @@ from vpp_ipip_tun_interface import VppIpIpTunInterface
 from vpp_interface import VppInterface
 from vpp_pg_interface import is_ipv6_misc
 from vpp_ip_route import VppIpRoute, VppRoutePath
 from vpp_interface import VppInterface
 from vpp_pg_interface import is_ipv6_misc
 from vpp_ip_route import VppIpRoute, VppRoutePath
+from vpp_l2 import VppBridgeDomain, VppBridgeDomainPort
+from vpp_vxlan_tunnel import VppVxlanTunnel
 from vpp_object import VppObject
 from vpp_papi import VppEnum
 from framework import is_distro_ubuntu2204, is_distro_debian11, tag_fixme_vpp_debug
 from vpp_object import VppObject
 from vpp_papi import VppEnum
 from framework import is_distro_ubuntu2204, is_distro_debian11, tag_fixme_vpp_debug
@@ -470,6 +473,7 @@ class VppWgPeer(VppObject):
         return self.noise.encrypt(bytes(p))
 
     def validate_encapped(self, rxs, tx, is_tunnel_ip6=False, is_transport_ip6=False):
         return self.noise.encrypt(bytes(p))
 
     def validate_encapped(self, rxs, tx, is_tunnel_ip6=False, is_transport_ip6=False):
+        ret_rxs = []
         for rx in rxs:
             rx = self.decrypt_transport(rx, is_tunnel_ip6)
             if is_transport_ip6 is False:
         for rx in rxs:
             rx = self.decrypt_transport(rx, is_tunnel_ip6)
             if is_transport_ip6 is False:
@@ -482,6 +486,8 @@ class VppWgPeer(VppObject):
                 # check the original packet is present
                 self._test.assertEqual(rx[IPv6].dst, tx[IPv6].dst)
                 self._test.assertEqual(rx[IPv6].hlim, tx[IPv6].hlim - 1)
                 # check the original packet is present
                 self._test.assertEqual(rx[IPv6].dst, tx[IPv6].dst)
                 self._test.assertEqual(rx[IPv6].hlim, tx[IPv6].hlim - 1)
+            ret_rxs.append(rx)
+        return ret_rxs
 
     def want_events(self):
         self._test.vapi.want_wireguard_peer_events(
 
     def want_events(self):
         self._test.vapi.want_wireguard_peer_events(
@@ -2510,6 +2516,227 @@ class TestWg(VppTestCase):
         peer_1.remove_vpp_config()
         wg0.remove_vpp_config()
 
         peer_1.remove_vpp_config()
         wg0.remove_vpp_config()
 
+    def _test_wg_large_packet_tmpl(self, is_async, is_ip6):
+        self.vapi.wg_set_async_mode(is_async)
+        port = 12323
+
+        # create wg interface
+        if is_ip6:
+            wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip6()
+        else:
+            wg0 = VppWgInterface(self, self.pg1.local_ip4, port).add_vpp_config()
+            wg0.admin_up()
+            wg0.config_ip4()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create a peer
+        if is_ip6:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_ip6, port + 1, ["1::3:0/112"]
+            ).add_vpp_config()
+        else:
+            peer_1 = VppWgPeer(
+                self, wg0, self.pg1.remote_ip4, port + 1, ["10.11.3.0/24"]
+            ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+        # create a route to rewrite traffic into the wg interface
+        if is_ip6:
+            r1 = VppIpRoute(
+                self, "1::3:0", 112, [VppRoutePath("1::3:1", wg0.sw_if_index)]
+            ).add_vpp_config()
+        else:
+            r1 = VppIpRoute(
+                self, "10.11.3.0", 24, [VppRoutePath("10.11.3.1", wg0.sw_if_index)]
+            ).add_vpp_config()
+
+        # wait for the peer to send a handshake initiation
+        rxs = self.pg1.get_capture(1, timeout=2)
+
+        # prepare and send a handshake response
+        # expect a keepalive message
+        resp = peer_1.consume_init(rxs[0], self.pg1, is_ip6=is_ip6)
+        rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+
+        # verify the keepalive message
+        b = peer_1.decrypt_transport(rxs[0], is_ip6=is_ip6)
+        self.assertEqual(0, len(b))
+
+        # prepare and send data packets
+        # expect to receive them decrypted
+        if is_ip6:
+            ip_header = IPv6(src="1::3:1", dst=self.pg0.remote_ip6, hlim=20)
+        else:
+            ip_header = IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20)
+        packet_len_opts = (
+            2500,  # two buffers
+            1500,  # one buffer
+            4500,  # three buffers
+            1910 if is_ip6 else 1950,  # auth tag is not contiguous
+        )
+        txs = []
+        for l in packet_len_opts:
+            txs.append(
+                peer_1.mk_tunnel_header(self.pg1, is_ip6=is_ip6)
+                / Wireguard(message_type=4, reserved_zero=0)
+                / WireguardTransport(
+                    receiver_index=peer_1.sender,
+                    counter=len(txs),
+                    encrypted_encapsulated_packet=peer_1.encrypt_transport(
+                        ip_header / UDP(sport=222, dport=223) / Raw(b"\xfe" * l)
+                    ),
+                )
+            )
+        rxs = self.send_and_expect(self.pg1, txs, self.pg0)
+
+        # verify decrypted packets
+        for i, l in enumerate(packet_len_opts):
+            if is_ip6:
+                self.assertEqual(rxs[i][IPv6].dst, self.pg0.remote_ip6)
+                self.assertEqual(rxs[i][IPv6].hlim, ip_header.hlim - 1)
+            else:
+                self.assertEqual(rxs[i][IP].dst, self.pg0.remote_ip4)
+                self.assertEqual(rxs[i][IP].ttl, ip_header.ttl - 1)
+            self.assertEqual(len(rxs[i][Raw]), l)
+            self.assertEqual(bytes(rxs[i][Raw]), b"\xfe" * l)
+
+        # prepare and send packets that will be rewritten into the wg interface
+        # expect data packets sent
+        if is_ip6:
+            ip_header = IPv6(src=self.pg0.remote_ip6, dst="1::3:2")
+        else:
+            ip_header = IP(src=self.pg0.remote_ip4, dst="10.11.3.2")
+        packet_len_opts = (
+            2500,  # two buffers
+            1500,  # one buffer
+            4500,  # three buffers
+            1980 if is_ip6 else 2000,  # no free space to write auth tag
+        )
+        txs = []
+        for l in packet_len_opts:
+            txs.append(
+                Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac)
+                / ip_header
+                / UDP(sport=555, dport=556)
+                / Raw(b"\xfe" * l)
+            )
+        rxs = self.send_and_expect(self.pg0, txs, self.pg1)
+
+        # verify the data packets
+        rxs_decrypted = peer_1.validate_encapped(
+            rxs, ip_header, is_tunnel_ip6=is_ip6, is_transport_ip6=is_ip6
+        )
+
+        for i, l in enumerate(packet_len_opts):
+            self.assertEqual(len(rxs_decrypted[i][Raw]), l)
+            self.assertEqual(bytes(rxs_decrypted[i][Raw]), b"\xfe" * l)
+
+        # remove configs
+        r1.remove_vpp_config()
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_large_packet_v4_sync(self):
+        """Large packet (v4, sync)"""
+        self._test_wg_large_packet_tmpl(is_async=False, is_ip6=False)
+
+    def test_wg_large_packet_v6_sync(self):
+        """Large packet (v6, sync)"""
+        self._test_wg_large_packet_tmpl(is_async=False, is_ip6=True)
+
+    def test_wg_large_packet_v4_async(self):
+        """Large packet (v4, async)"""
+        self._test_wg_large_packet_tmpl(is_async=True, is_ip6=False)
+
+    def test_wg_large_packet_v6_async(self):
+        """Large packet (v6, async)"""
+        self._test_wg_large_packet_tmpl(is_async=True, is_ip6=True)
+
+    def test_wg_lack_of_buf_headroom(self):
+        """Lack of buffer's headroom (v6 vxlan over v6 wg)"""
+        port = 12323
+
+        # create wg interface
+        wg0 = VppWgInterface(self, self.pg1.local_ip6, port).add_vpp_config()
+        wg0.admin_up()
+        wg0.config_ip6()
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # create a peer
+        peer_1 = VppWgPeer(
+            self, wg0, self.pg1.remote_ip6, port + 1, ["::/0"]
+        ).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
+
+        # create a route to enable communication between wg interface addresses
+        r1 = VppIpRoute(
+            self, wg0.remote_ip6, 128, [VppRoutePath("0.0.0.0", wg0.sw_if_index)]
+        ).add_vpp_config()
+
+        # wait for the peer to send a handshake initiation
+        rxs = self.pg1.get_capture(1, timeout=2)
+
+        # prepare and send a handshake response
+        # expect a keepalive message
+        resp = peer_1.consume_init(rxs[0], self.pg1, is_ip6=True)
+        rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+
+        # verify the keepalive message
+        b = peer_1.decrypt_transport(rxs[0], is_ip6=True)
+        self.assertEqual(0, len(b))
+
+        # create vxlan interface over the wg interface
+        vxlan0 = VppVxlanTunnel(self, src=wg0.local_ip6, dst=wg0.remote_ip6, vni=1111)
+        vxlan0.add_vpp_config()
+
+        # create bridge domain
+        bd1 = VppBridgeDomain(self, bd_id=1)
+        bd1.add_vpp_config()
+
+        # add the vxlan interface and pg0 to the bridge domain
+        bd1_ports = (
+            VppBridgeDomainPort(self, bd1, vxlan0).add_vpp_config(),
+            VppBridgeDomainPort(self, bd1, self.pg0).add_vpp_config(),
+        )
+
+        # prepare and send packets that will be rewritten into the vxlan interface
+        # expect they to be rewritten into the wg interface then and data packets sent
+        tx = (
+            Ether(dst="00:00:00:00:00:01", src="00:00:00:00:00:02")
+            / IPv6(src="::1", dst="::2", hlim=20)
+            / UDP(sport=1111, dport=1112)
+            / Raw(b"\xfe" * 1900)
+        )
+        rxs = self.send_and_expect(self.pg0, [tx] * 5, self.pg1)
+
+        # verify the data packet
+        for rx in rxs:
+            rx_decrypted = IPv6(peer_1.decrypt_transport(rx, is_ip6=True))
+
+            self.assertEqual(rx_decrypted[VXLAN].vni, vxlan0.vni)
+            inner = rx_decrypted[VXLAN].payload
+
+            # check the original packet is present
+            self.assertEqual(inner[IPv6].dst, tx[IPv6].dst)
+            self.assertEqual(inner[IPv6].hlim, tx[IPv6].hlim)
+            self.assertEqual(len(inner[Raw]), len(tx[Raw]))
+            self.assertEqual(bytes(inner[Raw]), bytes(tx[Raw]))
+
+        # remove configs
+        for bdp in bd1_ports:
+            bdp.remove_vpp_config()
+        bd1.remove_vpp_config()
+        vxlan0.remove_vpp_config()
+        r1.remove_vpp_config()
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
 
 @tag_fixme_vpp_debug
 class WireguardHandoffTests(TestWg):
 
 @tag_fixme_vpp_debug
 class WireguardHandoffTests(TestWg):