vppinfra: fix masks in AVX512 clib_count_equal_*
[vpp.git] / src / vppinfra / vector / count_equal.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2021 Cisco Systems, Inc.
3  */
4
5 #ifndef included_vector_count_equal_h
6 #define included_vector_count_equal_h
7 #include <vppinfra/clib.h>
8
9 static_always_inline uword
10 clib_count_equal_u64 (u64 *data, uword max_count)
11 {
12   uword count;
13   u64 first;
14
15   if (max_count <= 1)
16     return max_count;
17   if (data[0] != data[1])
18     return 1;
19
20   count = 0;
21   first = data[0];
22
23 #if defined(CLIB_HAVE_VEC256)
24   u64x4 splat = u64x4_splat (first);
25   while (count + 3 < max_count)
26     {
27       u64 bmp;
28       bmp = u8x32_msb_mask ((u8x32) (u64x4_load_unaligned (data) == splat));
29       if (bmp != 0xffffffff)
30         {
31           count += count_trailing_zeros (~bmp) / 8;
32           return count;
33         }
34
35       data += 4;
36       count += 4;
37     }
38 #else
39   count += 2;
40   data += 2;
41   while (count + 3 < max_count && ((data[0] ^ first) | (data[1] ^ first) |
42                                    (data[2] ^ first) | (data[3] ^ first)) == 0)
43     {
44       data += 4;
45       count += 4;
46     }
47 #endif
48   while (count < max_count && (data[0] == first))
49     {
50       data += 1;
51       count += 1;
52     }
53   return count;
54 }
55
56 static_always_inline uword
57 clib_count_equal_u32 (u32 *data, uword max_count)
58 {
59   uword count;
60   u32 first;
61
62   if (max_count <= 1)
63     return max_count;
64   if (data[0] != data[1])
65     return 1;
66
67   count = 0;
68   first = data[0];
69
70 #if defined(CLIB_HAVE_VEC512)
71   u32x16 splat = u32x16_splat (first);
72   while (count + 15 < max_count)
73     {
74       u32 bmp;
75       bmp = u32x16_is_equal_mask (u32x16_load_unaligned (data), splat);
76       if (bmp != pow2_mask (16))
77         return count + count_trailing_zeros (~bmp);
78
79       data += 16;
80       count += 16;
81     }
82   if (count == max_count)
83     return count;
84   else
85     {
86       u32 mask = pow2_mask (max_count - count);
87       u32 bmp =
88         u32x16_is_equal_mask (u32x16_mask_load_zero (data, mask), splat) &
89         mask;
90       return count + count_trailing_zeros (~bmp);
91     }
92 #elif defined(CLIB_HAVE_VEC256)
93   u32x8 splat = u32x8_splat (first);
94   while (count + 7 < max_count)
95     {
96       u32 bmp;
97 #ifdef __AVX512F__
98       bmp = u32x8_is_equal_mask (u32x8_load_unaligned (data), splat);
99       if (bmp != pow2_mask (8))
100         return count + count_trailing_zeros (~bmp);
101 #else
102       bmp = u8x32_msb_mask ((u8x32) (u32x8_load_unaligned (data) == splat));
103       if (bmp != 0xffffffff)
104         return count + count_trailing_zeros (~bmp) / 4;
105 #endif
106
107       data += 8;
108       count += 8;
109     }
110   if (count == max_count)
111     return count;
112 #if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
113   else
114     {
115       u32 mask = pow2_mask (max_count - count);
116       u32 bmp =
117         u32x8_is_equal_mask (u32x8_mask_load_zero (data, mask), splat) & mask;
118       return count + count_trailing_zeros (~bmp);
119     }
120 #endif
121 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
122   u32x4 splat = u32x4_splat (first);
123   while (count + 3 < max_count)
124     {
125       u64 bmp;
126       bmp = u8x16_msb_mask ((u8x16) (u32x4_load_unaligned (data) == splat));
127       if (bmp != pow2_mask (4 * 4))
128         {
129           count += count_trailing_zeros (~bmp) / 4;
130           return count;
131         }
132
133       data += 4;
134       count += 4;
135     }
136 #else
137   count += 2;
138   data += 2;
139   while (count + 3 < max_count && ((data[0] ^ first) | (data[1] ^ first) |
140                                    (data[2] ^ first) | (data[3] ^ first)) == 0)
141     {
142       data += 4;
143       count += 4;
144     }
145 #endif
146   while (count < max_count && (data[0] == first))
147     {
148       data += 1;
149       count += 1;
150     }
151   return count;
152 }
153
154 static_always_inline uword
155 clib_count_equal_u16 (u16 *data, uword max_count)
156 {
157   uword count;
158   u16 first;
159
160   if (max_count <= 1)
161     return max_count;
162   if (data[0] != data[1])
163     return 1;
164
165   count = 0;
166   first = data[0];
167
168 #if defined(CLIB_HAVE_VEC256)
169   u16x16 splat = u16x16_splat (first);
170   while (count + 15 < max_count)
171     {
172       u64 bmp;
173       bmp = u8x32_msb_mask ((u8x32) (u16x16_load_unaligned (data) == splat));
174       if (bmp != 0xffffffff)
175         {
176           count += count_trailing_zeros (~bmp) / 2;
177           return count;
178         }
179
180       data += 16;
181       count += 16;
182     }
183 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
184   u16x8 splat = u16x8_splat (first);
185   while (count + 7 < max_count)
186     {
187       u64 bmp;
188       bmp = u8x16_msb_mask ((u8x16) (u16x8_load_unaligned (data) == splat));
189       if (bmp != 0xffff)
190         {
191           count += count_trailing_zeros (~bmp) / 2;
192           return count;
193         }
194
195       data += 8;
196       count += 8;
197     }
198 #else
199   count += 2;
200   data += 2;
201   while (count + 3 < max_count && ((data[0] ^ first) | (data[1] ^ first) |
202                                    (data[2] ^ first) | (data[3] ^ first)) == 0)
203     {
204       data += 4;
205       count += 4;
206     }
207 #endif
208   while (count < max_count && (data[0] == first))
209     {
210       data += 1;
211       count += 1;
212     }
213   return count;
214 }
215
216 static_always_inline uword
217 clib_count_equal_u8 (u8 *data, uword max_count)
218 {
219   uword count;
220   u8 first;
221
222   if (max_count <= 1)
223     return max_count;
224   if (data[0] != data[1])
225     return 1;
226
227   count = 0;
228   first = data[0];
229
230 #if defined(CLIB_HAVE_VEC512)
231   u8x64 splat = u8x64_splat (first);
232   while (count + 63 < max_count)
233     {
234       u64 bmp;
235       bmp = u8x64_is_equal_mask (u8x64_load_unaligned (data), splat);
236       if (bmp != -1)
237         return count + count_trailing_zeros (~bmp);
238
239       data += 64;
240       count += 64;
241     }
242   if (count == max_count)
243     return count;
244 #if defined(CLIB_HAVE_VEC512_MASK_LOAD_STORE)
245   else
246     {
247       u64 mask = pow2_mask (max_count - count);
248       u64 bmp =
249         u8x64_is_equal_mask (u8x64_mask_load_zero (data, mask), splat) & mask;
250       return count + count_trailing_zeros (~bmp);
251     }
252 #endif
253 #elif defined(CLIB_HAVE_VEC256)
254   u8x32 splat = u8x32_splat (first);
255   while (count + 31 < max_count)
256     {
257       u64 bmp;
258       bmp = u8x32_msb_mask ((u8x32) (u8x32_load_unaligned (data) == splat));
259       if (bmp != 0xffffffff)
260         return count + count_trailing_zeros (~bmp);
261
262       data += 32;
263       count += 32;
264     }
265   if (count == max_count)
266     return count;
267 #if defined(CLIB_HAVE_VEC256_MASK_LOAD_STORE)
268   else
269     {
270       u32 mask = pow2_mask (max_count - count);
271       u64 bmp =
272         u8x32_msb_mask (u8x32_mask_load_zero (data, mask) == splat) & mask;
273       return count + count_trailing_zeros (~bmp);
274     }
275 #endif
276 #elif defined(CLIB_HAVE_VEC128) && defined(CLIB_HAVE_VEC128_MSB_MASK)
277   u8x16 splat = u8x16_splat (first);
278   while (count + 15 < max_count)
279     {
280       u64 bmp;
281       bmp = u8x16_msb_mask ((u8x16) (u8x16_load_unaligned (data) == splat));
282       if (bmp != 0xffff)
283         return count + count_trailing_zeros (~bmp);
284
285       data += 16;
286       count += 16;
287     }
288 #else
289   count += 2;
290   data += 2;
291   while (count + 3 < max_count && ((data[0] ^ first) | (data[1] ^ first) |
292                                    (data[2] ^ first) | (data[3] ^ first)) == 0)
293     {
294       data += 4;
295       count += 4;
296     }
297 #endif
298   while (count < max_count && (data[0] == first))
299     {
300       data += 1;
301       count += 1;
302     }
303   return count;
304 }
305
306 #endif