License: Wrap GPL block to 80 characters
[csit.git] / GPL / traffic_scripts / nat.py
1 #!/usr/bin/env python3
2
3 # Copyright (c) 2021 Cisco and/or its affiliates.
4 #
5 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6 #
7 # Licensed under the Apache License 2.0 or
8 # GNU General Public License v2.0 or later;  you may not use this file
9 # except in compliance with one of these Licenses. You
10 # may obtain a copy of the Licenses at:
11 #
12 #     http://www.apache.org/licenses/LICENSE-2.0
13 #     https://www.gnu.org/licenses/old-licenses/gpl-2.0-standalone.html
14 #
15 # Note: If this file is linked with Scapy, which is GPLv2+, your use of it
16 # must be under GPLv2+.  If at any point in the future it is no longer linked
17 # with Scapy (or other GPLv2+ licensed software), you are free to choose
18 # Apache 2.
19 #
20 # Unless required by applicable law or agreed to in writing, software
21 # distributed under the License is distributed on an "AS IS" BASIS,
22 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 # See the License for the specific language governing permissions and
24 # limitations under the License.
25
26 """Traffic script for NAT verification."""
27
28 import sys
29
30 import ipaddress
31
32 from scapy.layers.inet import IP, TCP, UDP
33 from scapy.layers.inet6 import IPv6, ICMPv6ND_NS
34 from scapy.layers.l2 import Ether
35 from scapy.packet import Raw
36
37 from .PacketVerifier import RxQueue, TxQueue
38 from .TrafficScriptArg import TrafficScriptArg
39
40
41 def valid_ipv4(ip):
42     try:
43         ipaddress.IPv4Address(ip)
44         return True
45     except (AttributeError, ipaddress.AddressValueError):
46         return False
47
48
49 def valid_ipv6(ip):
50     try:
51         ipaddress.IPv6Address(ip)
52         return True
53     except (AttributeError, ipaddress.AddressValueError):
54         return False
55
56
57 def main():
58     """Send, receive and check IP/IPv6 packets with UDP/TCP layer passing
59     through NAT.
60     """
61     args = TrafficScriptArg(
62         [
63             u"tx_src_mac", u"rx_dst_mac", u"src_ip_in", u"src_ip_out",
64             u"dst_ip", u"tx_dst_mac", u"rx_src_mac", u"protocol",
65             u"src_port_in", u"src_port_out", u"dst_port"
66         ]
67     )
68
69     tx_src_mac = args.get_arg(u"tx_src_mac")
70     tx_dst_mac = args.get_arg(u"tx_dst_mac")
71     rx_dst_mac = args.get_arg(u"rx_dst_mac")
72     rx_src_mac = args.get_arg(u"rx_src_mac")
73     src_ip_in = args.get_arg(u"src_ip_in")
74     src_ip_out = args.get_arg(u"src_ip_out")
75     dst_ip = args.get_arg(u"dst_ip")
76     protocol = args.get_arg(u"protocol")
77     sport_in = int(args.get_arg(u"src_port_in"))
78     try:
79         sport_out = int(args.get_arg(u"src_port_out"))
80     except ValueError:
81         sport_out = None
82     dst_port = int(args.get_arg(u"dst_port"))
83
84     tx_txq = TxQueue(args.get_arg(u"tx_if"))
85     tx_rxq = RxQueue(args.get_arg(u"tx_if"))
86     rx_txq = TxQueue(args.get_arg(u"rx_if"))
87     rx_rxq = RxQueue(args.get_arg(u"rx_if"))
88
89     sent_packets = list()
90     pkt_raw = Ether(src=tx_src_mac, dst=tx_dst_mac)
91
92     if valid_ipv4(src_ip_in) and valid_ipv4(dst_ip):
93         ip_layer = IP
94     elif valid_ipv6(src_ip_in) and valid_ipv6(dst_ip):
95         ip_layer = IPv6
96     else:
97         raise ValueError(u"IP not in correct format")
98     pkt_raw /= ip_layer(src=src_ip_in, dst=dst_ip)
99
100     if protocol == u"UDP":
101         pkt_raw /= UDP(sport=sport_in, dport=dst_port)
102         proto_layer = UDP
103     elif protocol == u"TCP":
104         # flags=0x2 => SYN flag set
105         pkt_raw /= TCP(sport=sport_in, dport=dst_port, flags=0x2)
106         proto_layer = TCP
107     else:
108         raise ValueError(u"Incorrect protocol")
109
110     pkt_raw /= Raw()
111     sent_packets.append(pkt_raw)
112     tx_txq.send(pkt_raw)
113
114     while True:
115         ether = rx_rxq.recv(2)
116
117         if ether is None:
118             raise RuntimeError(u"IP packet Rx timeout")
119
120         if ether.haslayer(ICMPv6ND_NS):
121             # read another packet in the queue if the current one is ICMPv6ND_NS
122             continue
123         else:
124             # otherwise process the current packet
125             break
126
127     if rx_dst_mac != ether[Ether].dst or rx_src_mac != ether[Ether].src:
128         raise RuntimeError(f"Matching packet unsuccessful: {ether!r}")
129
130     ip_pkt = ether.payload
131     if not isinstance(ip_pkt, ip_layer):
132         raise RuntimeError(f"Not an {ip_layer!s} packet received: {ip_pkt!r}")
133     if ip_pkt.src != src_ip_out:
134         raise RuntimeError(
135             f"Matching Src IP address unsuccessful: "
136             f"{src_ip_out} != {ip_pkt.src}"
137         )
138     if ip_pkt.dst != dst_ip:
139         raise RuntimeError(
140             f"Matching Dst IP address unsuccessful: {dst_ip} != {ip_pkt.dst}"
141         )
142
143     proto_pkt = ip_pkt.payload
144     if not isinstance(proto_pkt, proto_layer):
145         raise RuntimeError(
146             f"Not a {proto_layer!s} packet received: {proto_pkt!r}"
147         )
148     if sport_out is not None:
149         if proto_pkt.sport != sport_out:
150             raise RuntimeError(
151                 f"Matching Src {proto_layer!s} port unsuccessful: "
152                 f"{sport_out} != {proto_pkt.sport}"
153             )
154     else:
155         sport_out = proto_pkt.sport
156     if proto_pkt.dport != dst_port:
157         raise RuntimeError(
158             f"Matching Dst {proto_layer!s} port unsuccessful: "
159             f"{dst_port} != {proto_pkt.dport}"
160         )
161     if proto_layer == TCP:
162         if proto_pkt.flags != 0x2:
163             raise RuntimeError(
164                 f"Not a TCP SYN packet received: {proto_pkt!r}"
165             )
166
167     pkt_raw = Ether(src=rx_dst_mac, dst=rx_src_mac)
168     pkt_raw /= ip_layer(src=dst_ip, dst=src_ip_out)
169     pkt_raw /= proto_layer(sport=dst_port, dport=sport_out)
170     if proto_layer == TCP:
171         # flags=0x12 => SYN, ACK flags set
172         pkt_raw[TCP].flags = 0x12
173     pkt_raw /= Raw()
174     rx_txq.send(pkt_raw)
175
176     while True:
177         ether = tx_rxq.recv(2, ignore=sent_packets)
178
179         if ether is None:
180             raise RuntimeError(u"IP packet Rx timeout")
181
182         if ether.haslayer(ICMPv6ND_NS):
183             # read another packet in the queue if the current one is ICMPv6ND_NS
184             continue
185         else:
186             # otherwise process the current packet
187             break
188
189     if ether[Ether].dst != tx_src_mac or ether[Ether].src != tx_dst_mac:
190         raise RuntimeError(f"Matching packet unsuccessful: {ether!r}")
191
192     ip_pkt = ether.payload
193     if not isinstance(ip_pkt, ip_layer):
194         raise RuntimeError(f"Not an {ip_layer!s} packet received: {ip_pkt!r}")
195     if ip_pkt.src != dst_ip:
196         raise RuntimeError(
197             f"Matching Src IP address unsuccessful: {dst_ip} != {ip_pkt.src}"
198         )
199     if ip_pkt.dst != src_ip_in:
200         raise RuntimeError(
201             f"Matching Dst IP address unsuccessful: {src_ip_in} != {ip_pkt.dst}"
202         )
203
204     proto_pkt = ip_pkt.payload
205     if not isinstance(proto_pkt, proto_layer):
206         raise RuntimeError(
207             f"Not a {proto_layer!s} packet received: {proto_pkt!r}"
208         )
209     if proto_pkt.sport != dst_port:
210         raise RuntimeError(
211             f"Matching Src {proto_layer!s} port unsuccessful: "
212             f"{dst_port} != {proto_pkt.sport}"
213         )
214     if proto_pkt.dport != sport_in:
215         raise RuntimeError(
216             f"Matching Dst {proto_layer!s} port unsuccessful: "
217             f"{sport_in} != {proto_pkt.dport}"
218         )
219     if proto_layer == TCP:
220         if proto_pkt.flags != 0x12:
221             raise RuntimeError(
222                 f"Not a TCP SYN-ACK packet received: {proto_pkt!r}"
223             )
224
225     sys.exit(0)
226
227
228 if __name__ == u"__main__":
229     main()