From 7da9b5be41395cc6355f9cf278106aae7fd9f991 Mon Sep 17 00:00:00 2001 From: Mohsin Kazmi Date: Fri, 27 Aug 2021 18:57:16 +0200 Subject: [PATCH] vppinfra: add compress functions for u64, u16 and u8 Type: improvement Change-Id: I2640148b8959f9a8303520ba2815fe02f1e47928 Signed-off-by: Mohsin Kazmi --- src/vppinfra/vector/compress.h | 181 +++++++++++++++++++++++++++++++++++- src/vppinfra/vector/test/compress.c | 167 +++++++++++++++++++++++++++++++++ src/vppinfra/vector_avx512.h | 8 ++ 3 files changed, 352 insertions(+), 4 deletions(-) diff --git a/src/vppinfra/vector/compress.h b/src/vppinfra/vector/compress.h index 1d5d84e77ea..adb6503f711 100644 --- a/src/vppinfra/vector/compress.h +++ b/src/vppinfra/vector/compress.h @@ -7,6 +7,71 @@ #include #include +static_always_inline u64 * +clib_compress_u64_x64 (u64 *dst, u64 *src, u64 mask) +{ +#if defined(CLIB_HAVE_VEC512_COMPRESS) + u64x8u *sv = (u64x8u *) src; + for (int i = 0; i < 8; i++) + { + u64x8_compress_store (sv[i], mask, dst); + dst += _popcnt32 ((u8) mask); + mask >>= 8; + } +#elif defined(CLIB_HAVE_VEC256_COMPRESS) + u64x4u *sv = (u64x4u *) src; + for (int i = 0; i < 16; i++) + { + u64x4_compress_store (sv[i], mask, dst); + dst += _popcnt32 (((u8) mask) & 0x0f); + mask >>= 4; + } +#else + while (mask) + { + u16 bit = count_trailing_zeros (mask); + mask = clear_lowest_set_bit (mask); + dst++[0] = src[bit]; + } +#endif + return dst; +} + +/** \brief Compress array of 64-bit elemments into destination array based on + * mask + + @param dst destination array of u64 elements + @param src source array of u64 elements + @param mask array of u64 values representing compress mask + @param n_elts number of elements in the source array + @return number of elements stored in destionation array +*/ + +static_always_inline u32 +clib_compress_u64 (u64 *dst, u64 *src, u64 *mask, u32 n_elts) +{ + u64 *dst0 = dst; + while (n_elts >= 64) + { + if (mask[0] == ~0ULL) + { + clib_memcpy_fast (dst, src, 64 * sizeof (u64)); + dst += 64; + } + else + dst = clib_compress_u64_x64 (dst, src, mask[0]); + + mask++; + src += 64; + n_elts -= 64; + } + + if (PREDICT_TRUE (n_elts == 0)) + return dst - dst0; + + return clib_compress_u64_x64 (dst, src, mask[0] & pow2_mask (n_elts)) - dst0; +} + static_always_inline u32 * clib_compress_u32_x64 (u32 *dst, u32 *src, u64 mask) { @@ -14,9 +79,8 @@ clib_compress_u32_x64 (u32 *dst, u32 *src, u64 mask) u32x16u *sv = (u32x16u *) src; for (int i = 0; i < 4; i++) { - int cnt = _popcnt32 ((u16) mask); u32x16_compress_store (sv[i], mask, dst); - dst += cnt; + dst += _popcnt32 ((u16) mask); mask >>= 16; } @@ -24,9 +88,8 @@ clib_compress_u32_x64 (u32 *dst, u32 *src, u64 mask) u32x8u *sv = (u32x8u *) src; for (int i = 0; i < 8; i++) { - int cnt = _popcnt32 ((u8) mask); u32x8_compress_store (sv[i], mask, dst); - dst += cnt; + dst += _popcnt32 ((u8) mask); mask >>= 8; } #else @@ -75,4 +138,114 @@ clib_compress_u32 (u32 *dst, u32 *src, u64 *mask, u32 n_elts) return clib_compress_u32_x64 (dst, src, mask[0] & pow2_mask (n_elts)) - dst0; } +static_always_inline u16 * +clib_compress_u16_x64 (u16 *dst, u16 *src, u64 mask) +{ +#if defined(CLIB_HAVE_VEC512_COMPRESS_U8_U16) + u16x32u *sv = (u16x32u *) src; + for (int i = 0; i < 2; i++) + { + u16x32_compress_store (sv[i], mask, dst); + dst += _popcnt32 ((u32) mask); + mask >>= 32; + } +#else + while (mask) + { + u16 bit = count_trailing_zeros (mask); + mask = clear_lowest_set_bit (mask); + dst++[0] = src[bit]; + } +#endif + return dst; +} + +/** \brief Compress array of 16-bit elemments into destination array based on + * mask + + @param dst destination array of u16 elements + @param src source array of u16 elements + @param mask array of u64 values representing compress mask + @param n_elts number of elements in the source array + @return number of elements stored in destionation array +*/ + +static_always_inline u32 +clib_compress_u16 (u16 *dst, u16 *src, u64 *mask, u32 n_elts) +{ + u16 *dst0 = dst; + while (n_elts >= 64) + { + if (mask[0] == ~0ULL) + { + clib_memcpy_fast (dst, src, 64 * sizeof (u16)); + dst += 64; + } + else + dst = clib_compress_u16_x64 (dst, src, mask[0]); + + mask++; + src += 64; + n_elts -= 64; + } + + if (PREDICT_TRUE (n_elts == 0)) + return dst - dst0; + + return clib_compress_u16_x64 (dst, src, mask[0] & pow2_mask (n_elts)) - dst0; +} + +static_always_inline u8 * +clib_compress_u8_x64 (u8 *dst, u8 *src, u64 mask) +{ +#if defined(CLIB_HAVE_VEC512_COMPRESS_U8_U16) + u8x64u *sv = (u8x64u *) src; + u8x64_compress_store (sv[0], mask, dst); + dst += _popcnt64 (mask); +#else + while (mask) + { + u16 bit = count_trailing_zeros (mask); + mask = clear_lowest_set_bit (mask); + dst++[0] = src[bit]; + } +#endif + return dst; +} + +/** \brief Compress array of 8-bit elemments into destination array based on + * mask + + @param dst destination array of u8 elements + @param src source array of u8 elements + @param mask array of u64 values representing compress mask + @param n_elts number of elements in the source array + @return number of elements stored in destionation array +*/ + +static_always_inline u32 +clib_compress_u8 (u8 *dst, u8 *src, u64 *mask, u32 n_elts) +{ + u8 *dst0 = dst; + while (n_elts >= 64) + { + if (mask[0] == ~0ULL) + { + clib_memcpy_fast (dst, src, 64); + dst += 64; + } + else + dst = clib_compress_u8_x64 (dst, src, mask[0]); + + mask++; + src += 64; + n_elts -= 64; + } + + if (PREDICT_TRUE (n_elts == 0)) + return dst - dst0; + + return clib_compress_u8_x64 (dst, src, mask[0] & pow2_mask (n_elts)) - dst0; +} + #endif diff --git a/src/vppinfra/vector/test/compress.c b/src/vppinfra/vector/test/compress.c index 7e3eba9892d..9bc53ff1e41 100644 --- a/src/vppinfra/vector/test/compress.c +++ b/src/vppinfra/vector/test/compress.c @@ -6,12 +6,30 @@ #include #include +__clib_test_fn u32 +clib_compress_u64_wrapper (u64 *dst, u64 *src, u64 *mask, u32 n_elts) +{ + return clib_compress_u64 (dst, src, mask, n_elts); +} + __clib_test_fn u32 clib_compress_u32_wrapper (u32 *dst, u32 *src, u64 *mask, u32 n_elts) { return clib_compress_u32 (dst, src, mask, n_elts); } +__clib_test_fn u32 +clib_compress_u16_wrapper (u16 *dst, u16 *src, u64 *mask, u32 n_elts) +{ + return clib_compress_u16 (dst, src, mask, n_elts); +} + +__clib_test_fn u32 +clib_compress_u8_wrapper (u8 *dst, u8 *src, u64 *mask, u32 n_elts) +{ + return clib_compress_u8 (dst, src, mask, n_elts); +} + typedef struct { u64 mask[10]; @@ -30,6 +48,52 @@ static compress_test_t tests[] = { { .mask = { ~0ULL, 1, 1, ~0ULL }, .n_elts = 256 }, }; +static clib_error_t * +test_clib_compress_u64 (clib_error_t *err) +{ + u64 src[513]; + u64 dst[513]; + u32 i, j; + + for (i = 0; i < ARRAY_LEN (src); i++) + src[i] = i; + + for (i = 0; i < ARRAY_LEN (tests); i++) + { + compress_test_t *t = tests + i; + u64 *dp = dst; + u32 r; + + for (j = 0; j < ARRAY_LEN (dst); j++) + dst[j] = 0xa5a5a5a5a5a5a5a5; + + r = clib_compress_u64_wrapper (dst, src, t->mask, t->n_elts); + + for (j = 0; j < t->n_elts; j++) + { + if ((t->mask[j >> 6] & (1ULL << (j & 0x3f))) == 0) + continue; + if (dp[0] != src[j]) + return clib_error_return (err, + "wrong data in testcase %u at " + "(dst[%u] = 0x%lx, src[%u] = 0x%lx)", + i, dp - dst, dp[0], j, src[j]); + dp++; + } + + if (dst[dp - dst + 1] != 0xa5a5a5a5a5a5a5a5) + return clib_error_return (err, "buffer overrun in testcase %u", i); + + if (dp - dst != r) + return clib_error_return (err, "wrong number of elts in testcase %u", + i); + } + + return err; + + return err; +} + static clib_error_t * test_clib_compress_u32 (clib_error_t *err) { @@ -75,7 +139,110 @@ test_clib_compress_u32 (clib_error_t *err) return err; } +static clib_error_t * +test_clib_compress_u16 (clib_error_t *err) +{ + u16 src[513]; + u16 dst[513]; + u32 i, j; + + for (i = 0; i < ARRAY_LEN (src); i++) + src[i] = i; + + for (i = 0; i < ARRAY_LEN (tests); i++) + { + compress_test_t *t = tests + i; + u16 *dp = dst; + u32 r; + + for (j = 0; j < ARRAY_LEN (dst); j++) + dst[j] = 0xa5a5; + + r = clib_compress_u16_wrapper (dst, src, t->mask, t->n_elts); + + for (j = 0; j < t->n_elts; j++) + { + if ((t->mask[j >> 6] & (1ULL << (j & 0x3f))) == 0) + continue; + if (dp[0] != src[j]) + return clib_error_return (err, + "wrong data in testcase %u at " + "(dst[%u] = 0x%x, src[%u] = 0x%x)", + i, dp - dst, dp[0], j, src[j]); + dp++; + } + + if (dst[dp - dst + 1] != 0xa5a5) + return clib_error_return (err, "buffer overrun in testcase %u", i); + + if (dp - dst != r) + return clib_error_return (err, "wrong number of elts in testcase %u", + i); + } + + return err; +} + +static clib_error_t * +test_clib_compress_u8 (clib_error_t *err) +{ + u8 src[513]; + u8 dst[513]; + u32 i, j; + + for (i = 0; i < ARRAY_LEN (src); i++) + src[i] = i; + + for (i = 0; i < ARRAY_LEN (tests); i++) + { + compress_test_t *t = tests + i; + u8 *dp = dst; + u32 r; + + for (j = 0; j < ARRAY_LEN (dst); j++) + dst[j] = 0xa5; + + r = clib_compress_u8_wrapper (dst, src, t->mask, t->n_elts); + + for (j = 0; j < t->n_elts; j++) + { + if ((t->mask[j >> 6] & (1ULL << (j & 0x3f))) == 0) + continue; + if (dp[0] != src[j]) + return clib_error_return (err, + "wrong data in testcase %u at " + "(dst[%u] = 0x%x, src[%u] = 0x%x)", + i, dp - dst, dp[0], j, src[j]); + dp++; + } + + if (dst[dp - dst + 1] != 0xa5) + return clib_error_return (err, "buffer overrun in testcase %u", i); + + if (dp - dst != r) + return clib_error_return (err, "wrong number of elts in testcase %u", + i); + } + + return err; +} + +REGISTER_TEST (clib_compress_u64) = { + .name = "clib_compress_u64", + .fn = test_clib_compress_u64, +}; + REGISTER_TEST (clib_compress_u32) = { .name = "clib_compress_u32", .fn = test_clib_compress_u32, }; + +REGISTER_TEST (clib_compress_u16) = { + .name = "clib_compress_u16", + .fn = test_clib_compress_u16, +}; + +REGISTER_TEST (clib_compress_u8) = { + .name = "clib_compress_u8", + .fn = test_clib_compress_u8, +}; diff --git a/src/vppinfra/vector_avx512.h b/src/vppinfra/vector_avx512.h index 3a01c1ed824..5da490162d0 100644 --- a/src/vppinfra/vector_avx512.h +++ b/src/vppinfra/vector_avx512.h @@ -338,9 +338,17 @@ _ (u8x16, u16, _mm, __m128i, epi8) #ifdef CLIB_HAVE_VEC256 #define CLIB_HAVE_VEC256_COMPRESS +#ifdef __AVX512VBMI2__ +#define CLIB_HAVE_VEC256_COMPRESS_U8_U16 +#endif + #endif #ifdef CLIB_HAVE_VEC512 #define CLIB_HAVE_VEC512_COMPRESS +#ifdef __AVX512VBMI2__ +#define CLIB_HAVE_VEC512_COMPRESS_U8_U16 +#endif + #endif #ifndef __AVX512VBMI2__ -- 2.16.6