ikev2: add support for AES-GCM cipher in IKE
[vpp.git] / src / plugins / ikev2 / test / test_ikev2.py
1 import os
2 from cryptography import x509
3 from cryptography.hazmat.backends import default_backend
4 from cryptography.hazmat.primitives import hashes, hmac
5 from cryptography.hazmat.primitives.asymmetric import dh, padding
6 from cryptography.hazmat.primitives.serialization import load_pem_private_key
7 from cryptography.hazmat.primitives.ciphers import (
8     Cipher,
9     algorithms,
10     modes,
11 )
12 from scapy.layers.ipsec import ESP
13 from scapy.layers.inet import IP, UDP, Ether
14 from scapy.packet import raw, Raw
15 from scapy.utils import long_converter
16 from framework import VppTestCase, VppTestRunner
17 from vpp_ikev2 import Profile, IDType, AuthMethod
18 from vpp_papi import VppEnum
19
20
# Constant string used by the shared-key AUTH computation:
# AUTH = prf(prf(shared secret, KEY_PAD), <signed octets>).
KEY_PAD = b"Key Pad for IKEv2"
# Number of salt bytes carried at the end of AES-GCM key material.
SALT_SIZE = 4
# AES-GCM integrity check value (authentication tag) length in bytes.
GCM_ICV_SIZE = 16
# Length in bytes of the explicit IV sent with each AES-GCM message.
GCM_IV_SIZE = 8
25
26
# MODP Diffie-Hellman groups defined in RFC 3526.
# Every DH entry is a (p, g, key_len) tuple: the prime modulus, the
# generator and the byte length used to encode the shared secret.
_MODP_2048_HEX = """
    FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
    29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
    EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
    E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
    EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
    C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
    83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
    670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
    E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
    DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
    15728E5A 8AACAA68 FFFFFFFF FFFFFFFF"""

_MODP_3072_HEX = """
    FFFFFFFF FFFFFFFF C90FDAA2 2168C234 C4C6628B 80DC1CD1
    29024E08 8A67CC74 020BBEA6 3B139B22 514A0879 8E3404DD
    EF9519B3 CD3A431B 302B0A6D F25F1437 4FE1356D 6D51C245
    E485B576 625E7EC6 F44C42E9 A637ED6B 0BFF5CB6 F406B7ED
    EE386BFB 5A899FA5 AE9F2411 7C4B1FE6 49286651 ECE45B3D
    C2007CB8 A163BF05 98DA4836 1C55D39A 69163FA8 FD24CF5F
    83655D23 DCA3AD96 1C62F356 208552BB 9ED52907 7096966D
    670C354E 4ABC9804 F1746C08 CA18217C 32905E46 2E36CE3B
    E39E772C 180E8603 9B2783A2 EC07A28F B5C55DF0 6F4C52C9
    DE2BCBF6 95581718 3995497C EA956AE5 15D22618 98FA0510
    15728E5A 8AAAC42D AD33170D 04507A33 A85521AB DF1CBA64
    ECFB8504 58DBEF0A 8AEA7157 5D060C7D B3970F85 A6E1E4C7
    ABF5AE8C DB0933D7 1E8C94E0 4A25619D CEE3D226 1AD2EE6B
    F12FFA06 D98A0864 D8760273 3EC86A64 521F2B18 177B200C
    BBE11757 7A615D6C 770988C0 BAD946E2 08E24FA0 74E5AB31
    43DB5BFC E0FD108E 4B82D120 A93AD2CA FFFFFFFF FFFFFFFF"""


def _hex_words_to_int(text):
    """Turn a whitespace-separated run of hex words into one integer."""
    return int("".join(text.split()), 16)


