4b7d758be5d0a6a7dd5ad1a0a795ebf859785315
[csit.git] / GPL / traffic_scripts / ipsec_interface.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2020 Cisco and/or its affiliates.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at:
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """Traffic script for IPsec verification."""
17
18 import sys
19
20 from ipaddress import ip_address
21 from scapy.layers.inet import IP
22 from scapy.layers.inet6 import IPv6, ICMPv6ND_NS
23 from scapy.layers.ipsec import SecurityAssociation, ESP
24 from scapy.layers.l2 import Ether
25 from scapy.packet import Raw
26
27 from .PacketVerifier import RxQueue, TxQueue
28 from .TrafficScriptArg import TrafficScriptArg
29
30
31 def check_ipsec(
32         pkt_recv, ip_layer, src_mac, dst_mac, src_tun, dst_tun, src_ip, dst_ip,
33         sa_in):
34     """Check received IPsec packet.
35
36     :param pkt_recv: Received packet to verify.
37     :param ip_layer: Scapy IP layer.
38     :param src_mac: Source MAC address.
39     :param dst_mac: Destination MAC address.
40     :param src_tun: IPsec tunnel source address.
41     :param dst_tun: IPsec tunnel destination address.
42     :param src_ip: Source IP/IPv6 address of original IP/IPv6 packet.
43     :param dst_ip: Destination IP/IPv6 address of original IP/IPv6 packet.
44     :param sa_in: IPsec SA for packet decryption.
45     :type pkt_recv: scapy.Ether
46     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
47     :type src_mac: str
48     :type dst_mac: str
49     :type src_tun: str
50     :type dst_tun: str
51     :type src_ip: str
52     :type dst_ip: str
53     :type sa_in: scapy.layers.ipsec.SecurityAssociation
54     :raises RuntimeError: If received packet is invalid.
55     """
56     if pkt_recv[Ether].src != src_mac:
57         raise RuntimeError(
58             f"Received frame has invalid source MAC address: "
59             f"{pkt_recv[Ether].src} should be: {src_mac}"
60         )
61
62     if pkt_recv[Ether].dst != dst_mac:
63         raise RuntimeError(
64             f"Received frame has invalid destination MAC address: "
65             f"{pkt_recv[Ether].dst} should be: {dst_mac}"
66         )
67
68     if not pkt_recv.haslayer(ip_layer):
69         raise RuntimeError(
70             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
71         )
72
73     if pkt_recv[ip_layer].src != src_tun:
74         raise RuntimeError(
75             f"Received packet has invalid source address: "
76             f"{pkt_recv[ip_layer].src} should be: {src_tun}"
77         )
78
79     if pkt_recv[ip_layer].dst != dst_tun:
80         raise RuntimeError(
81             f"Received packet has invalid destination address: "
82             f"{pkt_recv[ip_layer].dst} should be: {dst_tun}"
83         )
84
85     if not pkt_recv.haslayer(ESP):
86         raise RuntimeError(f"Not an ESP packet received: {pkt_recv!r}")
87
88     ip_pkt = pkt_recv[ip_layer]
89     d_pkt = sa_in.decrypt(ip_pkt)
90
91     if d_pkt[ip_layer].dst != dst_ip:
92         raise RuntimeError(
93             f"Decrypted packet has invalid destination address: "
94             f"{d_pkt[ip_layer].dst} should be: {dst_ip}"
95         )
96
97     if d_pkt[ip_layer].src != src_ip:
98         raise RuntimeError(
99             f"Decrypted packet has invalid source address: "
100             f"{d_pkt[ip_layer].src} should be: {src_ip}"
101         )
102
103     if ip_layer == IP and d_pkt[ip_layer].proto != 61:
104         raise RuntimeError(
105             f"Decrypted packet has invalid IP protocol: "
106             f"{d_pkt[ip_layer].proto} should be: 61"
107         )
108
109
110 def check_ip(pkt_recv, ip_layer, src_mac, dst_mac, src_ip, dst_ip):
111     """Check received IP/IPv6 packet.
112
113     :param pkt_recv: Received packet to verify.
114     :param ip_layer: Scapy IP layer.
115     :param src_mac: Source MAC address.
116     :param dst_mac: Destination MAC address.
117     :param src_ip: Source IP/IPv6 address.
118     :param dst_ip: Destination IP/IPv6 address.
119     :type pkt_recv: scapy.Ether
120     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
121     :type src_mac: str
122     :type dst_mac: str
123     :type src_ip: str
124     :type dst_ip: str
125     :raises RuntimeError: If received packet is invalid.
126     """
127     if pkt_recv[Ether].src != src_mac:
128         raise RuntimeError(
129             f"Received frame has invalid source MAC address: "
130             f"{pkt_recv[Ether].src} should be: {src_mac}"
131         )
132
133     if pkt_recv[Ether].dst != dst_mac:
134         raise RuntimeError(
135             f"Received frame has invalid destination MAC address: "
136             f"{pkt_recv[Ether].dst} should be: {dst_mac}"
137         )
138
139     if not pkt_recv.haslayer(ip_layer):
140         raise RuntimeError(
141             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
142         )
143
144     if pkt_recv[ip_layer].dst != dst_ip:
145         raise RuntimeError(
146             f"Received packet has invalid destination address: "
147             f"{pkt_recv[ip_layer.name].dst} should be: {dst_ip}"
148         )
149
150     if pkt_recv[ip_layer].src != src_ip:
151         raise RuntimeError(
152             f"Received packet has invalid destination address: "
153             f"{pkt_recv[ip_layer.name].dst} should be: {src_ip}"
154         )
155
156     if ip_layer == IP and pkt_recv[ip_layer].proto != 61:
157         raise RuntimeError(
158             f"Received packet has invalid IP protocol: "
159             f"{pkt_recv[ip_layer].proto} should be: 61"
160         )
161
162
163 def main():
164     """Send and receive IPsec packet."""
165
166     args = TrafficScriptArg(
167         [
168             u"tx_src_mac", u"tx_dst_mac", u"rx_src_mac", u"rx_dst_mac",
169             u"src_ip", u"dst_ip", u"src_tun", u"dst_tun", u"crypto_alg",
170             u"crypto_key", u"integ_alg", u"integ_key", u"l_spi", u"r_spi"
171         ]
172     )
173
174     tx_txq = TxQueue(args.get_arg(u"tx_if"))
175     tx_rxq = RxQueue(args.get_arg(u"tx_if"))
176     rx_txq = TxQueue(args.get_arg(u"rx_if"))
177     rx_rxq = RxQueue(args.get_arg(u"rx_if"))
178
179     tx_src_mac = args.get_arg(u"tx_src_mac")
180     tx_dst_mac = args.get_arg(u"tx_dst_mac")
181     rx_src_mac = args.get_arg(u"rx_src_mac")
182     rx_dst_mac = args.get_arg(u"rx_dst_mac")
183     src_ip = args.get_arg(u"src_ip")
184     dst_ip = args.get_arg(u"dst_ip")
185     src_tun = args.get_arg(u"src_tun")
186     dst_tun = args.get_arg(u"dst_tun")
187     crypto_alg = args.get_arg(u"crypto_alg")
188     crypto_key = args.get_arg(u"crypto_key")
189     integ_alg = args.get_arg(u"integ_alg")
190     integ_key = args.get_arg(u"integ_key")
191     l_spi = int(args.get_arg(u"l_spi"))
192     r_spi = int(args.get_arg(u"r_spi"))
193
194     ip_layer = IP if ip_address(src_ip).version == 4 else IPv6
195     ip_pkt = ip_layer(src=src_ip, dst=dst_ip, proto=61) if ip_layer == IP \
196         else ip_layer(src=src_ip, dst=dst_ip)
197
198     tunnel_in = ip_layer(src=src_tun, dst=dst_tun)
199     tunnel_out = ip_layer(src=dst_tun, dst=src_tun)
200
201     sa_in = SecurityAssociation(
202         ESP, spi=l_spi, crypt_algo=crypto_alg,
203         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
204         auth_key=integ_key.encode(encoding=u"utf-8"),
205         tunnel_header=tunnel_in
206     )
207
208     sa_out = SecurityAssociation(
209         ESP, spi=r_spi, crypt_algo=crypto_alg,
210         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
211         auth_key=integ_key.encode(encoding=u"utf-8"),
212         tunnel_header=tunnel_out
213     )
214
215     sent_packets = list()
216     tx_pkt_send = (Ether(src=tx_src_mac, dst=tx_dst_mac) / ip_pkt)
217     tx_pkt_send /= Raw()
218     size_limit = 78 if ip_layer == IPv6 else 64
219     if len(tx_pkt_send) < size_limit:
220         tx_pkt_send[Raw].load += (b"\0" * (size_limit - len(tx_pkt_send)))
221     sent_packets.append(tx_pkt_send)
222     tx_txq.send(tx_pkt_send)
223
224     while True:
225         rx_pkt_recv = rx_rxq.recv(2)
226
227         if rx_pkt_recv is None:
228             raise RuntimeError(f"{ip_layer.name} packet Rx timeout")
229
230         if rx_pkt_recv.haslayer(ICMPv6ND_NS):
231             # read another packet in the queue if the current one is ICMPv6ND_NS
232             continue
233         else:
234             # otherwise process the current packet
235             break
236
237     check_ipsec(
238         rx_pkt_recv, ip_layer, rx_src_mac, rx_dst_mac, src_tun, dst_tun, src_ip,
239         dst_ip, sa_in
240     )
241
242     ip_pkt = ip_layer(src=dst_ip, dst=src_ip, proto=61) if ip_layer == IP \
243         else ip_layer(src=dst_ip, dst=src_ip)
244     ip_pkt /= Raw()
245     if len(ip_pkt) < (size_limit - 14):
246         ip_pkt[Raw].load += (b"\0" * (size_limit - 14 - len(ip_pkt)))
247     e_pkt = sa_out.encrypt(ip_pkt)
248     rx_pkt_send = (Ether(src=rx_dst_mac, dst=rx_src_mac) /
249                    e_pkt)
250     rx_txq.send(rx_pkt_send)
251
252     while True:
253         tx_pkt_recv = tx_rxq.recv(2, ignore=sent_packets)
254
255         if tx_pkt_recv is None:
256             raise RuntimeError(f"{ip_layer.name} packet Rx timeout")
257
258         if tx_pkt_recv.haslayer(ICMPv6ND_NS):
259             # read another packet in the queue if the current one is ICMPv6ND_NS
260             continue
261         else:
262             # otherwise process the current packet
263             break
264
265     check_ip(tx_pkt_recv, ip_layer, tx_dst_mac, tx_src_mac, dst_ip, src_ip)
266
267     sys.exit(0)
268
269
270 if __name__ == u"__main__":
271     main()