ipsec: add ipv6 support for ipsec tunnel interface
[vpp.git] / test / template_ipsec.py
index 7888a67..483699c 100644 (file)
@@ -1,7 +1,7 @@
 import unittest
 import socket
 
 import unittest
 import socket
 
-from scapy.layers.inet import IP, ICMP, TCP
+from scapy.layers.inet import IP, ICMP, TCP, UDP
 from scapy.layers.ipsec import SecurityAssociation
 from scapy.layers.l2 import Ether, Raw
 from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest
 from scapy.layers.ipsec import SecurityAssociation
 from scapy.layers.l2 import Ether, Raw
 from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest
@@ -41,6 +41,8 @@ class IPsecIPv4Params(object):
                                   IPSEC_API_CRYPTO_ALG_AES_CBC_128)
         self.crypt_algo = 'AES-CBC'  # scapy name
         self.crypt_key = 'JPjyOWBeVEQiMe7h'
                                   IPSEC_API_CRYPTO_ALG_AES_CBC_128)
         self.crypt_algo = 'AES-CBC'  # scapy name
         self.crypt_key = 'JPjyOWBeVEQiMe7h'
+        self.flags = 0
+        self.nat_header = None
 
 
 class IPsecIPv6Params(object):
 
 
 class IPsecIPv6Params(object):
@@ -73,6 +75,8 @@ class IPsecIPv6Params(object):
                                   IPSEC_API_CRYPTO_ALG_AES_CBC_256)
         self.crypt_algo = 'AES-CBC'  # scapy name
         self.crypt_key = 'JPjyOWBeVEQiMe7hJPjyOWBeVEQiMe7h'
                                   IPSEC_API_CRYPTO_ALG_AES_CBC_256)
         self.crypt_algo = 'AES-CBC'  # scapy name
         self.crypt_key = 'JPjyOWBeVEQiMe7hJPjyOWBeVEQiMe7h'
+        self.flags = 0
+        self.nat_header = None
 
 
 class TemplateIpsec(VppTestCase):
 
 
 class TemplateIpsec(VppTestCase):
@@ -168,29 +172,35 @@ class TemplateIpsec(VppTestCase):
             auth_algo=params.auth_algo, auth_key=params.auth_key,
             tunnel_header=ip_class_by_addr_type[params.addr_type](
                 src=self.tun_if.remote_addr[params.addr_type],
             auth_algo=params.auth_algo, auth_key=params.auth_key,
             tunnel_header=ip_class_by_addr_type[params.addr_type](
                 src=self.tun_if.remote_addr[params.addr_type],
-                dst=self.tun_if.local_addr[params.addr_type]))
+                dst=self.tun_if.local_addr[params.addr_type]),
+            nat_t_header=params.nat_header)
         vpp_tun_sa = SecurityAssociation(
             self.encryption_type, spi=params.scapy_tun_spi,
             crypt_algo=params.crypt_algo, crypt_key=params.crypt_key,
             auth_algo=params.auth_algo, auth_key=params.auth_key,
             tunnel_header=ip_class_by_addr_type[params.addr_type](
                 dst=self.tun_if.remote_addr[params.addr_type],
         vpp_tun_sa = SecurityAssociation(
             self.encryption_type, spi=params.scapy_tun_spi,
             crypt_algo=params.crypt_algo, crypt_key=params.crypt_key,
             auth_algo=params.auth_algo, auth_key=params.auth_key,
             tunnel_header=ip_class_by_addr_type[params.addr_type](
                 dst=self.tun_if.remote_addr[params.addr_type],
-                src=self.tun_if.local_addr[params.addr_type]))
+                src=self.tun_if.local_addr[params.addr_type]),
+            nat_t_header=params.nat_header)
         return vpp_tun_sa, scapy_tun_sa
 
     def configure_sa_tra(self, params):
         return vpp_tun_sa, scapy_tun_sa
 
     def configure_sa_tra(self, params):
-        params.scapy_tra_sa = SecurityAssociation(self.encryption_type,
-                                                  spi=params.vpp_tra_spi,
-                                                  crypt_algo=params.crypt_algo,
-                                                  crypt_key=params.crypt_key,
-                                                  auth_algo=params.auth_algo,
-                                                  auth_key=params.auth_key)
-        params.vpp_tra_sa = SecurityAssociation(self.encryption_type,
-                                                spi=params.scapy_tra_spi,
-                                                crypt_algo=params.crypt_algo,
-                                                crypt_key=params.crypt_key,
-                                                auth_algo=params.auth_algo,
-                                                auth_key=params.auth_key)
+        params.scapy_tra_sa = SecurityAssociation(
+            self.encryption_type,
+            spi=params.vpp_tra_spi,
+            crypt_algo=params.crypt_algo,
+            crypt_key=params.crypt_key,
+            auth_algo=params.auth_algo,
+            auth_key=params.auth_key,
+            nat_t_header=params.nat_header)
+        params.vpp_tra_sa = SecurityAssociation(
+            self.encryption_type,
+            spi=params.scapy_tra_spi,
+            crypt_algo=params.crypt_algo,
+            crypt_key=params.crypt_key,
+            auth_algo=params.auth_algo,
+            auth_key=params.auth_key,
+            nat_t_header=params.nat_header)
 
 
 class IpsecTcpTests(object):
 
 
 class IpsecTcpTests(object):
