ipsec: add support for AES CTR
[vpp.git] / test / template_ipsec.py
index 5a700e8..918c993 100644 (file)
@@ -5,7 +5,7 @@ import struct
 from scapy.layers.inet import IP, ICMP, TCP, UDP
 from scapy.layers.ipsec import SecurityAssociation, ESP
 from scapy.layers.l2 import Ether
 from scapy.layers.inet import IP, ICMP, TCP, UDP
 from scapy.layers.ipsec import SecurityAssociation, ESP
 from scapy.layers.l2 import Ether
-from scapy.packet import Raw
+from scapy.packet import raw, Raw
 from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest, IPv6ExtHdrHopByHop, \
     IPv6ExtHdrFragment, IPv6ExtHdrDestOpt
 
 from scapy.layers.inet6 import IPv6, ICMPv6EchoRequest, IPv6ExtHdrHopByHop, \
     IPv6ExtHdrFragment, IPv6ExtHdrDestOpt
 
@@ -15,7 +15,7 @@ from util import ppp, reassemble4, fragment_rfc791, fragment_rfc8200
 from vpp_papi import VppEnum
 
 
 from vpp_papi import VppEnum
 
 
-class IPsecIPv4Params(object):
+class IPsecIPv4Params:
 
     addr_type = socket.AF_INET
     addr_any = "0.0.0.0"
 
     addr_type = socket.AF_INET
     addr_any = "0.0.0.0"
@@ -28,14 +28,14 @@ class IPsecIPv4Params(object):
         self.remote_tun_if_host6 = '1111::1'
 
         self.scapy_tun_sa_id = 100
         self.remote_tun_if_host6 = '1111::1'
 
         self.scapy_tun_sa_id = 100
-        self.scapy_tun_spi = 1001
+        self.scapy_tun_spi = 1000
         self.vpp_tun_sa_id = 200
         self.vpp_tun_sa_id = 200
-        self.vpp_tun_spi = 1000
+        self.vpp_tun_spi = 2000
 
         self.scapy_tra_sa_id = 300
 
         self.scapy_tra_sa_id = 300
-        self.scapy_tra_spi = 2001
+        self.scapy_tra_spi = 3000
         self.vpp_tra_sa_id = 400
         self.vpp_tra_sa_id = 400
-        self.vpp_tra_spi = 2000
+        self.vpp_tra_spi = 4000
 
         self.auth_algo_vpp_id = (VppEnum.vl_api_ipsec_integ_alg_t.
                                  IPSEC_API_INTEG_ALG_SHA1_96)
 
         self.auth_algo_vpp_id = (VppEnum.vl_api_ipsec_integ_alg_t.
                                  IPSEC_API_INTEG_ALG_SHA1_96)
@@ -49,9 +49,12 @@ class IPsecIPv4Params(object):
         self.salt = 0
         self.flags = 0
         self.nat_header = None
         self.salt = 0
         self.flags = 0
         self.nat_header = None
+        self.tun_flags = (VppEnum.vl_api_tunnel_encap_decap_flags_t.
+                          TUNNEL_API_ENCAP_DECAP_FLAG_NONE)
+        self.dscp = 0
 
 
 
 
-class IPsecIPv6Params(object):
+class IPsecIPv6Params:
 
     addr_type = socket.AF_INET6
     addr_any = "0::0"
 
     addr_type = socket.AF_INET6
     addr_any = "0::0"
@@ -85,10 +88,13 @@ class IPsecIPv6Params(object):
         self.salt = 0
         self.flags = 0
         self.nat_header = None
         self.salt = 0
         self.flags = 0
         self.nat_header = None
+        self.tun_flags = (VppEnum.vl_api_tunnel_encap_decap_flags_t.
+                          TUNNEL_API_ENCAP_DECAP_FLAG_NONE)
+        self.dscp = 0
 
 
 def mk_scapy_crypt_key(p):
 
 
 def mk_scapy_crypt_key(p):
-    if p.crypt_algo == "AES-GCM":
+    if p.crypt_algo in ("AES-GCM", "AES-CTR"):
         return p.crypt_key + struct.pack("!I", p.salt)
     else:
         return p.crypt_key
         return p.crypt_key + struct.pack("!I", p.salt)
     else:
         return p.crypt_key
@@ -181,8 +187,10 @@ class TemplateIpsec(VppTestCase):
         super(TemplateIpsec, cls).tearDownClass()
 
     def setup_params(self):
         super(TemplateIpsec, cls).tearDownClass()
 
     def setup_params(self):
-        self.ipv4_params = IPsecIPv4Params()
-        self.ipv6_params = IPsecIPv6Params()
+        if not hasattr(self, 'ipv4_params'):
+            self.ipv4_params = IPsecIPv4Params()
+        if not hasattr(self, 'ipv6_params'):
+            self.ipv6_params = IPsecIPv6Params()
         self.params = {self.ipv4_params.addr_type: self.ipv4_params,
                        self.ipv6_params.addr_type: self.ipv6_params}
 
         self.params = {self.ipv4_params.addr_type: self.ipv4_params,
                        self.ipv6_params.addr_type: self.ipv6_params}
 
