vppinfra: more avx512 inlines (compress, expand, from, is_equal_mask)
[vpp.git] / src / vppinfra / vector_avx512.h
1 /*
2  * Copyright (c) 2015 Cisco and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #ifndef included_vector_avx512_h
17 #define included_vector_avx512_h
18
19 #include <vppinfra/clib.h>
20 #include <x86intrin.h>
21
22 /* *INDENT-OFF* */
23 #define foreach_avx512_vec512i \
24   _(i,8,64,epi8) _(i,16,32,epi16) _(i,32,16,epi32)  _(i,64,8,epi64)
25 #define foreach_avx512_vec512u \
26   _(u,8,64,epi8) _(u,16,32,epi16) _(u,32,16,epi32)  _(u,64,8,epi64)
27 #define foreach_avx512_vec512f \
28   _(f,32,8,ps) _(f,64,4,pd)
29
30 /* splat, load_unaligned, store_unaligned, is_all_zero, is_equal,
31    is_all_equal, is_zero_mask */
32 #define _(t, s, c, i) \
33 static_always_inline t##s##x##c                                         \
34 t##s##x##c##_splat (t##s x)                                             \
35 { return (t##s##x##c) _mm512_set1_##i (x); }                            \
36 \
37 static_always_inline t##s##x##c                                         \
38 t##s##x##c##_load_aligned (void *p)                                     \
39 { return (t##s##x##c) _mm512_load_si512 (p); }                          \
40 \
41 static_always_inline void                                               \
42 t##s##x##c##_store_aligned (t##s##x##c v, void *p)                      \
43 { _mm512_store_si512 ((__m512i *) p, (__m512i) v); }                    \
44 \
45 static_always_inline t##s##x##c                                         \
46 t##s##x##c##_load_unaligned (void *p)                                   \
47 { return (t##s##x##c) _mm512_loadu_si512 (p); }                         \
48 \
49 static_always_inline void                                               \
50 t##s##x##c##_store_unaligned (t##s##x##c v, void *p)                    \
51 { _mm512_storeu_si512 ((__m512i *) p, (__m512i) v); }                   \
52 \
53 static_always_inline int                                                \
54 t##s##x##c##_is_all_zero (t##s##x##c v)                                 \
55 { return (_mm512_test_epi64_mask ((__m512i) v, (__m512i) v) == 0); }    \
56 \
57 static_always_inline int                                                \
58 t##s##x##c##_is_equal (t##s##x##c a, t##s##x##c b)                      \
59 { return t##s##x##c##_is_all_zero (a ^ b); }                            \
60 \
61 static_always_inline int                                                \
62 t##s##x##c##_is_all_equal (t##s##x##c v, t##s x)                        \
63 { return t##s##x##c##_is_equal (v, t##s##x##c##_splat (x)); }           \
64 \
65 static_always_inline u##c                                               \
66 t##s##x##c##_is_zero_mask (t##s##x##c v)                                \
67 { return _mm512_test_##i##_mask ((__m512i) v, (__m512i) v); }           \
68 \
69 static_always_inline t##s##x##c                                         \
70 t##s##x##c##_interleave_lo (t##s##x##c a, t##s##x##c b)                 \
71 { return (t##s##x##c) _mm512_unpacklo_##i ((__m512i) a, (__m512i) b); } \
72 \
73 static_always_inline t##s##x##c                                         \
74 t##s##x##c##_interleave_hi (t##s##x##c a, t##s##x##c b)                 \
75 { return (t##s##x##c) _mm512_unpackhi_##i ((__m512i) a, (__m512i) b); } \
76
77
78 foreach_avx512_vec512i foreach_avx512_vec512u
79 #undef _
80 /* *INDENT-ON* */
81
82 static_always_inline u32
83 u16x32_msb_mask (u16x32 v)
84 {
85   return (u32) _mm512_movepi16_mask ((__m512i) v);
86 }
87
88 static_always_inline u32x16
89 u32x16_byte_swap (u32x16 v)
90 {
91   u8x64 swap = {
92     3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12,
93     3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12,
94     3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12,
95     3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12
96   };
97   return (u32x16) _mm512_shuffle_epi8 ((__m512i) v, (__m512i) swap);
98 }
99
100 static_always_inline u16x32
101 u16x32_byte_swap (u16x32 v)
102 {
103   u8x64 swap = {
104     1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
105     1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
106     1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
107     1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14
108   };
109   return (u16x32) _mm512_shuffle_epi8 ((__m512i) v, (__m512i) swap);
110 }
111
112 #define _(f, t)                                                               \
113   static_always_inline t f##_extract_lo (f v)                                 \
114   {                                                                           \
115     return (t) _mm512_extracti64x4_epi64 ((__m512i) v, 0);                    \
116   }                                                                           \
117   static_always_inline t f##_extract_hi (f v)                                 \
118   {                                                                           \
119     return (t) _mm512_extracti64x4_epi64 ((__m512i) v, 1);                    \
120   }
121
122 _ (u64x8, u64x4)
123 _ (u32x16, u32x8)
124 _ (u16x32, u16x16)
125 _ (u8x64, u8x32)
126 #undef _
127
128 static_always_inline u32
129 u32x16_min_scalar (u32x16 v)
130 {
131   return u32x8_min_scalar (u32x8_min (u32x16_extract_lo (v),
132                                       u32x16_extract_hi (v)));
133 }
134
135 static_always_inline u32x16
136 u32x16_insert_lo (u32x16 r, u32x8 v)
137 {
138   return (u32x16) _mm512_inserti64x4 ((__m512i) r, (__m256i) v, 0);
139 }
140
141 static_always_inline u32x16
142 u32x16_insert_hi (u32x16 r, u32x8 v)
143 {
144   return (u32x16) _mm512_inserti64x4 ((__m512i) r, (__m256i) v, 1);
145 }
146
147 static_always_inline u64x8
148 u64x8_permute (u64x8 a, u64x8 b, u64x8 mask)
149 {
150   return (u64x8) _mm512_permutex2var_epi64 ((__m512i) a, (__m512i) mask,
151                                             (__m512i) b);
152 }
153
154
155 #define u32x16_ternary_logic(a, b, c, d) \
156   (u32x16) _mm512_ternarylogic_epi32 ((__m512i) a, (__m512i) b, (__m512i) c, d)
157
158 #define u8x64_insert_u8x16(a, b, n) \
159   (u8x64) _mm512_inserti64x2 ((__m512i) (a), (__m128i) (b), n)
160
161 #define u8x64_extract_u8x16(a, n) \
162   (u8x16) _mm512_extracti64x2_epi64 ((__m512i) (a), n)
163
164 #define u8x64_word_shift_left(a,n)  (u8x64) _mm512_bslli_epi128((__m512i) a, n)
165 #define u8x64_word_shift_right(a,n) (u8x64) _mm512_bsrli_epi128((__m512i) a, n)
166
167 static_always_inline u8x64
168 u8x64_xor3 (u8x64 a, u8x64 b, u8x64 c)
169 {
170   return (u8x64) _mm512_ternarylogic_epi32 ((__m512i) a, (__m512i) b,
171                                             (__m512i) c, 0x96);
172 }
173
174 static_always_inline u8x64
175 u8x64_reflect_u8x16 (u8x64 x)
176 {
177   static const u8x64 mask = {
178     15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
179     15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
180     15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
181     15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
182   };
183   return (u8x64) _mm512_shuffle_epi8 ((__m512i) x, (__m512i) mask);
184 }
185
186 static_always_inline u8x64
187 u8x64_shuffle (u8x64 v, u8x64 m)
188 {
189   return (u8x64) _mm512_shuffle_epi8 ((__m512i) v, (__m512i) m);
190 }
191
192 #define u8x64_align_right(a, b, imm) \
193   (u8x64) _mm512_alignr_epi8 ((__m512i) a, (__m512i) b, imm)
194
195 static_always_inline u32
196 u32x16_sum_elts (u32x16 sum16)
197 {
198   u32x8 sum8;
199   sum16 += (u32x16) u8x64_align_right (sum16, sum16, 8);
200   sum16 += (u32x16) u8x64_align_right (sum16, sum16, 4);
201   sum8 = u32x16_extract_hi (sum16) + u32x16_extract_lo (sum16);
202   return sum8[0] + sum8[4];
203 }
204
205 static_always_inline u8x64
206 u8x64_mask_load (u8x64 a, void *p, u64 mask)
207 {
208   return (u8x64) _mm512_mask_loadu_epi8 ((__m512i) a, mask, p);
209 }
210
211 static_always_inline void
212 u8x64_mask_store (u8x64 a, void *p, u64 mask)
213 {
214   _mm512_mask_storeu_epi8 (p, mask, (__m512i) a);
215 }
216
217 static_always_inline u8x64
218 u8x64_splat_u8x16 (u8x16 a)
219 {
220   return (u8x64) _mm512_broadcast_i64x2 ((__m128i) a);
221 }
222
223 static_always_inline u32x16
224 u32x16_splat_u32x4 (u32x4 a)
225 {
226   return (u32x16) _mm512_broadcast_i64x2 ((__m128i) a);
227 }
228
229 static_always_inline u32x16
230 u32x16_mask_blend (u32x16 a, u32x16 b, u16 mask)
231 {
232   return (u32x16) _mm512_mask_blend_epi32 (mask, (__m512i) a, (__m512i) b);
233 }
234
235 static_always_inline u8x64
236 u8x64_mask_blend (u8x64 a, u8x64 b, u64 mask)
237 {
238   return (u8x64) _mm512_mask_blend_epi8 (mask, (__m512i) a, (__m512i) b);
239 }
240
241 #define _(t, m, e, p, it)                                                     \
242   static_always_inline m t##_is_equal_mask (t a, t b)                         \
243   {                                                                           \
244     return p##_cmpeq_##e##_mask ((it) a, (it) b);                             \
245   }
246 _ (u8x16, u16, epu8, _mm, __m128i)
247 _ (u16x8, u8, epu16, _mm, __m128i)
248 _ (u32x4, u8, epu32, _mm, __m128i)
249 _ (u64x2, u8, epu64, _mm, __m128i)
250
251 _ (u8x32, u32, epu8, _mm256, __m256i)
252 _ (u16x16, u16, epu16, _mm256, __m256i)
253 _ (u32x8, u8, epu32, _mm256, __m256i)
254 _ (u64x4, u8, epu64, _mm256, __m256i)
255
256 _ (u8x64, u64, epu8, _mm512, __m512i)
257 _ (u16x32, u32, epu16, _mm512, __m512i)
258 _ (u32x16, u16, epu32, _mm512, __m512i)
259 _ (u64x8, u8, epu64, _mm512, __m512i)
260 #undef _
261
262 #define _(f, t, fn, it)                                                       \
263   static_always_inline t t##_from_##f (f x) { return (t) fn ((it) x); }
264 _ (u16x16, u32x16, _mm512_cvtepi16_epi32, __m256i)
265 _ (u32x16, u16x16, _mm512_cvtusepi32_epi16, __m512i)
266 _ (u32x8, u16x8, _mm256_cvtusepi32_epi16, __m256i)
267 #undef _
268
269 #define _(vt, mt, bits, epi)                                                  \
270   static_always_inline vt vt##_compress (vt a, mt mask)                       \
271   {                                                                           \
272     return (vt) _mm##bits##_maskz_compress_##epi (mask, (__m##bits##i) a);    \
273   }                                                                           \
274   static_always_inline vt vt##_expand (vt a, mt mask)                         \
275   {                                                                           \
276     return (vt) _mm##bits##_maskz_expand_##epi (mask, (__m##bits##i) a);      \
277   }
278
279 _ (u64x8, u8, 512, epi64)
280 _ (u32x16, u16, 512, epi32)
281 _ (u64x4, u8, 256, epi64)
282 _ (u32x8, u8, 256, epi32)
283 #ifdef __AVX512VBMI2__
284 _ (u16x32, u32, 512, epi16)
285 _ (u8x64, u64, 512, epi8)
286 _ (u16x16, u16, 256, epi16)
287 _ (u8x32, u32, 256, epi8)
288 #endif
289 #undef _
290
291 #define CLIB_HAVE_VEC256_COMPRESS
292 #define CLIB_HAVE_VEC512_COMPRESS
293
294 #ifndef __AVX512VBMI2__
295 static_always_inline u16x16
296 u16x16_compress (u16x16 v, u16 mask)
297 {
298   return u16x16_from_u32x16 (u32x16_compress (u32x16_from_u16x16 (v), mask));
299 }
300
301 static_always_inline u16x8
302 u16x8_compress (u16x8 v, u8 mask)
303 {
304   return u16x8_from_u32x8 (u32x8_compress (u32x8_from_u16x8 (v), mask));
305 }
306 #endif
307
308 static_always_inline void
309 u32x16_transpose (u32x16 m[16])
310 {
311   __m512i r[16], a, b, c, d, x, y;
312
313   /* *INDENT-OFF* */
314   __m512i pm1 = (__m512i) (u64x8) { 0, 1, 8, 9, 4, 5, 12, 13};
315   __m512i pm2 = (__m512i) (u64x8) { 2, 3, 10, 11, 6, 7, 14, 15};
316   __m512i pm3 = (__m512i) (u64x8) { 0, 1, 2, 3, 8, 9, 10, 11};
317   __m512i pm4 = (__m512i) (u64x8) { 4, 5, 6, 7, 12, 13, 14, 15};
318   /* *INDENT-ON* */
319
320   r[0] = _mm512_unpacklo_epi32 ((__m512i) m[0], (__m512i) m[1]);
321   r[1] = _mm512_unpacklo_epi32 ((__m512i) m[2], (__m512i) m[3]);
322   r[2] = _mm512_unpacklo_epi32 ((__m512i) m[4], (__m512i) m[5]);
323   r[3] = _mm512_unpacklo_epi32 ((__m512i) m[6], (__m512i) m[7]);
324   r[4] = _mm512_unpacklo_epi32 ((__m512i) m[8], (__m512i) m[9]);
325   r[5] = _mm512_unpacklo_epi32 ((__m512i) m[10], (__m512i) m[11]);
326   r[6] = _mm512_unpacklo_epi32 ((__m512i) m[12], (__m512i) m[13]);
327   r[7] = _mm512_unpacklo_epi32 ((__m512i) m[14], (__m512i) m[15]);
328
329   r[8] = _mm512_unpackhi_epi32 ((__m512i) m[0], (__m512i) m[1]);
330   r[9] = _mm512_unpackhi_epi32 ((__m512i) m[2], (__m512i) m[3]);
331   r[10] = _mm512_unpackhi_epi32 ((__m512i) m[4], (__m512i) m[5]);
332   r[11] = _mm512_unpackhi_epi32 ((__m512i) m[6], (__m512i) m[7]);
333   r[12] = _mm512_unpackhi_epi32 ((__m512i) m[8], (__m512i) m[9]);
334   r[13] = _mm512_unpackhi_epi32 ((__m512i) m[10], (__m512i) m[11]);
335   r[14] = _mm512_unpackhi_epi32 ((__m512i) m[12], (__m512i) m[13]);
336   r[15] = _mm512_unpackhi_epi32 ((__m512i) m[14], (__m512i) m[15]);
337
338   a = _mm512_unpacklo_epi64 (r[0], r[1]);
339   b = _mm512_unpacklo_epi64 (r[2], r[3]);
340   c = _mm512_unpacklo_epi64 (r[4], r[5]);
341   d = _mm512_unpacklo_epi64 (r[6], r[7]);
342   x = _mm512_permutex2var_epi64 (a, pm1, b);
343   y = _mm512_permutex2var_epi64 (c, pm1, d);
344   m[0] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
345   m[8] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
346   x = _mm512_permutex2var_epi64 (a, pm2, b);
347   y = _mm512_permutex2var_epi64 (c, pm2, d);
348   m[4] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
349   m[12] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
350
351   a = _mm512_unpacklo_epi64 (r[8], r[9]);
352   b = _mm512_unpacklo_epi64 (r[10], r[11]);
353   c = _mm512_unpacklo_epi64 (r[12], r[13]);
354   d = _mm512_unpacklo_epi64 (r[14], r[15]);
355   x = _mm512_permutex2var_epi64 (a, pm1, b);
356   y = _mm512_permutex2var_epi64 (c, pm1, d);
357   m[2] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
358   m[10] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
359   x = _mm512_permutex2var_epi64 (a, pm2, b);
360   y = _mm512_permutex2var_epi64 (c, pm2, d);
361   m[6] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
362   m[14] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
363
364   a = _mm512_unpackhi_epi64 (r[0], r[1]);
365   b = _mm512_unpackhi_epi64 (r[2], r[3]);
366   c = _mm512_unpackhi_epi64 (r[4], r[5]);
367   d = _mm512_unpackhi_epi64 (r[6], r[7]);
368   x = _mm512_permutex2var_epi64 (a, pm1, b);
369   y = _mm512_permutex2var_epi64 (c, pm1, d);
370   m[1] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
371   m[9] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
372   x = _mm512_permutex2var_epi64 (a, pm2, b);
373   y = _mm512_permutex2var_epi64 (c, pm2, d);
374   m[5] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
375   m[13] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
376
377   a = _mm512_unpackhi_epi64 (r[8], r[9]);
378   b = _mm512_unpackhi_epi64 (r[10], r[11]);
379   c = _mm512_unpackhi_epi64 (r[12], r[13]);
380   d = _mm512_unpackhi_epi64 (r[14], r[15]);
381   x = _mm512_permutex2var_epi64 (a, pm1, b);
382   y = _mm512_permutex2var_epi64 (c, pm1, d);
383   m[3] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
384   m[11] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
385   x = _mm512_permutex2var_epi64 (a, pm2, b);
386   y = _mm512_permutex2var_epi64 (c, pm2, d);
387   m[7] = (u32x16) _mm512_permutex2var_epi64 (x, pm3, y);
388   m[15] = (u32x16) _mm512_permutex2var_epi64 (x, pm4, y);
389 }
390
391
392
393 static_always_inline void
394 u64x8_transpose (u64x8 m[8])
395 {
396   __m512i r[8], x, y;
397
398   /* *INDENT-OFF* */
399   __m512i pm1 = (__m512i) (u64x8) { 0, 1, 8, 9, 4, 5, 12, 13};
400   __m512i pm2 = (__m512i) (u64x8) { 2, 3, 10, 11, 6, 7, 14, 15};
401   __m512i pm3 = (__m512i) (u64x8) { 0, 1, 2, 3, 8, 9, 10, 11};
402   __m512i pm4 = (__m512i) (u64x8) { 4, 5, 6, 7, 12, 13, 14, 15};
403   /* *INDENT-ON* */
404
405   r[0] = _mm512_unpacklo_epi64 ((__m512i) m[0], (__m512i) m[1]);
406   r[1] = _mm512_unpacklo_epi64 ((__m512i) m[2], (__m512i) m[3]);
407   r[2] = _mm512_unpacklo_epi64 ((__m512i) m[4], (__m512i) m[5]);
408   r[3] = _mm512_unpacklo_epi64 ((__m512i) m[6], (__m512i) m[7]);
409   r[4] = _mm512_unpackhi_epi64 ((__m512i) m[0], (__m512i) m[1]);
410   r[5] = _mm512_unpackhi_epi64 ((__m512i) m[2], (__m512i) m[3]);
411   r[6] = _mm512_unpackhi_epi64 ((__m512i) m[4], (__m512i) m[5]);
412   r[7] = _mm512_unpackhi_epi64 ((__m512i) m[6], (__m512i) m[7]);
413
414   x = _mm512_permutex2var_epi64 (r[0], pm1, r[1]);
415   y = _mm512_permutex2var_epi64 (r[2], pm1, r[3]);
416   m[0] = (u64x8) _mm512_permutex2var_epi64 (x, pm3, y);
417   m[4] = (u64x8) _mm512_permutex2var_epi64 (x, pm4, y);
418   x = _mm512_permutex2var_epi64 (r[0], pm2, r[1]);
419   y = _mm512_permutex2var_epi64 (r[2], pm2, r[3]);
420   m[2] = (u64x8) _mm512_permutex2var_epi64 (x, pm3, y);
421   m[6] = (u64x8) _mm512_permutex2var_epi64 (x, pm4, y);
422
423   x = _mm512_permutex2var_epi64 (r[4], pm1, r[5]);
424   y = _mm512_permutex2var_epi64 (r[6], pm1, r[7]);
425   m[1] = (u64x8) _mm512_permutex2var_epi64 (x, pm3, y);
426   m[5] = (u64x8) _mm512_permutex2var_epi64 (x, pm4, y);
427   x = _mm512_permutex2var_epi64 (r[4], pm2, r[5]);
428   y = _mm512_permutex2var_epi64 (r[6], pm2, r[7]);
429   m[3] = (u64x8) _mm512_permutex2var_epi64 (x, pm3, y);
430   m[7] = (u64x8) _mm512_permutex2var_epi64 (x, pm4, y);
431 }
432
433 #endif /* included_vector_avx512_h */
434 /*
435  * fd.io coding-style-patch-verification: ON
436  *
437  * Local Variables:
438  * eval: (c-set-style "gnu")
439  * End:
440  */