wireguard: Fix for tunnel encap 35/28735/5
authorNeale Ranns <nranns@cisco.com>
Thu, 10 Sep 2020 08:49:10 +0000 (08:49 +0000)
committerDamjan Marion <dmarion@me.com>
Sat, 12 Sep 2020 08:20:59 +0000 (08:20 +0000)
Type: fix

add UT for sneding handshale init and transport packets

Signed-off-by: Neale Ranns <nranns@cisco.com>
Change-Id: Iab1ed8864c666d5a0ae0b2364a9ca4de3c8770dc

src/plugins/wireguard/test/test_wireguard.py
src/plugins/wireguard/wireguard_cookie.c
src/plugins/wireguard/wireguard_if.c
src/plugins/wireguard/wireguard_input.c
src/plugins/wireguard/wireguard_noise.c
src/plugins/wireguard/wireguard_output_tun.c
src/plugins/wireguard/wireguard_peer.c

index cd124f3..7734939 100755 (executable)
@@ -1,15 +1,24 @@
 #!/usr/bin/env python3
 """ Wg tests """
 
+import datetime
+import base64
+
+from hashlib import blake2s
 from scapy.packet import Packet
 from scapy.packet import Raw
-from scapy.layers.l2 import Ether
+from scapy.layers.l2 import Ether, ARP
 from scapy.layers.inet import IP, UDP
 from scapy.contrib.wireguard import Wireguard, WireguardResponse, \
-    WireguardInitiation
-from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey
+    WireguardInitiation, WireguardTransport
+from cryptography.hazmat.primitives.asymmetric.x25519 import \
+    X25519PrivateKey, X25519PublicKey
 from cryptography.hazmat.primitives.serialization import Encoding, \
     PrivateFormat, PublicFormat, NoEncryption
+from cryptography.hazmat.primitives.hashes import BLAKE2s, Hash
+from cryptography.hazmat.primitives.hmac import HMAC
+from cryptography.hazmat.backends import default_backend
+from noise.connection import NoiseConnection, Keypair
 
 from vpp_ipip_tun_interface import VppIpIpTunInterface
 from vpp_interface import VppInterface
@@ -25,41 +34,48 @@ Wg test.
 """
 
 
+def private_key_bytes(k):
+    return k.private_bytes(Encoding.Raw,
+                           PrivateFormat.Raw,
+                           NoEncryption())
+
+
+def public_key_bytes(k):
+    return k.public_bytes(Encoding.Raw,
+                          PublicFormat.Raw)
+
+
 class VppWgInterface(VppInterface):
     """
     VPP WireGuard interface
     """
 
-    def __init__(self, test, src, port, key=None):
+    def __init__(self, test, src, port):
         super(VppWgInterface, self).__init__(test)
 
-        self.key = key
-        if not self.key:
-            self.generate = True
-        else:
-            self.generate = False
         self.port = port
         self.src = src
+        self.private_key = X25519PrivateKey.generate()
+        self.public_key = self.private_key.public_key()
+
+    def public_key_bytes(self):
+        return public_key_bytes(self.public_key)
+
+    def private_key_bytes(self):
+        return private_key_bytes(self.private_key)
 
     def add_vpp_config(self):
         r = self.test.vapi.wireguard_interface_create(interface={
             'user_instance': 0xffffffff,
             'port': self.port,
             'src_ip': self.src,
-            'private_key': self.key_bytes()
+            'private_key': private_key_bytes(self.private_key),
+            'generate_key': False
         })
         self.set_sw_if_index(r.sw_if_index)
         self.test.registry.register(self, self.test.logger)
         return self
 
-    def key_bytes(self):
-        if self.key:
-            return self.key.private_bytes(Encoding.Raw,
-                                          PrivateFormat.Raw,
-                                          NoEncryption())
-        else:
-            return bytearray(32)
-
     def remove_vpp_config(self):
         self.test.vapi.wireguard_interface_delete(
             sw_if_index=self._sw_if_index)
@@ -70,7 +86,7 @@ class VppWgInterface(VppInterface):
             if t.interface.sw_if_index == self._sw_if_index and \
                str(t.interface.src_ip) == self.src and \
                t.interface.port == self.port and \
-               t.interface.private_key == self.key_bytes():
+               t.interface.private_key == private_key_bytes(self.private_key):
                 return True
         return False
 
@@ -91,6 +107,10 @@ def find_route(test, prefix, table_id=0):
     return False
 
 
+NOISE_HANDSHAKE_NAME = b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
+NOISE_IDENTIFIER_NAME = b"WireGuard v1 zx2c4 Jason@zx2c4.com"
+
+
 class VppWgPeer(VppObject):
 
     def __init__(self,
@@ -106,9 +126,12 @@ class VppWgPeer(VppObject):
         self.port = port
         self.allowed_ips = allowed_ips
         self.persistent_keepalive = persistent_keepalive
+
+        # remote peer's public
         self.private_key = X25519PrivateKey.generate()
         self.public_key = self.private_key.public_key()
-        self.hash = bytearray(16)
+
+        self.noise = NoiseConnection.from_name(NOISE_HANDSHAKE_NAME)
 
     def validate_routing(self):
         for a in self.allowed_ips:
@@ -129,6 +152,7 @@ class VppWgPeer(VppObject):
                 'sw_if_index': self.itf.sw_if_index,
                 'persistent_keepalive': self.persistent_keepalive})
         self.index = rv.peer_index
+        self.receiver_index = self.index + 1
         self._test.registry.register(self, self._test.logger)
         self.validate_routing()
         return self
@@ -141,13 +165,7 @@ class VppWgPeer(VppObject):
         return ("wireguard-peer-%s" % self.index)
 
     def public_key_bytes(self):
-        return self.public_key.public_bytes(Encoding.Raw,
-                                            PublicFormat.Raw)
-
-    def private_key_bytes(self):
-        return self.private_key.private_bytes(Encoding.Raw,
-                                              PrivateFormat.Raw,
-                                              NoEncryption())
+        return public_key_bytes(self.public_key)
 
     def query_vpp_config(self):
         peers = self._test.vapi.wireguard_peers_dump()
@@ -167,6 +185,148 @@ class VppWgPeer(VppObject):
                 return True
         return False
 
+    def set_responder(self):
+        self.noise.set_as_responder()
+
+    def mk_tunnel_header(self, tx_itf):
+        return (Ether(dst=tx_itf.local_mac, src=tx_itf.remote_mac) /
+                IP(src=self.endpoint, dst=self.itf.src) /
+                UDP(sport=self.port, dport=self.itf.port))
+
+    def noise_init(self, public_key=None):
+        self.noise.set_prologue(NOISE_IDENTIFIER_NAME)
+        self.noise.set_psks(psk=bytes(bytearray(32)))
+
+        if not public_key:
+            public_key = self.itf.public_key
+
+        # local/this private
+        self.noise.set_keypair_from_private_bytes(
+            Keypair.STATIC,
+            private_key_bytes(self.private_key))
+        # remote's public
+        self.noise.set_keypair_from_public_bytes(
+            Keypair.REMOTE_STATIC,
+            public_key_bytes(public_key))
+
+        self.noise.start_handshake()
+
+    def mk_handshake(self, tx_itf, public_key=None):
+        self.noise.set_as_initiator()
+        self.noise_init(public_key)
+
+        p = (Wireguard() / WireguardInitiation())
+
+        p[Wireguard].message_type = 1
+        p[Wireguard].reserved_zero = 0
+        p[WireguardInitiation].sender_index = self.receiver_index
+
+        # some random data for the message
+        #  lifted from the noise protocol's wireguard example
+        now = datetime.datetime.now()
+        tai = struct.pack('!qi', 4611686018427387914 + int(now.timestamp()),
+                          int(now.microsecond * 1e3))
+        b = self.noise.write_message(payload=tai)
+
+        # load noise into init message
+        p[WireguardInitiation].unencrypted_ephemeral = b[0:32]
+        p[WireguardInitiation].encrypted_static = b[32:80]
+        p[WireguardInitiation].encrypted_timestamp = b[80:108]
+
+        # generate the mac1 hash
+        mac_key = blake2s(b'mac1----' +
+                          self.itf.public_key_bytes()).digest()
+        p[WireguardInitiation].mac1 = blake2s(bytes(p)[0:116],
+                                              digest_size=16,
+                                              key=mac_key).digest()
+        p[WireguardInitiation].mac2 = bytearray(16)
+
+        p = (self.mk_tunnel_header(tx_itf) / p)
+
+        return p
+
+    def verify_header(self, p):
+        self._test.assertEqual(p[IP].src, self.itf.src)
+        self._test.assertEqual(p[IP].dst, self.endpoint)
+        self._test.assertEqual(p[UDP].sport, self.itf.port)
+        self._test.assertEqual(p[UDP].dport, self.port)
+        self._test.assert_packet_checksums_valid(p)
+
+    def consume_init(self, p, tx_itf):
+        self.noise.set_as_responder()
+        self.noise_init(self.itf.public_key)
+        self.verify_header(p)
+
+        init = Wireguard(p[Raw])
+
+        self._test.assertEqual(init[Wireguard].message_type, 1)
+        self._test.assertEqual(init[Wireguard].reserved_zero, 0)
+
+        self.sender = init[WireguardInitiation].sender_index
+
+        # validate the hash
+        mac_key = blake2s(b'mac1----' +
+                          public_key_bytes(self.public_key)).digest()
+        mac1 = blake2s(bytes(init)[0:-32],
+                       digest_size=16,
+                       key=mac_key).digest()
+        self._test.assertEqual(init[WireguardInitiation].mac1, mac1)
+
+        # this passes only unencrypted_ephemeral, encrypted_static,
+        # encrypted_timestamp fields of the init
+        payload = self.noise.read_message(bytes(init)[8:-32])
+
+        # build the response
+        b = self.noise.write_message()
+        mac_key = blake2s(b'mac1----' +
+                          public_key_bytes(self.itf.public_key)).digest()
+        resp = (Wireguard(message_type=2, reserved_zero=0) /
+                WireguardResponse(sender_index=self.receiver_index,
+                                  receiver_index=self.sender,
+                                  unencrypted_ephemeral=b[0:32],
+                                  encrypted_nothing=b[32:]))
+        mac1 = blake2s(bytes(resp)[:-32],
+                       digest_size=16,
+                       key=mac_key).digest()
+        resp[WireguardResponse].mac1 = mac1
+
+        resp = (self.mk_tunnel_header(tx_itf) / resp)
+        self._test.assertTrue(self.noise.handshake_finished)
+
+        return resp
+
+    def consume_response(self, p):
+        self.verify_header(p)
+
+        resp = Wireguard(p[Raw])
+
+        self._test.assertEqual(resp[Wireguard].message_type, 2)
+        self._test.assertEqual(resp[Wireguard].reserved_zero, 0)
+        self._test.assertEqual(resp[WireguardResponse].receiver_index,
+                               self.receiver_index)
+
+        self.sender = resp[Wireguard].sender_index
+
+        payload = self.noise.read_message(bytes(resp)[12:60])
+        self._test.assertEqual(payload, b'')
+        self._test.assertTrue(self.noise.handshake_finished)
+
+    def decrypt_transport(self, p):
+        self.verify_header(p)
+
+        p = Wireguard(p[Raw])
+        self._test.assertEqual(p[Wireguard].message_type, 4)
+        self._test.assertEqual(p[Wireguard].reserved_zero, 0)
+        self._test.assertEqual(p[WireguardTransport].receiver_index,
+                               self.receiver_index)
+
+        d = self.noise.decrypt(
+            p[WireguardTransport].encrypted_encapsulated_packet)
+        return d
+
+    def encrypt_transport(self, p):
+        return self.noise.encrypt(bytes(p))
+
 
 class TestWg(VppTestCase):
     """ Wireguard Test Case """
@@ -192,6 +352,7 @@ class TestWg(VppTestCase):
         super(TestWg, cls).tearDownClass()
 
     def test_wg_interface(self):
+        """ Simple interface creation """
         port = 12312
 
         # Create interface
@@ -204,7 +365,51 @@ class TestWg(VppTestCase):
         # delete interface
         wg0.remove_vpp_config()
 
-    def test_wg_peer(self):
+    def test_handshake_hash(self):
+        """ test hashing an init message """
+        # a init packet generated by linux given the key below
+        h = "0100000098b9032b" \
+            "55cc4b39e73c3d24" \
+            "a2a1ab884b524a81" \
+            "1808bb86640fb70d" \
+            "e93154fec1879125" \
+            "ab012624a27f0b75" \
+            "c0a2582f438ddb5f" \
+            "8e768af40b4ab444" \
+            "02f9ff473e1b797e" \
+            "80d39d93c5480c82" \
+            "a3d4510f70396976" \
+            "586fb67300a5167b" \
+            "ae6ca3ff3dfd00eb" \
+            "59be198810f5aa03" \
+            "6abc243d2155ee4f" \
+            "2336483900aef801" \
+            "08752cd700000000" \
+            "0000000000000000" \
+            "00000000"
+
+        b = bytearray.fromhex(h)
+        tgt = Wireguard(b)
+
+        pubb = base64.b64decode("aRuHFTTxICIQNefp05oKWlJv3zgKxb8+WW7JJMh0jyM=")
+        pub = X25519PublicKey.from_public_bytes(pubb)
+
+        self.assertEqual(pubb, public_key_bytes(pub))
+
+        # strip the macs and build a new packet
+        init = b[0:-32]
+        mac_key = blake2s(b'mac1----' + public_key_bytes(pub)).digest()
+        init += blake2s(init,
+                        digest_size=16,
+                        key=mac_key).digest()
+        init += b'\x00' * 16
+
+        act = Wireguard(init)
+
+        self.assertEqual(tgt, act)
+
+    def test_wg_peer_resp(self):
+        """ Send handshake response """
         wg_output_node_name = '/err/wg-output-tun/'
         wg_input_node_name = '/err/wg-input/'
 
@@ -213,16 +418,9 @@ class TestWg(VppTestCase):
         # Create interfaces
         wg0 = VppWgInterface(self,
                              self.pg1.local_ip4,
-                             port,
-                             key=X25519PrivateKey.generate()).add_vpp_config()
-        wg1 = VppWgInterface(self,
-                             self.pg2.local_ip4,
-                             port+1).add_vpp_config()
+                             port).add_vpp_config()
         wg0.admin_up()
-        wg1.admin_up()
-
-        # Check peer counter
-        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
+        wg0.config_ip4()
 
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
@@ -236,43 +434,210 @@ class TestWg(VppTestCase):
         self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
 
         # wait for the peer to send a handshake
-        capture = self.pg1.get_capture(1, timeout=2)
-        handshake = capture[0]
+        rx = self.pg1.get_capture(1, timeout=2)
+
+        # consume the handshake in the noise protocol and
+        # generate the response
+        resp = peer_1.consume_init(rx[0], self.pg1)
+
+        # send the response, get keepalive
+        rxs = self.send_and_expect(self.pg1, [resp], self.pg1)
+
+        for rx in rxs:
+            b = peer_1.decrypt_transport(rx)
+            self.assertEqual(0, len(b))
+
+        # send a packets that are routed into the tunnel
+        p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+             IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
+             UDP(sport=555, dport=556) /
+             Raw(b'\x00' * 80))
+
+        rxs = self.send_and_expect(self.pg0, p * 255, self.pg1)
+
+        for rx in rxs:
+            rx = IP(peer_1.decrypt_transport(rx))
+            # chech the oringial packet is present
+            self.assertEqual(rx[IP].dst, p[IP].dst)
+            self.assertEqual(rx[IP].ttl, p[IP].ttl-1)
+
+        # send packets into the tunnel, expect to receive them on
+        # the other side
+        p = [(peer_1.mk_tunnel_header(self.pg1) /
+              Wireguard(message_type=4, reserved_zero=0) /
+              WireguardTransport(
+                  receiver_index=peer_1.sender,
+                  counter=ii,
+                  encrypted_encapsulated_packet=peer_1.encrypt_transport(
+                      (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+                       UDP(sport=222, dport=223) /
+                       Raw())))) for ii in range(255)]
+
+        rxs = self.send_and_expect(self.pg1, p, self.pg0)
+
+        for rx in rxs:
+            self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+            self.assertEqual(rx[IP].ttl, 19)
+
+    def test_wg_peer_init(self):
+        """ Send handshake init """
+        wg_output_node_name = '/err/wg-output-tun/'
+        wg_input_node_name = '/err/wg-input/'
 
-        self.assertEqual(handshake[IP].src, wg0.src)
-        self.assertEqual(handshake[IP].dst, peer_1.endpoint)
-        self.assertEqual(handshake[UDP].sport, wg0.port)
-        self.assertEqual(handshake[UDP].dport, peer_1.port)
-        handshake = Wireguard(handshake[Raw])
-        self.assertEqual(handshake.message_type, 1)  # "initiate")
-        init = handshake[WireguardInitiation]
+        port = 12323
+
+        # Create interfaces
+        wg0 = VppWgInterface(self,
+                             self.pg1.local_ip4,
+                             port).add_vpp_config()
+        wg0.admin_up()
+        wg0.config_ip4()
+
+        peer_1 = VppWgPeer(self,
+                           wg0,
+                           self.pg1.remote_ip4,
+                           port+1,
+                           ["10.11.2.0/24",
+                            "10.11.3.0/24"]).add_vpp_config()
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
 
         # route a packet into the wg interface
         #  use the allowed-ip prefix
+        #  this is dropped because the peer is not initiated
+        p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+             IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
+             UDP(sport=555, dport=556) /
+             Raw())
+        self.send_and_assert_no_replies(self.pg0, [p])
+
+        kp_error = wg_output_node_name + "Keypair error"
+        self.assertEqual(1, self.statistics.get_err_counter(kp_error))
+
+        # send a handsake from the peer with an invalid MAC
+        p = peer_1.mk_handshake(self.pg1)
+        p[WireguardInitiation].mac1 = b'foobar'
+        self.send_and_assert_no_replies(self.pg1, [p])
+        self.assertEqual(1, self.statistics.get_err_counter(
+            wg_input_node_name + "Invalid MAC handshake"))
+
+        # send a handsake from the peer but signed by the wrong key.
+        p = peer_1.mk_handshake(self.pg1,
+                                X25519PrivateKey.generate().public_key())
+        self.send_and_assert_no_replies(self.pg1, [p])
+        self.assertEqual(1, self.statistics.get_err_counter(
+            wg_input_node_name + "Peer error"))
+
+        # send a valid handsake init for which we expect a response
+        p = peer_1.mk_handshake(self.pg1)
+
+        rx = self.send_and_expect(self.pg1, [p], self.pg1)
+
+        peer_1.consume_response(rx[0])
+
+        # route a packet into the wg interface
+        #  this is dropped because the peer is still not initiated
         p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
              IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
              UDP(sport=555, dport=556) /
              Raw())
-        # rx = self.send_and_expect(self.pg0, [p], self.pg1)
-        rx = self.send_and_assert_no_replies(self.pg0, [p])
+        self.send_and_assert_no_replies(self.pg0, [p])
+        self.assertEqual(2, self.statistics.get_err_counter(kp_error))
+
+        # send a data packet from the peer through the tunnel
+        # this completes the handshake
+        p = (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+             UDP(sport=222, dport=223) /
+             Raw())
+        d = peer_1.encrypt_transport(p)
+        p = (peer_1.mk_tunnel_header(self.pg1) /
+             (Wireguard(message_type=4, reserved_zero=0) /
+              WireguardTransport(receiver_index=peer_1.sender,
+                                 counter=0,
+                                 encrypted_encapsulated_packet=d)))
+        rxs = self.send_and_expect(self.pg1, [p], self.pg0)
+
+        for rx in rxs:
+            self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+            self.assertEqual(rx[IP].ttl, 19)
+
+        # send a packets that are routed into the tunnel
+        p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+             IP(src=self.pg0.remote_ip4, dst="10.11.3.2") /
+             UDP(sport=555, dport=556) /
+             Raw(b'\x00' * 80))
+
+        rxs = self.send_and_expect(self.pg0, p * 255, self.pg1)
 
-        self.logger.info(self.vapi.cli("sh error"))
-        init_sent = wg_output_node_name + "Keypair error"
-        self.assertEqual(1, self.statistics.get_err_counter(init_sent))
+        for rx in rxs:
+            rx = IP(peer_1.decrypt_transport(rx))
+
+            # chech the oringial packet is present
+            self.assertEqual(rx[IP].dst, p[IP].dst)
+            self.assertEqual(rx[IP].ttl, p[IP].ttl-1)
+
+        # send packets into the tunnel, expect to receive them on
+        # the other side
+        p = [(peer_1.mk_tunnel_header(self.pg1) /
+              Wireguard(message_type=4, reserved_zero=0) /
+              WireguardTransport(
+                  receiver_index=peer_1.sender,
+                  counter=ii+1,
+                  encrypted_encapsulated_packet=peer_1.encrypt_transport(
+                      (IP(src="10.11.3.1", dst=self.pg0.remote_ip4, ttl=20) /
+                       UDP(sport=222, dport=223) /
+                       Raw())))) for ii in range(255)]
+
+        rxs = self.send_and_expect(self.pg1, p, self.pg0)
+
+        for rx in rxs:
+            self.assertEqual(rx[IP].dst, self.pg0.remote_ip4)
+            self.assertEqual(rx[IP].ttl, 19)
+
+        peer_1.remove_vpp_config()
+        wg0.remove_vpp_config()
+
+    def test_wg_multi_peer(self):
+        """ multiple peer setup """
+        port = 12323
+
+        # Create interfaces
+        wg0 = VppWgInterface(self,
+                             self.pg1.local_ip4,
+                             port).add_vpp_config()
+        wg1 = VppWgInterface(self,
+                             self.pg2.local_ip4,
+                             port+1).add_vpp_config()
+        wg0.admin_up()
+        wg1.admin_up()
+
+        # Check peer counter
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
+
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
 
         # Create many peers on sencond interface
         NUM_PEERS = 16
         self.pg2.generate_remote_hosts(NUM_PEERS)
         self.pg2.configure_ipv4_neighbors()
+        self.pg1.generate_remote_hosts(NUM_PEERS)
+        self.pg1.configure_ipv4_neighbors()
 
-        peers = []
+        peers_1 = []
+        peers_2 = []
         for i in range(NUM_PEERS):
-            peers.append(VppWgPeer(self,
-                                   wg1,
-                                   self.pg2.remote_hosts[i].ip4,
-                                   port+1+i,
-                                   ["10.10.%d.4/32" % i]).add_vpp_config())
-            self.assertEqual(len(self.vapi.wireguard_peers_dump()), i+2)
+            peers_1.append(VppWgPeer(self,
+                                     wg0,
+                                     self.pg1.remote_hosts[i].ip4,
+                                     port+1+i,
+                                     ["10.0.%d.4/32" % i]).add_vpp_config())
+            peers_2.append(VppWgPeer(self,
+                                     wg1,
+                                     self.pg2.remote_hosts[i].ip4,
+                                     port+100+i,
+                                     ["10.100.%d.4/32" % i]).add_vpp_config())
+
+        self.assertEqual(len(self.vapi.wireguard_peers_dump()), NUM_PEERS*2)
 
         self.logger.info(self.vapi.cli("show wireguard peer"))
         self.logger.info(self.vapi.cli("show wireguard interface"))
@@ -281,12 +646,12 @@ class TestWg(VppTestCase):
         self.logger.info(self.vapi.cli("sh ip fib 10.11.3.0"))
 
         # remove peers
-        for p in peers:
+        for p in peers_1:
+            self.assertTrue(p.query_vpp_config())
+            p.remove_vpp_config()
+        for p in peers_2:
             self.assertTrue(p.query_vpp_config())
             p.remove_vpp_config()
-        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 1)
-        peer_1.remove_vpp_config()
-        self.assertEqual(len(self.vapi.wireguard_peers_dump()), 0)
 
         wg0.remove_vpp_config()
-        wg1.remove_vpp_config()
+        wg1.remove_vpp_config()
index aa476f7..f54ce71 100755 (executable)
@@ -86,7 +86,7 @@ cookie_checker_validate_macs (vlib_main_t * vm, cookie_checker_t * cc,
   len = len - sizeof (message_macs_t);
   cookie_macs_mac1 (&our_cm, buf, len, cc->cc_mac1_key);
 
-  /* If mac1 is invald, we want to drop the packet */
+  /* If mac1 is invalid, we want to drop the packet */
   if (clib_memcmp (our_cm.mac1, cm->mac1, COOKIE_MAC_SIZE) != 0)
     return INVALID_MAC;
 
index ff8ed35..ed90146 100644 (file)
@@ -42,11 +42,21 @@ format_wg_if (u8 * s, va_list * args)
   key_to_base64 (wgi->local.l_private, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " private-key:%s", key);
+  s =
+    format (s, " %U", format_hex_bytes, wgi->local.l_private,
+           NOISE_PUBLIC_KEY_LEN);
 
   key_to_base64 (wgi->local.l_public, NOISE_PUBLIC_KEY_LEN, key);
 
   s = format (s, " public-key:%s", key);
 
+  s =
+    format (s, " %U", format_hex_bytes, wgi->local.l_public,
+           NOISE_PUBLIC_KEY_LEN);
+
+  s = format (s, " mac-key: %U", format_hex_bytes,
+             &wgi->cookie_checker.cc_mac1_key, NOISE_PUBLIC_KEY_LEN);
+
   return (s);
 }
 
@@ -235,9 +245,6 @@ wg_if_create (u32 user_instance,
   if (~0 == wg_if->user_instance)
     wg_if->user_instance = t_idx;
 
-  udp_dst_port_info_t *pi = udp_get_dst_port_info (&udp_main, port, UDP_IP4);
-  if (pi)
-    return (VNET_API_ERROR_VALUE_EXIST);
   udp_register_dst_port (vlib_get_main (), port, wg_input_node.index, 1);
 
   vec_validate_init_empty (wg_if_index_by_port, port, INDEX_INVALID);
@@ -280,16 +287,17 @@ wg_if_delete (u32 sw_if_index)
 
   vnet_hw_interface_t *hw = vnet_get_sup_hw_interface (vnm, sw_if_index);
   if (hw == 0 || hw->dev_class_index != wg_if_device_class.index)
-    return VNET_API_ERROR_INVALID_SW_IF_INDEX;
+    return VNET_API_ERROR_INVALID_VALUE;
 
   wg_if_t *wg_if;
   wg_if = wg_if_get (wg_if_find_by_sw_if_index (sw_if_index));
   if (NULL == wg_if)
-    return VNET_API_ERROR_INVALID_SW_IF_INDEX;
+    return VNET_API_ERROR_INVALID_SW_IF_INDEX_2;
 
-  if (wg_if_instance_free (hw->dev_instance) < 0)
-    return VNET_API_ERROR_INVALID_SW_IF_INDEX;
+  if (wg_if_instance_free (wg_if->user_instance) < 0)
+    return VNET_API_ERROR_INVALID_VALUE_2;
 
+  udp_unregister_dst_port (vlib_get_main (), wg_if->port, 1);
   wg_if_index_by_port[wg_if->port] = INDEX_INVALID;
   vnet_delete_hw_interface (vnm, hw->hw_if_index);
   pool_put (wg_if_pool, wg_if);
index 832ad03..cdd65f8 100755 (executable)
@@ -313,12 +313,12 @@ VLIB_NODE_FN (wg_input_node) (vlib_main_t * vm,
            if (entry)
              {
                peer = pool_elt_at_index (wmp->peers, *entry);
-               if (!peer)
-                 {
-                   next[0] = WG_INPUT_NEXT_ERROR;
-                   b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
-                   goto out;
-                 }
+             }
+           else
+             {
+               next[0] = WG_INPUT_NEXT_ERROR;
+               b[0]->error = node->errors[WG_INPUT_ERROR_PEER];
+               goto out;
              }
 
            u16 encr_len = b[0]->current_length - sizeof (message_data_t);
index dc7d506..b47bb57 100755 (executable)
@@ -536,7 +536,7 @@ noise_remote_ready (noise_remote_t * r)
   return ret;
 }
 
-static void
+static bool
 chacha20poly1305_calc (vlib_main_t * vm,
                       u8 * src,
                       u32 src_len,
@@ -580,6 +580,8 @@ chacha20poly1305_calc (vlib_main_t * vm,
     {
       clib_memcpy (dst + src_len, op->tag, NOISE_AUTHTAG_LEN);
     }
+
+  return (op->status == VNET_CRYPTO_OP_STATUS_COMPLETED);
 }
 
 enum noise_state_crypt
@@ -668,9 +670,10 @@ noise_remote_decrypt (vlib_main_t * vm, noise_remote_t * r, uint32_t r_idx,
   /* Decrypt, then validate the counter. We don't want to validate the
    * counter before decrypting as we do not know the message is authentic
    * prior to decryption. */
-  chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce,
-                        VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
-                        kp->kp_recv_index);
+  if (!chacha20poly1305_calc (vm, src, srclen, dst, NULL, 0, nonce,
+                             VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC,
+                             kp->kp_recv_index))
+    goto error;
 
   if (!noise_counter_recv (&kp->kp_ctr, nonce))
     goto error;
@@ -936,8 +939,9 @@ noise_msg_decrypt (vlib_main_t * vm, uint8_t * dst, uint8_t * src,
                   uint8_t hash[NOISE_HASH_LEN])
 {
   /* Nonce always zero for Noise_IK */
-  chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
-                        VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx);
+  if (!chacha20poly1305_calc (vm, src, src_len, dst, hash, NOISE_HASH_LEN, 0,
+                             VNET_CRYPTO_OP_CHACHA20_POLY1305_DEC, key_idx))
+    return false;
   noise_mix_hash (hash, src, src_len);
   return true;
 }
index daec7a4..cdfd9d7 100755 (executable)
@@ -115,7 +115,8 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
   while (n_left_from > 0)
     {
       ip4_udp_header_t *hdr = vlib_buffer_get_current (b[0]);
-      u8 *plain_data = vlib_buffer_get_current (b[0]) + sizeof (ip4_header_t);
+      u8 *plain_data = (vlib_buffer_get_current (b[0]) +
+                       sizeof (ip4_udp_header_t));
       u16 plain_data_len =
        clib_net_to_host_u16 (((ip4_header_t *) plain_data)->length);
 
@@ -144,8 +145,8 @@ VLIB_NODE_FN (wg_output_tun_node) (vlib_main_t * vm,
        * Ensure there is enough space to write the encrypted data
        * into the packet
        */
-      if (PREDICT_FALSE (encrypted_packet_len > WG_OUTPUT_SCRATCH_SIZE) ||
-         PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) <
+      if (PREDICT_FALSE (encrypted_packet_len >= WG_OUTPUT_SCRATCH_SIZE) ||
+         PREDICT_FALSE ((b[0]->current_data + encrypted_packet_len) >=
                         vlib_buffer_get_default_data_size (vm)))
        {
          b[0]->error = node->errors[WG_OUTPUT_ERROR_TOO_BIG];
index 0dcc4e2..04b07d9 100755 (executable)
@@ -380,15 +380,16 @@ format_wg_peer (u8 * s, va_list * va)
   peer = wg_peer_get (peeri);
   key_to_base64 (peer->remote.r_public, NOISE_PUBLIC_KEY_LEN, key);
 
-  s = format (s, "[%d] key:%=45s endpoint:[%U->%U] %U keep-alive:%d adj:%d",
+  s = format (s, "[%d] endpoint:[%U->%U] %U keep-alive:%d adj:%d",
              peeri,
-             key,
              format_wg_peer_endpoint, &peer->src,
              format_wg_peer_endpoint, &peer->dst,
              format_vnet_sw_if_index_name, vnet_get_main (),
              peer->wg_sw_if_index,
              peer->persistent_keepalive_interval, peer->adj_index);
-
+  s = format (s, "\n  key:%=s %U",
+             key, format_hex_bytes, peer->remote.r_public,
+             NOISE_PUBLIC_KEY_LEN);
   s = format (s, "\n  allowed-ips:");
   vec_foreach (allowed_ip, peer->allowed_ips)
   {