ikev2: better packet parsing functions
[vpp.git] / src / plugins / ikev2 / test / test_ikev2.py
index 6116ebb..0bdc417 100644 (file)
@@ -114,7 +114,7 @@ class CryptoAlgo(object):
     def pad(self, data):
         pad_len = (len(data) // self.bs + 1) * self.bs - len(data)
         data = data + b'\x00' * (pad_len - 1)
-        return data + bytes([pad_len])
+        return data + bytes([pad_len - 1])
 
 
 class AuthAlgo(object):
@@ -167,6 +167,7 @@ class IKEv2SA(object):
         else:
             self.sport = 500
             self.dport = 500
+        self.msg_id = 0
         self.dh_params = None
         self.test = test
         self.priv_key = priv_key
@@ -190,6 +191,10 @@ class IKEv2SA(object):
             self.r_nonce = None
         self.child_sas = [IKEv2ChildSA(local_ts, remote_ts)]
 
+    def new_msg_id(self):
+        self.msg_id += 1
+        return self.msg_id
+
     def dh_pub_key(self):
         return self.i_dh_data
 
@@ -502,10 +507,35 @@ class TemplateResponder(VppTestCase):
 
     def tearDown(self):
         super(TemplateResponder, self).tearDown()
+        if self.sa.is_initiator:
+            self.initiate_del_sa()
+            r = self.vapi.ikev2_sa_dump()
+            self.assertEqual(len(r), 0)
+
         self.p.remove_vpp_config()
         self.assertIsNone(self.p.query_vpp_config())
 
-    def create_ike_msg(self, src_if, msg, sport=500, dport=500, natt=False):
+    def verify_del_sa(self, packet):
+        ih = self.get_ike_header(packet)
+        self.assertEqual(ih.id, self.sa.msg_id)
+        self.assertEqual(ih.exch_type, 37)  # exchange informational
+
+    def initiate_del_sa(self):
+        header = ikev2.IKEv2(init_SPI=self.sa.ispi, resp_SPI=self.sa.rspi,
+                             flags='Initiator', exch_type='INFORMATIONAL',
+                             id=self.sa.new_msg_id())
+        del_sa = ikev2.IKEv2_payload_Delete(proto='IKEv2')
+        ike_msg = self.encrypt_ike_msg(header, del_sa, 'Delete')
+        packet = self.create_packet(self.pg0, ike_msg,
+                                    self.sa.sport, self.sa.dport,
+                                    self.sa.natt)
+        self.pg0.add_stream(packet)
+        self.pg0.enable_capture()
+        self.pg_start()
+        capture = self.pg0.get_capture(1)
+        self.verify_del_sa(capture[0])
+
+    def create_packet(self, src_if, msg, sport=500, dport=500, natt=False):
         res = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                IP(src=src_if.remote_ip4, dst=src_if.local_ip4) /
                UDP(sport=sport, dport=dport))
@@ -552,15 +582,49 @@ class TemplateResponder(VppTestCase):
                     load=src_nat)
             self.sa.init_req_packet = self.sa.init_req_packet / nat_detection
 
-        ike_msg = self.create_ike_msg(self.pg0, self.sa.init_req_packet,
-                                      self.sa.sport, self.sa.dport,
-                                      self.sa.natt)
+        ike_msg = self.create_packet(self.pg0, self.sa.init_req_packet,
+                                     self.sa.sport, self.sa.dport,
+                                     self.sa.natt)
         self.pg0.add_stream(ike_msg)
         self.pg0.enable_capture()
         self.pg_start()
         capture = self.pg0.get_capture(1)
         self.verify_sa_init(capture[0])
 
+    def encrypt_ike_msg(self, header, plain, first_payload):
+        if self.sa.ike_crypto == 'AES-GCM-16ICV':
+            data = self.sa.ike_crypto_alg.pad(raw(plain))
+            plen = len(data) + GCM_IV_SIZE + GCM_ICV_SIZE +\
+                len(ikev2.IKEv2_payload_Encrypted())
+            tlen = plen + len(ikev2.IKEv2())
+
+            # prepare aad data
+            sk_p = ikev2.IKEv2_payload_Encrypted(next_payload=first_payload,
+                                                 length=plen)
+            header.length = tlen
+            res = header / sk_p
+            encr = self.sa.encrypt(raw(plain), raw(res))
+            sk_p = ikev2.IKEv2_payload_Encrypted(next_payload=first_payload,
+                                                 length=plen, load=encr)
+            res = header / sk_p
+        else:
+            encr = self.sa.encrypt(raw(plain))
+            trunc_len = self.sa.ike_integ_alg.trunc_len
+            plen = len(encr) + len(ikev2.IKEv2_payload_Encrypted()) + trunc_len
+            tlen = plen + len(ikev2.IKEv2())
+
+            sk_p = ikev2.IKEv2_payload_Encrypted(next_payload=first_payload,
+                                                 length=plen, load=encr)
+            header.length = tlen
+            res = header / sk_p
+
+            integ_data = raw(res)
+            hmac_data = self.sa.compute_hmac(self.sa.ike_integ_alg.mod(),
+                                             self.sa.my_authkey, integ_data)
+            res = res / Raw(hmac_data[:trunc_len])
+        assert(len(res) == tlen)
+        return res
+
     def send_sa_auth(self):
         tr_attr = self.sa.esp_crypto_attr()
         trans = (ikev2.IKEv2_payload_Transform(transform_type='Encryption',
@@ -595,48 +659,14 @@ class TemplateResponder(VppTestCase):
                  traffic_selector=tsr) /
                  ikev2.IKEv2_payload_Notify(type='INITIAL_CONTACT'))
 
-        if self.sa.ike_crypto == 'AES-GCM-16ICV':
-            data = self.sa.ike_crypto_alg.pad(raw(plain))
-            plen = len(data) + GCM_IV_SIZE + GCM_ICV_SIZE +\
-                len(ikev2.IKEv2_payload_Encrypted())
-            tlen = plen + len(ikev2.IKEv2())
-
-            # prepare aad data
-            sk_p = ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
-                                                 length=plen)
-            sa_auth = (ikev2.IKEv2(init_SPI=self.sa.ispi,
-                       resp_SPI=self.sa.rspi, id=1,
-                       length=tlen, flags='Initiator', exch_type='IKE_AUTH'))
-            sa_auth /= sk_p
+        header = ikev2.IKEv2(
+                init_SPI=self.sa.ispi,
+                resp_SPI=self.sa.rspi, id=self.sa.new_msg_id(),
+                flags='Initiator', exch_type='IKE_AUTH')
 
-            encr = self.sa.encrypt(raw(plain), raw(sa_auth))
-            sk_p = ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
-                                                 length=plen, load=encr)
-            sa_auth = (ikev2.IKEv2(init_SPI=self.sa.ispi,
-                       resp_SPI=self.sa.rspi, id=1,
-                       length=tlen, flags='Initiator', exch_type='IKE_AUTH'))
-            sa_auth /= sk_p
-        else:
-            encr = self.sa.encrypt(raw(plain))
-            trunc_len = self.sa.ike_integ_alg.trunc_len
-            plen = len(encr) + len(ikev2.IKEv2_payload_Encrypted()) + trunc_len
-            tlen = plen + len(ikev2.IKEv2())
-
-            sk_p = ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
-                                                 length=plen, load=encr)
-            sa_auth = (ikev2.IKEv2(init_SPI=self.sa.ispi,
-                       resp_SPI=self.sa.rspi, id=1,
-                       length=tlen, flags='Initiator', exch_type='IKE_AUTH'))
-            sa_auth /= sk_p
-
-            integ_data = raw(sa_auth)
-            hmac_data = self.sa.compute_hmac(self.sa.ike_integ_alg.mod(),
-                                             self.sa.my_authkey, integ_data)
-            sa_auth = sa_auth / Raw(hmac_data[:trunc_len])
-
-        assert(len(sa_auth) == tlen)
-        packet = self.create_ike_msg(self.pg0, sa_auth, self.sa.sport,
-                                     self.sa.dport, self.sa.natt)
+        ike_msg = self.encrypt_ike_msg(header, plain, 'IDi')
+        packet = self.create_packet(self.pg0, ike_msg, self.sa.sport,
+                                    self.sa.dport, self.sa.natt)
         self.pg0.add_stream(packet)
         self.pg0.enable_capture()
         self.pg_start()
@@ -656,6 +686,7 @@ class TemplateResponder(VppTestCase):
     def verify_sa_init(self, packet):
         ih = self.get_ike_header(packet)
 
+        self.assertEqual(ih.id, self.sa.msg_id)
         self.assertEqual(ih.exch_type, 34)
         self.assertTrue('Response' in ih.flags)
         self.assertEqual(ih.init_SPI, self.sa.ispi)
@@ -691,6 +722,7 @@ class TemplateResponder(VppTestCase):
         ike = self.get_ike_header(packet)
         udp = packet[UDP]
         self.verify_udp(udp)
+        self.assertEqual(ike.id, self.sa.msg_id)
         plain = self.sa.hmac_and_decrypt(ike)
         self.sa.calc_child_keys()
 
@@ -1123,5 +1155,43 @@ class Test_IKE_AES_GCM_16_256(TemplateResponder, Ikev2Params):
             'ike-dh': '2048MODPgr'})
 
 
+class TestMalformedMessages(TemplateResponder, Ikev2Params):
+    """ malformed packet test """
+
+    def tearDown(self):
+        pass
+
+    def config_tc(self):
+        self.config_params()
+
+    def assert_counter(self, count, name):
+        node_name = '/err/ikev2/' + name
+        self.assertEqual(count, self.statistics.get_err_counter(node_name))
+
+    def create_ike_init_msg(self, length=None, payload=None):
+        msg = ikev2.IKEv2(length=length, init_SPI='\x11' * 8,
+                          flags='Initiator', exch_type='IKE_SA_INIT')
+        if payload is not None:
+            msg /= payload
+        return self.create_packet(self.pg0, msg, self.sa.sport,
+                                  self.sa.dport)
+
+    def verify_bad_packet_length(self):
+        ike_msg = self.create_ike_init_msg(length=0xdead)
+        self.send_and_assert_no_replies(self.pg0, ike_msg * self.pkt_count)
+        self.assert_counter(self.pkt_count, 'Bad packet length')
+
+    def verify_bad_sa_payload_length(self):
+        p = ikev2.IKEv2_payload_SA(length=0xdead)
+        ike_msg = self.create_ike_init_msg(payload=p)
+        self.send_and_assert_no_replies(self.pg0, ike_msg * self.pkt_count)
+        self.assert_counter(self.pkt_count, 'Malformed packet')
+
+    def test_responder(self):
+        self.pkt_count = 254
+        self.verify_bad_packet_length()
+        self.verify_bad_sa_payload_length()
+
+
 if __name__ == '__main__':
     unittest.main(testRunner=VppTestRunner)