ikev2: support ipv6 traffic selectors & overlay
[vpp.git] / src / plugins / ikev2 / test / test_ikev2.py
index 0bdc417..a47c59f 100644 (file)
@@ -9,15 +9,20 @@ from cryptography.hazmat.primitives.ciphers import (
     algorithms,
     modes,
 )
-from ipaddress import IPv4Address
+from ipaddress import IPv4Address, IPv6Address, ip_address
 from scapy.layers.ipsec import ESP
 from scapy.layers.inet import IP, UDP, Ether
+from scapy.layers.inet6 import IPv6
 from scapy.packet import raw, Raw
 from scapy.utils import long_converter
 from framework import VppTestCase, VppTestRunner
 from vpp_ikev2 import Profile, IDType, AuthMethod
 from vpp_papi import VppEnum
 
+try:
+    text_type = unicode
+except NameError:
+    text_type = str
 
 KEY_PAD = b"Key Pad for IKEv2"
 SALT_SIZE = 4
@@ -156,9 +161,9 @@ class IKEv2ChildSA(object):
 
 
 class IKEv2SA(object):
-    def __init__(self, test, is_initiator=True, spi=b'\x04' * 8,
-                 i_id=None, r_id=None, id_type='fqdn', nonce=None,
-                 auth_data=None, local_ts=None, remote_ts=None,
+    def __init__(self, test, is_initiator=True, i_id=None, r_id=None,
+                 spi=b'\x01\x02\x03\x04\x05\x06\x07\x08', id_type='fqdn',
+                 nonce=None, auth_data=None, local_ts=None, remote_ts=None,
                  auth_method='shared-key', priv_key=None, natt=False):
         self.natt = natt
         if natt:
@@ -182,12 +187,12 @@ class IKEv2SA(object):
             self.id_type = id_type
         self.auth_method = auth_method
         if self.is_initiator:
-            self.rspi = None
+            self.rspi = 8 * b'\x00'
             self.ispi = spi
             self.i_nonce = nonce
         else:
             self.rspi = spi
-            self.ispi = None
+            self.ispi = 8 * b'\x00'
             self.r_nonce = None
         self.child_sas = [IKEv2ChildSA(local_ts, remote_ts)]
 
@@ -416,18 +421,25 @@ class IKEv2SA(object):
             ct = ep.load[:-integ_trunc]
             return self.decrypt(ct)
 
-    def generate_ts(self):
+    def build_ts_addr(self, ts, version):
+        return {'starting_address_v' + version: ts['start_addr'],
+                'ending_address_v' + version: ts['end_addr']}
+
+    def generate_ts(self, is_ip4):
         c = self.child_sas[0]
-        ts1 = ikev2.IPv4TrafficSelector(
-                IP_protocol_ID=0,
-                start_port=0,
-                end_port=0xffff,
-                starting_address_v4=c.local_ts['start_addr'],
-                ending_address_v4=c.local_ts['end_addr'])
-        ts2 = ikev2.IPv4TrafficSelector(
-                IP_protocol_ID=0,
-                starting_address_v4=c.remote_ts['start_addr'],
-                ending_address_v4=c.remote_ts['end_addr'])
+        ts_data = {'IP_protocol_ID': 0,
+                   'start_port': 0,
+                   'end_port': 0xffff}
+        if is_ip4:
+            ts_data.update(self.build_ts_addr(c.local_ts, '4'))
+            ts1 = ikev2.IPv4TrafficSelector(**ts_data)
+            ts_data.update(self.build_ts_addr(c.remote_ts, '4'))
+            ts2 = ikev2.IPv4TrafficSelector(**ts_data)
+        else:
+            ts_data.update(self.build_ts_addr(c.local_ts, '6'))
+            ts1 = ikev2.IPv6TrafficSelector(**ts_data)
+            ts_data.update(self.build_ts_addr(c.remote_ts, '6'))
+            ts2 = ikev2.IPv6TrafficSelector(**ts_data)
         return ([ts1], [ts2])
 
     def set_ike_props(self, crypto, crypto_key_len, integ, prf, dh):
@@ -474,7 +486,7 @@ class IKEv2SA(object):
         return self.crypto_attr(self.esp_crypto_key_len)
 
     def compute_nat_sha1(self, ip, port):
-        data = self.ispi + b'\x00' * 8 + ip + (port).to_bytes(2, 'big')
+        data = self.ispi + self.rspi + ip + (port).to_bytes(2, 'big')
         digest = hashes.Hash(hashes.SHA1(), backend=default_backend())
         digest.update(data)
         return digest.finalize()
@@ -493,6 +505,8 @@ class TemplateResponder(VppTestCase):
             i.admin_up()
             i.config_ip4()
             i.resolve_arp()
+            i.config_ip6()
+            i.resolve_ndp()
 
     @classmethod
     def tearDownClass(cls):
@@ -504,6 +518,8 @@ class TemplateResponder(VppTestCase):
         self.p.add_vpp_config()
         self.assertIsNotNone(self.p.query_vpp_config())
         self.sa.generate_dh_data()
+        self.vapi.cli('ikev2 set logging level 4')
+        self.vapi.cli('event-lo clear')
 
     def tearDown(self):
         super(TemplateResponder, self).tearDown()
@@ -528,16 +544,25 @@ class TemplateResponder(VppTestCase):
         ike_msg = self.encrypt_ike_msg(header, del_sa, 'Delete')
         packet = self.create_packet(self.pg0, ike_msg,
                                     self.sa.sport, self.sa.dport,
-                                    self.sa.natt)
+                                    self.sa.natt, self.ip6)
         self.pg0.add_stream(packet)
         self.pg0.enable_capture()
         self.pg_start()
         capture = self.pg0.get_capture(1)
         self.verify_del_sa(capture[0])
 
-    def create_packet(self, src_if, msg, sport=500, dport=500, natt=False):
+    def create_packet(self, src_if, msg, sport=500, dport=500, natt=False,
+                      use_ip6=False):
+        if use_ip6:
+            src_ip = src_if.remote_ip6
+            dst_ip = src_if.local_ip6
+            ip_layer = IPv6
+        else:
+            src_ip = src_if.remote_ip4
+            dst_ip = src_if.local_ip4
+            ip_layer = IP
         res = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
-               IP(src=src_if.remote_ip4, dst=src_if.local_ip4) /
+               ip_layer(src=src_ip, dst=dst_ip) /
                UDP(sport=sport, dport=dport))
         if natt:
             # insert non ESP marker
@@ -575,16 +600,24 @@ class TemplateResponder(VppTestCase):
                                           load=self.sa.i_nonce))
 
         if behind_nat:
-            src_nat = self.sa.compute_nat_sha1(b'\x0a\x0a\x0a\x01',
-                                               self.sa.sport)
-            nat_detection = ikev2.IKEv2_payload_Notify(
-                    type='NAT_DETECTION_SOURCE_IP',
-                    load=src_nat)
-            self.sa.init_req_packet = self.sa.init_req_packet / nat_detection
+            src_address = b'\x0a\x0a\x0a\x01'
+        else:
+            src_address = bytes(self.pg0.local_ip4, 'ascii')
+
+        src_nat = self.sa.compute_nat_sha1(src_address, self.sa.sport)
+        dst_nat = self.sa.compute_nat_sha1(bytes(self.pg0.remote_ip4, 'ascii'),
+                                           self.sa.sport)
+        nat_src_detection = ikev2.IKEv2_payload_Notify(
+                type='NAT_DETECTION_SOURCE_IP', load=src_nat)
+        nat_dst_detection = ikev2.IKEv2_payload_Notify(
+                type='NAT_DETECTION_DESTINATION_IP', load=dst_nat)
+        self.sa.init_req_packet = (self.sa.init_req_packet /
+                                   nat_src_detection /
+                                   nat_dst_detection)
 
         ike_msg = self.create_packet(self.pg0, self.sa.init_req_packet,
                                      self.sa.sport, self.sa.dport,
-                                     self.sa.natt)
+                                     self.sa.natt, self.ip6)
         self.pg0.add_stream(ike_msg)
         self.pg0.enable_capture()
         self.pg_start()
@@ -642,7 +675,7 @@ class TemplateResponder(VppTestCase):
         props = (ikev2.IKEv2_payload_Proposal(proposal=1, proto='ESP',
                  SPIsize=4, SPI=os.urandom(4), trans_nb=4, trans=trans))
 
-        tsi, tsr = self.sa.generate_ts()
+        tsi, tsr = self.sa.generate_ts(self.p.ts_is_ip4)
         plain = (ikev2.IKEv2_payload_IDi(next_payload='IDr',
                  IDtype=self.sa.id_type, load=self.sa.i_id) /
                  ikev2.IKEv2_payload_IDr(next_payload='AUTH',
@@ -666,7 +699,7 @@ class TemplateResponder(VppTestCase):
 
         ike_msg = self.encrypt_ike_msg(header, plain, 'IDi')
         packet = self.create_packet(self.pg0, ike_msg, self.sa.sport,
-                                    self.sa.dport, self.sa.natt)
+                                    self.sa.dport, self.sa.natt, self.ip6)
         self.pg0.add_stream(packet)
         self.pg0.enable_capture()
         self.pg_start()
@@ -681,6 +714,8 @@ class TemplateResponder(VppTestCase):
             # ipsec register for port 4500
             esp = packet[ESP]
             ih = self.verify_and_remove_non_esp_marker(esp)
+
+        self.assertEqual(ih.version, 0x20)
         return ih
 
     def verify_sa_init(self, packet):
@@ -776,10 +811,14 @@ class TemplateResponder(VppTestCase):
         r = self.vapi.ikev2_sa_dump()
         self.assertEqual(len(r), 1)
         sa = r[0].sa
-        self.assertEqual(self.sa.ispi, (sa.ispi).to_bytes(8, 'little'))
+        self.assertEqual(self.sa.ispi, (sa.ispi).to_bytes(8, 'big'))
         self.assertEqual(self.sa.rspi, (sa.rspi).to_bytes(8, 'big'))
-        self.assertEqual(sa.iaddr, IPv4Address(self.pg0.remote_ip4))
-        self.assertEqual(sa.raddr, IPv4Address(self.pg0.local_ip4))
+        if self.ip6:
+            self.assertEqual(sa.iaddr, IPv6Address(self.pg0.remote_ip6))
+            self.assertEqual(sa.raddr, IPv6Address(self.pg0.local_ip6))
+        else:
+            self.assertEqual(sa.iaddr, IPv4Address(self.pg0.remote_ip4))
+            self.assertEqual(sa.raddr, IPv4Address(self.pg0.local_ip4))
         self.verify_keymat(sa.keys, self.sa, 'sk_d')
         self.verify_keymat(sa.keys, self.sa, 'sk_ai')
         self.verify_keymat(sa.keys, self.sa, 'sk_ar')
@@ -806,7 +845,7 @@ class TemplateResponder(VppTestCase):
         self.verify_keymat(csa.keys, c, 'sk_ei')
         self.verify_keymat(csa.keys, c, 'sk_er')
 
-        tsi, tsr = self.sa.generate_ts()
+        tsi, tsr = self.sa.generate_ts(self.p.ts_is_ip4)
         tsi = tsi[0]
         tsr = tsr[0]
         r = self.vapi.ikev2_traffic_selector_dump(
@@ -838,10 +877,17 @@ class TemplateResponder(VppTestCase):
             self.assertTrue(api_ts.is_local)
         else:
             self.assertFalse(api_ts.is_local)
-        self.assertEqual(api_ts.start_addr,
-                         IPv4Address(ts.starting_address_v4))
-        self.assertEqual(api_ts.end_addr,
-                         IPv4Address(ts.ending_address_v4))
+
+        if self.p.ts_is_ip4:
+            self.assertEqual(api_ts.start_addr,
+                             IPv4Address(ts.starting_address_v4))
+            self.assertEqual(api_ts.end_addr,
+                             IPv4Address(ts.ending_address_v4))
+        else:
+            self.assertEqual(api_ts.start_addr,
+                             IPv6Address(ts.starting_address_v6))
+            self.assertEqual(api_ts.end_addr,
+                             IPv6Address(ts.ending_address_v6))
         self.assertEqual(api_ts.start_port, ts.start_port)
         self.assertEqual(api_ts.end_port, ts.end_port)
         self.assertEqual(api_ts.protocol_id, ts.IP_protocol_ID)
@@ -872,6 +918,7 @@ class Ikev2Params(object):
 
         is_natt = 'natt' in params and params['natt'] or False
         self.p = Profile(self, 'pr1')
+        self.ip6 = False if 'ip6' not in params else params['ip6']
 
         if 'auth' in params and params['auth'] == 'rsa-sig':
             auth_method = 'rsa-sig'
@@ -897,8 +944,12 @@ class Ikev2Params(object):
 
         self.p.add_local_id(id_type='fqdn', data=b'vpp.home')
         self.p.add_remote_id(id_type='fqdn', data=b'roadwarrior.example.com')
-        self.p.add_local_ts(start_addr='10.10.10.0', end_addr='10.10.10.255')
-        self.p.add_remote_ts(start_addr='10.0.0.0', end_addr='10.0.0.255')
+        loc_ts = {'start_addr': '10.10.10.0', 'end_addr': '10.10.10.255'} if\
+            'loc_ts' not in params else params['loc_ts']
+        rem_ts = {'start_addr': '10.0.0.0', 'end_addr': '10.0.0.255'} if\
+            'rem_ts' not in params else params['rem_ts']
+        self.p.add_local_ts(**loc_ts)
+        self.p.add_remote_ts(**rem_ts)
 
         self.sa = IKEv2SA(self, i_id=self.p.remote_id['data'],
                           r_id=self.p.local_id['data'],
@@ -964,14 +1015,14 @@ class TestApi(VppTestCase):
 
     def test_profile_api(self):
         """ test profile dump API """
-        loc_ts = {
+        loc_ts4 = {
                     'proto': 8,
                     'start_port': 1,
                     'end_port': 19,
                     'start_addr': '3.3.3.2',
                     'end_addr': '3.3.3.3',
                 }
-        rem_ts = {
+        rem_ts4 = {
                     'proto': 9,
                     'start_port': 10,
                     'end_port': 119,
@@ -979,14 +1030,29 @@ class TestApi(VppTestCase):
                     'end_addr': '2.3.4.6',
                 }
 
+        loc_ts6 = {
+                    'proto': 8,
+                    'start_port': 1,
+                    'end_port': 19,
+                    'start_addr': 'ab::1',
+                    'end_addr': 'ab::4',
+                }
+        rem_ts6 = {
+                    'proto': 9,
+                    'start_port': 10,
+                    'end_port': 119,
+                    'start_addr': 'cd::12',
+                    'end_addr': 'cd::13',
+                }
+
         conf = {
             'p1': {
                 'name': 'p1',
                 'loc_id': ('fqdn', b'vpp.home'),
                 'rem_id': ('fqdn', b'roadwarrior.example.com'),
-                'loc_ts': loc_ts,
-                'rem_ts': rem_ts,
-                'responder': {'sw_if_index': 0, 'ip4': '5.6.7.8'},
+                'loc_ts': loc_ts4,
+                'rem_ts': rem_ts4,
+                'responder': {'sw_if_index': 0, 'addr': '5.6.7.8'},
                 'ike_ts': {
                         'crypto_alg': 20,
                         'crypto_key_size': 32,
@@ -1008,10 +1074,10 @@ class TestApi(VppTestCase):
             'p2': {
                 'name': 'p2',
                 'loc_id': ('ip4-addr', b'192.168.2.1'),
-                'rem_id': ('ip4-addr', b'192.168.2.2'),
-                'loc_ts': loc_ts,
-                'rem_ts': rem_ts,
-                'responder': {'sw_if_index': 4, 'ip4': '5.6.7.99'},
+                'rem_id': ('ip6-addr', b'abcd::1'),
+                'loc_ts': loc_ts6,
+                'rem_ts': rem_ts6,
+                'responder': {'sw_if_index': 4, 'addr': 'def::10'},
                 'ike_ts': {
                         'crypto_alg': 12,
                         'crypto_key_size': 16,
@@ -1042,12 +1108,14 @@ class TestApi(VppTestCase):
         self.assertEqual(api_ts.protocol_id, cfg_ts['proto'])
         self.assertEqual(api_ts.start_port, cfg_ts['start_port'])
         self.assertEqual(api_ts.end_port, cfg_ts['end_port'])
-        self.assertEqual(api_ts.start_addr, IPv4Address(cfg_ts['start_addr']))
-        self.assertEqual(api_ts.end_addr, IPv4Address(cfg_ts['end_addr']))
+        self.assertEqual(api_ts.start_addr,
+                         ip_address(text_type(cfg_ts['start_addr'])))
+        self.assertEqual(api_ts.end_addr,
+                         ip_address(text_type(cfg_ts['end_addr'])))
 
     def verify_responder(self, api_r, cfg_r):
         self.assertEqual(api_r.sw_if_index, cfg_r['sw_if_index'])
-        self.assertEqual(api_r.ip4, IPv4Address(cfg_r['ip4']))
+        self.assertEqual(api_r.addr, ip_address(cfg_r['addr']))
 
     def verify_transforms(self, api_ts, cfg_ts):
         self.assertEqual(api_ts.crypto_alg, cfg_ts['crypto_alg'])
@@ -1150,9 +1218,15 @@ class Test_IKE_AES_GCM_16_256(TemplateResponder, Ikev2Params):
     """
     def config_tc(self):
         self.config_params({
+            'ip6': True,
+            'natt': True,
             'ike-crypto': ('AES-GCM-16ICV', 32),
             'ike-integ': 'NULL',
-            'ike-dh': '2048MODPgr'})
+            'ike-dh': '2048MODPgr',
+            'loc_ts': {'start_addr': 'ab:cd::0',
+                       'end_addr': 'ab:cd::10'},
+            'rem_ts': {'start_addr': '11::0',
+                       'end_addr': '11::100'}})
 
 
 class TestMalformedMessages(TemplateResponder, Ikev2Params):
@@ -1164,8 +1238,8 @@ class TestMalformedMessages(TemplateResponder, Ikev2Params):
     def config_tc(self):
         self.config_params()
 
-    def assert_counter(self, count, name):
-        node_name = '/err/ikev2/' + name
+    def assert_counter(self, count, name, version='ip4'):
+        node_name = '/err/ikev2-%s/' % version + name
         self.assertEqual(count, self.statistics.get_err_counter(node_name))
 
     def create_ike_init_msg(self, length=None, payload=None):