@@ -210,7 +220,7 @@ class IpsecTcpTests(object):
         self.assert_packet_checksums_valid(decrypted)
 
 
         self.assert_packet_checksums_valid(decrypted)
 
 
-class IpsecTraTests(object):
+class IpsecTra4Tests(object):
     def test_tra_anti_replay(self, count=1):
         """ ipsec v4 transport anti-reply test """
         p = self.params[socket.AF_INET]
     def test_tra_anti_replay(self, count=1):
         """ ipsec v4 transport anti-reply test """
         p = self.params[socket.AF_INET]
@@ -304,6 +314,15 @@ class IpsecTraTests(object):
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
+        pkts = p.tra_sa_in.get_stats()['packets']
+        self.assertEqual(pkts, count,
+                         "incorrect SA in counts: expected %d != %d" %
+                         (count, pkts))
+        pkts = p.tra_sa_out.get_stats()['packets']
+        self.assertEqual(pkts, count,
+                         "incorrect SA out counts: expected %d != %d" %
+                         (count, pkts))
+
         self.assert_packet_counter_equal(self.tra4_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tra4_decrypt_node_name, count)
 
         self.assert_packet_counter_equal(self.tra4_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tra4_decrypt_node_name, count)
 
@@ -311,6 +330,8 @@ class IpsecTraTests(object):
         """ ipsec v4 transport burst test """
         self.test_tra_basic(count=257)
 
         """ ipsec v4 transport burst test """
         self.test_tra_basic(count=257)
 
+
+class IpsecTra6Tests(object):
     def test_tra_basic6(self, count=1):
         """ ipsec v6 transport basic test """
         self.vapi.cli("clear errors")
     def test_tra_basic6(self, count=1):
         """ ipsec v6 transport basic test """
         self.vapi.cli("clear errors")
@@ -333,6 +354,14 @@ class IpsecTraTests(object):
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
+        pkts = p.tra_sa_in.get_stats()['packets']
+        self.assertEqual(pkts, count,
+                         "incorrect SA in counts: expected %d != %d" %
+                         (count, pkts))
+        pkts = p.tra_sa_out.get_stats()['packets']
+        self.assertEqual(pkts, count,
+                         "incorrect SA out counts: expected %d != %d" %
+                         (count, pkts))
         self.assert_packet_counter_equal(self.tra6_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tra6_decrypt_node_name, count)
 
         self.assert_packet_counter_equal(self.tra6_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tra6_decrypt_node_name, count)
 
@@ -341,6 +370,10 @@ class IpsecTraTests(object):
         self.test_tra_basic6(count=257)
 
 
         self.test_tra_basic6(count=257)
 
 
+class IpsecTra46Tests(IpsecTra4Tests, IpsecTra6Tests):
+    pass
+
+
 class IpsecTun4Tests(object):
     def test_tun_basic44(self, count=1):
         """ ipsec 4o4 tunnel basic test """
 class IpsecTun4Tests(object):
     def test_tun_basic44(self, count=1):
         """ ipsec 4o4 tunnel basic test """
@@ -380,6 +413,22 @@ class IpsecTun4Tests(object):
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
+        if (hasattr(p, "spd_policy_in_any")):
+            pkts = p.spd_policy_in_any.get_stats()['packets']
+            self.assertEqual(pkts, count,
+                             "incorrect SPD any policy: expected %d != %d" %
+                             (count, pkts))
+
+        if (hasattr(p, "tun_sa_in")):
+            pkts = p.tun_sa_in.get_stats()['packets']
+            self.assertEqual(pkts, count,
+                             "incorrect SA in counts: expected %d != %d" %
+                             (count, pkts))
+            pkts = p.tun_sa_out.get_stats()['packets']
+            self.assertEqual(pkts, count,
+                             "incorrect SA out counts: expected %d != %d" %
+                             (count, pkts))
+
         self.assert_packet_counter_equal(self.tun4_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tun4_decrypt_node_name, count)
 
         self.assert_packet_counter_equal(self.tun4_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tun4_decrypt_node_name, count)
 
@@ -428,6 +477,15 @@ class IpsecTun6Tests(object):
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show ipsec"))
 
+        if (hasattr(p, "tun_sa_in")):
+            pkts = p.tun_sa_in.get_stats()['packets']
+            self.assertEqual(pkts, count,
+                             "incorrect SA in counts: expected %d != %d" %
+                             (count, pkts))
+            pkts = p.tun_sa_out.get_stats()['packets']
+            self.assertEqual(pkts, count,
+                             "incorrect SA out counts: expected %d != %d" %
+                             (count, pkts))
         self.assert_packet_counter_equal(self.tun6_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tun6_decrypt_node_name, count)
 
         self.assert_packet_counter_equal(self.tun6_encrypt_node_name, count)
         self.assert_packet_counter_equal(self.tun6_decrypt_node_name, count)
 
@@ -436,7 +494,7 @@ class IpsecTun6Tests(object):
         self.test_tun_basic66(count=257)
 
 
         self.test_tun_basic66(count=257)
 
 
-class IpsecTunTests(IpsecTun4Tests, IpsecTun6Tests):
+class IpsecTun46Tests(IpsecTun4Tests, IpsecTun6Tests):
     pass
 
 
     pass