SNAT: in2out translation as an output feature (VPP-903)
[vpp.git] / test / test_snat.py
index df81fb5..1db1546 100644 (file)
@@ -615,6 +615,12 @@ class TestSNAT(MethodHolder):
                                                      intf.is_inside,
                                                      is_add=0)
 
+        interfaces = self.vapi.snat_interface_output_feature_dump()
+        for intf in interfaces:
+            self.vapi.snat_interface_add_del_output_feature(intf.sw_if_index,
+                                                            intf.is_inside,
+                                                            is_add=0)
+
         static_mappings = self.vapi.snat_static_mapping_dump()
         for sm in static_mappings:
             self.vapi.snat_add_static_mapping(sm.local_ip_address,
@@ -2108,6 +2114,146 @@ class TestSNAT(MethodHolder):
             self.logger.error(ppp("Unexpected or invalid packet:", packet))
             raise
 
+    def test_output_feature(self):
+        """ S-NAT interface output feature (in2out postrouting) """
+        self.snat_add_address(self.snat_addr)
+        self.vapi.snat_interface_add_del_output_feature(self.pg0.sw_if_index)
+        self.vapi.snat_interface_add_del_output_feature(self.pg1.sw_if_index,
+                                                        is_inside=0)
+
+        # in2out
+        pkts = self.create_stream_in(self.pg0, self.pg1)
+        self.pg0.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg1.get_capture(len(pkts))
+        self.verify_capture_out(capture)
+
+        # out2in
+        pkts = self.create_stream_out(self.pg1)
+        self.pg1.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg0.get_capture(len(pkts))
+        self.verify_capture_in(capture, self.pg0)
+
+    def test_output_feature_vrf_aware(self):
+        """ S-NAT interface output feature VRF aware (in2out postrouting) """
+        nat_ip_vrf10 = "10.0.0.10"
+        nat_ip_vrf20 = "10.0.0.20"
+
+        self.vapi.ip_add_del_route(dst_address=self.pg3.remote_ip4n,
+                                   dst_address_length=32,
+                                   next_hop_address=self.pg3.remote_ip4n,
+                                   next_hop_sw_if_index=self.pg3.sw_if_index,
+                                   table_id=10)
+        self.vapi.ip_add_del_route(dst_address=self.pg3.remote_ip4n,
+                                   dst_address_length=32,
+                                   next_hop_address=self.pg3.remote_ip4n,
+                                   next_hop_sw_if_index=self.pg3.sw_if_index,
+                                   table_id=20)
+
+        self.snat_add_address(nat_ip_vrf10, vrf_id=10)
+        self.snat_add_address(nat_ip_vrf20, vrf_id=20)
+        self.vapi.snat_interface_add_del_output_feature(self.pg4.sw_if_index)
+        self.vapi.snat_interface_add_del_output_feature(self.pg6.sw_if_index)
+        self.vapi.snat_interface_add_del_output_feature(self.pg3.sw_if_index,
+                                                        is_inside=0)
+
+        # in2out VRF 10
+        pkts = self.create_stream_in(self.pg4, self.pg3)
+        self.pg4.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg3.get_capture(len(pkts))
+        self.verify_capture_out(capture, nat_ip=nat_ip_vrf10)
+
+        # out2in VRF 10
+        pkts = self.create_stream_out(self.pg3, dst_ip=nat_ip_vrf10)
+        self.pg3.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg4.get_capture(len(pkts))
+        self.verify_capture_in(capture, self.pg4)
+
+        # in2out VRF 20
+        pkts = self.create_stream_in(self.pg6, self.pg3)
+        self.pg6.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg3.get_capture(len(pkts))
+        self.verify_capture_out(capture, nat_ip=nat_ip_vrf20)
+
+        # out2in VRF 20
+        pkts = self.create_stream_out(self.pg3, dst_ip=nat_ip_vrf20)
+        self.pg3.add_stream(pkts)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg6.get_capture(len(pkts))
+        self.verify_capture_in(capture, self.pg6)
+
+    def _test_output_feature_hairpinning(self):
+        """ S-NAT interface output feature hairpinning (in2out postrouting) """
+        host = self.pg0.remote_hosts[0]
+        server = self.pg0.remote_hosts[1]
+        host_in_port = 1234
+        host_out_port = 0
+        server_in_port = 5678
+        server_out_port = 8765
+
+        self.snat_add_address(self.snat_addr)
+        self.vapi.snat_interface_add_del_output_feature(self.pg0.sw_if_index)
+        self.vapi.snat_interface_add_del_output_feature(self.pg1.sw_if_index,
+                                                        is_inside=0)
+
+        # add static mapping for server
+        self.snat_add_static_mapping(server.ip4, self.snat_addr,
+                                     server_in_port, server_out_port,
+                                     proto=IP_PROTOS.tcp)
+
+        # send packet from host to server
+        p = (Ether(src=host.mac, dst=self.pg0.local_mac) /
+             IP(src=host.ip4, dst=self.snat_addr) /
+             TCP(sport=host_in_port, dport=server_out_port))
+        self.pg0.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg0.get_capture(1)
+        p = capture[0]
+        try:
+            ip = p[IP]
+            tcp = p[TCP]
+            self.assertEqual(ip.src, self.snat_addr)
+            self.assertEqual(ip.dst, server.ip4)
+            self.assertNotEqual(tcp.sport, host_in_port)
+            self.assertEqual(tcp.dport, server_in_port)
+            self.check_tcp_checksum(p)
+            host_out_port = tcp.sport
+        except:
+            self.logger.error(ppp("Unexpected or invalid packet:", p))
+            raise
+
+        # send reply from server to host
+        p = (Ether(src=server.mac, dst=self.pg0.local_mac) /
+             IP(src=server.ip4, dst=self.snat_addr) /
+             TCP(sport=server_in_port, dport=host_out_port))
+        self.pg0.add_stream(p)
+        self.pg_enable_capture(self.pg_interfaces)
+        self.pg_start()
+        capture = self.pg0.get_capture(1)
+        p = capture[0]
+        try:
+            ip = p[IP]
+            tcp = p[TCP]
+            self.assertEqual(ip.src, self.snat_addr)
+            self.assertEqual(ip.dst, host.ip4)
+            self.assertEqual(tcp.sport, server_out_port)
+            self.assertEqual(tcp.dport, host_in_port)
+            self.check_tcp_checksum(p)
+        except:
+            self.logger.error(ppp("Unexpected or invalid packet:"), p)
+            raise
+
     def tearDown(self):
         super(TestSNAT, self).tearDown()
         if not self.vpp_dead: