FIX: Add ICMPv6MLReport2 masking
[csit.git] / GPL / traffic_scripts / ipsec_policy.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2021 Cisco and/or its affiliates.
4 #
5 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6 #
7 # Licensed under the Apache License 2.0 or
8 # GNU General Public License v2.0 or later;  you may not use this file
9 # except in compliance with one of these Licenses. You
10 # may obtain a copy of the Licenses at:
11 #
12 #     http://www.apache.org/licenses/LICENSE-2.0
13 #     https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html
14 #
15 # Note: If this file is linked with Scapy, which is GPLv2+, your use of it
16 # must be under GPLv2+.  If at any point in the future it is no longer linked
17 # with Scapy (or other GPLv2+ licensed software), you are free to choose
18 # Apache 2.
19 #
20 # Unless required by applicable law or agreed to in writing, software
21 # distributed under the License is distributed on an "AS IS" BASIS,
22 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 # See the License for the specific language governing permissions and
24 # limitations under the License.
25
26 """Traffic script for IPsec verification."""
27
28 import sys
29
30 from ipaddress import ip_address
31 from scapy.layers.inet import IP
32 from scapy.layers.inet6 import IPv6, ICMPv6ND_NS, ICMPv6MLReport2
33 from scapy.layers.ipsec import SecurityAssociation, ESP
34 from scapy.layers.l2 import Ether
35 from scapy.packet import Raw
36
37 from .PacketVerifier import RxQueue, TxQueue
38 from .TrafficScriptArg import TrafficScriptArg
39
40
41 def check_ipsec(pkt_recv, ip_layer, dst_tun, src_ip, dst_ip, sa_in):
42     """Check received IPsec packet.
43
44     :param pkt_recv: Received packet to verify.
45     :param ip_layer: Scapy IP layer.
46     :param dst_tun: IPsec tunnel destination address.
47     :param src_ip: Source IP/IPv6 address of original IP/IPv6 packet.
48     :param dst_ip: Destination IP/IPv6 address of original IP/IPv6 packet.
49     :param sa_in: IPsec SA for packet decryption.
50     :type pkt_recv: scapy.Ether
51     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
52     :type dst_tun: str
53     :type src_ip: str
54     :type dst_ip: str
55     :type sa_in: scapy.layers.ipsec.SecurityAssociation
56     :raises RuntimeError: If received packet is invalid.
57     """
58     if not pkt_recv.haslayer(ip_layer):
59         raise RuntimeError(
60             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
61         )
62
63     if pkt_recv[ip_layer].dst != dst_tun:
64         raise RuntimeError(
65             f"Received packet has invalid destination address: "
66             f"{pkt_recv[ip_layer].dst} should be: {dst_tun}"
67         )
68
69     if not pkt_recv.haslayer(ESP):
70         raise RuntimeError(f"Not an ESP packet received: {pkt_recv!r}")
71
72     ip_pkt = pkt_recv[ip_layer]
73     d_pkt = sa_in.decrypt(ip_pkt)
74
75     if d_pkt[ip_layer].dst != dst_ip:
76         raise RuntimeError(
77             f"Decrypted packet has invalid destination address: "
78             f"{d_pkt[ip_layer].dst} should be: {dst_ip}"
79         )
80
81     if d_pkt[ip_layer].src != src_ip:
82         raise RuntimeError(
83             f"Decrypted packet has invalid source address: "
84             f"{d_pkt[ip_layer].src} should be: {src_ip}"
85         )
86
87     if ip_layer == IP and d_pkt[ip_layer].proto != 61:
88         raise RuntimeError(
89             f"Decrypted packet has invalid IP protocol: "
90             f"{d_pkt[ip_layer].proto} should be: 61"
91         )
92
93
94 def check_ip(pkt_recv, ip_layer, src_ip, dst_ip):
95     """Check received IP/IPv6 packet.
96
97     :param pkt_recv: Received packet to verify.
98     :param ip_layer: Scapy IP layer.
99     :param src_ip: Source IP/IPv6 address.
100     :param dst_ip: Destination IP/IPv6 address.
101     :type pkt_recv: scapy.Ether
102     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
103     :type src_ip: str
104     :type dst_ip: str
105     :raises RuntimeError: If received packet is invalid.
106     """
107     if not pkt_recv.haslayer(ip_layer):
108         raise RuntimeError(
109             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
110         )
111
112     if pkt_recv[ip_layer].dst != dst_ip:
113         raise RuntimeError(
114             f"Received packet has invalid destination address: "
115             f"{pkt_recv[ip_layer.name].dst} should be: {dst_ip}"
116         )
117
118     if pkt_recv[ip_layer].src != src_ip:
119         raise RuntimeError(
120             f"Received packet has invalid destination address: "
121             f"{pkt_recv[ip_layer.name].dst} should be: {src_ip}"
122         )
123
124     if ip_layer == IP and pkt_recv[ip_layer].proto != 61:
125         raise RuntimeError(
126             f"Received packet has invalid IP protocol: "
127             f"{pkt_recv[ip_layer].proto} should be: 61"
128         )
129
130
131 # TODO: Pylint says too-many-locals and too-many-statements. Refactor!
132 def main():
133     """Send and receive IPsec packet."""
134
135     args = TrafficScriptArg(
136         [
137             u"tx_src_mac", u"tx_dst_mac", u"rx_src_mac", u"rx_dst_mac",
138             u"src_ip", u"dst_ip", u"crypto_alg", u"crypto_key", u"integ_alg",
139             u"integ_key", u"l_spi", u"r_spi"
140         ],
141         [u"src_tun", u"dst_tun"]
142     )
143
144     tx_txq = TxQueue(args.get_arg(u"tx_if"))
145     tx_rxq = RxQueue(args.get_arg(u"tx_if"))
146     rx_txq = TxQueue(args.get_arg(u"rx_if"))
147     rx_rxq = RxQueue(args.get_arg(u"rx_if"))
148
149     tx_src_mac = args.get_arg(u"tx_src_mac")
150     tx_dst_mac = args.get_arg(u"tx_dst_mac")
151     rx_src_mac = args.get_arg(u"rx_src_mac")
152     rx_dst_mac = args.get_arg(u"rx_dst_mac")
153     src_ip = args.get_arg(u"src_ip")
154     dst_ip = args.get_arg(u"dst_ip")
155     crypto_alg = args.get_arg(u"crypto_alg")
156     crypto_key = args.get_arg(u"crypto_key")
157     integ_alg = args.get_arg(u"integ_alg")
158     integ_key = args.get_arg(u"integ_key")
159     l_spi = int(args.get_arg(u"l_spi"))
160     r_spi = int(args.get_arg(u"r_spi"))
161     src_tun = args.get_arg(u"src_tun")
162     dst_tun = args.get_arg(u"dst_tun")
163
164     ip_layer = IP if ip_address(src_ip).version == 4 else IPv6
165
166     tunnel_out = ip_layer(src=src_tun, dst=dst_tun) if src_tun and dst_tun \
167         else None
168     tunnel_in = ip_layer(src=dst_tun, dst=src_tun) if src_tun and dst_tun \
169         else None
170
171     if not (src_tun and dst_tun):
172         src_tun = src_ip
173
174     sa_in = SecurityAssociation(
175         ESP, spi=r_spi, crypt_algo=crypto_alg,
176         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
177         auth_key=integ_key.encode(encoding=u"utf-8"), tunnel_header=tunnel_in
178     )
179
180     sa_out = SecurityAssociation(
181         ESP, spi=l_spi, crypt_algo=crypto_alg,
182         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
183         auth_key=integ_key.encode(encoding=u"utf-8"), tunnel_header=tunnel_out
184     )
185
186     ip_pkt = ip_layer(src=src_ip, dst=dst_ip, proto=61) if ip_layer == IP \
187         else ip_layer(src=src_ip, dst=dst_ip)
188     ip_pkt = ip_layer(ip_pkt)
189
190     e_pkt = sa_out.encrypt(ip_pkt)
191     tx_pkt_send = (Ether(src=tx_src_mac, dst=tx_dst_mac) /
192                    e_pkt)
193
194     sent_packets = list()
195     tx_pkt_send /= Raw()
196     sent_packets.append(tx_pkt_send)
197     tx_txq.send(tx_pkt_send)
198
199     while True:
200         rx_pkt_recv = rx_rxq.recv(2)
201
202         if rx_pkt_recv is None:
203             raise RuntimeError(f"{ip_layer.name} packet Rx timeout")
204
205         if rx_pkt_recv.haslayer(ICMPv6ND_NS):
206             # read another packet in the queue if the current one is ICMPv6ND_NS
207             continue
208         elif rx_pkt_recv.haslayer(ICMPv6MLReport2):
209             # read another packet in the queue if the current one is
210             # ICMPv6MLReport2
211             continue
212         else:
213             # otherwise process the current packet
214             break
215
216     check_ip(rx_pkt_recv, ip_layer, src_ip, dst_ip)
217
218     rx_ip_pkt = ip_layer(src=dst_ip, dst=src_ip, proto=61) if ip_layer == IP \
219         else ip_layer(src=dst_ip, dst=src_ip)
220     rx_pkt_send = (Ether(src=rx_dst_mac, dst=rx_src_mac) /
221                    rx_ip_pkt)
222
223     rx_pkt_send /= Raw()
224     rx_txq.send(rx_pkt_send)
225
226     while True:
227         tx_pkt_recv = tx_rxq.recv(2, sent_packets)
228
229         if tx_pkt_recv is None:
230             raise RuntimeError(u"ESP packet Rx timeout")
231
232         if tx_pkt_recv.haslayer(ICMPv6ND_NS):
233             # read another packet in the queue if the current one is ICMPv6ND_NS
234             continue
235         else:
236             # otherwise process the current packet
237             break
238
239     check_ipsec(tx_pkt_recv, ip_layer, src_tun, dst_ip, src_ip, sa_in)
240
241     sys.exit(0)
242
243
244 if __name__ == u"__main__":
245     main()