DH = {
    '2048MODPgr': (_hex_words_to_int(_MODP_2048_HEX), 2, 256),
    '3072MODPgr': (_hex_words_to_int(_MODP_3072_HEX), 2, 384),
}
61
62
class CryptoAlgo(object):
    """Encryption transform used to protect IKEv2 Encrypted payloads.

    ``cipher`` is a ``cryptography`` cipher algorithm class (None for the
    NULL transform) and ``mode`` the matching ``cryptography`` mode class.
    For the AEAD (AES-GCM) transform the key handed to encrypt()/decrypt()
    carries SALT_SIZE salt bytes at its end.
    """

    def __init__(self, name, cipher, mode):
        self.name = name
        self.cipher = cipher
        self.mode = mode
        if self.cipher is not None:
            # cipher block size in bytes
            self.bs = self.cipher.block_size // 8

            if self.name == 'AES-GCM-16ICV':
                # GCM sends an explicit 8-byte IV; the nonce is salt || IV
                self.iv_len = GCM_IV_SIZE
            else:
                self.iv_len = self.bs

    def encrypt(self, data, key, aad=None):
        """Encrypt already padded ``data`` with ``key``.

        Without ``aad`` the result is IV || ciphertext. With ``aad`` the
        AEAD path is used: the last SALT_SIZE bytes of ``key`` are the salt,
        ``aad`` is authenticated and the result is
        IV || ciphertext || ICV (GCM_ICV_SIZE bytes).
        """
        iv = os.urandom(self.iv_len)
        if aad is None:
            encryptor = Cipher(self.cipher(key), self.mode(iv),
                               default_backend()).encryptor()
            return iv + encryptor.update(data) + encryptor.finalize()

        salt = key[-SALT_SIZE:]
        encryptor = Cipher(self.cipher(key[:-SALT_SIZE]),
                           self.mode(salt + iv),
                           default_backend()).encryptor()
        encryptor.authenticate_additional_data(aad)
        ct = encryptor.update(data) + encryptor.finalize()
        return iv + ct + encryptor.tag[:GCM_ICV_SIZE]

    def decrypt(self, data, key, aad=None, icv=None):
        """Decrypt ``data`` (IV || ciphertext) with ``key``.

        Without ``aad`` the raw plaintext is returned with its padding still
        attached. With ``aad`` the AEAD path is used: ``icv`` is the tag to
        verify, the salt is taken from the end of ``key`` and the padding
        plus Pad Length octet are stripped from the result.

        Raises:
            ValueError: AEAD decryption requested without an ``icv``.
        """
        if aad is None:
            iv = data[:self.iv_len]
            ct = data[self.iv_len:]
            decryptor = Cipher(self.cipher(key), self.mode(iv),
                               default_backend()).decryptor()
            return decryptor.update(ct) + decryptor.finalize()

        if icv is None:
            raise ValueError('AEAD decryption requires the ICV (tag)')
        salt = key[-SALT_SIZE:]
        nonce = salt + data[:GCM_IV_SIZE]
        ct = data[GCM_IV_SIZE:]
        decryptor = Cipher(self.cipher(key[:-SALT_SIZE]),
                           self.mode(nonce, icv, len(icv)),
                           default_backend()).decryptor()
        decryptor.authenticate_additional_data(aad)
        pt = decryptor.update(ct) + decryptor.finalize()
        # the last octet is Pad Length = number of padding bytes before it
        pad_len = pt[-1] + 1
        return pt[:-pad_len]

    def pad(self, data):
        """Pad ``data`` to a multiple of the block size (RFC 7296 3.14).

        Appends zero padding followed by a Pad Length octet that holds the
        number of padding bytes (excluding the Pad Length octet itself).
        At least the Pad Length octet is always added.
        """
        # total bytes to append: padding bytes + the Pad Length octet
        total = (len(data) // self.bs + 1) * self.bs - len(data)
        pad_bytes = total - 1
        return data + b'\x00' * pad_bytes + bytes([pad_bytes])
117
118
class AuthAlgo(object):
    """Description of a keyed MAC / PRF transform.

    Attributes are stored as given:
        name: transform name.
        mac: MAC class, or None.
        mod: hash class handed to ``mac``, or None.
        key_len: key length in bytes.
        trunc_len: number of output bytes kept; falls back to ``key_len``
            when omitted or falsy.
    """

    def __init__(self, name, mac, mod, key_len, trunc_len=None):
        self.name, self.mac, self.mod = name, mac, mod
        self.key_len = key_len
        # no (or zero) truncation length: keep key_len bytes of output
        self.trunc_len = trunc_len if trunc_len else key_len
126
127
# Encryption transforms understood by the test, keyed by transform name.
CRYPTO_ALGOS = {
    'NULL': CryptoAlgo('NULL', cipher=None, mode=None),
    'AES-CBC': CryptoAlgo('AES-CBC', cipher=algorithms.AES, mode=modes.CBC),
    'AES-GCM-16ICV': CryptoAlgo('AES-GCM-16ICV', cipher=algorithms.AES,
                                mode=modes.GCM),
}

# Integrity transforms: (name, MAC, hash, key length, truncated ICV length).
# Each HMAC must use the hash named by the transform (RFC 2404 for SHA1-96,
# RFC 4868 for the SHA-2 variants).
AUTH_ALGOS = {
    'NULL': AuthAlgo('NULL', mac=None, mod=None, key_len=0, trunc_len=0),
    'HMAC-SHA1-96': AuthAlgo('HMAC-SHA1-96', hmac.HMAC, hashes.SHA1, 20, 12),
    'SHA2-256-128': AuthAlgo('SHA2-256-128', hmac.HMAC, hashes.SHA256, 32, 16),
    'SHA2-384-192': AuthAlgo('SHA2-384-192', hmac.HMAC, hashes.SHA384, 48, 24),
    'SHA2-512-256': AuthAlgo('SHA2-512-256', hmac.HMAC, hashes.SHA512, 64, 32),
}

# Pseudo-random functions; the output length defaults to the key length.
PRF_ALGOS = {
    'NULL': AuthAlgo('NULL', mac=None, mod=None, key_len=0, trunc_len=0),
    'PRF_HMAC_SHA2_256': AuthAlgo('PRF_HMAC_SHA2_256', hmac.HMAC,
                                  hashes.SHA256, 32),
}
148
149
class IKEv2ChildSA(object):
    """State of a single Child SA carried by the IKE SA.

    ``local_ts`` / ``remote_ts`` are stored unchanged; generate_ts() reads
    their 'start_addr' and 'end_addr' keys. When ``spi`` is not given (or
    is empty) a random 4-byte SPI is generated.
    """

    def __init__(self, local_ts, remote_ts, spi=None):
        if not spi:
            spi = os.urandom(4)
        self.spi = spi
        self.local_ts = local_ts
        self.remote_ts = remote_ts
155
156
class IKEv2SA(object):
    """Test-side state of an IKEv2 SA negotiated with VPP.

    Tracks SPIs, nonces, Diffie-Hellman values, the negotiated IKE and ESP
    algorithms and the derived SK_* keys, and provides the crypto helpers
    used to build and check IKE messages. The SA plays the initiator role
    unless ``is_initiator`` is False.
    """

    def __init__(self, test, is_initiator=True, spi=b'\x04' * 8,
                 i_id=None, r_id=None, id_type='fqdn', nonce=None,
                 auth_data=None, local_ts=None, remote_ts=None,
                 auth_method='shared-key', priv_key=None, natt=False):
        self.natt = natt
        # IKE runs on UDP 500, or UDP 4500 when NAT traversal is used
        if natt:
            self.sport = 4500
            self.dport = 4500
        else:
            self.sport = 500
            self.dport = 500
        self.dh_params = None
        self.test = test
        self.priv_key = priv_key
        self.is_initiator = is_initiator
        nonce = nonce or os.urandom(32)
        self.auth_data = auth_data
        self.i_id = i_id
        self.r_id = r_id
        if isinstance(id_type, str):
            self.id_type = IDType.value(id_type)
        else:
            self.id_type = id_type
        self.auth_method = auth_method
        if self.is_initiator:
            self.rspi = None
            self.ispi = spi
            self.i_nonce = nonce
        else:
            self.rspi = spi
            self.ispi = None
            # the generated nonce is our own, i.e. the responder nonce
            self.r_nonce = nonce
        self.child_sas = [IKEv2ChildSA(local_ts, remote_ts)]

    def dh_pub_key(self):
        """Return the initiator DH public value (big-endian bytes)."""
        return self.i_dh_data

    def compute_secret(self):
        """Return g^ir = peer_pub^priv mod p, encoded big-endian.

        The result is padded to the group's byte length from ``ike_group``.
        """
        priv = self.dh_private_key
        peer = self.r_dh_data
        p, _g, key_len = self.ike_group
        shared = pow(int.from_bytes(peer, 'big'),
                     int.from_bytes(priv, 'big'), p)
        return shared.to_bytes(key_len, 'big')

    def generate_dh_data(self):
        """Generate the initiator DH key pair for the group ``ike_dh``.

        Only the initiator side is implemented; for a responder nothing
        is generated.

        Raises:
            NotImplementedError: ``ike_dh`` is not a known DH group.
        """
        if self.is_initiator:
            if self.ike_dh not in DH:
                raise NotImplementedError('%s not in DH group' % self.ike_dh)
            if self.dh_params is None:
                dhg = DH[self.ike_dh]
                pn = dh.DHParameterNumbers(dhg[0], dhg[1])
                self.dh_params = pn.parameters(default_backend())
            priv = self.dh_params.generate_private_key()
            pub = priv.public_key()
            x = priv.private_numbers().x
            self.dh_private_key = x.to_bytes(priv.key_size // 8, 'big')
            y = pub.public_numbers().y
            self.i_dh_data = y.to_bytes(pub.key_size // 8, 'big')

    def complete_dh_data(self):
        """Compute and store the DH shared secret once the peer KE is known."""
        self.dh_shared_secret = self.compute_secret()

    def calc_child_keys(self):
        """Derive the first Child SA's keys: KEYMAT = prf+(SK_d, Ni | Nr).

        Keys are taken in the order: encryption i->r, integrity i->r,
        encryption r->i, integrity r->i. When the ESP integrity algorithm
        has no key (AEAD), a SALT_SIZE salt follows each encryption key
        instead of an integrity key.
        """
        prf = self.ike_prf_alg.mod()
        s = self.i_nonce + self.r_nonce
        c = self.child_sas[0]

        encr_key_len = self.esp_crypto_key_len
        integ_key_len = self.esp_integ_alg.key_len
        salt_len = 0 if integ_key_len else SALT_SIZE

        total_len = (integ_key_len * 2 +
                     encr_key_len * 2 +
                     salt_len * 2)
        keymat = self.calc_prfplus(prf, self.sk_d, s, total_len)

        pos = 0
        c.sk_ei = keymat[pos:pos + encr_key_len]
        pos += encr_key_len

        if integ_key_len:
            c.sk_ai = keymat[pos:pos + integ_key_len]
            pos += integ_key_len
        else:
            c.salt_ei = keymat[pos:pos + salt_len]
            pos += salt_len

        c.sk_er = keymat[pos:pos + encr_key_len]
        pos += encr_key_len

        if integ_key_len:
            c.sk_ar = keymat[pos:pos + integ_key_len]
            pos += integ_key_len
        else:
            c.salt_er = keymat[pos:pos + salt_len]
            pos += salt_len

    def calc_prfplus(self, prf, key, seed, length):
        """Return at least ``length`` bytes of prf+(key, seed) (RFC 7296 2.13).

        prf+ = T1 | T2 | ... with Tn = prf(key, Tn-1 | seed | n), where the
        one-octet counter n runs from 1 to 255. The result may be longer
        than ``length``; callers slice what they need. Returns None when
        ``length`` cannot be reached within 255 iterations.
        """
        result = b''
        t = b''
        counter = 1
        while len(result) < length and counter <= 255:
            t = self.calc_prf(prf, key, t + seed + bytes([counter]))
            result += t
            counter += 1

        if len(result) < length:
            return None
        return result

    def calc_prf(self, prf, key, data):
        """Return the negotiated PRF (MAC ``ike_prf_alg.mac`` with hash
        instance ``prf``) of ``data`` under ``key``."""
        h = self.ike_prf_alg.mac(key, prf, backend=default_backend())
        h.update(data)
        return h.finalize()

    def calc_keys(self):
        """Derive the IKE SA keys (RFC 7296 section 2.14).

        SKEYSEED = prf(Ni | Nr, g^ir) and
        SK_d | SK_ai | SK_ar | SK_ei | SK_er | SK_pi | SK_pr
            = prf+(SKEYSEED, Ni | Nr | SPIi | SPIr).
        When no IKE integrity algorithm is used (AEAD), SALT_SIZE salt bytes
        are appended to both SK_ei and SK_er.
        """
        prf = self.ike_prf_alg.mod()
        # SKEYSEED = prf(Ni | Nr, g^ir)
        s = self.i_nonce + self.r_nonce
        self.skeyseed = self.calc_prf(prf, s, self.dh_shared_secret)

        # calculate S = Ni | Nr | SPIi | SPIr
        s = s + self.ispi + self.rspi

        prf_key_trunc = self.ike_prf_alg.trunc_len
        encr_key_len = self.ike_crypto_key_len
        tr_prf_key_len = self.ike_prf_alg.key_len
        integ_key_len = self.ike_integ_alg.key_len
        salt_size = 0 if integ_key_len else SALT_SIZE

        total_len = (prf_key_trunc +
                     integ_key_len * 2 +
                     encr_key_len * 2 +
                     tr_prf_key_len * 2 +
                     salt_size * 2)
        keymat = self.calc_prfplus(prf, self.skeyseed, s, total_len)

        pos = 0
        self.sk_d = keymat[pos:pos + prf_key_trunc]
        pos += prf_key_trunc

        self.sk_ai = keymat[pos:pos + integ_key_len]
        pos += integ_key_len
        self.sk_ar = keymat[pos:pos + integ_key_len]
        pos += integ_key_len

        self.sk_ei = keymat[pos:pos + encr_key_len + salt_size]
        pos += encr_key_len + salt_size
        self.sk_er = keymat[pos:pos + encr_key_len + salt_size]
        pos += encr_key_len + salt_size

        self.sk_pi = keymat[pos:pos + tr_prf_key_len]
        pos += tr_prf_key_len
        self.sk_pr = keymat[pos:pos + tr_prf_key_len]

    def generate_authmsg(self, prf, packet):
        """Return the octets to be signed/MACed for our AUTH payload.

        octets = packet | peer nonce | prf(SK_px, IDType | RESERVED | ID),
        using the initiator or responder identity, nonce and SK_p key
        according to our role.
        """
        if self.is_initiator:
            ident = self.i_id
            nonce = self.r_nonce
            key = self.sk_pi
        else:
            ident = self.r_id
            nonce = self.i_nonce
            key = self.sk_pr
        data = bytes([self.id_type, 0, 0, 0]) + ident
        id_hash = self.calc_prf(prf, key, data)
        return packet + nonce + id_hash

    def auth_init(self):
        """Compute the initiator AUTH payload data into ``auth_data``.

        For 'shared-key', AUTH = prf(prf(secret, KEY_PAD), octets); for
        'rsa-sig' the octets are signed with ``priv_key`` (PKCS#1 v1.5,
        SHA-1).

        Raises:
            TypeError: unknown ``auth_method``.
        """
        prf = self.ike_prf_alg.mod()
        authmsg = self.generate_authmsg(prf, raw(self.init_req_packet))
        if self.auth_method == 'shared-key':
            psk = self.calc_prf(prf, self.auth_data, KEY_PAD)
            self.auth_data = self.calc_prf(prf, psk, authmsg)
        elif self.auth_method == 'rsa-sig':
            self.auth_data = self.priv_key.sign(authmsg, padding.PKCS1v15(),
                                                hashes.SHA1())
        else:
            raise TypeError('unknown auth method type!')

    def encrypt(self, data, aad=None):
        """Pad and encrypt ``data`` with our IKE encryption key."""
        data = self.ike_crypto_alg.pad(data)
        return self.ike_crypto_alg.encrypt(data, self.my_cryptokey, aad)

    @property
    def peer_authkey(self):
        """Integrity key used by the peer (SK_ar for an initiator)."""
        if self.is_initiator:
            return self.sk_ar
        return self.sk_ai

    @property
    def my_authkey(self):
        """Integrity key used by us (SK_ai for an initiator)."""
        if self.is_initiator:
            return self.sk_ai
        return self.sk_ar

    @property
    def my_cryptokey(self):
        """Encryption key used by us (SK_ei for an initiator)."""
        if self.is_initiator:
            return self.sk_ei
        return self.sk_er

    @property
    def peer_cryptokey(self):
        """Encryption key used by the peer (SK_er for an initiator)."""
        if self.is_initiator:
            return self.sk_er
        return self.sk_ei

    def concat(self, alg, key_len):
        """Return '<alg>-<key length in bits>', e.g. 'AES-CBC-256'."""
        return alg + '-' + str(key_len * 8)

    @property
    def vpp_ike_cypto_alg(self):
        """IKE encryption algorithm name with key size, VPP enum style."""
        return self.concat(self.ike_crypto, self.ike_crypto_key_len)

    @property
    def vpp_esp_cypto_alg(self):
        """ESP encryption algorithm name with key size, VPP enum style."""
        return self.concat(self.esp_crypto, self.esp_crypto_key_len)

    def verify_hmac(self, ikemsg):
        """Assert that the trailing ICV of raw ``ikemsg`` is a valid HMAC
        of the preceding bytes under the peer's integrity key."""
        integ_trunc = self.ike_integ_alg.trunc_len
        exp_hmac = ikemsg[-integ_trunc:]
        data = ikemsg[:-integ_trunc]
        computed_hmac = self.compute_hmac(self.ike_integ_alg.mod(),
                                          self.peer_authkey, data)
        self.test.assertEqual(computed_hmac[:integ_trunc], exp_hmac)

    def compute_hmac(self, integ, key, data):
        """Return the untruncated IKE integrity MAC of ``data``."""
        h = self.ike_integ_alg.mac(key, integ, backend=default_backend())
        h.update(data)
        return h.finalize()

    def decrypt(self, data, aad=None, icv=None):
        """Decrypt data sent by the peer with the peer's encryption key."""
        return self.ike_crypto_alg.decrypt(data, self.peer_cryptokey, aad, icv)

    def hmac_and_decrypt(self, ike):
        """Authenticate and decrypt the Encrypted payload of ``ike``.

        For AES-GCM the IKE header plus Encrypted payload header are the
        AAD and the trailing GCM_ICV_SIZE bytes the tag; otherwise the HMAC
        is verified first. Returns the decrypted payload bytes (still padded
        on the non-AEAD path).
        """
        ep = ike[ikev2.IKEv2_payload_Encrypted]
        if self.ike_crypto == 'AES-GCM-16ICV':
            aad_len = len(ikev2.IKEv2_payload_Encrypted()) + len(ikev2.IKEv2())
            ct = ep.load[:-GCM_ICV_SIZE]
            tag = ep.load[-GCM_ICV_SIZE:]
            return self.decrypt(ct, raw(ike)[:aad_len], tag)

        self.verify_hmac(raw(ike))
        integ_trunc = self.ike_integ_alg.trunc_len

        # remove ICV and decrypt payload
        ct = ep.load[:-integ_trunc]
        return self.decrypt(ct)

    def generate_ts(self):
        """Return ([TSi], [TSr]) IPv4 selectors of the first Child SA."""
        c = self.child_sas[0]
        ts1 = ikev2.IPv4TrafficSelector(
                IP_protocol_ID=0,
                starting_address_v4=c.local_ts['start_addr'],
                ending_address_v4=c.local_ts['end_addr'])
        ts2 = ikev2.IPv4TrafficSelector(
                IP_protocol_ID=0,
                starting_address_v4=c.remote_ts['start_addr'],
                ending_address_v4=c.remote_ts['end_addr'])
        return ([ts1], [ts2])

    def set_ike_props(self, crypto, crypto_key_len, integ, prf, dh):
        """Select the IKE SA transforms (key length in bytes).

        Raises:
            TypeError: an unsupported crypto/integ/prf name is given.
        """
        if crypto not in CRYPTO_ALGOS:
            raise TypeError('unsupported encryption algo %r' % crypto)
        self.ike_crypto = crypto
        self.ike_crypto_alg = CRYPTO_ALGOS[crypto]
        self.ike_crypto_key_len = crypto_key_len

        if integ not in AUTH_ALGOS:
            raise TypeError('unsupported auth algo %r' % integ)
        self.ike_integ = None if integ == 'NULL' else integ
        self.ike_integ_alg = AUTH_ALGOS[integ]

        if prf not in PRF_ALGOS:
            raise TypeError('unsupported prf algo %r' % prf)
        self.ike_prf = prf
        self.ike_prf_alg = PRF_ALGOS[prf]
        self.ike_dh = dh
        self.ike_group = DH[self.ike_dh]

    def set_esp_props(self, crypto, crypto_key_len, integ):
        """Select the Child SA (ESP) transforms (key length in bytes).

        Raises:
            TypeError: an unsupported crypto/integ name is given.
        """
        self.esp_crypto_key_len = crypto_key_len
        if crypto not in CRYPTO_ALGOS:
            raise TypeError('unsupported encryption algo %r' % crypto)
        self.esp_crypto = crypto
        self.esp_crypto_alg = CRYPTO_ALGOS[crypto]

        if integ not in AUTH_ALGOS:
            raise TypeError('unsupported auth algo %r' % integ)
        self.esp_integ = None if integ == 'NULL' else integ
        self.esp_integ_alg = AUTH_ALGOS[integ]

    def crypto_attr(self, key_len, crypto=None):
        """Return (Key Length attribute, transform length) for a transform.

        The attribute is 0x800e (AF bit + attribute type 14, Key Length)
        in the upper 16 bits and the key length in bits in the lower 16;
        the transform then spans 12 bytes. ``crypto`` names the encryption
        algorithm to check and defaults to the IKE one.

        Raises:
            Exception: the algorithm takes no Key Length attribute here.
        """
        if crypto is None:
            crypto = self.ike_crypto
        if crypto in ['AES-CBC', 'AES-GCM-16ICV']:
            return (0x800e << 16 | key_len << 3, 12)
        raise Exception('unsupported attribute type')

    def ike_crypto_attr(self):
        """Key Length attribute for the IKE encryption transform."""
        return self.crypto_attr(self.ike_crypto_key_len, self.ike_crypto)

    def esp_crypto_attr(self):
        """Key Length attribute for the ESP encryption transform."""
        return self.crypto_attr(self.esp_crypto_key_len, self.esp_crypto)

    def compute_nat_sha1(self, ip, port):
        """Return SHA-1(SPIi | 8 zero octets | ip | port) for NAT detection."""
        data = self.ispi + b'\x00' * 8 + ip + (port).to_bytes(2, 'big')
        digest = hashes.Hash(hashes.SHA1(), backend=default_backend())
        digest.update(data)
        return digest.finalize()
473
474
class TemplateResponder(VppTestCase):
    """ responder test template """

    @classmethod
    def setUpClass(cls):
        # scapy's IKEv2 layer is imported here and published as the
        # module-level name ``ikev2`` used throughout this file
        import scapy.contrib.ikev2 as _ikev2
        globals()['ikev2'] = _ikev2
        super(TemplateResponder, cls).setUpClass()
        cls.create_pg_interfaces(range(2))
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()

    @classmethod
    def tearDownClass(cls):
        super(TemplateResponder, cls).tearDownClass()

    def setUp(self):
        super(TemplateResponder, self).setUp()
        # NOTE(review): config_tc() is not defined here; presumably each
        # concrete test class provides it (setting self.p and self.sa)
        self.config_tc()
        self.p.add_vpp_config()
        self.sa.generate_dh_data()

    def create_ike_msg(self, src_if, msg, sport=500, dport=500, natt=False):
        """Wrap IKE ``msg`` in Ether/IP/UDP from the remote host on
        ``src_if``, adding the 4-byte non-ESP marker when ``natt``."""
        res = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
               IP(src=src_if.remote_ip4, dst=src_if.local_ip4) /
               UDP(sport=sport, dport=dport))
        if natt:
            # insert non ESP marker
            res = res / Raw(b'\x00' * 4)
        return res / msg

    def send_sa_init(self, behind_nat=False):
        """Send IKE_SA_INIT and verify VPP's response.

        With ``behind_nat`` a NAT_DETECTION_SOURCE_IP notify computed for a
        different source address is appended to trigger NAT detection.
        """
        tr_attr = self.sa.ike_crypto_attr()
        trans = (ikev2.IKEv2_payload_Transform(transform_type='Encryption',
                 transform_id=self.sa.ike_crypto, length=tr_attr[1],
                 key_length=tr_attr[0]) /
                 ikev2.IKEv2_payload_Transform(transform_type='Integrity',
                 transform_id=self.sa.ike_integ) /
                 ikev2.IKEv2_payload_Transform(transform_type='PRF',
                 transform_id=self.sa.ike_prf_alg.name) /
                 ikev2.IKEv2_payload_Transform(transform_type='GroupDesc',
                 transform_id=self.sa.ike_dh))

        props = (ikev2.IKEv2_payload_Proposal(proposal=1, proto='IKEv2',
                 trans_nb=4, trans=trans))

        if behind_nat:
            next_payload = 'Notify'
        else:
            next_payload = None

        self.sa.init_req_packet = (
                ikev2.IKEv2(init_SPI=self.sa.ispi,
                            flags='Initiator', exch_type='IKE_SA_INIT') /
                ikev2.IKEv2_payload_SA(next_payload='KE', prop=props) /
                ikev2.IKEv2_payload_KE(next_payload='Nonce',
                                       group=self.sa.ike_dh,
                                       load=self.sa.dh_pub_key()) /
                ikev2.IKEv2_payload_Nonce(next_payload=next_payload,
                                          load=self.sa.i_nonce))

        if behind_nat:
            src_nat = self.sa.compute_nat_sha1(b'\x0a\x0a\x0a\x01',
                                               self.sa.sport)
            nat_detection = ikev2.IKEv2_payload_Notify(
                    type='NAT_DETECTION_SOURCE_IP',
                    load=src_nat)
            self.sa.init_req_packet = self.sa.init_req_packet / nat_detection

        ike_msg = self.create_ike_msg(self.pg0, self.sa.init_req_packet,
                                      self.sa.sport, self.sa.dport,
                                      self.sa.natt)
        self.pg0.add_stream(ike_msg)
        self.pg0.enable_capture()
        self.pg_start()
        capture = self.pg0.get_capture(1)
        self.verify_sa_init(capture[0])

    def _ike_auth_header(self, tlen):
        """Return the IKE header of our IKE_AUTH request (message ID 1)."""
        return ikev2.IKEv2(init_SPI=self.sa.ispi, resp_SPI=self.sa.rspi, id=1,
                           length=tlen, flags='Initiator',
                           exch_type='IKE_AUTH')

    def send_sa_auth(self):
        """Send IKE_AUTH (IDi, IDr, AUTH, SA, TSi, TSr, INITIAL_CONTACT)
        inside an Encrypted payload and verify VPP's response."""
        tr_attr = self.sa.esp_crypto_attr()
        trans = (ikev2.IKEv2_payload_Transform(transform_type='Encryption',
                 transform_id=self.sa.esp_crypto, length=tr_attr[1],
                 key_length=tr_attr[0]) /
                 ikev2.IKEv2_payload_Transform(transform_type='Integrity',
                 transform_id=self.sa.esp_integ) /
                 ikev2.IKEv2_payload_Transform(
                 transform_type='Extended Sequence Number',
                 transform_id='No ESN') /
                 ikev2.IKEv2_payload_Transform(
                 transform_type='Extended Sequence Number',
                 transform_id='ESN'))

        props = (ikev2.IKEv2_payload_Proposal(proposal=1, proto='ESP',
                 SPIsize=4, SPI=os.urandom(4), trans_nb=4, trans=trans))

        tsi, tsr = self.sa.generate_ts()
        plain = (ikev2.IKEv2_payload_IDi(next_payload='IDr',
                 IDtype=self.sa.id_type, load=self.sa.i_id) /
                 ikev2.IKEv2_payload_IDr(next_payload='AUTH',
                 IDtype=self.sa.id_type, load=self.sa.r_id) /
                 ikev2.IKEv2_payload_AUTH(next_payload='SA',
                 auth_type=AuthMethod.value(self.sa.auth_method),
                 load=self.sa.auth_data) /
                 ikev2.IKEv2_payload_SA(next_payload='TSi', prop=props) /
                 ikev2.IKEv2_payload_TSi(next_payload='TSr',
                 number_of_TSs=len(tsi),
                 traffic_selector=tsi) /
                 ikev2.IKEv2_payload_TSr(next_payload='Notify',
                 number_of_TSs=len(tsr),
                 traffic_selector=tsr) /
                 ikev2.IKEv2_payload_Notify(type='INITIAL_CONTACT'))

        sk_hdr_len = len(ikev2.IKEv2_payload_Encrypted())
        if self.sa.ike_crypto == 'AES-GCM-16ICV':
            data = self.sa.ike_crypto_alg.pad(raw(plain))
            plen = len(data) + GCM_IV_SIZE + GCM_ICV_SIZE + sk_hdr_len
            tlen = plen + len(ikev2.IKEv2())

            # AAD: IKE header + Encrypted payload header with final lengths
            aad = (self._ike_auth_header(tlen) /
                   ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
                                                 length=plen))
            encr = self.sa.encrypt(raw(plain), raw(aad))
            sa_auth = (self._ike_auth_header(tlen) /
                       ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
                                                     length=plen, load=encr))
        else:
            encr = self.sa.encrypt(raw(plain))
            trunc_len = self.sa.ike_integ_alg.trunc_len
            plen = len(encr) + sk_hdr_len + trunc_len
            tlen = plen + len(ikev2.IKEv2())

            sa_auth = (self._ike_auth_header(tlen) /
                       ikev2.IKEv2_payload_Encrypted(next_payload='IDi',
                                                     length=plen, load=encr))

            # the truncated HMAC covers the whole message preceding it
            integ_data = raw(sa_auth)
            hmac_data = self.sa.compute_hmac(self.sa.ike_integ_alg.mod(),
                                             self.sa.my_authkey, integ_data)
            sa_auth = sa_auth / Raw(hmac_data[:trunc_len])

        self.assertEqual(len(sa_auth), tlen)
        packet = self.create_ike_msg(self.pg0, sa_auth, self.sa.sport,
                                     self.sa.dport, self.sa.natt)
        self.pg0.add_stream(packet)
        self.pg0.enable_capture()
        self.pg_start()
        capture = self.pg0.get_capture(1)
        self.verify_sa_auth(capture[0])

    def get_ike_header(self, packet):
        """Return the IKEv2 layer of a captured packet."""
        try:
            ih = packet[ikev2.IKEv2]
        except IndexError:
            # this is a workaround for getting IKEv2 layer as both ikev2 and
            # ipsec register for port 4500
            esp = packet[ESP]
            ih = self.verify_and_remove_non_esp_marker(esp)
        return ih

    def verify_sa_init(self, packet):
        """Check VPP's IKE_SA_INIT response, record its SPI, nonce and KE,
        then derive the IKE keys and our AUTH data."""
        ih = self.get_ike_header(packet)

        self.assertEqual(ih.exch_type, 34)  # IKE_SA_INIT
        self.assertTrue('Response' in ih.flags)
        self.assertEqual(ih.init_SPI, self.sa.ispi)
        self.assertNotEqual(ih.resp_SPI, 0)
        self.sa.rspi = ih.resp_SPI
        try:
            ih[ikev2.IKEv2_payload_SA]  # raises IndexError if missing
            self.sa.r_nonce = ih[ikev2.IKEv2_payload_Nonce].load
            self.sa.r_dh_data = ih[ikev2.IKEv2_payload_KE].load
        except IndexError:
            self.logger.error(
                "unexpected reply: SA/Nonce/KE payload not found!")
            # show() only prints; dump=True returns the text for the log
            self.logger.error(ih.show(dump=True))
            raise
        self.sa.complete_dh_data()
        self.sa.calc_keys()
        self.sa.auth_init()

    def verify_and_remove_non_esp_marker(self, packet):
        """In NAT-T mode, assert and strip the 4-byte non-ESP marker and
        parse the rest as IKEv2; otherwise return ``packet`` unchanged."""
        if self.sa.natt:
            # if we are in nat traversal mode check for non esp marker
            # and remove it
            data = raw(packet)
            self.assertEqual(data[:4], b'\x00' * 4)
            return ikev2.IKEv2(data[4:])
        return packet

    def verify_udp(self, udp):
        """Check the response uses the expected UDP ports."""
        self.assertEqual(udp.sport, self.sa.sport)
        self.assertEqual(udp.dport, self.sa.dport)

    def verify_sa_auth(self, packet):
        """Check VPP's IKE_AUTH response and derive the Child SA keys."""
        ike = self.get_ike_header(packet)
        udp = packet[UDP]
        self.verify_udp(udp)
        self.sa.hmac_and_decrypt(ike)
        self.sa.calc_child_keys()

    def verify_child_sas(self):
        """Compare the two IPsec SAs installed by VPP with our Child SA.

        sas[0] is checked against the r->i keys and sas[1] against the
        i->r keys.
        """
        sas = self.vapi.ipsec_sa_dump()
        self.assertEqual(len(sas), 2)
        sa0 = sas[0].entry
        sa1 = sas[1].entry
        c = self.sa.child_sas[0]

        vpp_crypto_alg = self.vpp_enums[self.sa.vpp_esp_cypto_alg]
        self.assertEqual(sa0.crypto_algorithm, vpp_crypto_alg)
        self.assertEqual(sa1.crypto_algorithm, vpp_crypto_alg)

        if self.sa.esp_integ is None:
            vpp_integ_alg = 0
        else:
            vpp_integ_alg = self.vpp_enums[self.sa.esp_integ]
        self.assertEqual(sa0.integrity_algorithm, vpp_integ_alg)
        self.assertEqual(sa1.integrity_algorithm, vpp_integ_alg)

        # verify crypto keys
        self.assertEqual(sa0.crypto_key.length, len(c.sk_er))
        self.assertEqual(sa1.crypto_key.length, len(c.sk_ei))
        self.assertEqual(sa0.crypto_key.data[:len(c.sk_er)], c.sk_er)
        self.assertEqual(sa1.crypto_key.data[:len(c.sk_ei)], c.sk_ei)

        # verify integ keys, or the salts when no integrity algorithm is used
        if vpp_integ_alg:
            self.assertEqual(sa0.integrity_key.length, len(c.sk_ar))
            self.assertEqual(sa1.integrity_key.length, len(c.sk_ai))
            self.assertEqual(sa0.integrity_key.data[:len(c.sk_ar)], c.sk_ar)
            self.assertEqual(sa1.integrity_key.data[:len(c.sk_ai)], c.sk_ai)
        else:
            self.assertEqual(sa0.salt.to_bytes(4, 'little'), c.salt_er)
            self.assertEqual(sa1.salt.to_bytes(4, 'little'), c.salt_ei)

    def test_responder(self):
        self.send_sa_init(self.sa.natt)
        self.send_sa_auth()
        self.verify_child_sas()