@@ -790,7 +798,7 @@ class IpsecTun4(object):
                              "incorrect SA in counts: expected %d != %d" %
                              (count, pkts))
             pkts = p.tun_sa_out.get_stats(worker)['packets']
                              "incorrect SA in counts: expected %d != %d" %
                              (count, pkts))
             pkts = p.tun_sa_out.get_stats(worker)['packets']
-            self.assertEqual(pkts, count,
+            self.assertEqual(pkts, n_frags,
                              "incorrect SA out counts: expected %d != %d" %
                              (count, pkts))
 
                              "incorrect SA out counts: expected %d != %d" %
                              (count, pkts))
 
@@ -803,6 +811,15 @@ class IpsecTun4(object):
             self.assert_equal(rx[IP].dst, self.pg1.remote_ip4)
             self.assert_packet_checksums_valid(rx)
 
             self.assert_equal(rx[IP].dst, self.pg1.remote_ip4)
             self.assert_packet_checksums_valid(rx)
 
+    def verify_esp_padding(self, sa, esp_payload, decrypt_pkt):
+        align = sa.crypt_algo.block_size
+        if align < 4:
+            align = 4
+        exp_len = (len(decrypt_pkt) + 2 + (align - 1)) & ~(align - 1)
+        exp_len += sa.crypt_algo.iv_size
+        exp_len += sa.crypt_algo.icv_size or sa.auth_algo.icv_size
+        self.assertEqual(exp_len, len(esp_payload))
+
     def verify_encrypted(self, p, sa, rxs):
         decrypt_pkts = []
         for rx in rxs:
     def verify_encrypted(self, p, sa, rxs):
         decrypt_pkts = []
         for rx in rxs:
@@ -811,9 +828,12 @@ class IpsecTun4(object):
             self.assert_packet_checksums_valid(rx)
             self.assertEqual(len(rx) - len(Ether()), rx[IP].len)
             try:
             self.assert_packet_checksums_valid(rx)
             self.assertEqual(len(rx) - len(Ether()), rx[IP].len)
             try:
-                decrypt_pkt = p.vpp_tun_sa.decrypt(rx[IP])
+                rx_ip = rx[IP]
+                decrypt_pkt = p.vpp_tun_sa.decrypt(rx_ip)
                 if not decrypt_pkt.haslayer(IP):
                     decrypt_pkt = IP(decrypt_pkt[Raw].load)
                 if not decrypt_pkt.haslayer(IP):
                     decrypt_pkt = IP(decrypt_pkt[Raw].load)
+                if rx_ip.proto == socket.IPPROTO_ESP:
+                    self.verify_esp_padding(sa, rx_ip[ESP].data, decrypt_pkt)
                 decrypt_pkts.append(decrypt_pkt)
                 self.assert_equal(decrypt_pkt.src, self.pg1.remote_ip4)
                 self.assert_equal(decrypt_pkt.dst, p.remote_tun_if_host)
                 decrypt_pkts.append(decrypt_pkt)
                 self.assert_equal(decrypt_pkt.src, self.pg1.remote_ip4)
                 self.assert_equal(decrypt_pkt.dst, p.remote_tun_if_host)
@@ -914,6 +934,7 @@ class IpsecTun4(object):
 
     def verify_tun_64(self, p, count=1):
         self.vapi.cli("clear errors")
 
     def verify_tun_64(self, p, count=1):
         self.vapi.cli("clear errors")
+        self.vapi.cli("clear ipsec sa")
         try:
             send_pkts = self.gen_encrypt_pkts6(p, p.scapy_tun_sa, self.tun_if,
                                                src=p.remote_tun_if_host6,
         try:
             send_pkts = self.gen_encrypt_pkts6(p, p.scapy_tun_sa, self.tun_if,
                                                src=p.remote_tun_if_host6,
@@ -1104,6 +1125,7 @@ class IpsecTun6(object):
     def verify_tun_46(self, p, count=1):
         """ ipsec 4o6 tunnel basic test """
         self.vapi.cli("clear errors")
     def verify_tun_46(self, p, count=1):
         """ ipsec 4o6 tunnel basic test """
         self.vapi.cli("clear errors")
+        self.vapi.cli("clear ipsec sa")
         try:
             send_pkts = self.gen_encrypt_pkts(p, p.scapy_tun_sa, self.tun_if,
                                               src=p.remote_tun_if_host4,
         try:
             send_pkts = self.gen_encrypt_pkts(p, p.scapy_tun_sa, self.tun_if,
                                               src=p.remote_tun_if_host4,