nat: improve outside port selection & perf 64/27464/7
authorKlement Sekera <ksekera@cisco.com>
Mon, 8 Jun 2020 11:10:55 +0000 (11:10 +0000)
committerOle Trøan <otroan@employees.org>
Tue, 16 Jun 2020 09:05:15 +0000 (09:05 +0000)
Prefer using source port form packet as outside port if possible.

Type: improvement
Signed-off-by: Klement Sekera <ksekera@cisco.com>
Change-Id: I5c25f6a42386f38c9a6cc95bd7dda9f090b49817

src/plugins/nat/in2out_ed.c
src/plugins/nat/nat.c
src/plugins/nat/nat_inlines.h
src/plugins/nat/test/test_nat.py

index 49e3812..19b1288 100644 (file)
@@ -191,27 +191,18 @@ icmp_in2out_ed_slow_path (snat_main_t * sm, vlib_buffer_t * b0,
   return next0;
 }
 
-static_always_inline u16
-snat_random_port (u16 min, u16 max)
-{
-  snat_main_t *sm = &snat_main;
-  return min + random_u32 (&sm->random_seed) /
-    (random_u32_max () / (max - min + 1) + 1);
-}
-
 static int
 nat_ed_alloc_addr_and_port (snat_main_t * sm, u32 rx_fib_index,
                            u32 nat_proto, u32 thread_index,
                            ip4_address_t r_addr, u16 r_port, u8 proto,
                            u16 port_per_thread, u32 snat_thread_index,
                            snat_session_t * s,
-                           ip4_address_t * allocated_addr,
-                           u16 * allocated_port,
+                           ip4_address_t * outside_addr,
+                           u16 * outside_port,
                            clib_bihash_kv_16_8_t * out2in_ed_kv)
 {
   int i;
   snat_address_t *a, *ga = 0;
-  u32 portnum;
   snat_main_per_thread_data_t *tsm = &sm->per_thread_data[thread_index];
 
   const u16 port_thread_offset = (port_per_thread * snat_thread_index) + 1024;
@@ -225,29 +216,39 @@ nat_ed_alloc_addr_and_port (snat_main_t * sm, u32 rx_fib_index,
   case NAT_PROTOCOL_##N:                                                     \
     if (a->fib_index == rx_fib_index)                                        \
       {                                                                      \
-        u16 port = snat_random_port (1, port_per_thread);                    \
+        /* first try port suggested by caller */                             \
+        u16 port = clib_net_to_host_u16 (*outside_port);                   \
+        u16 port_offset = port - port_thread_offset;                         \
+        if (port <= port_thread_offset ||                                    \
+            port > port_thread_offset + port_per_thread)                     \
+          {                                                                  \
+            /* need to pick a different port, suggested port doesn't fit in  \
+             * this thread's port range */                                   \
+            port_offset = snat_random_port (1, port_per_thread);             \
+            port = port_thread_offset + port_offset;                         \
+          }                                                                  \
         u16 attempts = port_per_thread;                                      \
-        while (attempts > 0)                                                 \
+        do                                                                   \
           {                                                                  \
-            --attempts;                                                      \
-            portnum = port_thread_offset + port;                             \
-            init_ed_kv (out2in_ed_kv, a->addr,                               \
-                        clib_host_to_net_u16 (portnum), r_addr, r_port,      \
-                        s->out2in.fib_index, proto, thread_index,            \
-                        s - tsm->sessions);                                  \
+            init_ed_kv (out2in_ed_kv, a->addr, clib_host_to_net_u16 (port),  \
+                        r_addr, r_port, s->out2in.fib_index, proto,          \
+                        thread_index, s - tsm->sessions);                    \
             int rv = clib_bihash_add_del_16_8 (&sm->out2in_ed, out2in_ed_kv, \
                                                2 /* is_add */);              \
             if (0 == rv)                                                     \
               {                                                              \
-                ++a->busy_##n##_port_refcounts[portnum];                     \
+                ++a->busy_##n##_port_refcounts[port];                        \
                 a->busy_##n##_ports_per_thread[thread_index]++;              \
                 a->busy_##n##_ports++;                                       \
-                *allocated_addr = a->addr;                                   \
-                *allocated_port = clib_host_to_net_u16 (portnum);            \
+                *outside_addr = a->addr;                                   \
+                *outside_port = clib_host_to_net_u16 (port);               \
                 return 0;                                                    \
               }                                                              \
-            port = (port + 1) % port_per_thread;                             \
+            port_offset = (port_offset + 1) % port_per_thread;               \
+            port = port_thread_offset + port_offset;                         \
+            --attempts;                                                      \
           }                                                                  \
+        while (attempts > 0);                                                \
       }                                                                      \
     else if (a->fib_index == ~0)                                             \
       {                                                                      \
@@ -326,8 +327,8 @@ slow_path_ed (snat_main_t * sm,
   snat_main_per_thread_data_t *tsm = &sm->per_thread_data[thread_index];
   clib_bihash_kv_16_8_t out2in_ed_kv;
   nat44_is_idle_session_ctx_t ctx;
-  ip4_address_t allocated_addr;
-  u16 allocated_port;
+  ip4_address_t outside_addr;
+  u16 outside_port;
   u8 identity_nat;
 
   u32 nat_proto = ip_proto_to_nat_proto (proto);
@@ -393,20 +394,21 @@ slow_path_ed (snat_main_t * sm,
        }
 
       /* Try to create dynamic translation */
+      outside_port = l_port;   // suggest using local port to allocation function
       if (nat_ed_alloc_addr_and_port (sm, rx_fib_index, nat_proto,
                                      thread_index, r_addr, r_port, proto,
                                      sm->port_per_thread,
                                      tsm->snat_thread_index, s,
-                                     &allocated_addr,
-                                     &allocated_port, &out2in_ed_kv))
+                                     &outside_addr,
+                                     &outside_port, &out2in_ed_kv))
        {
          nat_elog_notice ("addresses exhausted");
          b->error = node->errors[NAT_IN2OUT_ED_ERROR_OUT_OF_PORTS];
          nat_ed_session_delete (sm, s, thread_index, 1);
          return NAT_NEXT_DROP;
        }
-      s->out2in.addr = allocated_addr;
-      s->out2in.port = allocated_port;
+      s->out2in.addr = outside_addr;
+      s->out2in.port = outside_port;
     }
   else
     {
index e4fed18..60ef22f 100644 (file)
@@ -2851,14 +2851,6 @@ end:
   return 0;
 }
 
-static_always_inline u16
-snat_random_port (u16 min, u16 max)
-{
-  snat_main_t *sm = &snat_main;
-  return min + random_u32 (&sm->random_seed) /
-    (random_u32_max () / (max - min + 1) + 1);
-}
-
 int
 snat_alloc_outside_address_and_port (snat_address_t * addresses,
                                     u32 fib_index,
index 40ac3d3..6741175 100644 (file)
@@ -813,6 +813,21 @@ increment_v4_address (ip4_address_t * a)
   a->as_u32 = clib_host_to_net_u32 (v);
 }
 
+static_always_inline u16
+snat_random_port (u16 min, u16 max)
+{
+  snat_main_t *sm = &snat_main;
+  u32 rwide;
+  u16 r;
+
+  rwide = random_u32 (&sm->random_seed);
+  r = rwide & 0xFFFF;
+  if (r >= min && r <= max)
+    return r;
+
+  return min + (rwide % (max - min + 1));
+}
+
 #endif /* __included_nat_inlines_h__ */
 
 /*
index 6dee818..e996373 100644 (file)
@@ -448,7 +448,7 @@ class MethodHolder(VppTestCase):
         return pkts
 
     def verify_capture_out(self, capture, nat_ip=None, same_port=False,
-                           dst_ip=None, is_ip6=False):
+                           dst_ip=None, is_ip6=False, ignore_port=False):
         """
         Verify captured packets on outside network
 
@@ -474,25 +474,32 @@ class MethodHolder(VppTestCase):
                 if dst_ip is not None:
                     self.assertEqual(packet[IP46].dst, dst_ip)
                 if packet.haslayer(TCP):
-                    if same_port:
-                        self.assertEqual(packet[TCP].sport, self.tcp_port_in)
-                    else:
-                        self.assertNotEqual(
-                            packet[TCP].sport, self.tcp_port_in)
+                    if not ignore_port:
+                        if same_port:
+                            self.assertEqual(
+                                packet[TCP].sport, self.tcp_port_in)
+                        else:
+                            self.assertNotEqual(
+                                packet[TCP].sport, self.tcp_port_in)
                     self.tcp_port_out = packet[TCP].sport
                     self.assert_packet_checksums_valid(packet)
                 elif packet.haslayer(UDP):
-                    if same_port:
-                        self.assertEqual(packet[UDP].sport, self.udp_port_in)
-                    else:
-                        self.assertNotEqual(
-                            packet[UDP].sport, self.udp_port_in)
+                    if not ignore_port:
+                        if same_port:
+                            self.assertEqual(
+                                packet[UDP].sport, self.udp_port_in)
+                        else:
+                            self.assertNotEqual(
+                                packet[UDP].sport, self.udp_port_in)
                     self.udp_port_out = packet[UDP].sport
                 else:
-                    if same_port:
-                        self.assertEqual(packet[ICMP46].id, self.icmp_id_in)
-                    else:
-                        self.assertNotEqual(packet[ICMP46].id, self.icmp_id_in)
+                    if not ignore_port:
+                        if same_port:
+                            self.assertEqual(
+                                packet[ICMP46].id, self.icmp_id_in)
+                        else:
+                            self.assertNotEqual(
+                                packet[ICMP46].id, self.icmp_id_in)
                     self.icmp_id_out = packet[ICMP46].id
                     self.assert_packet_checksums_valid(packet)
             except:
@@ -1105,7 +1112,8 @@ class MethodHolder(VppTestCase):
         else:
             raise Exception("Unsupported protocol")
 
-    def frag_in_order(self, proto=IP_PROTOS.tcp, dont_translate=False):
+    def frag_in_order(self, proto=IP_PROTOS.tcp, dont_translate=False,
+                      ignore_port=False):
         layer = self.proto2layer(proto)
 
         if proto == IP_PROTOS.tcp:
@@ -1132,14 +1140,16 @@ class MethodHolder(VppTestCase):
         if proto != IP_PROTOS.icmp:
             if not dont_translate:
                 self.assertEqual(p[layer].dport, 20)
-                self.assertNotEqual(p[layer].sport, self.port_in)
+                if not ignore_port:
+                    self.assertNotEqual(p[layer].sport, self.port_in)
             else:
                 self.assertEqual(p[layer].sport, self.port_in)
         else:
-            if not dont_translate:
-                self.assertNotEqual(p[layer].id, self.port_in)
-            else:
-                self.assertEqual(p[layer].id, self.port_in)
+            if not ignore_port:
+                if not dont_translate:
+                    self.assertNotEqual(p[layer].id, self.port_in)
+                else:
+                    self.assertEqual(p[layer].id, self.port_in)
         self.assertEqual(data, p[Raw].load)
 
         # out2in
@@ -1220,7 +1230,7 @@ class MethodHolder(VppTestCase):
                 self.assertEqual(p[layer].id, self.port_in)
             self.assertEqual(data, p[Raw].load)
 
-    def reass_hairpinning(self, proto=IP_PROTOS.tcp):
+    def reass_hairpinning(self, proto=IP_PROTOS.tcp, ignore_port=False):
         layer = self.proto2layer(proto)
 
         if proto == IP_PROTOS.tcp:
@@ -1243,13 +1253,16 @@ class MethodHolder(VppTestCase):
                                         self.nat_addr,
                                         self.server.ip4)
         if proto != IP_PROTOS.icmp:
-            self.assertNotEqual(p[layer].sport, self.host_in_port)
+            if not ignore_port:
+                self.assertNotEqual(p[layer].sport, self.host_in_port)
             self.assertEqual(p[layer].dport, self.server_in_port)
         else:
-            self.assertNotEqual(p[layer].id, self.host_in_port)
+            if not ignore_port:
+                self.assertNotEqual(p[layer].id, self.host_in_port)
         self.assertEqual(data, p[Raw].load)
 
-    def frag_out_of_order(self, proto=IP_PROTOS.tcp, dont_translate=False):
+    def frag_out_of_order(self, proto=IP_PROTOS.tcp, dont_translate=False,
+                          ignore_port=False):
         layer = self.proto2layer(proto)
 
         if proto == IP_PROTOS.tcp:
@@ -1278,14 +1291,16 @@ class MethodHolder(VppTestCase):
             if proto != IP_PROTOS.icmp:
                 if not dont_translate:
                     self.assertEqual(p[layer].dport, 20)
-                    self.assertNotEqual(p[layer].sport, self.port_in)
+                    if not ignore_port:
+                        self.assertNotEqual(p[layer].sport, self.port_in)
                 else:
                     self.assertEqual(p[layer].sport, self.port_in)
             else:
-                if not dont_translate:
-                    self.assertNotEqual(p[layer].id, self.port_in)
-                else:
-                    self.assertEqual(p[layer].id, self.port_in)
+                if not ignore_port:
+                    if not dont_translate:
+                        self.assertNotEqual(p[layer].id, self.port_in)
+                    else:
+                        self.assertEqual(p[layer].id, self.port_in)
             self.assertEqual(data, p[Raw].load)
 
             # out2in
@@ -4437,9 +4452,9 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.vapi.nat44_interface_add_del_feature(
             sw_if_index=self.pg1.sw_if_index,
             is_add=1)
-        self.frag_in_order(proto=IP_PROTOS.tcp)
-        self.frag_in_order(proto=IP_PROTOS.udp)
-        self.frag_in_order(proto=IP_PROTOS.icmp)
+        self.frag_in_order(proto=IP_PROTOS.tcp, ignore_port=True)
+        self.frag_in_order(proto=IP_PROTOS.udp, ignore_port=True)
+        self.frag_in_order(proto=IP_PROTOS.icmp, ignore_port=True)
 
     def test_frag_in_order_dont_translate(self):
         """ NAT44 don't translate fragments arriving in order """
@@ -4463,9 +4478,9 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.vapi.nat44_interface_add_del_feature(
             sw_if_index=self.pg1.sw_if_index,
             is_add=1)
-        self.frag_out_of_order(proto=IP_PROTOS.tcp)
-        self.frag_out_of_order(proto=IP_PROTOS.udp)
-        self.frag_out_of_order(proto=IP_PROTOS.icmp)
+        self.frag_out_of_order(proto=IP_PROTOS.tcp, ignore_port=True)
+        self.frag_out_of_order(proto=IP_PROTOS.udp, ignore_port=True)
+        self.frag_out_of_order(proto=IP_PROTOS.icmp, ignore_port=True)
 
     def test_frag_out_of_order_dont_translate(self):
         """ NAT44 don't translate fragments arriving out of order """
@@ -4593,9 +4608,9 @@ class TestNAT44EndpointDependent(MethodHolder):
                                       proto=IP_PROTOS.udp)
         self.nat44_add_static_mapping(self.server.ip4, self.nat_addr)
 
-        self.reass_hairpinning(proto=IP_PROTOS.tcp)
-        self.reass_hairpinning(proto=IP_PROTOS.udp)
-        self.reass_hairpinning(proto=IP_PROTOS.icmp)
+        self.reass_hairpinning(proto=IP_PROTOS.tcp, ignore_port=True)
+        self.reass_hairpinning(proto=IP_PROTOS.udp, ignore_port=True)
+        self.reass_hairpinning(proto=IP_PROTOS.icmp, ignore_port=True)
 
     def test_clear_sessions(self):
         """ NAT44 ED session clearing test """
@@ -4617,7 +4632,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         sessions = self.statistics.get_counter('/nat44/total-sessions')
         self.assertTrue(sessions[0][0] > 0)
@@ -4664,7 +4679,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         err = self.statistics.get_err_counter(
             '/err/nat44-ed-in2out-slowpath/TCP packets')
@@ -4752,7 +4767,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         err_new = self.statistics.get_err_counter(
             '/err/nat44-ed-in2out-slowpath/out of ports')
@@ -4806,7 +4821,7 @@ class TestNAT44EndpointDependent(MethodHolder):
             self.pg_enable_capture(self.pg_interfaces)
             self.pg_start()
             capture = self.pg8.get_capture(len(pkts))
-            self.verify_capture_out(capture)
+            self.verify_capture_out(capture, ignore_port=True)
 
             err = self.statistics.get_err_counter(
                 '/err/nat44-ed-in2out-slowpath/TCP packets')
@@ -5555,13 +5570,13 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
         pkts = self.create_stream_in(self.pg0, self.pg1)
         self.pg0.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         # from external network back to local network host
         pkts = self.create_stream_out(self.pg1)
@@ -5585,7 +5600,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         pkts = self.create_stream_out(self.pg1)
         self.pg1.add_stream(pkts)
@@ -6595,7 +6610,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
         capture = self.pg1.get_capture(len(pkts))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         # out2in
         pkts = self.create_stream_out(self.pg1)
@@ -6627,7 +6642,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         pkts_in2out = self.create_stream_in(self.pg0, self.pg1)
         capture = self.send_and_expect(self.pg0, pkts_in2out, self.pg1,
                                        len(pkts_in2out))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         # send out2in again, with sessions created it should work now
         pkts_out2in = self.create_stream_out(self.pg1)
@@ -6657,7 +6672,7 @@ class TestNAT44EndpointDependent(MethodHolder):
         # send in2out to generate ACL state (NAT state was created earlier)
         capture = self.send_and_expect(self.pg0, pkts_in2out, self.pg1,
                                        len(pkts_in2out))
-        self.verify_capture_out(capture)
+        self.verify_capture_out(capture, ignore_port=True)
 
         # send out2in again. ACL state exists so it should work now.
         # TCP packets with the syn flag set also need the ack flag
@@ -6762,7 +6777,6 @@ class TestNAT44EndpointDependent(MethodHolder):
             ip = p[IP]
             tcp = p[TCP]
             self.assertEqual(ip.src, self.nat_addr)
-            self.assertNotEqual(tcp.sport, 2345)
             self.assert_packet_checksums_valid(p)
             port = tcp.sport
         except: