ikev2: use both local and remote ID for profile lookup
[vpp.git] / src / plugins / ikev2 / test / test_ikev2.py
1 import os
2 from cryptography.hazmat.backends import default_backend
3 from cryptography.hazmat.primitives import hashes, hmac
4 from cryptography.hazmat.primitives.asymmetric import dh
5 from cryptography.hazmat.primitives.ciphers import (
6     Cipher,
7     algorithms,
8     modes,
9 )
10 from scapy.layers.inet import IP, UDP, Ether
11 from scapy.packet import raw, Raw
12 from scapy.utils import long_converter
13 from framework import VppTestCase, VppTestRunner
14 from vpp_ikev2 import Profile, IDType
15
16
# Fixed pad string; auth_init() feeds it to the PRF, keyed with the profile's
# shared secret, to derive the key that signs the AUTH payload.
KEY_PAD = b"Key Pad for IKEv2"


# Diffie-Hellman groups usable by this test, keyed by group name.
# The 2048-bit MODP prime below is the one defined in rfc3526.
# tuple structure is (p, g, key_len): prime modulus, generator, and the
# byte length compute_secret() uses when serialising the shared secret.
DH = {
    '2048MODPgr': (long_converter("""
    FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
    29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
    EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
    E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
    EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
    C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
    83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
    670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
    E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
    DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
    15728E5A 8AACAA68 FFFFFFFF FFFFFFFF"""), 2, 256)
}
36
37
class CryptoAlgo(object):
    """Symmetric cipher helper used to protect IKEv2 encrypted payloads.

    Wraps a ``cryptography`` algorithm class (``cipher``) and mode class
    (``mode``) under the transform ``name`` used elsewhere in the test.
    """

    def __init__(self, name, cipher, mode):
        """Store the algorithm description.

        :param name: transform name, e.g. 'AES-CBC'
        :param cipher: algorithm class called as ``cipher(key)``, or None
        :param mode: mode class called as ``mode(iv)``, or None
        """
        self.name = name
        self.cipher = cipher
        self.mode = mode
        # The 'NULL' entry has no cipher, hence no block size attribute.
        if self.cipher is not None:
            # block_size is reported in bits; keep it in bytes.
            self.bs = self.cipher.block_size // 8

    def encrypt(self, data, key):
        """Encrypt already padded ``data`` and return ``iv + ciphertext``.

        A fresh random IV of one block is generated for every call.
        """
        iv = os.urandom(self.bs)
        encryptor = Cipher(self.cipher(key), self.mode(iv),
                           default_backend()).encryptor()
        return iv + encryptor.update(data) + encryptor.finalize()

    def decrypt(self, data, key, icv=None):
        """Decrypt ``iv + ciphertext`` and return the (still padded) plaintext.

        ``icv`` is accepted for interface compatibility and is not used.
        """
        iv = data[:self.bs]
        ct = data[self.bs:]
        # Use the configured cipher and mode, mirroring encrypt(), rather
        # than a hard-coded AES/CBC pair.
        decryptor = Cipher(self.cipher(key), self.mode(iv),
                           default_backend()).decryptor()
        return decryptor.update(ct) + decryptor.finalize()

    def pad(self, data):
        """Pad ``data`` to a whole number of cipher blocks.

        Appends zero padding followed by a one-byte Pad Length field.  At
        least one byte (the Pad Length itself) is always appended, so input
        that is already block aligned grows by a full block.

        :returns: data + padding + pad-length octet
        """
        # Total number of bytes appended, including the Pad Length octet.
        trailer_len = self.bs - len(data) % self.bs
        padding = b'\x00' * (trailer_len - 1)
        # Pad Length counts the Padding field only, not itself
        # (RFC 7296 section 3.14).
        return data + padding + bytes([trailer_len - 1])
64
65
class AuthAlgo(object):
    """Description of a keyed-hash algorithm (integrity or PRF).

    :param name: transform name, e.g. 'HMAC-SHA1-96'
    :param mac: MAC constructor called as ``mac(key, mod_instance, backend=...)``
    :param mod: hash algorithm class, instantiated with ``mod()``
    :param key_len: key length in bytes
    :param trunc_len: length in bytes the output is truncated to;
                      defaults to ``key_len`` when omitted
    """

    def __init__(self, name, mac, mod, key_len, trunc_len=None):
        self.name = name
        self.mac = mac
        self.mod = mod
        self.key_len = key_len
        # Only fall back when the argument was omitted: an explicit 0 is a
        # legitimate value and must not be replaced by key_len.
        self.trunc_len = key_len if trunc_len is None else trunc_len
73
74
# Lookup tables of the algorithms this test can negotiate.  Each table maps
# an algorithm's name to its descriptor object.

# Encryption algorithms.
CRYPTO_ALGOS = {
    algo.name: algo for algo in (
        CryptoAlgo('NULL', None, None),
        CryptoAlgo('AES-CBC', algorithms.AES, modes.CBC),
    )
}

# Integrity algorithms.
AUTH_ALGOS = {
    algo.name: algo for algo in (
        AuthAlgo('NULL', None, None, 0, 0),
        AuthAlgo('HMAC-SHA1-96', mac=hmac.HMAC, mod=hashes.SHA1,
                 key_len=20, trunc_len=12),
    )
}

# Pseudo-random functions.
PRF_ALGOS = {
    algo.name: algo for algo in (
        AuthAlgo('NULL', None, None, 0, 0),
        AuthAlgo('PRF_HMAC_SHA2_256', mac=hmac.HMAC, mod=hashes.SHA256,
                 key_len=32),
    )
}
90
91
class IKEv2ChildSA(object):
    """Child SA record: its SPI plus local and remote traffic selectors."""

    def __init__(self, local_ts, remote_ts, spi=None):
        """Remember the selectors; draw a random 4-byte SPI if none is given."""
        self.local_ts = local_ts
        self.remote_ts = remote_ts
        if spi:
            self.spi = spi
        else:
            self.spi = os.urandom(4)
