SNAT: fix 1:1 NAT without port hairpinning TCP checksum update 77/7077/4
authorMatus Fabian <matfabia@cisco.com>
Fri, 9 Jun 2017 09:33:30 +0000 (02:33 -0700)
committerOle Trøan <otroan@employees.org>
Tue, 13 Jun 2017 08:19:10 +0000 (08:19 +0000)
Change-Id: I5077fcf3671a6116b475f87e43120efc10ecaa08
Signed-off-by: Matus Fabian <matfabia@cisco.com>
src/plugins/snat/in2out.c
test/test_snat.py

index ddde702..685cdca 100644 (file)
@@ -854,6 +854,16 @@ snat_hairpinning (snat_main_t *sm,
               udp0->checksum = 0;
             }
         }
+      else
+        {
+          if (PREDICT_TRUE(proto0 == SNAT_PROTOCOL_TCP))
+            {
+              sum0 = tcp0->checksum;
+              sum0 = ip_csum_update (sum0, old_dst_addr0, new_dst_addr0,
+                                     ip4_header_t, dst_address);
+              tcp0->checksum = ip_csum_fold(sum0);
+            }
+        }
     }
 }
 
index c6344a9..c2f9280 100644 (file)
@@ -27,6 +27,17 @@ class MethodHolder(VppTestCase):
     def tearDown(self):
         super(MethodHolder, self).tearDown()
 
+    def check_tcp_checksum(self, pkt):
+        """
+        Check TCP checksum in IP packet
+
+        :param pkt: Packet to check TCP checksum
+        """
+        new = pkt.__class__(str(pkt))
+        del new['TCP'].chksum
+        new = new.__class__(str(new))
+        self.assertEqual(new['TCP'].chksum, pkt['TCP'].chksum)
+
     def create_stream_in(self, in_if, out_if, ttl=64):
         """
         Create packet stream for inside network
@@ -1111,6 +1122,7 @@ class TestSNAT(MethodHolder):
             self.assertEqual(ip.dst, server.ip4)
             self.assertNotEqual(tcp.sport, host_in_port)
             self.assertEqual(tcp.dport, server_in_port)
+            self.check_tcp_checksum(p)
             host_out_port = tcp.sport
         except:
             self.logger.error(ppp("Unexpected or invalid packet:", p))
@@ -1132,6 +1144,7 @@ class TestSNAT(MethodHolder):
             self.assertEqual(ip.dst, host.ip4)
             self.assertEqual(tcp.sport, server_out_port)
             self.assertEqual(tcp.dport, host_in_port)
+            self.check_tcp_checksum(p)
         except:
             self.logger.error(ppp("Unexpected or invalid packet:"), p)
             raise
@@ -1182,6 +1195,7 @@ class TestSNAT(MethodHolder):
                     self.assertNotEqual(packet[TCP].sport, self.tcp_port_in)
                     self.assertEqual(packet[TCP].dport, server_tcp_port)
                     self.tcp_port_out = packet[TCP].sport
+                    self.check_tcp_checksum(packet)
                 elif packet.haslayer(UDP):
                     self.assertNotEqual(packet[UDP].sport, self.udp_port_in)
                     self.assertEqual(packet[UDP].dport, server_udp_port)
@@ -1218,6 +1232,7 @@ class TestSNAT(MethodHolder):
                 if packet.haslayer(TCP):
                     self.assertEqual(packet[TCP].dport, self.tcp_port_in)
                     self.assertEqual(packet[TCP].sport, server_tcp_port)
+                    self.check_tcp_checksum(packet)
                 elif packet.haslayer(UDP):
                     self.assertEqual(packet[UDP].dport, self.udp_port_in)
                     self.assertEqual(packet[UDP].sport, server_udp_port)
@@ -1253,6 +1268,7 @@ class TestSNAT(MethodHolder):
                     self.assertEqual(packet[TCP].sport, self.tcp_port_in)
                     self.assertEqual(packet[TCP].dport, server_tcp_port)
                     self.tcp_port_out = packet[TCP].sport
+                    self.check_tcp_checksum(packet)
                 elif packet.haslayer(UDP):
                     self.assertEqual(packet[UDP].sport, self.udp_port_in)
                     self.assertEqual(packet[UDP].dport, server_udp_port)
@@ -1289,6 +1305,7 @@ class TestSNAT(MethodHolder):
                 if packet.haslayer(TCP):
                     self.assertEqual(packet[TCP].dport, self.tcp_port_in)
                     self.assertEqual(packet[TCP].sport, server_tcp_port)
+                    self.check_tcp_checksum(packet)
                 elif packet.haslayer(UDP):
                     self.assertEqual(packet[UDP].dport, self.udp_port_in)
                     self.assertEqual(packet[UDP].sport, server_udp_port)