8ba7eb488df8547fc873c8d8f2bf494014c9c431
[csit.git] / GPL / traffic_scripts / ipsec_policy.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2022 Cisco and/or its affiliates.
4 #
5 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6 #
7 # Licensed under the Apache License 2.0 or
8 # GNU General Public License v2.0 or later;  you may not use this file
9 # except in compliance with one of these Licenses. You
10 # may obtain a copy of the Licenses at:
11 #
12 #     http://www.apache.org/licenses/LICENSE-2.0
13 #     https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html
14 #
15 # Note: If this file is linked with Scapy, which is GPLv2+, your use of it
16 # must be under GPLv2+.  If at any point in the future it is no longer linked
17 # with Scapy (or other GPLv2+ licensed software), you are free to choose
18 # Apache 2.
19 #
20 # Unless required by applicable law or agreed to in writing, software
21 # distributed under the License is distributed on an "AS IS" BASIS,
22 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 # See the License for the specific language governing permissions and
24 # limitations under the License.
25
26 """Traffic script for IPsec verification."""
27
28 import sys
29
30 from ipaddress import ip_address
31 from scapy.layers.inet import IP
32 from scapy.layers.inet6 import IPv6, ICMPv6ND_NS, ICMPv6MLReport2, ICMPv6ND_RA
33 from scapy.layers.ipsec import SecurityAssociation, ESP
34 from scapy.layers.l2 import Ether
35 from scapy.packet import Raw
36
37 from .PacketVerifier import RxQueue, TxQueue
38 from .TrafficScriptArg import TrafficScriptArg
39
40
41 def check_ipsec(pkt_recv, ip_layer, dst_tun, src_ip, dst_ip, sa_in):
42     """Check received IPsec packet.
43
44     :param pkt_recv: Received packet to verify.
45     :param ip_layer: Scapy IP layer.
46     :param dst_tun: IPsec tunnel destination address.
47     :param src_ip: Source IP/IPv6 address of original IP/IPv6 packet.
48     :param dst_ip: Destination IP/IPv6 address of original IP/IPv6 packet.
49     :param sa_in: IPsec SA for packet decryption.
50     :type pkt_recv: scapy.Ether
51     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
52     :type dst_tun: str
53     :type src_ip: str
54     :type dst_ip: str
55     :type sa_in: scapy.layers.ipsec.SecurityAssociation
56     :raises RuntimeError: If received packet is invalid.
57     """
58     if not pkt_recv.haslayer(ip_layer):
59         raise RuntimeError(
60             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
61         )
62
63     if pkt_recv[ip_layer].dst != dst_tun:
64         raise RuntimeError(
65             f"Received packet has invalid destination address: "
66             f"{pkt_recv[ip_layer].dst} should be: {dst_tun}"
67         )
68
69     if not pkt_recv.haslayer(ESP):
70         raise RuntimeError(f"Not an ESP packet received: {pkt_recv!r}")
71
72     ip_pkt = pkt_recv[ip_layer]
73     d_pkt = sa_in.decrypt(ip_pkt)
74
75     if d_pkt[ip_layer].dst != dst_ip:
76         raise RuntimeError(
77             f"Decrypted packet has invalid destination address: "
78             f"{d_pkt[ip_layer].dst} should be: {dst_ip}"
79         )
80
81     if d_pkt[ip_layer].src != src_ip:
82         raise RuntimeError(
83             f"Decrypted packet has invalid source address: "
84             f"{d_pkt[ip_layer].src} should be: {src_ip}"
85         )
86
87     if ip_layer == IP and d_pkt[ip_layer].proto != 61:
88         raise RuntimeError(
89             f"Decrypted packet has invalid IP protocol: "
90             f"{d_pkt[ip_layer].proto} should be: 61"
91         )
92
93
94 def check_ip(pkt_recv, ip_layer, src_ip, dst_ip):
95     """Check received IP/IPv6 packet.
96
97     :param pkt_recv: Received packet to verify.
98     :param ip_layer: Scapy IP layer.
99     :param src_ip: Source IP/IPv6 address.
100     :param dst_ip: Destination IP/IPv6 address.
101     :type pkt_recv: scapy.Ether
102     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
103     :type src_ip: str
104     :type dst_ip: str
105     :raises RuntimeError: If received packet is invalid.
106     """
107     if not pkt_recv.haslayer(ip_layer):
108         raise RuntimeError(
109             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
110         )
111
112     if pkt_recv[ip_layer].dst != dst_ip:
113         raise RuntimeError(
114             f"Received packet has invalid destination address: "
115             f"{pkt_recv[ip_layer.name].dst} should be: {dst_ip}"
116         )
117
118     if pkt_recv[ip_layer].src != src_ip:
119         raise RuntimeError(
120             f"Received packet has invalid destination address: "
121             f"{pkt_recv[ip_layer.name].dst} should be: {src_ip}"
122         )
123
124     if ip_layer == IP and pkt_recv[ip_layer].proto != 61:
125         raise RuntimeError(
126             f"Received packet has invalid IP protocol: "
127             f"{pkt_recv[ip_layer].proto} should be: 61"
128         )
129
130
131 def main():
132     """Send and receive IPsec packet."""
133
134     args = TrafficScriptArg(
135         [
136             u"tx_src_mac", u"tx_dst_mac", u"rx_src_mac", u"rx_dst_mac",
137             u"src_ip", u"dst_ip", u"crypto_alg", u"crypto_key", u"integ_alg",
138             u"integ_key", u"l_spi", u"r_spi"
139         ],
140         [u"src_tun", u"dst_tun"]
141     )
142
143     tx_txq = TxQueue(args.get_arg(u"tx_if"))
144     tx_rxq = RxQueue(args.get_arg(u"tx_if"))
145     rx_txq = TxQueue(args.get_arg(u"rx_if"))
146     rx_rxq = RxQueue(args.get_arg(u"rx_if"))
147
148     tx_src_mac = args.get_arg(u"tx_src_mac")
149     tx_dst_mac = args.get_arg(u"tx_dst_mac")
150     rx_src_mac = args.get_arg(u"rx_src_mac")
151     rx_dst_mac = args.get_arg(u"rx_dst_mac")
152     src_ip = args.get_arg(u"src_ip")
153     dst_ip = args.get_arg(u"dst_ip")
154     src_tun = args.get_arg(u"src_tun")
155     dst_tun = args.get_arg(u"dst_tun")
156     crypto_alg = args.get_arg(u"crypto_alg")
157     crypto_key = args.get_arg(u"crypto_key")
158     integ_alg = args.get_arg(u"integ_alg")
159     integ_key = args.get_arg(u"integ_key")
160     l_spi = int(args.get_arg(u"l_spi"))
161     r_spi = int(args.get_arg(u"r_spi"))
162
163     ip_layer = IP if ip_address(src_ip).version == 4 else IPv6
164     ip_pkt = ip_layer(src=src_ip, dst=dst_ip, proto=61) if ip_layer == IP \
165         else ip_layer(src=src_ip, dst=dst_ip)
166
167     tunnel_out = ip_layer(src=src_tun, dst=dst_tun) if src_tun and dst_tun \
168         else None
169     tunnel_in = ip_layer(src=dst_tun, dst=src_tun) if src_tun and dst_tun \
170         else None
171
172     if not (src_tun and dst_tun):
173         src_tun = src_ip
174
175     sa_in = SecurityAssociation(
176         ESP, spi=r_spi, crypt_algo=crypto_alg,
177         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
178         auth_key=integ_key.encode(encoding=u"utf-8"), tunnel_header=tunnel_in
179     )
180
181     sa_out = SecurityAssociation(
182         ESP, spi=l_spi, crypt_algo=crypto_alg,
183         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
184         auth_key=integ_key.encode(encoding=u"utf-8"), tunnel_header=tunnel_out
185     )
186
187     sent_packets = list()
188     tx_pkt_send = (
189         Ether(src=tx_src_mac, dst=tx_dst_mac) / sa_out.encrypt(ip_pkt))
190
191     tx_pkt_send /= Raw()
192     size_limit = 78 if ip_layer == IPv6 else 64
193     if len(tx_pkt_send) < (size_limit - 14):
194         tx_pkt_send[Raw].load += (b"\0" * (size_limit - 14 - len(tx_pkt_send)))
195     sent_packets.append(tx_pkt_send)
196     tx_txq.send(tx_pkt_send)
197
198     while True:
199         rx_pkt_recv = rx_rxq.recv(2)
200
201         if rx_pkt_recv is None:
202             raise RuntimeError(f"{ip_layer.name} packet Rx timeout")
203
204         if rx_pkt_recv.haslayer(ICMPv6ND_NS):
205             # read another packet in the queue if the current one is ICMPv6ND_NS
206             continue
207         if rx_pkt_recv.haslayer(ICMPv6MLReport2):
208             # read another packet in the queue if the current one is
209             # ICMPv6MLReport2
210             continue
211         if rx_pkt_recv.haslayer(ICMPv6ND_RA):
212             # read another packet in the queue if the current one is
213             # ICMPv6ND_RA
214             continue
215
216         # otherwise process the current packet
217         break
218
219     check_ip(
220         rx_pkt_recv, ip_layer, src_ip, dst_ip
221     )
222
223     rx_ip_pkt = ip_layer(src=dst_ip, dst=src_ip, proto=61) if ip_layer == IP \
224         else ip_layer(src=dst_ip, dst=src_ip)
225     rx_pkt_send = (
226         Ether(src=rx_dst_mac, dst=rx_src_mac) / rx_ip_pkt)
227     rx_pkt_send /= Raw()
228     if len(rx_pkt_send) < size_limit:
229         rx_pkt_send[Raw].load += (b"\0" * (size_limit - len(rx_pkt_send)))
230     rx_txq.send(rx_pkt_send)
231
232     while True:
233         tx_pkt_recv = tx_rxq.recv(2, ignore=sent_packets)
234
235         if tx_pkt_recv is None:
236             raise RuntimeError(f"{ip_layer.name} packet Rx timeout")
237
238         if tx_pkt_recv.haslayer(ICMPv6ND_NS):
239             # read another packet in the queue if the current one is ICMPv6ND_NS
240             continue
241         if tx_pkt_recv.haslayer(ICMPv6MLReport2):
242             # read another packet in the queue if the current one is
243             # ICMPv6MLReport2
244             continue
245         if tx_pkt_recv.haslayer(ICMPv6ND_RA):
246             # read another packet in the queue if the current one is
247             # ICMPv6ND_RA
248             continue
249
250         break
251
252     check_ipsec(tx_pkt_recv, ip_layer, src_tun, dst_ip, src_ip, sa_in)
253
254
255 if __name__ == u"__main__":
256     main()