make test: improve documentation and PEP8 compliance
[vpp.git] / test / test_snat.py
index e90d9c0..d23becf 100644 (file)
@@ -2,12 +2,12 @@
 
 import socket
 import unittest
-from logging import *
 
 from framework import VppTestCase, VppTestRunner
 
 from scapy.layers.inet import IP, TCP, UDP, ICMP
 from scapy.layers.l2 import Ether
+from util import ppp
 
 
 class TestSNAT(VppTestCase):
@@ -34,6 +34,9 @@ class TestSNAT(VppTestCase):
                 i.config_ip4()
                 i.resolve_arp()
 
+            cls.pg0.generate_remote_hosts(2)
+            cls.pg0.configure_ipv4_neighbors()
+
             cls.overlapping_interfaces = list(list(cls.pg_interfaces[4:7]))
 
             for i in cls.overlapping_interfaces:
@@ -85,7 +88,7 @@ class TestSNAT(VppTestCase):
         :param dst_ip: Destination IP address (Default use global SNAT address)
         """
         if dst_ip is None:
-             dst_ip=self.snat_addr
+            dst_ip = self.snat_addr
         pkts = []
         # TCP
         p = (Ether(dst=out_if.local_mac, src=out_if.remote_mac) /
@@ -127,13 +130,15 @@ class TestSNAT(VppTestCase):
                     if same_port:
                         self.assertEqual(packet[TCP].sport, self.tcp_port_in)
                     else:
-                        self.assertNotEqual(packet[TCP].sport, self.tcp_port_in)
+                        self.assertNotEqual(
+                            packet[TCP].sport, self.tcp_port_in)
                     self.tcp_port_out = packet[TCP].sport
                 elif packet.haslayer(UDP):
                     if same_port:
                         self.assertEqual(packet[UDP].sport, self.udp_port_in)
                     else:
-                        self.assertNotEqual(packet[UDP].sport, self.udp_port_in)
+                        self.assertNotEqual(
+                            packet[UDP].sport, self.udp_port_in)
                     self.udp_port_out = packet[UDP].sport
                 else:
                     if same_port:
@@ -142,8 +147,8 @@ class TestSNAT(VppTestCase):
                         self.assertNotEqual(packet[ICMP].id, self.icmp_id_in)
                     self.icmp_id_out = packet[ICMP].id
             except:
-                error("Unexpected or invalid packet (outside network):")
-                error(packet.show())
+                self.logger.error(ppp("Unexpected or invalid packet "
+                                      "(outside network):", packet))
                 raise
 
     def verify_capture_in(self, capture, in_if, packet_num=3):
@@ -165,8 +170,8 @@ class TestSNAT(VppTestCase):
                 else:
                     self.assertEqual(packet[ICMP].id, self.icmp_id_in)
             except:
-                error("Unexpected or invalid packet (inside network):")
-                error(packet.show())
+                self.logger.error(ppp("Unexpected or invalid packet "
+                                      "(inside network):", packet))
                 raise
 
     def clear_snat(self):
@@ -212,8 +217,14 @@ class TestSNAT(VppTestCase):
             addr_only = 0
         l_ip = socket.inet_pton(socket.AF_INET, local_ip)
         e_ip = socket.inet_pton(socket.AF_INET, external_ip)
-        self.vapi.snat_add_static_mapping(l_ip, e_ip, local_port, external_port,
-                                          addr_only, vrf_id, is_add)
+        self.vapi.snat_add_static_mapping(
+            l_ip,
+            e_ip,
+            local_port,
+            external_port,
+            addr_only,
+            vrf_id,
+            is_add)
 
     def snat_add_address(self, ip, is_add=1):
         """
@@ -238,7 +249,7 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture()
+        capture = self.pg1.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in
@@ -246,7 +257,7 @@ class TestSNAT(VppTestCase):
         self.pg1.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture()
+        capture = self.pg0.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg0)
 
     def test_static_in(self):
@@ -267,7 +278,7 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture()
+        capture = self.pg1.get_capture(len(pkts))
         self.verify_capture_out(capture, nat_ip, True)
 
         # out2in
@@ -275,7 +286,7 @@ class TestSNAT(VppTestCase):
         self.pg1.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture()
+        capture = self.pg0.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg0)
 
     def test_static_out(self):
@@ -296,7 +307,7 @@ class TestSNAT(VppTestCase):
         self.pg1.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture()
+        capture = self.pg0.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg0)
 
         # in2out
@@ -304,7 +315,7 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture()
+        capture = self.pg1.get_capture(len(pkts))
         self.verify_capture_out(capture, nat_ip, True)
 
     def test_static_with_port_in(self):
@@ -330,7 +341,7 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture()
+        capture = self.pg1.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in
@@ -338,7 +349,7 @@ class TestSNAT(VppTestCase):
         self.pg1.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture()
+        capture = self.pg0.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg0)
 
     def test_static_with_port_out(self):
@@ -364,7 +375,7 @@ class TestSNAT(VppTestCase):
         self.pg1.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture()
+        capture = self.pg0.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg0)
 
         # in2out
@@ -372,7 +383,7 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture()
+        capture = self.pg1.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
     def test_static_vrf_aware(self):
@@ -398,7 +409,7 @@ class TestSNAT(VppTestCase):
         self.pg4.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture, nat_ip1, True)
 
         # inside interface VRF don't match SNAT static mapping VRF (packets
@@ -407,11 +418,12 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
-        self.verify_capture_out(capture, packet_num=0)
+        self.pg3.assert_nothing_captured()
 
     def test_multiple_inside_interfaces(self):
-        """ SNAT multiple inside interfaces with non-overlapping address space """
+        """
+        SNAT multiple inside interfaces with non-overlapping address space
+        """
 
         self.snat_add_address(self.snat_addr)
         self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index)
@@ -425,7 +437,7 @@ class TestSNAT(VppTestCase):
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in 1st interface
@@ -433,7 +445,7 @@ class TestSNAT(VppTestCase):
         self.pg3.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture()
+        capture = self.pg0.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg0)
 
         # in2out 2nd interface
@@ -441,7 +453,7 @@ class TestSNAT(VppTestCase):
         self.pg1.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in 2nd interface
@@ -449,7 +461,7 @@ class TestSNAT(VppTestCase):
         self.pg3.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture()
+        capture = self.pg1.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg1)
 
         # in2out 3rd interface
@@ -457,7 +469,7 @@ class TestSNAT(VppTestCase):
         self.pg2.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in 3rd interface
@@ -465,7 +477,7 @@ class TestSNAT(VppTestCase):
         self.pg3.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg2.get_capture()
+        capture = self.pg2.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg2)
 
     def test_inside_overlapping_interfaces(self):
@@ -483,7 +495,7 @@ class TestSNAT(VppTestCase):
         self.pg4.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in 1st interface
@@ -491,7 +503,7 @@ class TestSNAT(VppTestCase):
         self.pg3.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg4.get_capture()
+        capture = self.pg4.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg4)
 
         # in2out 2nd interface
@@ -499,7 +511,7 @@ class TestSNAT(VppTestCase):
         self.pg5.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in 2nd interface
@@ -507,7 +519,7 @@ class TestSNAT(VppTestCase):
         self.pg3.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg5.get_capture()
+        capture = self.pg5.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg5)
 
         # in2out 3rd interface
@@ -515,7 +527,7 @@ class TestSNAT(VppTestCase):
         self.pg6.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg3.get_capture()
+        capture = self.pg3.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         # out2in 3rd interface
@@ -523,9 +535,94 @@ class TestSNAT(VppTestCase):
         self.pg3.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg6.get_capture()
+        capture = self.pg6.get_capture(len(pkts))
         self.verify_capture_in(capture, self.pg6)
 
+    def test_hairpinning(self):
+        """ SNAT hairpinning """
+
+        host = self.pg0.remote_hosts[0]
+        server = self.pg0.remote_hosts[1]
+        host_in_port = 1234
+        host_out_port = 0
+        server_in_port = 5678
+        server_out_port = 8765
+
+        self.snat_add_address(self.snat_addr)
+        self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index)
+        self.vapi.snat_interface_add_del_feature(self.pg1.sw_if_index,
+                                                 is_inside=0)
+        # add static mapping for server
+        self.snat_add_static_mapping(server.ip4, self.snat_addr,
+                                     server_in_port, server_out_port)
+
+        # send packet from host to server
+        p = (Ether(src=host.mac, dst=self.pg0.local_mac) /
+             IP(src=host.ip4, dst=self.snat_addr) /
+             TCP(sport=host_in_port, dport=server_out_port))
+        self.pg0.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg0.get_capture(1)
+        p = capture[0]
+        try:
+            ip = p[IP]
+            tcp = p[TCP]
+            self.assertEqual(ip.src, self.snat_addr)
+            self.assertEqual(ip.dst, server.ip4)
+            self.assertNotEqual(tcp.sport, host_in_port)
+            self.assertEqual(tcp.dport, server_in_port)
+            host_out_port = tcp.sport
+        except:
+            self.logger.error(ppp("Unexpected or invalid packet:", p))
+            raise
+
+        # send reply from server to host
+        p = (Ether(src=server.mac, dst=self.pg0.local_mac) /
+             IP(src=server.ip4, dst=self.snat_addr) /
+             TCP(sport=server_in_port, dport=host_out_port))
+        self.pg0.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg0.get_capture(1)
+        p = capture[0]
+        try:
+            ip = p[IP]
+            tcp = p[TCP]
+            self.assertEqual(ip.src, self.snat_addr)
+            self.assertEqual(ip.dst, host.ip4)
+            self.assertEqual(tcp.sport, server_out_port)
+            self.assertEqual(tcp.dport, host_in_port)
+        except:
+            self.logger.error(ppp("Unexpected or invalid packet:"), p)
+            raise
+
+    def test_max_translations_per_user(self):
+        """ MAX translations per user - recycle the least recently used """
+
+        self.snat_add_address(self.snat_addr)
+        self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index)
+        self.vapi.snat_interface_add_del_feature(self.pg1.sw_if_index,
+                                                 is_inside=0)
+
+        # get maximum number of translations per user
+        snat_config = self.vapi.snat_show_config()
+
+        # send more than maximum number of translations per user packets
+        pkts_num = snat_config.max_translations_per_user + 5
+        pkts = []
+        for port in range(0, pkts_num):
+            p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+                 IP(src=self.pg0.remote_ip4, dst=self.pg1.remote_ip4) /
+                 TCP(sport=1025 + port))
+            pkts.append(p)
+        self.pg0.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+
+        # verify number of translated packet
+        self.pg1.get_capture(pkts_num)
+
     def tearDown(self):
         super(TestSNAT, self).tearDown()
         if not self.vpp_dead: