ikev2: fix setting responder/initiator addresses
[vpp.git] / src / plugins / ikev2 / test / test_ikev2.py
index 4f64b56..f75a517 100644 (file)
@@ -1,4 +1,5 @@
 import os
+from socket import inet_pton
 from cryptography import x509
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import hashes, hmac
@@ -510,8 +511,10 @@ class IKEv2SA(object):
     def esp_crypto_attr(self):
         return self.crypto_attr(self.esp_crypto_key_len)
 
-    def compute_nat_sha1(self, ip, port):
-        data = self.ispi + self.rspi + ip + (port).to_bytes(2, 'big')
+    def compute_nat_sha1(self, ip, port, rspi=None):
+        if rspi is None:
+            rspi = self.rspi
+        data = self.ispi + rspi + ip + (port).to_bytes(2, 'big')
         digest = hashes.Hash(hashes.SHA1(), backend=default_backend())
         digest.update(data)
         return digest.finalize()
@@ -775,6 +778,36 @@ class TemplateInitiator(IkePeer):
     def tearDown(self):
         super(TemplateInitiator, self).tearDown()
 
+    @staticmethod
+    def find_notify_payload(packet, notify_type):
+        n = packet[ikev2.IKEv2_payload_Notify]
+        while n is not None:
+            if n.type == notify_type:
+                return n
+            n = n.payload
+        return None
+
+    def verify_nat_detection(self, packet):
+        if self.ip6:
+            iph = packet[IPv6]
+        else:
+            iph = packet[IP]
+        udp = packet[UDP]
+
+        # NAT_DETECTION_SOURCE_IP
+        s = self.find_notify_payload(packet, 16388)
+        self.assertIsNotNone(s)
+        src_sha = self.sa.compute_nat_sha1(
+                inet_pton(socket.AF_INET, iph.src), udp.sport, b'\x00' * 8)
+        self.assertEqual(s.load, src_sha)
+
+        # NAT_DETECTION_DESTINATION_IP
+        s = self.find_notify_payload(packet, 16389)
+        self.assertIsNotNone(s)
+        dst_sha = self.sa.compute_nat_sha1(
+                inet_pton(socket.AF_INET, iph.dst), udp.dport, b'\x00' * 8)
+        self.assertEqual(s.load, dst_sha)
+
     def verify_sa_init_request(self, packet):
         ih = packet[ikev2.IKEv2]
         self.assertNotEqual(ih.init_SPI, 8 * b'\x00')
@@ -798,6 +831,7 @@ class TemplateInitiator(IkePeer):
         self.assertEqual(prop.trans[2].transform_id,
                          self.p.ike_transforms['dh_group'])
 
+        self.verify_nat_detection(packet)
         self.sa.complete_dh_data()
         self.sa.calc_keys()
 
@@ -957,11 +991,6 @@ class TemplateResponder(IkePeer):
         props = (ikev2.IKEv2_payload_Proposal(proposal=1, proto='IKEv2',
                  trans_nb=4, trans=trans))
 
-        if behind_nat:
-            next_payload = 'Notify'
-        else:
-            next_payload = None
-
         self.sa.init_req_packet = (
                 ikev2.IKEv2(init_SPI=self.sa.ispi,
                             flags='Initiator', exch_type='IKE_SA_INIT') /
@@ -969,19 +998,21 @@ class TemplateResponder(IkePeer):
                 ikev2.IKEv2_payload_KE(next_payload='Nonce',
                                        group=self.sa.ike_dh,
                                        load=self.sa.my_dh_pub_key) /
-                ikev2.IKEv2_payload_Nonce(next_payload=next_payload,
+                ikev2.IKEv2_payload_Nonce(next_payload='Notify',
                                           load=self.sa.i_nonce))
 
         if behind_nat:
             src_address = b'\x0a\x0a\x0a\x01'
         else:
-            src_address = bytes(self.pg0.local_ip4, 'ascii')
+            src_address = inet_pton(socket.AF_INET, self.pg0.remote_ip4)
 
         src_nat = self.sa.compute_nat_sha1(src_address, self.sa.sport)
-        dst_nat = self.sa.compute_nat_sha1(bytes(self.pg0.remote_ip4, 'ascii'),
-                                           self.sa.sport)
+        dst_nat = self.sa.compute_nat_sha1(
+                inet_pton(socket.AF_INET, self.pg0.local_ip4),
+                self.sa.dport)
         nat_src_detection = ikev2.IKEv2_payload_Notify(
-                type='NAT_DETECTION_SOURCE_IP', load=src_nat)
+                type='NAT_DETECTION_SOURCE_IP', load=src_nat,
+                next_payload='Notify')
         nat_dst_detection = ikev2.IKEv2_payload_Notify(
                 type='NAT_DETECTION_DESTINATION_IP', load=dst_nat)
         self.sa.init_req_packet = (self.sa.init_req_packet /