ipsec: fast path outbound policy matching implementation for ipv6
[vpp.git] / src / vnet / ipsec / ipsec_spd_policy.c
index 8cdbe32..b198c20 100644 (file)
@@ -167,8 +167,10 @@ ipsec_add_del_policy (vlib_main_t * vm,
        * Try adding the policy into fast path SPD first. Only adding to
        * traditional SPD when failed.
        **/
-      if (im->fp_spd_is_enabled &&
-         (policy->type == IPSEC_SPD_POLICY_IP4_OUTBOUND))
+      if ((im->ipv4_fp_spd_is_enabled &&
+          policy->type == IPSEC_SPD_POLICY_IP4_OUTBOUND) ||
+         (im->ipv6_fp_spd_is_enabled &&
+          policy->type == IPSEC_SPD_POLICY_IP6_OUTBOUND))
        return ipsec_fp_add_del_policy ((void *) &spd->fp_spd, policy, 1,
                                        stat_index);
 
@@ -192,12 +194,11 @@ ipsec_add_del_policy (vlib_main_t * vm,
        * Try to delete the policy from the fast path SPD first. Delete from
        * traditional SPD when fp delete fails.
        **/
-      /**
-       * TODO: add ipv6 fast path support for outbound and
-       * ipv4/v6 inbound support for fast path
-       */
-      if (im->fp_spd_is_enabled &&
-         (policy->type == IPSEC_SPD_POLICY_IP4_OUTBOUND))
+
+      if ((im->ipv4_fp_spd_is_enabled &&
+          policy->type == IPSEC_SPD_POLICY_IP4_OUTBOUND) ||
+         (im->ipv6_fp_spd_is_enabled &&
+          policy->type == IPSEC_SPD_POLICY_IP6_OUTBOUND))
        return ipsec_fp_add_del_policy ((void *) &spd->fp_spd, policy, 0,
                                        stat_index);
 
@@ -247,27 +248,26 @@ find_mask_type_index (ipsec_main_t *im, ipsec_fp_5tuple_t *mask)
 }
 
 static_always_inline void
-fill_ip6_hash_policy_kv (ipsec_main_t *im, ipsec_fp_5tuple_t *match,
-                        ipsec_fp_5tuple_t *mask, clib_bihash_kv_40_8_t *kv)
+fill_ip6_hash_policy_kv (ipsec_fp_5tuple_t *match, ipsec_fp_5tuple_t *mask,
+                        clib_bihash_kv_40_8_t *kv)
 {
   ipsec_fp_lookup_value_t *kv_val = (ipsec_fp_lookup_value_t *) &kv->value;
-  u64 *pmatch = (u64 *) &match;
-  u64 *pmask = (u64 *) &mask;
+  u64 *pmatch = (u64 *) &match->ip6_laddr;
+  u64 *pmask = (u64 *) &mask->ip6_laddr;
   u64 *pkey = (u64 *) &kv->key;
 
   *pkey++ = *pmatch++ & *pmask++;
   *pkey++ = *pmatch++ & *pmask++;
   *pkey++ = *pmatch++ & *pmask++;
   *pkey++ = *pmatch++ & *pmask++;
-  *pkey++ = *pmatch++ & *pmask++;
-  *pkey++ = *pmatch++ & *pmask++;
+  *pkey = *pmatch & *pmask;
 
   kv_val->as_u64 = 0;
 }
 
 static_always_inline void
-fill_ip4_hash_policy_kv (ipsec_main_t *im, ipsec_fp_5tuple_t *match,
-                        ipsec_fp_5tuple_t *mask, clib_bihash_kv_16_8_t *kv)
+fill_ip4_hash_policy_kv (ipsec_fp_5tuple_t *match, ipsec_fp_5tuple_t *mask,
+                        clib_bihash_kv_16_8_t *kv)
 {
   ipsec_fp_lookup_value_t *kv_val = (ipsec_fp_lookup_value_t *) &kv->value;
   u64 *pmatch = (u64 *) &match->laddr;
@@ -301,6 +301,44 @@ get_highest_set_bit_u32 (u32 x)
   return x ^= x >> 1;
 }
 
+static_always_inline u64
+mask_out_highest_set_bit_u64 (u64 x)
+{
+  x |= x >> 32;
+  x |= x >> 16;
+  x |= x >> 8;
+  x |= x >> 4;
+  x |= x >> 2;
+  x |= x >> 1;
+  return ~x;
+}
+
+static_always_inline void
+ipsec_fp_get_policy_ports_mask (ipsec_policy_t *policy,
+                               ipsec_fp_5tuple_t *mask)
+{
+  if (PREDICT_TRUE ((policy->protocol == IP_PROTOCOL_TCP) ||
+                   (policy->protocol == IP_PROTOCOL_UDP) ||
+                   (policy->protocol == IP_PROTOCOL_SCTP)))
+    {
+      mask->lport = policy->lport.start ^ policy->lport.stop;
+      mask->rport = policy->rport.start ^ policy->rport.stop;
+
+      mask->lport = get_highest_set_bit_u16 (mask->lport);
+      mask->lport = ~(mask->lport - 1) & (~mask->lport);
+
+      mask->rport = get_highest_set_bit_u16 (mask->rport);
+      mask->rport = ~(mask->rport - 1) & (~mask->rport);
+    }
+  else
+    {
+      mask->lport = 0;
+      mask->rport = 0;
+    }
+
+  mask->protocol = (policy->protocol == IPSEC_POLICY_PROTOCOL_ANY) ? 0 : ~0;
+}
+
 static_always_inline void
 ipsec_fp_ip4_get_policy_mask (ipsec_policy_t *policy, ipsec_fp_5tuple_t *mask)
 {
@@ -312,7 +350,7 @@ ipsec_fp_ip4_get_policy_mask (ipsec_policy_t *policy, ipsec_fp_5tuple_t *mask)
   u32 *prmask = (u32 *) &mask->raddr;
 
   memset (mask, 0, sizeof (mask->l3_zero_pad));
-  memset (plmask, 1, sizeof (*mask) - sizeof (mask->l3_zero_pad));
+  memset (plmask, 0xff, sizeof (*mask) - sizeof (mask->l3_zero_pad));
   /* find bits where start != stop */
   *plmask = *pladdr_start ^ *pladdr_stop;
   *prmask = *praddr_start ^ *praddr_stop;
@@ -349,49 +387,52 @@ ipsec_fp_ip4_get_policy_mask (ipsec_policy_t *policy, ipsec_fp_5tuple_t *mask)
   mask->protocol = (policy->protocol == IPSEC_POLICY_PROTOCOL_ANY) ? 0 : ~0;
 }
 
-static_always_inline int
+static_always_inline void
 ipsec_fp_ip6_get_policy_mask (ipsec_policy_t *policy, ipsec_fp_5tuple_t *mask)
 {
   u64 *pladdr_start = (u64 *) &policy->laddr.start;
   u64 *pladdr_stop = (u64 *) &policy->laddr.stop;
-  u64 *plmask = (u64 *) &mask->laddr;
+  u64 *plmask = (u64 *) &mask->ip6_laddr;
   u64 *praddr_start = (u64 *) &policy->raddr.start;
   u64 *praddr_stop = (u64 *) &policy->raddr.stop;
   u64 *prmask = (u64 *) &mask->ip6_raddr;
-  u16 *plport_start = (u16 *) &policy->lport.start;
-  u16 *plport_stop = (u16 *) &policy->lport.stop;
-  u16 *prport_start = (u16 *) &policy->rport.start;
-  u16 *prport_stop = (u16 *) &policy->rport.stop;
-
-  /* test if x is not power of 2. The test form is  !((x & (x - 1)) == 0) */
-  if (((*pladdr_stop - *pladdr_start + 1) & (*pladdr_stop - *pladdr_start)) &&
-      (((*(pladdr_stop + 1) - *(pladdr_start + 1)) + 1) &
-       (*(pladdr_stop + 1) - *(pladdr_start + 1))))
-    return -1;
 
-  if (((*praddr_stop - *praddr_start + 1) & (*praddr_stop - *praddr_start)) &&
-      (((*(praddr_stop + 1) - *(praddr_start + 1)) + 1) &
-       (*(praddr_stop + 1) - *(praddr_start + 1))))
-    return -1;
+  memset (mask, 0xff, sizeof (ipsec_fp_5tuple_t));
 
-  if (((*plport_stop - *plport_start + 1) & (*plport_stop - *plport_start)))
-    return -1;
+  *plmask = (*pladdr_start++ ^ *pladdr_stop++);
 
-  if (((*prport_stop - *prport_start + 1) & (*prport_stop - *prport_start)))
-    return -1;
+  *prmask = (*praddr_start++ ^ *praddr_stop++);
 
-  memset (mask, 1, sizeof (ipsec_fp_5tuple_t));
+  /* Find most significant bit set (that is the first position
+   * start differs from stop). Mask out everything after that bit and
+   * the bit itself. Remember that policy stores start and stop in the net
+   * order.
+   */
+  *plmask = clib_host_to_net_u64 (
+    mask_out_highest_set_bit_u64 (clib_net_to_host_u64 (*plmask)));
 
-  *plmask++ = ~(*pladdr_start++ ^ *pladdr_stop++);
-  *plmask++ = ~(*pladdr_start++ ^ *pladdr_stop++);
+  if (*plmask++ & clib_host_to_net_u64 (0x1))
+    {
+      *plmask = (*pladdr_start ^ *pladdr_stop);
+      *plmask = clib_host_to_net_u64 (
+       mask_out_highest_set_bit_u64 (clib_net_to_host_u64 (*plmask)));
+    }
+  else
+    *plmask = 0;
 
-  *prmask++ = ~(*praddr_start++ ^ *praddr_stop++);
-  *prmask++ = ~(*praddr_start++ ^ *praddr_stop++);
+  *prmask = clib_host_to_net_u64 (
+    mask_out_highest_set_bit_u64 (clib_net_to_host_u64 (*prmask)));
 
-  mask->lport = ~(policy->lport.start ^ policy->lport.stop);
-  mask->rport = ~(policy->rport.start ^ policy->rport.stop);
-  mask->protocol = 0;
-  return 0;
+  if (*prmask++ & clib_host_to_net_u64 (0x1))
+    {
+      *prmask = (*pladdr_start ^ *pladdr_stop);
+      *prmask = clib_host_to_net_u64 (
+       mask_out_highest_set_bit_u64 (clib_net_to_host_u64 (*prmask)));
+    }
+  else
+    *prmask = 0;
+
+  ipsec_fp_get_policy_ports_mask (policy, mask);
 }
 
 static_always_inline void
@@ -454,7 +495,7 @@ ipsec_fp_ip4_add_policy (ipsec_main_t *im, ipsec_spd_fp_t *fp_spd,
   policy->fp_mask_type_id = mask_index;
   ipsec_fp_get_policy_5tuple (policy, &policy_5tuple);
 
-  fill_ip4_hash_policy_kv (im, &policy_5tuple, &mask, &kv);
+  fill_ip4_hash_policy_kv (&policy_5tuple, &mask, &kv);
 
   res = clib_bihash_search_inline_2_16_8 (&fp_spd->fp_ip4_lookup_hash, &kv,
                                          &result);
@@ -523,19 +564,7 @@ ipsec_fp_ip6_add_policy (ipsec_main_t *im, ipsec_spd_fp_t *fp_spd,
 
   ipsec_fp_5tuple_t mask, policy_5tuple;
   int res;
-  /* u64 hash; */
-
-  if (PREDICT_FALSE (!fp_spd->fp_ip6_lookup_hash_initialized))
-    {
-      clib_bihash_init_40_8 (
-       &fp_spd->fp_ip6_lookup_hash, "SPD_FP ip6 rules lookup bihash",
-       im->fp_lookup_hash_buckets,
-       im->fp_lookup_hash_buckets * IPSEC_FP_IP6_HASH_MEM_PER_BUCKET);
-      fp_spd->fp_ip6_lookup_hash_initialized = 1;
-    }
-
-  if (ipsec_fp_ip6_get_policy_mask (policy, &mask) != 0)
-    return -1;
+  ipsec_fp_ip6_get_policy_mask (policy, &mask);
 
   pool_get (im->policies, vp);
   policy_index = vp - im->policies;
@@ -555,10 +584,9 @@ ipsec_fp_ip6_add_policy (ipsec_main_t *im, ipsec_spd_fp_t *fp_spd,
     mte = im->fp_mask_types + mask_index;
 
   policy->fp_mask_type_id = mask_index;
-  ipsec_fp_ip6_get_policy_mask (policy, &mask);
   ipsec_fp_get_policy_5tuple (policy, &policy_5tuple);
 
-  fill_ip6_hash_policy_kv (im, &policy_5tuple, &mask, &kv);
+  fill_ip6_hash_policy_kv (&policy_5tuple, &mask, &kv);
 
   res = clib_bihash_search_inline_2_40_8 (&fp_spd->fp_ip6_lookup_hash, &kv,
                                          &result);
@@ -626,7 +654,7 @@ ipsec_fp_ip6_del_policy (ipsec_main_t *im, ipsec_spd_fp_t *fp_spd,
 
   ipsec_fp_ip6_get_policy_mask (policy, &mask);
   ipsec_fp_get_policy_5tuple (policy, &policy_5tuple);
-  fill_ip6_hash_policy_kv (im, &policy_5tuple, &mask, &kv);
+  fill_ip6_hash_policy_kv (&policy_5tuple, &mask, &kv);
   res = clib_bihash_search_inline_2_40_8 (&fp_spd->fp_ip6_lookup_hash, &kv,
                                          &result);
   if (res != 0)
@@ -706,7 +734,7 @@ ipsec_fp_ip4_del_policy (ipsec_main_t *im, ipsec_spd_fp_t *fp_spd,
 
   ipsec_fp_ip4_get_policy_mask (policy, &mask);
   ipsec_fp_get_policy_5tuple (policy, &policy_5tuple);
-  fill_ip4_hash_policy_kv (im, &policy_5tuple, &mask, &kv);
+  fill_ip4_hash_policy_kv (&policy_5tuple, &mask, &kv);
   res = clib_bihash_search_inline_2_16_8 (&fp_spd->fp_ip4_lookup_hash, &kv,
                                          &result);
   if (res != 0)