726
727
class Ikev2Params(object):
    """Mixin that configures VPP profile 'pr1' and the matching test SA."""

    def config_params(self, params=None):
        """Configure the VPP IKEv2 profile and ``self.sa`` from ``params``.

        Recognised keys: 'natt'; 'auth' ('rsa-sig' additionally needs
        'server-key', 'client-cert', 'server-cert' and 'client-key' file
        names under $BR/../src/plugins/ikev2/test/certs/); 'ike-crypto'
        and 'esp-crypto' as (name, key length in bytes) tuples;
        'ike-integ', 'ike-dh' and 'esp-integ'. Missing keys use defaults.
        """
        if params is None:
            params = {}
        ec = VppEnum.vl_api_ipsec_crypto_alg_t
        ei = VppEnum.vl_api_ipsec_integ_alg_t
        # test-side algorithm names -> VPP IPsec API enum values
        self.vpp_enums = {
                'AES-CBC-128': ec.IPSEC_API_CRYPTO_ALG_AES_CBC_128,
                'AES-CBC-192': ec.IPSEC_API_CRYPTO_ALG_AES_CBC_192,
                'AES-CBC-256': ec.IPSEC_API_CRYPTO_ALG_AES_CBC_256,
                'AES-GCM-16ICV-128':  ec.IPSEC_API_CRYPTO_ALG_AES_GCM_128,
                'AES-GCM-16ICV-192':  ec.IPSEC_API_CRYPTO_ALG_AES_GCM_192,
                'AES-GCM-16ICV-256':  ec.IPSEC_API_CRYPTO_ALG_AES_GCM_256,

                'HMAC-SHA1-96': ei.IPSEC_API_INTEG_ALG_SHA1_96,
                'SHA2-256-128': ei.IPSEC_API_INTEG_ALG_SHA_256_128,
                'SHA2-384-192': ei.IPSEC_API_INTEG_ALG_SHA_384_192,
                'SHA2-512-256': ei.IPSEC_API_INTEG_ALG_SHA_512_256}

        is_natt = params.get('natt') or False
        self.p = Profile(self, 'pr1')

        if params.get('auth') == 'rsa-sig':
            auth_method = 'rsa-sig'
            work_dir = os.getenv('BR') + '/../src/plugins/ikev2/test/certs/'
            self.vapi.ikev2_set_local_key(
                    key_file=work_dir + params['server-key'])

            client_file = work_dir + params['client-cert']
            with open(work_dir + params['server-cert']) as f:
                server_pem = f.read()
            with open(work_dir + params['client-key']) as f:
                client_priv = f.read()
            client_priv = load_pem_private_key(str.encode(client_priv), None,
                                               default_backend())
            self.peer_cert = x509.load_pem_x509_certificate(
                    str.encode(server_pem),
                    default_backend())
            self.p.add_auth(method='rsa-sig', data=str.encode(client_file))
            auth_data = None
        else:
            auth_data = b'$3cr3tpa$$w0rd'
            self.p.add_auth(method='shared-key', data=auth_data)
            auth_method = 'shared-key'
            client_priv = None

        self.p.add_local_id(id_type='fqdn', data=b'vpp.home')
        self.p.add_remote_id(id_type='fqdn', data=b'roadwarrior.example.com')
        # local TS 10.10.10.0 - 10.10.10.255, remote TS 10.0.0.0 - 10.0.0.255
        self.p.add_local_ts(start_addr=0x0a0a0a00, end_addr=0x0a0a0aff)
        self.p.add_remote_ts(start_addr=0xa000000, end_addr=0xa0000ff)

        self.sa = IKEv2SA(self, i_id=self.p.remote_id['data'],
                          r_id=self.p.local_id['data'],
                          id_type=self.p.local_id['id_type'], natt=is_natt,
                          priv_key=client_priv, auth_method=auth_method,
                          auth_data=auth_data,
                          local_ts=self.p.remote_ts, remote_ts=self.p.local_ts)

        ike_crypto = params.get('ike-crypto', ('AES-CBC', 32))
        ike_integ = params.get('ike-integ', 'HMAC-SHA1-96')
        ike_dh = params.get('ike-dh', '2048MODPgr')

        esp_crypto = params.get('esp-crypto', ('AES-CBC', 32))
        esp_integ = params.get('esp-integ', 'HMAC-SHA1-96')

        self.sa.set_ike_props(
                crypto=ike_crypto[0], crypto_key_len=ike_crypto[1],
                integ=ike_integ, prf='PRF_HMAC_SHA2_256', dh=ike_dh)
        self.sa.set_esp_props(
                crypto=esp_crypto[0], crypto_key_len=esp_crypto[1],
                integ=esp_integ)
799
800
class TestResponderNATT(TemplateResponder, Ikev2Params):
    """ test ikev2 responder - nat traversal """
    def config_tc(self):
        # Only NAT traversal is switched on; everything else keeps the
        # defaults chosen by Ikev2Params.config_params.
        natt_params = dict(natt=True)
        self.config_params(natt_params)
806
807
class TestResponderPsk(TemplateResponder, Ikev2Params):
    """ test ikev2 responder - pre shared key auth """
    def config_tc(self):
        # An empty override dict: no 'auth' key means config_params
        # falls back to shared-key authentication.
        self.config_params({})
812
813
class TestResponderRsaSign(TemplateResponder, Ikev2Params):
    """ test ikev2 responder - cert based auth """
    def config_tc(self):
        # Each PEM file is named after the parameter key that selects it.
        cert_params = {
            key: key + '.pem'
            for key in ('server-key', 'client-key',
                        'client-cert', 'server-cert')}
        cert_params['auth'] = 'rsa-sig'
        self.config_params(cert_params)
823
824
class Test_IKE_AES_CBC_128_SHA256_128_MODP2048_ESP_AES_CBC_192_SHA_384_192\
        (TemplateResponder, Ikev2Params):
    """
    IKE:AES_CBC_128_SHA256_128,DH=modp2048 ESP:AES_CBC_192_SHA_384_192
    """
    def config_tc(self):
        # IKE SA: 16-byte (128-bit) AES-CBC key, SHA2-256 integrity.
        ike_params = {
            'ike-crypto': ('AES-CBC', 16),
            'ike-integ': 'SHA2-256-128',
            'ike-dh': '2048MODPgr'}
        # Child (ESP) SA: 24-byte (192-bit) AES-CBC key, SHA2-384 integrity.
        esp_params = {
            'esp-crypto': ('AES-CBC', 24),
            'esp-integ': 'SHA2-384-192'}
        merged = dict(ike_params)
        merged.update(esp_params)
        self.config_params(merged)