97
98
class IKEv2SA(object):
    """State, key derivation and crypto helpers for one IKEv2 SA.

    The object models the test's side of the exchange.  Negotiated
    algorithms are configured with set_ike_props()/set_esp_props(); keys are
    derived with calc_keys() and calc_child_keys() once the peer's nonce and
    DH public value have been stored on the instance.
    """

    def __init__(self, test, is_initiator=True, spi=b'\x04' * 8,
                 i_id=None, r_id=None, id_type='fqdn', nonce=None,
                 auth_data=None, local_ts=None, remote_ts=None,
                 auth_method='shared-key'):
        """Create the SA and its single child SA.

        :param test: test case object, used for assertions
        :param is_initiator: True when this side initiates the exchange
        :param spi: this side's 8-byte IKE SPI
        :param i_id: initiator identity (bytes)
        :param r_id: responder identity (bytes)
        :param id_type: identity type, either a name understood by
                        IDType.value() or an already numeric value
        :param nonce: this side's nonce; 32 random bytes when omitted
        :param auth_data: shared secret used by auth_init()
        :param local_ts: local traffic selector of the child SA
        :param remote_ts: remote traffic selector of the child SA
        :param auth_method: authentication method name
        """
        self.dh_params = None
        self.test = test
        self.is_initiator = is_initiator
        nonce = nonce or os.urandom(32)
        self.auth_data = auth_data
        self.i_id = i_id
        self.r_id = r_id
        if isinstance(id_type, str):
            self.id_type = IDType.value(id_type)
        else:
            self.id_type = id_type
        self.auth_method = auth_method
        if self.is_initiator:
            self.rspi = None
            self.ispi = spi
            self.i_nonce = nonce
        else:
            self.rspi = spi
            self.ispi = None
            # Our own nonce is the responder nonce; do not discard it.
            self.r_nonce = nonce
        self.child_sas = [IKEv2ChildSA(local_ts, remote_ts)]

    def dh_pub_key(self):
        """Return the initiator's DH public value (bytes)."""
        return self.i_dh_data

    def compute_secret(self):
        """Return the DH shared secret: peer_public ** private mod p.

        The result is serialised big-endian to the group's key length.
        """
        priv = self.dh_private_key
        peer = self.r_dh_data
        p, _g, key_len = self.ike_group
        return pow(int.from_bytes(peer, 'big'),
                   int.from_bytes(priv, 'big'), p).to_bytes(key_len, 'big')

    def generate_dh_data(self):
        """Generate this side's DH key pair (initiator only).

        :raises NotImplementedError: if the configured group is not in DH
        """
        # generate DH keys
        if self.is_initiator:
            if self.ike_dh not in DH:
                raise NotImplementedError('%s not in DH group' % self.ike_dh)
            if self.dh_params is None:
                dhg = DH[self.ike_dh]
                pn = dh.DHParameterNumbers(dhg[0], dhg[1])
                self.dh_params = pn.parameters(default_backend())
            priv = self.dh_params.generate_private_key()
            pub = priv.public_key()
            x = priv.private_numbers().x
            self.dh_private_key = x.to_bytes(priv.key_size // 8, 'big')
            y = pub.public_numbers().y
            self.i_dh_data = y.to_bytes(pub.key_size // 8, 'big')

    def complete_dh_data(self):
        """Compute and store the DH shared secret."""
        self.dh_shared_secret = self.compute_secret()

    def calc_child_keys(self):
        """Derive the first child SA's keys from SK_d and both nonces.

        Key material is split in the order sk_ei, sk_ai, sk_er, sk_ar and
        stored on the child SA object.
        """
        prf = self.ike_prf_alg.mod()
        s = self.i_nonce + self.r_nonce
        c = self.child_sas[0]

        encr_key_len = self.esp_crypto_key_len
        # Child SA keys are sized by the ESP integrity algorithm, not by the
        # one protecting the IKE SA.
        integ_key_len = self.esp_integ_alg.key_len
        total_len = (integ_key_len * 2 +
                     encr_key_len * 2)
        keymat = self.calc_prfplus(prf, self.sk_d, s, total_len)

        pos = 0
        c.sk_ei = keymat[pos:pos+encr_key_len]
        pos += encr_key_len

        c.sk_ai = keymat[pos:pos+integ_key_len]
        pos += integ_key_len

        c.sk_er = keymat[pos:pos+encr_key_len]
        pos += encr_key_len

        c.sk_ar = keymat[pos:pos+integ_key_len]
        pos += integ_key_len

    def calc_prfplus(self, prf, key, seed, length):
        """Expand ``key``/``seed`` into at least ``length`` bytes (prf+).

        Each round computes T(n) = prf(key, T(n-1) | seed | n) with T(0)
        empty, and the rounds are concatenated.  The result is not trimmed;
        callers slice what they need.

        :returns: the key material, or None if ``length`` bytes cannot be
                  produced within the counter limit
        """
        r = b''
        t = b''
        x = 1
        while len(r) < length and x < 255:
            t = self.calc_prf(prf, key, t + seed + bytes([x]))
            r += t
            x += 1

        # Fail only when the output is really short; reaching the counter
        # limit on the very round that completes the output is a success.
        if len(r) < length:
            return None
        return r

    def calc_prf(self, prf, key, data):
        """Return prf(key, data) using the negotiated PRF's MAC."""
        h = self.ike_prf_alg.mac(key, prf, backend=default_backend())
        h.update(data)
        return h.finalize()

    def calc_keys(self):
        """Derive SKEYSEED and the IKE SA keys.

        Key material is split in the order
        sk_d, sk_ai, sk_ar, sk_ei, sk_er, sk_pi, sk_pr.
        """
        prf = self.ike_prf_alg.mod()
        # SKEYSEED = prf(Ni | Nr, g^ir)
        s = self.i_nonce + self.r_nonce
        self.skeyseed = self.calc_prf(prf, s, self.dh_shared_secret)

        # calculate S = Ni | Nr | SPIi SPIr
        s = s + self.ispi + self.rspi

        prf_key_trunc = self.ike_prf_alg.trunc_len
        encr_key_len = self.ike_crypto_key_len
        tr_prf_key_len = self.ike_prf_alg.key_len
        integ_key_len = self.ike_integ_alg.key_len
        total_len = (prf_key_trunc +
                     integ_key_len * 2 +
                     encr_key_len * 2 +
                     tr_prf_key_len * 2)
        keymat = self.calc_prfplus(prf, self.skeyseed, s, total_len)

        pos = 0
        self.sk_d = keymat[:pos+prf_key_trunc]
        pos += prf_key_trunc

        self.sk_ai = keymat[pos:pos+integ_key_len]
        pos += integ_key_len
        self.sk_ar = keymat[pos:pos+integ_key_len]
        pos += integ_key_len

        self.sk_ei = keymat[pos:pos+encr_key_len]
        pos += encr_key_len
        self.sk_er = keymat[pos:pos+encr_key_len]
        pos += encr_key_len

        self.sk_pi = keymat[pos:pos+tr_prf_key_len]
        pos += tr_prf_key_len
        self.sk_pr = keymat[pos:pos+tr_prf_key_len]

    def generate_authmsg(self, prf, packet):
        """Build the octets this side signs in its AUTH payload.

        The result is ``packet`` + the peer's nonce + prf(SK_p, ID header |
        ID data) for this side's identity.

        :param prf: hash algorithm instance handed to calc_prf()
        :param packet: raw bytes of this side's first message
        """
        if self.is_initiator:
            id_data = self.i_id
            nonce = self.r_nonce
            key = self.sk_pi
        else:
            # Responder signs with its own identity and SK_pr over the
            # initiator's nonce.
            id_data = self.r_id
            nonce = self.i_nonce
            key = self.sk_pr
        data = bytes([self.id_type, 0, 0, 0]) + id_data
        id_hash = self.calc_prf(prf, key, data)
        return packet + nonce + id_hash

    def auth_init(self):
        """Compute the AUTH payload value for shared-key authentication.

        Note: self.auth_data holds the shared secret on entry and is
        overwritten with the computed AUTH value.
        """
        prf = self.ike_prf_alg.mod()
        authmsg = self.generate_authmsg(prf, raw(self.init_req_packet))
        psk = self.calc_prf(prf, self.auth_data, KEY_PAD)
        self.auth_data = self.calc_prf(prf, psk, authmsg)

    def encrypt(self, data):
        """Pad and encrypt ``data`` with this side's encryption key."""
        data = self.ike_crypto_alg.pad(data)
        return self.ike_crypto_alg.encrypt(data, self.my_cryptokey)

    @property
    def peer_authkey(self):
        """Integrity key the peer uses for messages it sends."""
        if self.is_initiator:
            return self.sk_ar
        return self.sk_ai

    @property
    def my_authkey(self):
        """Integrity key this side uses for messages it sends."""
        if self.is_initiator:
            return self.sk_ai
        return self.sk_ar

    @property
    def my_cryptokey(self):
        """Encryption key this side uses for messages it sends."""
        if self.is_initiator:
            return self.sk_ei
        return self.sk_er

    @property
    def peer_cryptokey(self):
        """Encryption key the peer uses for messages it sends."""
        if self.is_initiator:
            return self.sk_er
        return self.sk_ei

    def verify_hmac(self, ikemsg):
        """Assert that the trailing ICV of raw message ``ikemsg`` is valid."""
        integ_trunc = self.ike_integ_alg.trunc_len
        # Split by position: a [-n:] / [:-n] pair misbehaves when n is 0.
        split = len(ikemsg) - integ_trunc
        exp_hmac = ikemsg[split:]
        data = ikemsg[:split]
        computed_hmac = self.compute_hmac(self.ike_integ_alg.mod(),
                                          self.peer_authkey, data)
        self.test.assertEqual(computed_hmac[:integ_trunc], exp_hmac)

    def compute_hmac(self, integ, key, data):
        """Return the untruncated integrity MAC of ``data`` under ``key``."""
        h = self.ike_integ_alg.mac(key, integ, backend=default_backend())
        h.update(data)
        return h.finalize()

    def decrypt(self, data):
        """Decrypt ``data`` (iv + ciphertext) with the peer's encryption key."""
        return self.ike_crypto_alg.decrypt(data, self.peer_cryptokey)

    def hmac_and_decrypt(self, ike):
        """Verify the ICV of ``ike`` and return its decrypted payload.

        :param ike: scapy IKEv2 packet containing an Encrypted payload
        :returns: plaintext bytes, still including padding
        """
        ep = ike[ikev2.IKEv2_payload_Encrypted]
        self.verify_hmac(raw(ike))
        integ_trunc = self.ike_integ_alg.trunc_len

        # remove ICV and decrypt payload
        ct = ep.load[:len(ep.load) - integ_trunc]
        return self.decrypt(ct)

    def generate_ts(self):
        """Return ([TSi], [TSr]) built from the first child SA's selectors."""
        c = self.child_sas[0]
        ts1 = ikev2.IPv4TrafficSelector(
                IP_protocol_ID=0,
                starting_address_v4=c.local_ts['start_addr'],
                ending_address_v4=c.local_ts['end_addr'])
        ts2 = ikev2.IPv4TrafficSelector(
                IP_protocol_ID=0,
                starting_address_v4=c.remote_ts['start_addr'],
                ending_address_v4=c.remote_ts['end_addr'])
        return ([ts1], [ts2])

    def set_ike_props(self, crypto, crypto_key_len, integ, prf, dh):
        """Select the algorithms protecting the IKE SA.

        :param crypto: key into CRYPTO_ALGOS
        :param crypto_key_len: encryption key length in bytes
        :param integ: key into AUTH_ALGOS
        :param prf: key into PRF_ALGOS
        :param dh: key into DH
        :raises TypeError: for an unknown crypto, integ or prf name
        :raises KeyError: for an unknown DH group name
        """
        if crypto not in CRYPTO_ALGOS:
            raise TypeError('unsupported encryption algo %r' % crypto)
        self.ike_crypto = crypto
        self.ike_crypto_alg = CRYPTO_ALGOS[crypto]
        self.ike_crypto_key_len = crypto_key_len

        if integ not in AUTH_ALGOS:
            raise TypeError('unsupported auth algo %r' % integ)
        self.ike_integ = integ
        self.ike_integ_alg = AUTH_ALGOS[integ]

        if prf not in PRF_ALGOS:
            raise TypeError('unsupported prf algo %r' % prf)
        self.ike_prf = prf
        self.ike_prf_alg = PRF_ALGOS[prf]
        self.ike_dh = dh
        self.ike_group = DH[self.ike_dh]

    def set_esp_props(self, crypto, crypto_key_len, integ):
        """Select the algorithms proposed for the child (ESP) SA.

        :param crypto: key into CRYPTO_ALGOS
        :param crypto_key_len: encryption key length in bytes
        :param integ: key into AUTH_ALGOS
        :raises TypeError: for an unknown crypto or integ name
        """
        self.esp_crypto_key_len = crypto_key_len
        if crypto not in CRYPTO_ALGOS:
            raise TypeError('unsupported encryption algo %r' % crypto)
        self.esp_crypto = crypto
        self.esp_crypto_alg = CRYPTO_ALGOS[crypto]

        if integ not in AUTH_ALGOS:
            raise TypeError('unsupported auth algo %r' % integ)
        self.esp_integ = integ
        self.esp_integ_alg = AUTH_ALGOS[integ]

    def crypto_attr(self, key_len, crypto=None):
        """Return (key-length attribute, transform length) for a cipher.

        The attribute packs 0x800e in the upper 16 bits and the key length
        in bits (``key_len`` bytes * 8) in the lower bits.

        :param key_len: key length in bytes
        :param crypto: cipher name to validate; defaults to the IKE cipher
        :raises Exception: if the cipher takes no key-length attribute
        """
        if crypto is None:
            crypto = self.ike_crypto
        if crypto in ['AES-CBC', 'AES-GCM']:
            return (0x800e << 16 | key_len << 3, 12)
        raise Exception('unsupported attribute type')

    def ike_crypto_attr(self):
        """Key-length attribute for the IKE SA encryption transform."""
        return self.crypto_attr(self.ike_crypto_key_len)

    def esp_crypto_attr(self):
        """Key-length attribute for the ESP encryption transform."""
        # Validate against the ESP cipher, not the IKE one.
        return self.crypto_attr(self.esp_crypto_key_len, self.esp_crypto)
361
362
class TestResponder(VppTestCase):
    """ responder test """

    @classmethod
    def setUpClass(cls):
        # scapy's IKEv2 layer is imported here and published as the
        # module-level name 'ikev2' used throughout this file.
        import scapy.contrib.ikev2 as _ikev2
        globals()['ikev2'] = _ikev2
        super(TestResponder, cls).setUpClass()
        cls.create_pg_interfaces(range(2))
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()

    @classmethod
    def tearDownClass(cls):
        super(TestResponder, cls).tearDownClass()

    def setUp(self):
        super(TestResponder, self).setUp()
        self.config_tc()

    def config_tc(self):
        """Configure the VPP profile and the matching initiator-side SA."""
        self.p = Profile(self, 'pr1')
        self.p.add_auth(method='shared-key', data=b'$3cr3tpa$$w0rd')
        self.p.add_local_id(id_type='fqdn', data=b'vpp.home')
        self.p.add_remote_id(id_type='fqdn', data=b'roadwarrior.example.com')
        # 10.10.10.0 - 10.10.10.255 (start previously lacked a hex digit)
        self.p.add_local_ts(start_addr=0x0a0a0a00, end_addr=0x0a0a0aff)
        # 10.0.0.0 - 10.0.0.255
        self.p.add_remote_ts(start_addr=0xa000000, end_addr=0xa0000ff)
        self.p.add_vpp_config()

        # The test plays the initiator, so the profile's remote side is our
        # local side and vice versa.
        self.sa = IKEv2SA(self, i_id=self.p.remote_id['data'],
                          r_id=self.p.local_id['data'],
                          is_initiator=True, auth_data=self.p.auth['data'],
                          id_type=self.p.local_id['id_type'],
                          local_ts=self.p.remote_ts, remote_ts=self.p.local_ts)

        self.sa.set_ike_props(crypto='AES-CBC', crypto_key_len=32,
                              integ='HMAC-SHA1-96', prf='PRF_HMAC_SHA2_256',
                              dh='2048MODPgr')
        self.sa.set_esp_props(crypto='AES-CBC', crypto_key_len=32,
                              integ='HMAC-SHA1-96')
        self.sa.generate_dh_data()

    def create_ike_msg(self, src_if, msg, sport=500, dport=500):
        """Wrap ``msg`` in Ether/IP/UDP headers addressed to VPP on src_if."""
        return (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                IP(src=src_if.remote_ip4, dst=src_if.local_ip4) /
                UDP(sport=sport, dport=dport) / msg)

    def send_sa_init(self):
        """Send IKE_SA_INIT on pg0 and process the single captured reply."""
        tr_attr = self.sa.ike_crypto_attr()
        trans = (ikev2.IKEv2_payload_Transform(transform_type='Encryption',
                 transform_id=self.sa.ike_crypto, length=tr_attr[1],
                 key_length=tr_attr[0]) /
                 ikev2.IKEv2_payload_Transform(transform_type='Integrity',
                 transform_id=self.sa.ike_integ) /
                 ikev2.IKEv2_payload_Transform(transform_type='PRF',
                 transform_id=self.sa.ike_prf_alg.name) /
                 ikev2.IKEv2_payload_Transform(transform_type='GroupDesc',
                 transform_id=self.sa.ike_dh))

        props = (ikev2.IKEv2_payload_Proposal(proposal=1, proto='IKEv2',
                 trans_nb=4, trans=trans))

        # Kept on the SA: auth_init() signs the raw bytes of this message.
        self.sa.init_req_packet = (
                ikev2.IKEv2(init_SPI=self.sa.ispi,
                            flags='Initiator', exch_type='IKE_SA_INIT') /
                ikev2.IKEv2_payload_SA(next_payload='KE', prop=props) /
                ikev2.IKEv2_payload_KE(next_payload='Nonce',
                                       group=self.sa.ike_dh,
                                       load=self.sa.dh_pub_key()) /
                ikev2.IKEv2_payload_Nonce(load=self.sa.i_nonce))

        ike_msg = self.create_ike_msg(self.pg0, self.sa.init_req_packet)
        self.pg0.add_stream(ike_msg)
        self.pg0.enable_capture()
        self.pg_start()
        capture = self.pg0.get_capture(1)
        self.verify_sa_init(capture[0])

    def send_sa_auth(self):
        """Send the encrypted IKE_AUTH request and process the reply."""
        tr_attr = self.sa.esp_crypto_attr()
        trans = (ikev2.IKEv2_payload_Transform(transform_type='Encryption',
                 transform_id=self.sa.esp_crypto, length=tr_attr[1],
                 key_length=tr_attr[0]) /
                 ikev2.IKEv2_payload_Transform(transform_type='Integrity',
                 transform_id=self.sa.esp_integ) /
                 ikev2.IKEv2_payload_Transform(
                 transform_type='Extended Sequence Number',
                 transform_id='No ESN') /
                 ikev2.IKEv2_payload_Transform(
                 transform_type='Extended Sequence Number',
                 transform_id='ESN'))

        # Propose the SPI recorded on the child SA rather than a throwaway
        # random value, so the object reflects what was sent.
        child = self.sa.child_sas[0]
        props = (ikev2.IKEv2_payload_Proposal(proposal=1, proto='ESP',
                 SPIsize=4, SPI=child.spi, trans_nb=4, trans=trans))

        tsi, tsr = self.sa.generate_ts()
        plain = (ikev2.IKEv2_payload_IDi(next_payload='IDr',
                 IDtype=self.sa.id_type, load=self.sa.i_id) /
                 ikev2.IKEv2_payload_IDr(next_payload='AUTH',
                 IDtype=self.sa.id_type, load=self.sa.r_id) /
                 ikev2.IKEv2_payload_AUTH(next_payload='SA',
                 auth_type=2, load=self.sa.auth_data) /
                 ikev2.IKEv2_payload_SA(next_payload='TSi', prop=props) /
                 ikev2.IKEv2_payload_TSi(next_payload='TSr',
                 number_of_TSs=len(tsi),
                 traffic_selector=tsi) /
                 ikev2.IKEv2_payload_TSr(next_payload='Notify',
                 number_of_TSs=len(tsr),
                 traffic_selector=tsr) /
                 ikev2.IKEv2_payload_Notify(type='INITIAL_CONTACT'))
        encr = self.sa.encrypt(raw(plain))

        # Lengths are set explicitly because the ICV is appended as a Raw
        # layer only after the MAC has been computed over the message.
        trunc_len = self.sa.ike_integ_alg.trunc_len
        plen = len(encr) + len(ikev2.IKEv2_payload_Encrypted()) + trunc_len
        tlen = plen + len(ikev2.IKEv2())

        sk_p = ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
                                             length=plen, load=encr)
        sa_auth = (ikev2.IKEv2(init_SPI=self.sa.ispi, resp_SPI=self.sa.rspi,
                   length=tlen, flags='Initiator', exch_type='IKE_AUTH', id=1))
        sa_auth /= sk_p

        integ_data = raw(sa_auth)
        hmac_data = self.sa.compute_hmac(self.sa.ike_integ_alg.mod(),
                                         self.sa.my_authkey, integ_data)
        sa_auth = sa_auth / Raw(hmac_data[:trunc_len])
        # unittest assertion: not stripped under -O and reports both values
        self.assertEqual(len(sa_auth), tlen)

        packet = self.create_ike_msg(self.pg0, sa_auth)
        self.pg0.add_stream(packet)
        self.pg0.enable_capture()
        self.pg_start()
        capture = self.pg0.get_capture(1)
        self.verify_sa_auth(capture[0])

    def verify_sa_init(self, packet):
        """Check the IKE_SA_INIT reply and derive the IKE SA keys from it."""
        ih = packet[ikev2.IKEv2]
        self.assertEqual(ih.exch_type, 34)
        self.assertTrue('Response' in ih.flags)
        self.assertEqual(ih.init_SPI, self.sa.ispi)
        # The SPI is an 8-byte string; comparing it with the integer 0
        # could never fail.
        self.assertNotEqual(ih.resp_SPI, 8 * b'\x00')
        self.sa.rspi = ih.resp_SPI
        try:
            # Looked up only to make sure the SA payload is present.
            ih[ikev2.IKEv2_payload_SA]
            self.sa.r_nonce = ih[ikev2.IKEv2_payload_Nonce].load
            self.sa.r_dh_data = ih[ikev2.IKEv2_payload_KE].load
        except (IndexError, AttributeError):
            # scapy raises IndexError when a layer is missing
            self.logger.error(
                "unexpected reply: SA/Nonce/KE payload not found!")
            raise
        self.sa.complete_dh_data()
        self.sa.calc_keys()
        self.sa.auth_init()

    def verify_sa_auth(self, packet):
        """Verify and decrypt the IKE_AUTH reply, then derive child keys."""
        try:
            ike = packet[ikev2.IKEv2]
            # Looked up only to make sure the Encrypted payload is present.
            packet[ikev2.IKEv2_payload_Encrypted]
        except (IndexError, KeyError):
            # scapy raises IndexError when a layer is missing
            self.logger.error("unexpected reply: no IKEv2/Encrypt payload!")
            raise
        # The plaintext is not inspected; a failed ICV check fails the test.
        self.sa.hmac_and_decrypt(ike)
        self.sa.calc_child_keys()

    def verify_child_sas(self):
        """Compare the keys of VPP's two IPsec SAs with the derived ones."""
        sas = self.vapi.ipsec_sa_dump()
        self.assertEqual(len(sas), 2)
        sa0 = sas[0].entry
        sa1 = sas[1].entry
        c = self.sa.child_sas[0]

        # verify crypto keys
        self.assertEqual(sa0.crypto_key.length, len(c.sk_er))
        self.assertEqual(sa1.crypto_key.length, len(c.sk_ei))
        self.assertEqual(sa0.crypto_key.data[:len(c.sk_er)], c.sk_er)
        self.assertEqual(sa1.crypto_key.data[:len(c.sk_ei)], c.sk_ei)

        # verify integ keys
        self.assertEqual(sa0.integrity_key.length, len(c.sk_ar))
        self.assertEqual(sa1.integrity_key.length, len(c.sk_ai))
        self.assertEqual(sa0.integrity_key.data[:len(c.sk_ar)], c.sk_ar)
        self.assertEqual(sa1.integrity_key.data[:len(c.sk_ai)], c.sk_ai)

    def test_responder(self):
        # Full exchange: IKE_SA_INIT, IKE_AUTH, then check installed keys.
        self.send_sa_init()
        self.send_sa_auth()
        self.verify_child_sas()
551
552
if __name__ == '__main__':
    # unittest is not imported at the top of the file; import it here so
    # running the module as a script does not fail with NameError.
    import unittest
    unittest.main(testRunner=VppTestRunner)