vppinfra: AVX512 in clib_count_equal_*
[vpp.git] / src / vppinfra / vector / count_equal.h
index 98770cf..a2aeecd 100644 (file)
@@ -67,28 +67,62 @@ clib_count_equal_u32 (u32 *data, uword max_count)
   count = 0;
   first = data[0];
 
-#if defined(CLIB_HAVE_VEC256)
+#if defined(CLIB_HAVE_VEC512)
+  u32x16 splat = u32x16_splat (first);
+  while (count + 15 < max_count)
+    {
+      u32 bmp;
+      bmp = u32x16_is_equal_mask (u32x16_load_unaligned (data), splat);
+      if (bmp != pow2_mask (16))
+       return count + count_trailing_zeros (~bmp);
+
+      data += 16;
+      count += 16;
+    }
+  if (count == max_count)
+    return count;
+  else
+    {
+      u32 mask = pow2_mask (max_count - count);
+      u32 bmp =
+       u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#elif defined(CLIB_HAVE_VEC256)
   u32x8 splat = u32x8_splat (first);
   while (count + 7 < max_count)
     {
-      u64 bmp;
+      u32 bmp;
+#ifdef __AVX512F__
+      bmp = u32x8_is_equal_mask (u32x8_load_unaligned (data), splat);
+      if (bmp != pow2_mask (8))
+       return count + count_trailing_zeros (~bmp);
+#else
       bmp = u8x32_msb_mask ((u8x32) (u32x8_load_unaligned (data) == splat));
       if (bmp != 0xffffffff)
-       {
-         count += count_trailing_zeros (~bmp) / 4;
-         return count;
-       }
+       return count + count_trailing_zeros (~bmp) / 4;
+#endif
 
       data += 8;
       count += 8;
     }
+  if (count == max_count)
+    return count;
+#if defined(CxLIB_HAVE_VEC256_MASK_LOAD_STORE)
+  else
+    {
+      u32 mask = pow2_mask (max_count - count);
+      u32 bmp = u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#endif
 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
   u32x4 splat = u32x4_splat (first);
   while (count + 3 < max_count)
     {
       u64 bmp;
       bmp = u8x16_msb_mask ((u8x16) (u32x4_load_unaligned (data) == splat));
-      if (bmp != 0xffff)
+      if (bmp != pow2_mask (4 * 4))
        {
          count += count_trailing_zeros (~bmp) / 4;
          return count;
@@ -191,18 +225,50 @@ clib_count_equal_u8 (u8 *data, uword max_count)
   count = 0;
   first = data[0];
 
-#if defined(CLIB_HAVE_VEC256)
+#if defined(CLIB_HAVE_VEC512)
+  u8x64 splat = u8x64_splat (first);
+  while (count + 63 < max_count)
+    {
+      u64 bmp;
+      bmp = u8x64_is_equal_mask (u8x64_load_unaligned (data), splat);
+      if (bmp != -1)
+       return count + count_trailing_zeros (~bmp);
+
+      data += 64;
+      count += 64;
+    }
+  if (count == max_count)
+    return count;
+#if defined(CLIB_HAVE_VEC512_MASK_LOAD_STORE)
+  else
+    {
+      u64 mask = pow2_mask (max_count - count);
+      u64 bmp = u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#endif
+#elif defined(CLIB_HAVE_VEC256)
   u8x32 splat = u8x32_splat (first);
   while (count + 31 < max_count)
     {
       u64 bmp;
       bmp = u8x32_msb_mask ((u8x32) (u8x32_load_unaligned (data) == splat));
       if (bmp != 0xffffffff)
-       return max_count;
+       return count + count_trailing_zeros (~bmp);
 
       data += 32;
       count += 32;
     }
+  if (count == max_count)
+    return count;
+#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
+  else
+    {
+      u32 mask = pow2_mask (max_count - count);
+      u64 bmp = u8x32_msb_mask (u8x32_mask_load_zero (data, mask) == splat);
+      return count + count_trailing_zeros (~bmp);
+    }
+#endif
 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
   u8x16 splat = u8x16_splat (first);
   while (count + 15 < max_count)
@@ -210,10 +276,7 @@ clib_count_equal_u8 (u8 *data, uword max_count)
       u64 bmp;
       bmp = u8x16_msb_mask ((u8x16) (u8x16_load_unaligned (data) == splat));
       if (bmp != 0xffff)
-       {
-         count += count_trailing_zeros (~bmp);
-         return count;
-       }
+       return count + count_trailing_zeros (~bmp);
 
       data += 16;
       count += 16;
@@ -235,4 +298,5 @@ clib_count_equal_u8 (u8 *data, uword max_count)
     }
   return count;
 }
+
 #endif