SNAT: NAT packet with unknown L4 protocol if match 1:1 NAT 87/7187/2
authorMatus Fabian <matfabia@cisco.com>
Mon, 19 Jun 2017 11:28:04 +0000 (04:28 -0700)
committerOle Trøan <otroan@employees.org>
Mon, 19 Jun 2017 13:47:19 +0000 (13:47 +0000)
Change-Id: Ic81c6098d615fdb6a874e532921efd833fed872c
Signed-off-by: Matus Fabian <matfabia@cisco.com>
src/plugins/snat/in2out.c
src/plugins/snat/out2in.c
test/test_snat.py

index 685cdca..d396c79 100644 (file)
@@ -967,6 +967,55 @@ static inline u32 icmp_in2out_slow_path (snat_main_t *sm,
   return next0;
 }
 
   return next0;
 }
 
+static void
+snat_in2out_unknown_proto (snat_main_t *sm,
+                           vlib_buffer_t * b,
+                           ip4_header_t * ip,
+                           u32 rx_fib_index)
+{
+  clib_bihash_kv_8_8_t kv, value;
+  snat_static_mapping_t *m;
+  snat_session_key_t m_key;
+  u32 old_addr, new_addr;
+  ip_csum_t sum;
+
+  m_key.addr = ip->src_address;
+  m_key.port = 0;
+  m_key.protocol = 0;
+  m_key.fib_index = rx_fib_index;
+  kv.key = m_key.as_u64;
+  if (clib_bihash_search_8_8 (&sm->static_mapping_by_local, &kv, &value))
+    return;
+
+  m = pool_elt_at_index (sm->static_mappings, value.value);
+
+  old_addr = ip->src_address.as_u32;
+  new_addr = ip->src_address.as_u32 = m->external_addr.as_u32;
+  sum = ip->checksum;
+  sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, src_address);
+  ip->checksum = ip_csum_fold (sum);
+
+  /* Hairpinning */
+  m_key.addr = ip->dst_address;
+  m_key.fib_index = sm->outside_fib_index;
+  kv.key = m_key.as_u64;
+  if (clib_bihash_search_8_8 (&sm->static_mapping_by_external, &kv, &value))
+    {
+      vnet_buffer(b)->sw_if_index[VLIB_TX] = sm->outside_fib_index;
+      return;
+    }
+
+  m = pool_elt_at_index (sm->static_mappings, value.value);
+
+  old_addr = ip->dst_address.as_u32;
+  new_addr = ip->dst_address.as_u32 = m->local_addr.as_u32;
+  sum = ip->checksum;
+  sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, dst_address);
+  ip->checksum = ip_csum_fold (sum);
+
+  vnet_buffer(b)->sw_if_index[VLIB_TX] = vnet_buffer(b)->sw_if_index[VLIB_RX];
+}
+
 static inline uword
 snat_in2out_node_fn_inline (vlib_main_t * vm,
                             vlib_node_runtime_t * node,
 static inline uword
 snat_in2out_node_fn_inline (vlib_main_t * vm,
                             vlib_node_runtime_t * node,
@@ -1065,8 +1114,11 @@ snat_in2out_node_fn_inline (vlib_main_t * vm,
           if (is_slow_path)
             {
               if (PREDICT_FALSE (proto0 == ~0))
           if (is_slow_path)
             {
               if (PREDICT_FALSE (proto0 == ~0))
-                goto trace00;
-              
+                {
+                  snat_in2out_unknown_proto (sm, b0, ip0, rx_fib_index0);
+                  goto trace00;
+                }
+
               if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
                 {
                   next0 = icmp_in2out_slow_path 
               if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
                 {
                   next0 = icmp_in2out_slow_path 
@@ -1205,8 +1257,11 @@ snat_in2out_node_fn_inline (vlib_main_t * vm,
           if (is_slow_path)
             {
               if (PREDICT_FALSE (proto1 == ~0))
           if (is_slow_path)
             {
               if (PREDICT_FALSE (proto1 == ~0))
-                goto trace01;
-              
+                {
+                  snat_in2out_unknown_proto (sm, b1, ip1, rx_fib_index1);
+                  goto trace01;
+                }
+
               if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP))
                 {
                   next1 = icmp_in2out_slow_path 
               if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP))
                 {
                   next1 = icmp_in2out_slow_path 
@@ -1380,8 +1435,11 @@ snat_in2out_node_fn_inline (vlib_main_t * vm,
           if (is_slow_path)
             {
               if (PREDICT_FALSE (proto0 == ~0))
           if (is_slow_path)
             {
               if (PREDICT_FALSE (proto0 == ~0))
-                goto trace0;
-              
+                {
+                  snat_in2out_unknown_proto (sm, b0, ip0, rx_fib_index0);
+                  goto trace0;
+                }
+
               if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
                 {
                   next0 = icmp_in2out_slow_path 
               if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
                 {
                   next0 = icmp_in2out_slow_path 
index e8ddcf1..5c12b47 100644 (file)
@@ -611,6 +611,37 @@ static inline u32 icmp_out2in_slow_path (snat_main_t *sm,
   return next0;
 }
 
   return next0;
 }
 
+static void
+snat_out2in_unknown_proto (snat_main_t *sm,
+                           vlib_buffer_t * b,
+                           ip4_header_t * ip,
+                           u32 rx_fib_index)
+{
+  clib_bihash_kv_8_8_t kv, value;
+  snat_static_mapping_t *m;
+  snat_session_key_t m_key;
+  u32 old_addr, new_addr;
+  ip_csum_t sum;
+
+  m_key.addr = ip->dst_address;
+  m_key.port = 0;
+  m_key.protocol = 0;
+  m_key.fib_index = rx_fib_index;
+  kv.key = m_key.as_u64;
+  if (clib_bihash_search_8_8 (&sm->static_mapping_by_external, &kv, &value))
+    return;
+
+  m = pool_elt_at_index (sm->static_mappings, value.value);
+
+  old_addr = ip->dst_address.as_u32;
+  new_addr = ip->dst_address.as_u32 = m->local_addr.as_u32;
+  sum = ip->checksum;
+  sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, dst_address);
+  ip->checksum = ip_csum_fold (sum);
+
+  vnet_buffer(b)->sw_if_index[VLIB_TX] = m->fib_index;
+}
+
 static uword
 snat_out2in_node_fn (vlib_main_t * vm,
                  vlib_node_runtime_t * node,
 static uword
 snat_out2in_node_fn (vlib_main_t * vm,
                  vlib_node_runtime_t * node,
@@ -703,7 +734,10 @@ snat_out2in_node_fn (vlib_main_t * vm,
           proto0 = ip_proto_to_snat_proto (ip0->protocol);
 
           if (PREDICT_FALSE (proto0 == ~0))
           proto0 = ip_proto_to_snat_proto (ip0->protocol);
 
           if (PREDICT_FALSE (proto0 == ~0))
+            {
+              snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0);
               goto trace0;
               goto trace0;
+            }
 
           if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
             {
 
           if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
             {
@@ -838,7 +872,10 @@ snat_out2in_node_fn (vlib_main_t * vm,
           proto1 = ip_proto_to_snat_proto (ip1->protocol);
 
           if (PREDICT_FALSE (proto1 == ~0))
           proto1 = ip_proto_to_snat_proto (ip1->protocol);
 
           if (PREDICT_FALSE (proto1 == ~0))
+            {
+              snat_out2in_unknown_proto(sm, b1, ip1, rx_fib_index1);
               goto trace1;
               goto trace1;
+            }
 
           if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP))
             {
 
           if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP))
             {
@@ -997,7 +1034,10 @@ snat_out2in_node_fn (vlib_main_t * vm,
           proto0 = ip_proto_to_snat_proto (ip0->protocol);
 
           if (PREDICT_FALSE (proto0 == ~0))
           proto0 = ip_proto_to_snat_proto (ip0->protocol);
 
           if (PREDICT_FALSE (proto0 == ~0))
+            {
+              snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0);
               goto trace00;
               goto trace00;
+            }
 
           if (PREDICT_FALSE(ip0->ttl == 1))
             {
 
           if (PREDICT_FALSE(ip0->ttl == 1))
             {
index ee689e6..e148fba 100644 (file)
@@ -9,7 +9,7 @@ from scapy.layers.inet import IP, TCP, UDP, ICMP
 from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror
 from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest, ICMPv6EchoReply
 from scapy.layers.inet6 import ICMPv6DestUnreach, IPerror6
 from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror
 from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest, ICMPv6EchoReply
 from scapy.layers.inet6 import ICMPv6DestUnreach, IPerror6
-from scapy.layers.l2 import Ether, ARP
+from scapy.layers.l2 import Ether, ARP, GRE
 from scapy.data import IP_PROTOS
 from scapy.packet import bind_layers
 from util import ppp
 from scapy.data import IP_PROTOS
 from scapy.packet import bind_layers
 from util import ppp
@@ -1835,6 +1835,54 @@ class TestSNAT(MethodHolder):
         capture = self.pg8.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
         capture = self.pg8.get_capture(len(pkts))
         self.verify_capture_out(capture)
 
+    def test_static_unknown_proto(self):
+        """ 1:1 NAT translate packet with unknown protocol """
+        nat_ip = "10.0.0.10"
+        self.snat_add_static_mapping(self.pg0.remote_ip4, nat_ip)
+        self.vapi.snat_interface_add_del_feature(self.pg0.sw_if_index)
+        self.vapi.snat_interface_add_del_feature(self.pg1.sw_if_index,
+                                                 is_inside=0)
+
+        # in2out
+        p = (Ether(dst=self.pg0.local_mac, src=self.pg0.remote_mac) /
+             IP(src=self.pg0.remote_ip4, dst=self.pg1.remote_ip4) /
+             GRE() /
+             IP(src=self.pg2.remote_ip4, dst=self.pg2.remote_ip4) /
+             TCP(sport=1234, dport=1234))
+        self.pg0.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        p = self.pg1.get_capture(1)
+        packet = p[0]
+        try:
+            self.assertEqual(packet[IP].src, nat_ip)
+            self.assertEqual(packet[IP].dst, self.pg1.remote_ip4)
+            self.assertTrue(packet.haslayer(GRE))
+            self.check_ip_checksum(packet)
+        except:
+            self.logger.error(ppp("Unexpected or invalid packet:", packet))
+            raise
+
+        # out2in
+        p = (Ether(dst=self.pg1.local_mac, src=self.pg1.remote_mac) /
+             IP(src=self.pg1.remote_ip4, dst=nat_ip) /
+             GRE() /
+             IP(src=self.pg2.remote_ip4, dst=self.pg2.remote_ip4) /
+             TCP(sport=1234, dport=1234))
+        self.pg1.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        p = self.pg0.get_capture(1)
+        packet = p[0]
+        try:
+            self.assertEqual(packet[IP].src, self.pg1.remote_ip4)
+            self.assertEqual(packet[IP].dst, self.pg0.remote_ip4)
+            self.assertTrue(packet.haslayer(GRE))
+            self.check_ip_checksum(packet)
+        except:
+            self.logger.error(ppp("Unexpected or invalid packet:", packet))
+            raise
+
     def tearDown(self):
         super(TestSNAT, self).tearDown()
         if not self.vpp_dead:
     def tearDown(self):
         super(TestSNAT, self).tearDown()
         if not self.vpp_dead: