Ignore unexpected ICMPv6 Neighbor Discovery - Neighbor Solicitation packets
[csit.git] / resources / traffic_scripts / ipsec.py
index 13d44b8..1561738 100755 (executable)
 import sys
 import logging
 
+# pylint: disable=no-name-in-module
+# pylint: disable=import-error
 logging.getLogger("scapy.runtime").setLevel(logging.ERROR)
-from scapy.all import Ether, IP, ICMP, IPv6, ICMPv6EchoRequest, ICMPv6EchoReply
+
+from scapy.all import Ether
+from scapy.layers.inet import ICMP, IP
+from scapy.layers.inet6 import IPv6, ICMPv6ND_NS
+from scapy.layers.inet6 import ICMPv6EchoRequest, ICMPv6EchoReply
 from scapy.layers.ipsec import SecurityAssociation, ESP
 from ipaddress import ip_address
 
@@ -39,7 +45,7 @@ def check_ipv4(pkt_recv, dst_tun, src_ip, dst_ip, sa_in):
     :type dst_tun: str
     :type src_ip: str
     :type dst_ip: str
-    :type sa_sa: scapy.layers.ipsec.SecurityAssociation
+    :type sa_in: scapy.layers.ipsec.SecurityAssociation
     :raises RuntimeError: If received packet is invalid.
     """
     if not pkt_recv.haslayer(IP):
@@ -55,15 +61,15 @@ def check_ipv4(pkt_recv, dst_tun, src_ip, dst_ip, sa_in):
         raise RuntimeError(
             'Not an ESP packet received: {0}'.format(pkt_recv.__repr__()))
 
-    ip_pkt = pkt_recv['IP']
+    ip_pkt = pkt_recv[IP]
     d_pkt = sa_in.decrypt(ip_pkt)
 
-    if d_pkt['IP'].dst != dst_ip:
+    if d_pkt[IP].dst != dst_ip:
         raise RuntimeError(
             'Decrypted packet has invalid destination address: {0} '
             'should be: {1}'.format(d_pkt['IP'].dst, dst_ip))
 
-    if d_pkt['IP'].src != src_ip:
+    if d_pkt[IP].src != src_ip:
         raise RuntimeError(
             'Decrypted packet has invalid source address: {0} should be: {1}'
             .format(d_pkt['IP'].src, src_ip))
@@ -93,7 +99,7 @@ def check_ipv6(pkt_recv, dst_tun, src_ip, dst_ip, sa_in):
         raise RuntimeError(
             'Not an IPv6 packet received: {0}'.format(pkt_recv.__repr__()))
 
-    if pkt_recv['IPv6'].dst != dst_tun:
+    if pkt_recv[IPv6].dst != dst_tun:
         raise RuntimeError(
             'Received packet has invalid destination address: {0} '
             'should be: {1}'.format(pkt_recv['IPv6'].dst, dst_tun))
@@ -102,15 +108,15 @@ def check_ipv6(pkt_recv, dst_tun, src_ip, dst_ip, sa_in):
         raise RuntimeError(
             'Not an ESP packet received: {0}'.format(pkt_recv.__repr__()))
 
-    ip_pkt = pkt_recv['IPv6']
+    ip_pkt = pkt_recv[IPv6]
     d_pkt = sa_in.decrypt(ip_pkt)
 
-    if d_pkt['IPv6'].dst != dst_ip:
+    if d_pkt[IPv6].dst != dst_ip:
         raise RuntimeError(
             'Decrypted packet has invalid destination address {0}: '
             'should be: {1}'.format(d_pkt['IPv6'].dst, dst_ip))
 
-    if d_pkt['IPv6'].src != src_ip:
+    if d_pkt[IPv6].src != src_ip:
         raise RuntimeError(
             'Decrypted packet has invalid source address: {0} should be: {1}'
             .format(d_pkt['IPv6'].src, src_ip))
@@ -175,25 +181,33 @@ def main():
     sent_packets = []
 
     if is_ipv4:
-        ip_pkt = IP(src=src_ip, dst=dst_ip) / \
-                 ICMP()
+        ip_pkt = (IP(src=src_ip, dst=dst_ip) /
+                  ICMP())
         ip_pkt = IP(str(ip_pkt))
     else:
-        ip_pkt = IPv6(src=src_ip, dst=dst_ip) / \
-                 ICMPv6EchoRequest()
+        ip_pkt = (IPv6(src=src_ip, dst=dst_ip) /
+                  ICMPv6EchoRequest())
         ip_pkt = IPv6(str(ip_pkt))
 
     e_pkt = sa_out.encrypt(ip_pkt)
-    pkt_send = Ether(src=src_mac, dst=dst_mac) / \
-               e_pkt
+    pkt_send = (Ether(src=src_mac, dst=dst_mac) /
+                e_pkt)
 
     sent_packets.append(pkt_send)
     txq.send(pkt_send)
 
-    pkt_recv = rxq.recv(2, sent_packets)
+    while True:
+        pkt_recv = rxq.recv(2, sent_packets)
 
-    if pkt_recv is None:
-        raise RuntimeError('ESP packet Rx timeout')
+        if pkt_recv is None:
+            raise RuntimeError('ESP packet Rx timeout')
+
+        if pkt_recv.haslayer(ICMPv6ND_NS):
+            # read another packet in the queue if the current one is ICMPv6ND_NS
+            continue
+        else:
+            # otherwise process the current packet
+            break
 
     if is_ipv4:
         check_ipv4(pkt_recv, src_tun, dst_ip, src_ip, sa_in)
@@ -202,5 +216,6 @@ def main():
 
     sys.exit(0)
 
+
 if __name__ == "__main__":
     main()