837
838
class TestAES_CBC_128_SHA256_128_MODP3072_ESP_AES_GCM_16\
        (TemplateResponder, Ikev2Params):
    """
    IKE:AES_CBC_128_SHA256_128,DH=modp3072 ESP:AES_GCM_16
    """
    def config_tc(self):
        self.config_params({
            # 16-byte key = AES-CBC-128, as the class name/docstring state
            # (a 32-byte key would select AES-CBC-256).
            'ike-crypto': ('AES-CBC', 16),
            'ike-integ': 'SHA2-256-128',
            # AES-GCM is an AEAD cipher, so ESP uses no separate integrity
            # algorithm.
            'esp-crypto': ('AES-GCM-16ICV', 32),
            'esp-integ': 'NULL',
            'ike-dh': '3072MODPgr'})
851
852
class Test_IKE_AES_GCM_16_256(TemplateResponder, Ikev2Params):
    """
    IKE:AES_GCM_16_256
    """
    def config_tc(self):
        # 32-byte key selects the 256-bit variant of AES-GCM with a 16-byte
        # ICV; being an AEAD cipher, it needs no IKE integrity algorithm.
        gcm_key_len = 32
        overrides = {
            'ike-crypto': ('AES-GCM-16ICV', gcm_key_len),
            'ike-integ': 'NULL',
            'ike-dh': '2048MODPgr',
        }
        self.config_params(overrides)
862
863
if __name__ == '__main__':
    # unittest is not imported at the top of this module, so bring it into
    # scope here before running the suite directly.
    import unittest
    unittest.main(testRunner=VppTestRunner)