vppinfra: fix masks in AVX512 clib_count_equal_*
[vpp.git] / src / vppinfra / vector / count_equal.h
index a2aeecd..ca2fbb7 100644 (file)
@@ -85,7 +85,8 @@ clib_count_equal_u32 (u32 *data, uword max_count)
     {
       u32 mask = pow2_mask (max_count - count);
       u32 bmp =
-       u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat);
+       u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat) &
+       mask;
       return count + count_trailing_zeros (~bmp);
     }
 #elif defined(CLIB_HAVE_VEC256)
@@ -108,11 +109,12 @@ clib_count_equal_u32 (u32 *data, uword max_count)
     }
   if (count == max_count)
     return count;
-#if defined(CxLIB_HAVE_VEC256_MASK_LOAD_STORE)
+#if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
   else
     {
       u32 mask = pow2_mask (max_count - count);
-      u32 bmp = u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat);
+      u32 bmp =
+       u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat) & mask;
       return count + count_trailing_zeros (~bmp);
     }
 #endif
@@ -243,7 +245,8 @@ clib_count_equal_u8 (u8 *data, uword max_count)
   else
     {
       u64 mask = pow2_mask (max_count - count);
-      u64 bmp = u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat);
+      u64 bmp =
+       u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat) & mask;
       return count + count_trailing_zeros (~bmp);
     }
 #endif
@@ -265,7 +268,8 @@ clib_count_equal_u8 (u8 *data, uword max_count)
   else
     {
       u32 mask = pow2_mask (max_count - count);
-      u64 bmp = u8x32_msb_mask (u8x32_mask_load_zero (data, mask) == splat);
+      u64 bmp =
+       u8x32_msb_mask (u8x32_mask_load_zero (data, mask) == splat) & mask;
       return count + count_trailing_zeros (~bmp);
     }
 #endif