nat: fix multi worker scenarios
[vpp.git] / src / plugins / nat / test / test_nat44_ed.py
index 7f61eed..2ce7f23 100644 (file)
@@ -2,7 +2,7 @@
 
 import unittest
 from io import BytesIO
-from random import randint
+from random import randint, shuffle, choice
 
 import scapy.compat
 from framework import VppTestCase, VppTestRunner
@@ -1953,7 +1953,8 @@ class TestNAT44ED(NAT44EDTestCase):
 
 class TestNAT44EDMW(TestNAT44ED):
     """ NAT44ED MW Test Case """
-    vpp_worker_count = 1
+    vpp_worker_count = 4
+    max_sessions = 5000
 
     @unittest.skip('MW fix required')
     def test_users_dump(self):
@@ -2014,6 +2015,10 @@ class TestNAT44EDMW(TestNAT44ED):
 
     def test_dynamic(self):
         """ NAT44ED dynamic translation test """
+        pkt_count = 1500
+        tcp_port_offset = 20
+        udp_port_offset = 20
+        icmp_id_offset = 20
 
         self.nat_add_address(self.nat_addr)
         self.nat_add_inside_interface(self.pg0)
@@ -2025,14 +2030,31 @@ class TestNAT44EDMW(TestNAT44ED):
         ic1 = self.statistics['/nat44-ed/in2out/slowpath/icmp']
         dc1 = self.statistics['/nat44-ed/in2out/slowpath/drops']
 
-        pkts = self.create_stream_in(self.pg0, self.pg1)
-        # TODO: specify worker=idx, also stats have to
-        #       know from which worker to take capture
-        self.pg0.add_stream(pkts)
+        i2o_pkts = [[] for x in range(0, self.vpp_worker_count)]
+
+        for i in range(pkt_count):
+            p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+                 IP(src=self.pg0.remote_ip4, dst=self.pg1.remote_ip4) /
+                 TCP(sport=tcp_port_offset + i, dport=20))
+            i2o_pkts[p[TCP].sport % self.vpp_worker_count].append(p)
+
+            p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+                 IP(src=self.pg0.remote_ip4, dst=self.pg1.remote_ip4) /
+                 UDP(sport=udp_port_offset + i, dport=20))
+            i2o_pkts[p[UDP].sport % self.vpp_worker_count].append(p)
+
+            p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+                 IP(src=self.pg0.remote_ip4, dst=self.pg1.remote_ip4) /
+                 ICMP(id=icmp_id_offset + i, type='echo-request'))
+            i2o_pkts[p[ICMP].id % self.vpp_worker_count].append(p)
+
+        for i in range(0, self.vpp_worker_count):
+            if len(i2o_pkts[i]) > 0:
+                self.pg0.add_stream(i2o_pkts[i], worker=i)
+
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture, ignore_port=True)
+        capture = self.pg1.get_capture(pkt_count * 3)
 
         if_idx = self.pg0.sw_if_index
         tc2 = self.statistics['/nat44-ed/in2out/slowpath/tcp']
@@ -2040,23 +2062,82 @@ class TestNAT44EDMW(TestNAT44ED):
         ic2 = self.statistics['/nat44-ed/in2out/slowpath/icmp']
         dc2 = self.statistics['/nat44-ed/in2out/slowpath/drops']
 
-        self.assertEqual(tc2[:, if_idx].sum() - tc1[:, if_idx].sum(), 2)
-        self.assertEqual(uc2[:, if_idx].sum() - uc1[:, if_idx].sum(), 1)
-        self.assertEqual(ic2[:, if_idx].sum() - ic1[:, if_idx].sum(), 1)
+        self.assertEqual(
+            tc2[:, if_idx].sum() - tc1[:, if_idx].sum(), pkt_count)
+        self.assertEqual(
+            uc2[:, if_idx].sum() - uc1[:, if_idx].sum(), pkt_count)
+        self.assertEqual(
+            ic2[:, if_idx].sum() - ic1[:, if_idx].sum(), pkt_count)
         self.assertEqual(dc2[:, if_idx].sum() - dc1[:, if_idx].sum(), 0)
 
+        self.logger.info(self.vapi.cli("show trace"))
+
         # out2in
         tc1 = self.statistics['/nat44-ed/out2in/fastpath/tcp']
         uc1 = self.statistics['/nat44-ed/out2in/fastpath/udp']
         ic1 = self.statistics['/nat44-ed/out2in/fastpath/icmp']
         dc1 = self.statistics['/nat44-ed/out2in/fastpath/drops']
 
-        pkts = self.create_stream_out(self.pg1)
-        self.pg1.add_stream(pkts)
+        recvd_tcp_ports = set()
+        recvd_udp_ports = set()
+        recvd_icmp_ids = set()
+
+        for p in capture:
+            if TCP in p:
+                recvd_tcp_ports.add(p[TCP].sport)
+            if UDP in p:
+                recvd_udp_ports.add(p[UDP].sport)
+            if ICMP in p:
+                recvd_icmp_ids.add(p[ICMP].id)
+
+        recvd_tcp_ports = list(recvd_tcp_ports)
+        recvd_udp_ports = list(recvd_udp_ports)
+        recvd_icmp_ids = list(recvd_icmp_ids)
+
+        o2i_pkts = [[] for x in range(0, self.vpp_worker_count)]
+        for i in range(pkt_count):
+            p = (Ether(dst=self.pg1.local_mac, src=self.pg1.remote_mac) /
+                 IP(src=self.pg1.remote_ip4, dst=self.nat_addr) /
+                 TCP(dport=choice(recvd_tcp_ports), sport=20))
+            o2i_pkts[p[TCP].dport % self.vpp_worker_count].append(p)
+
+            p = (Ether(dst=self.pg1.local_mac, src=self.pg1.remote_mac) /
+                 IP(src=self.pg1.remote_ip4, dst=self.nat_addr) /
+                 UDP(dport=choice(recvd_udp_ports), sport=20))
+            o2i_pkts[p[UDP].dport % self.vpp_worker_count].append(p)
+
+            p = (Ether(dst=self.pg1.local_mac, src=self.pg1.remote_mac) /
+                 IP(src=self.pg1.remote_ip4, dst=self.nat_addr) /
+                 ICMP(id=choice(recvd_icmp_ids), type='echo-reply'))
+            o2i_pkts[p[ICMP].id % self.vpp_worker_count].append(p)
+
+        for i in range(0, self.vpp_worker_count):
+            if len(o2i_pkts[i]) > 0:
+                self.pg1.add_stream(o2i_pkts[i], worker=i)
+
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        capture = self.pg0.get_capture(len(pkts))
-        self.verify_capture_in(capture, self.pg0)
+        capture = self.pg0.get_capture(pkt_count * 3)
+        for packet in capture:
+            try:
+                self.assert_packet_checksums_valid(packet)
+                self.assertEqual(packet[IP].dst, self.pg0.remote_ip4)
+                if packet.haslayer(TCP):
+                    self.assert_in_range(
+                        packet[TCP].dport, tcp_port_offset,
+                        tcp_port_offset + pkt_count, "dst TCP port")
+                elif packet.haslayer(UDP):
+                    self.assert_in_range(
+                        packet[UDP].dport, udp_port_offset,
+                        udp_port_offset + pkt_count, "dst UDP port")
+                else:
+                    self.assert_in_range(
+                        packet[ICMP].id, icmp_id_offset,
+                        icmp_id_offset + pkt_count, "ICMP id")
+            except:
+                self.logger.error(ppp("Unexpected or invalid packet "
+                                      "(inside network):", packet))
+                raise
 
         if_idx = self.pg1.sw_if_index
         tc2 = self.statistics['/nat44-ed/out2in/fastpath/tcp']
@@ -2064,13 +2145,17 @@ class TestNAT44EDMW(TestNAT44ED):
         ic2 = self.statistics['/nat44-ed/out2in/fastpath/icmp']
         dc2 = self.statistics['/nat44-ed/out2in/fastpath/drops']
 
-        self.assertEqual(tc2[:, if_idx].sum() - tc1[:, if_idx].sum(), 2)
-        self.assertEqual(uc2[:, if_idx].sum() - uc1[:, if_idx].sum(), 1)
-        self.assertEqual(ic2[:, if_idx].sum() - ic1[:, if_idx].sum(), 1)
+        self.assertEqual(
+            tc2[:, if_idx].sum() - tc1[:, if_idx].sum(), pkt_count)
+        self.assertEqual(
+            uc2[:, if_idx].sum() - uc1[:, if_idx].sum(), pkt_count)
+        self.assertEqual(
+            ic2[:, if_idx].sum() - ic1[:, if_idx].sum(), pkt_count)
         self.assertEqual(dc2[:, if_idx].sum() - dc1[:, if_idx].sum(), 0)
 
         sc = self.statistics['/nat44-ed/total-sessions']
-        self.assertEqual(sc[:, 0].sum(), 3)
+        self.assertEqual(sc[:, 0].sum(), len(recvd_tcp_ports) +
+                         len(recvd_udp_ports) + len(recvd_icmp_ids))
 
     def test_frag_in_order(self):
         """ NAT44ED translate fragments arriving in order """
@@ -2697,7 +2782,7 @@ class TestNAT44EDMW(TestNAT44ED):
                 server1_n += 1
             else:
                 server2_n += 1
-        self.assertGreater(server1_n, server2_n)
+        self.assertGreaterEqual(server1_n, server2_n)
 
         local = {
             'addr': server3.ip4,