da45565e1b8d7295cf472244e0d173ec59e00577
[csit.git] / GPL / traffic_scripts / ipsec_policy.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2020 Cisco and/or its affiliates.
4 #
5 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6 #
7 # Licensed under the Apache License 2.0 or
8 # GNU General Public License v2.0 or later;  you may not use this file
9 # except in compliance with one of these Licenses. You
10 # may obtain a copy of the Licenses at:
11 #
12 #     http://www.apache.org/licenses/LICENSE-2.0
13 #     https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html
14 #
15 # Note: If this file is linked with Scapy, which is GPLv2+, your use of it
16 # must be under GPLv2+.  If at any point in the future it is no longer linked
17 # with Scapy (or other GPLv2+ licensed software), you are free to choose Apache 2.
18 #
19 # Unless required by applicable law or agreed to in writing, software
20 # distributed under the License is distributed on an "AS IS" BASIS,
21 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 # See the License for the specific language governing permissions and
23 # limitations under the License.
24
25 """Traffic script for IPsec verification."""
26
27 import sys
28
29 from ipaddress import ip_address
30 from scapy.layers.inet import IP
31 from scapy.layers.inet6 import IPv6, ICMPv6ND_NS
32 from scapy.layers.ipsec import SecurityAssociation, ESP
33 from scapy.layers.l2 import Ether
34 from scapy.packet import Raw
35
36 from .PacketVerifier import RxQueue, TxQueue
37 from .TrafficScriptArg import TrafficScriptArg
38
39
40 def check_ipsec(pkt_recv, ip_layer, dst_tun, src_ip, dst_ip, sa_in):
41     """Check received IPsec packet.
42
43     :param pkt_recv: Received packet to verify.
44     :param ip_layer: Scapy IP layer.
45     :param dst_tun: IPsec tunnel destination address.
46     :param src_ip: Source IP/IPv6 address of original IP/IPv6 packet.
47     :param dst_ip: Destination IP/IPv6 address of original IP/IPv6 packet.
48     :param sa_in: IPsec SA for packet decryption.
49     :type pkt_recv: scapy.Ether
50     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
51     :type dst_tun: str
52     :type src_ip: str
53     :type dst_ip: str
54     :type sa_in: scapy.layers.ipsec.SecurityAssociation
55     :raises RuntimeError: If received packet is invalid.
56     """
57     if not pkt_recv.haslayer(ip_layer):
58         raise RuntimeError(
59             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
60         )
61
62     if pkt_recv[ip_layer].dst != dst_tun:
63         raise RuntimeError(
64             f"Received packet has invalid destination address: "
65             f"{pkt_recv[ip_layer].dst} should be: {dst_tun}"
66         )
67
68     if not pkt_recv.haslayer(ESP):
69         raise RuntimeError(f"Not an ESP packet received: {pkt_recv!r}")
70
71     ip_pkt = pkt_recv[ip_layer]
72     d_pkt = sa_in.decrypt(ip_pkt)
73
74     if d_pkt[ip_layer].dst != dst_ip:
75         raise RuntimeError(
76             f"Decrypted packet has invalid destination address: "
77             f"{d_pkt[ip_layer].dst} should be: {dst_ip}"
78         )
79
80     if d_pkt[ip_layer].src != src_ip:
81         raise RuntimeError(
82             f"Decrypted packet has invalid source address: "
83             f"{d_pkt[ip_layer].src} should be: {src_ip}"
84         )
85
86     if ip_layer == IP and d_pkt[ip_layer].proto != 61:
87         raise RuntimeError(
88             f"Decrypted packet has invalid IP protocol: "
89             f"{d_pkt[ip_layer].proto} should be: 61"
90         )
91
92
93 def check_ip(pkt_recv, ip_layer, src_ip, dst_ip):
94     """Check received IP/IPv6 packet.
95
96     :param pkt_recv: Received packet to verify.
97     :param ip_layer: Scapy IP layer.
98     :param src_ip: Source IP/IPv6 address.
99     :param dst_ip: Destination IP/IPv6 address.
100     :type pkt_recv: scapy.Ether
101     :type ip_layer: scapy.layers.inet.IP or scapy.layers.inet6.IPv6
102     :type src_ip: str
103     :type dst_ip: str
104     :raises RuntimeError: If received packet is invalid.
105     """
106     if not pkt_recv.haslayer(ip_layer):
107         raise RuntimeError(
108             f"Not an {ip_layer.name} packet received: {pkt_recv!r}"
109         )
110
111     if pkt_recv[ip_layer].dst != dst_ip:
112         raise RuntimeError(
113             f"Received packet has invalid destination address: "
114             f"{pkt_recv[ip_layer.name].dst} should be: {dst_ip}"
115         )
116
117     if pkt_recv[ip_layer].src != src_ip:
118         raise RuntimeError(
119             f"Received packet has invalid destination address: "
120             f"{pkt_recv[ip_layer.name].dst} should be: {src_ip}"
121         )
122
123     if ip_layer == IP and pkt_recv[ip_layer].proto != 61:
124         raise RuntimeError(
125             f"Received packet has invalid IP protocol: "
126             f"{pkt_recv[ip_layer].proto} should be: 61"
127         )
128
129
130 # TODO: Pylint says too-many-locals and too-many-statements. Refactor!
131 def main():
132     """Send and receive IPsec packet."""
133
134     args = TrafficScriptArg(
135         [
136             u"tx_src_mac", u"tx_dst_mac", u"rx_src_mac", u"rx_dst_mac",
137             u"src_ip", u"dst_ip", u"crypto_alg", u"crypto_key", u"integ_alg",
138             u"integ_key", u"l_spi", u"r_spi"
139         ],
140         [u"src_tun", u"dst_tun"]
141     )
142
143     tx_txq = TxQueue(args.get_arg(u"tx_if"))
144     tx_rxq = RxQueue(args.get_arg(u"tx_if"))
145     rx_txq = TxQueue(args.get_arg(u"rx_if"))
146     rx_rxq = RxQueue(args.get_arg(u"rx_if"))
147
148     tx_src_mac = args.get_arg(u"tx_src_mac")
149     tx_dst_mac = args.get_arg(u"tx_dst_mac")
150     rx_src_mac = args.get_arg(u"rx_src_mac")
151     rx_dst_mac = args.get_arg(u"rx_dst_mac")
152     src_ip = args.get_arg(u"src_ip")
153     dst_ip = args.get_arg(u"dst_ip")
154     crypto_alg = args.get_arg(u"crypto_alg")
155     crypto_key = args.get_arg(u"crypto_key")
156     integ_alg = args.get_arg(u"integ_alg")
157     integ_key = args.get_arg(u"integ_key")
158     l_spi = int(args.get_arg(u"l_spi"))
159     r_spi = int(args.get_arg(u"r_spi"))
160     src_tun = args.get_arg(u"src_tun")
161     dst_tun = args.get_arg(u"dst_tun")
162
163     ip_layer = IP if ip_address(src_ip).version == 4 else IPv6
164
165     tunnel_out = ip_layer(src=src_tun, dst=dst_tun) if src_tun and dst_tun \
166         else None
167     tunnel_in = ip_layer(src=dst_tun, dst=src_tun) if src_tun and dst_tun \
168         else None
169
170     if not (src_tun and dst_tun):
171         src_tun = src_ip
172
173     sa_in = SecurityAssociation(
174         ESP, spi=r_spi, crypt_algo=crypto_alg,
175         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
176         auth_key=integ_key.encode(encoding=u"utf-8"), tunnel_header=tunnel_in
177     )
178
179     sa_out = SecurityAssociation(
180         ESP, spi=l_spi, crypt_algo=crypto_alg,
181         crypt_key=crypto_key.encode(encoding=u"utf-8"), auth_algo=integ_alg,
182         auth_key=integ_key.encode(encoding=u"utf-8"), tunnel_header=tunnel_out
183     )
184
185     ip_pkt = ip_layer(src=src_ip, dst=dst_ip, proto=61) if ip_layer == IP \
186         else ip_layer(src=src_ip, dst=dst_ip)
187     ip_pkt = ip_layer(ip_pkt)
188
189     e_pkt = sa_out.encrypt(ip_pkt)
190     tx_pkt_send = (Ether(src=tx_src_mac, dst=tx_dst_mac) /
191                    e_pkt)
192
193     sent_packets = list()
194     tx_pkt_send /= Raw()
195     sent_packets.append(tx_pkt_send)
196     tx_txq.send(tx_pkt_send)
197
198     while True:
199         rx_pkt_recv = rx_rxq.recv(2)
200
201         if rx_pkt_recv is None:
202             raise RuntimeError(f"{ip_layer.name} packet Rx timeout")
203
204         if rx_pkt_recv.haslayer(ICMPv6ND_NS):
205             # read another packet in the queue if the current one is ICMPv6ND_NS
206             continue
207         else:
208             # otherwise process the current packet
209             break
210
211     check_ip(rx_pkt_recv, ip_layer, src_ip, dst_ip)
212
213     rx_ip_pkt = ip_layer(src=dst_ip, dst=src_ip, proto=61) if ip_layer == IP \
214         else ip_layer(src=dst_ip, dst=src_ip)
215     rx_pkt_send = (Ether(src=rx_dst_mac, dst=rx_src_mac) /
216                    rx_ip_pkt)
217
218     rx_pkt_send /= Raw()
219     rx_txq.send(rx_pkt_send)
220
221     while True:
222         tx_pkt_recv = tx_rxq.recv(2, sent_packets)
223
224         if tx_pkt_recv is None:
225             raise RuntimeError(u"ESP packet Rx timeout")
226
227         if tx_pkt_recv.haslayer(ICMPv6ND_NS):
228             # read another packet in the queue if the current one is ICMPv6ND_NS
229             continue
230         else:
231             # otherwise process the current packet
232             break
233
234     check_ipsec(tx_pkt_recv, ip_layer, src_tun, dst_ip, src_ip, sa_in)
235
236     sys.exit(0)
237
238
239 if __name__ == u"__main__":
240     main()