1 diff --git a/scapy/layers/ipsec.py b/scapy/layers/ipsec.py
2 index ae057ee1..b6806f71 100644
3 --- a/scapy/layers/ipsec.py
4 +++ b/scapy/layers/ipsec.py
5 @@ -56,6 +56,7 @@ from scapy.fields import ByteEnumField, ByteField, IntField, PacketField, \
6 ShortField, StrField, XIntField, XStrField, XStrLenField
7 from scapy.packet import Packet, bind_layers, Raw
8 from scapy.layers.inet import IP, UDP
9 +from scapy.contrib.mpls import MPLS
10 import scapy.modules.six as six
11 from scapy.modules.six.moves import range
12 from scapy.layers.inet6 import IPv6, IPv6ExtHdrHopByHop, IPv6ExtHdrDestOpt, \
13 @@ -359,13 +360,17 @@ class CryptAlgo(object):
14 encryptor = cipher.encryptor()
18 - aad = struct.pack('!LLL', esp.spi, esn, esp.seq)
20 - aad = struct.pack('!LL', esp.spi, esp.seq)
21 + aad = sa.build_aead(esp)
22 + if self.name == 'AES-NULL-GMAC':
23 + aad = aad + esp.iv + data
24 + aes_null_gmac_data = data
26 encryptor.authenticate_additional_data(aad)
28 data = encryptor.update(data) + encryptor.finalize()
29 data += encryptor.tag[:self.icv_size]
30 + if self.name == 'AES-NULL-GMAC':
31 + data = aes_null_gmac_data + data
33 data = encryptor.update(data) + encryptor.finalize()
35 @@ -399,17 +404,19 @@ class CryptAlgo(object):
36 decryptor = cipher.decryptor()
39 + aad = sa.build_aead(esp)
40 + if self.name == 'AES-NULL-GMAC':
41 + aad = aad + iv + data
42 + aes_null_gmac_data = data
44 # Tag value check is done during the finalize method
46 - decryptor.authenticate_additional_data(
47 - struct.pack('!LLL', esp.spi, esn, esp.seq))
49 - decryptor.authenticate_additional_data(
50 - struct.pack('!LL', esp.spi, esp.seq))
51 + decryptor.authenticate_additional_data(aad)
53 data = decryptor.update(data) + decryptor.finalize()
54 except InvalidTag as err:
55 raise IPSecIntegrityError(err)
56 + if self.name == 'AES-NULL-GMAC':
57 + data = aes_null_gmac_data + data
59 # extract padlen and nh
60 padlen = orb(data[-2])
61 @@ -445,6 +452,7 @@ if algorithms:
62 CRYPT_ALGOS['AES-CTR'] = CryptAlgo('AES-CTR',
63 cipher=algorithms.AES,
68 format_mode_iv=_aes_ctr_format_mode_iv)
69 @@ -452,14 +460,24 @@ if algorithms:
70 CRYPT_ALGOS['AES-GCM'] = CryptAlgo('AES-GCM',
71 cipher=algorithms.AES,
77 format_mode_iv=_salt_format_mode_iv)
78 + CRYPT_ALGOS['AES-NULL-GMAC'] = CryptAlgo('AES-NULL-GMAC',
79 + cipher=algorithms.AES,
85 + format_mode_iv=_salt_format_mode_iv)
86 if hasattr(modes, 'CCM'):
87 CRYPT_ALGOS['AES-CCM'] = CryptAlgo('AES-CCM',
88 cipher=algorithms.AES,
94 @@ -544,7 +562,7 @@ class AuthAlgo(object):
96 return self.mac(key, self.digestmod(), default_backend())
98 - def sign(self, pkt, key):
99 + def sign(self, pkt, key, trailer=None):
101 Sign an IPsec (ESP or AH) packet with this algo.
103 @@ -560,16 +578,20 @@ class AuthAlgo(object):
105 if pkt.haslayer(ESP):
106 mac.update(raw(pkt[ESP]))
108 + mac.update(trailer)
109 pkt[ESP].data += mac.finalize()[:self.icv_size]
111 elif pkt.haslayer(AH):
112 clone = zero_mutable_fields(pkt.copy(), sending=True)
113 mac.update(raw(clone))
115 + mac.update(trailer)
116 pkt[AH].icv = mac.finalize()[:self.icv_size]
120 - def verify(self, pkt, key):
121 + def verify(self, pkt, key, trailer):
123 Check that the integrity check value (icv) of a packet is valid.
125 @@ -600,6 +622,8 @@ class AuthAlgo(object):
126 clone = zero_mutable_fields(pkt.copy(), sending=False)
128 mac.update(raw(clone))
130 + mac.update(trailer) # bytearray(4)) #raw(trailer))
131 computed_icv = mac.finalize()[:self.icv_size]
133 # XXX: Cannot use mac.verify because the ICV can be truncated
134 @@ -788,7 +812,7 @@ class SecurityAssociation(object):
135 This class is responsible of "encryption" and "decryption" of IPsec packets. # noqa: E501
138 - SUPPORTED_PROTOS = (IP, IPv6)
139 + SUPPORTED_PROTOS = (IP, IPv6, MPLS)
141 def __init__(self, proto, spi, seq_num=1, crypt_algo=None, crypt_key=None,
142 auth_algo=None, auth_key=None, tunnel_header=None, nat_t_header=None, esn_en=False, esn=0): # noqa: E501
143 @@ -862,6 +886,23 @@ class SecurityAssociation(object):
144 raise TypeError('nat_t_header must be %s' % UDP.name)
145 self.nat_t_header = nat_t_header
147 + def build_aead(self, esp):
149 + return (struct.pack('!LLL', esp.spi, self.seq_num >> 32, esp.seq))
151 + return (struct.pack('!LL', esp.spi, esp.seq))
153 + def build_seq_num(self, num):
154 + # only lower order bits are transmitted
155 + # higher order bits are used in the ICV
156 + lower = num & 0xffffffff
160 + return lower, struct.pack("!I", upper)
164 def check_spi(self, pkt):
165 if pkt.spi != self.spi:
166 raise TypeError('packet spi=0x%x does not match the SA spi=0x%x' %
167 @@ -875,7 +916,8 @@ class SecurityAssociation(object):
168 if len(iv) != self.crypt_algo.iv_size:
169 raise TypeError('iv length must be %s' % self.crypt_algo.iv_size) # noqa: E501
171 - esp = _ESPPlain(spi=self.spi, seq=seq_num or self.seq_num, iv=iv)
172 + low_seq_num, high_seq_num = self.build_seq_num(seq_num or self.seq_num)
173 + esp = _ESPPlain(spi=self.spi, seq=low_seq_num, iv=iv)
175 if self.tunnel_header:
176 tunnel = self.tunnel_header.copy()
177 @@ -899,7 +941,7 @@ class SecurityAssociation(object):
178 esn_en=esn_en or self.esn_en,
181 - self.auth_algo.sign(esp, self.auth_key)
182 + self.auth_algo.sign(esp, self.auth_key, high_seq_num)
184 if self.nat_t_header:
185 nat_t_header = self.nat_t_header.copy()
186 @@ -926,7 +968,8 @@ class SecurityAssociation(object):
188 def _encrypt_ah(self, pkt, seq_num=None):
190 - ah = AH(spi=self.spi, seq=seq_num or self.seq_num,
191 + low_seq_num, high_seq_num = self.build_seq_num(seq_num or self.seq_num)
192 + ah = AH(spi=self.spi, seq=low_seq_num,
193 icv=b"\x00" * self.auth_algo.icv_size)
195 if self.tunnel_header:
196 @@ -966,7 +1009,8 @@ class SecurityAssociation(object):
198 ip_header.plen = len(ip_header.payload) + len(ah) + len(payload)
200 - signed_pkt = self.auth_algo.sign(ip_header / ah / payload, self.auth_key) # noqa: E501
201 + signed_pkt = self.auth_algo.sign(ip_header / ah / payload,
202 + self.auth_key, high_seq_num) # noqa: E501
204 # sequence number must always change, unless specified by the user
206 @@ -1003,11 +1047,12 @@ class SecurityAssociation(object):
208 def _decrypt_esp(self, pkt, verify=True, esn_en=None, esn=None):
210 + low_seq_num, high_seq_num = self.build_seq_num(self.seq_num)
215 - self.auth_algo.verify(encrypted, self.auth_key)
216 + self.auth_algo.verify(encrypted, self.auth_key, high_seq_num)
218 esp = self.crypt_algo.decrypt(self, encrypted, self.crypt_key,
219 self.crypt_algo.icv_size or
220 @@ -1048,9 +1093,10 @@ class SecurityAssociation(object):
222 def _decrypt_ah(self, pkt, verify=True):
224 + low_seq_num, high_seq_num = self.build_seq_num(self.seq_num)
227 - self.auth_algo.verify(pkt, self.auth_key)
228 + self.auth_algo.verify(pkt, self.auth_key, high_seq_num)