vppinfra: toeplitz hash four in parallel
[vpp.git] / src / vppinfra / vector / toeplitz.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2021 Cisco Systems, Inc.
3  */
4
5 #ifndef included_vector_toeplitz_h
6 #define included_vector_toeplitz_h
7 #include <vppinfra/clib.h>
8
9 typedef struct
10 {
11   u16 key_length;
12   u16 gfni_offset;
13   u8 data[];
14 } clib_toeplitz_hash_key_t;
15
16 clib_toeplitz_hash_key_t *clib_toeplitz_hash_key_init (u8 *key, u32 keylen);
17 void clib_toeplitz_hash_key_free (clib_toeplitz_hash_key_t *k);
18
19 #ifdef CLIB_HAVE_VEC256
20 static_always_inline u32x8
21 toeplitz_hash_one_x8 (u32x8 hash, u64x4 v4, u8 data, u8 off)
22 {
23   u32x8 v8 = u32x8_shuffle2 (v4 << (off * 8), v4 << (off * 8 + 4),
24                              /*uppper 32 bits of each u64 in reverse order */
25                              15, 13, 11, 9, 7, 5, 3, 1);
26
27 #ifdef CLIB_HAVE_VEC256_MASK_BITWISE_OPS
28   return u32x8_mask_xor (hash, v8, data);
29 #else
30   static const u32x8 bits = { 1, 2, 4, 8, 16, 32, 64, 128 };
31   return hash ^ (((u32x8_splat (data) & bits) != u32x8_zero ()) & v8);
32 #endif
33 }
34 #endif
35
36 #if defined(__GFNI__) && defined(__AVX512F__)
37 static const u8x64 __clib_toeplitz_hash_gfni_permute = {
38   /* clang-format off */
39   0x00, 0x01, 0x02, 0x03, 0x40, 0x41, 0x42, 0x43,
40   0x01, 0x02, 0x03, 0x04, 0x41, 0x42, 0x43, 0x44,
41   0x02, 0x03, 0x04, 0x05, 0x42, 0x43, 0x44, 0x45,
42   0x03, 0x04, 0x05, 0x06, 0x43, 0x44, 0x45, 0x46,
43   0x04, 0x05, 0x06, 0x07, 0x44, 0x45, 0x46, 0x47,
44   0x05, 0x06, 0x07, 0x08, 0x45, 0x46, 0x47, 0x48,
45   0x06, 0x07, 0x08, 0x09, 0x46, 0x47, 0x48, 0x49,
46   0x07, 0x08, 0x09, 0x0a, 0x47, 0x48, 0x49, 0x4a
47   /* clang-format on */
48 };
49 static_always_inline u64x8
50 clib_toeplitz_hash_gfni_one (u8x64 d0, u64x8 m, int i)
51 {
52
53   d0 = i == 1 ? (u8x64) u64x8_align_right (d0, d0, 1) : d0;
54   d0 = i == 2 ? (u8x64) u64x8_align_right (d0, d0, 2) : d0;
55   d0 = i == 3 ? (u8x64) u64x8_align_right (d0, d0, 3) : d0;
56   d0 = i == 4 ? (u8x64) u64x8_align_right (d0, d0, 4) : d0;
57   d0 = i == 5 ? (u8x64) u64x8_align_right (d0, d0, 5) : d0;
58   d0 = i == 6 ? (u8x64) u64x8_align_right (d0, d0, 6) : d0;
59
60   d0 = u8x64_permute (__clib_toeplitz_hash_gfni_permute, d0);
61
62   return (u64x8) _mm512_gf2p8affine_epi64_epi8 ((__m512i) d0, (__m512i) m, 0);
63 }
64
65 static_always_inline u64x8
66 clib_toeplitz_hash_gfni_two (u8x64 d0, u8x64 d1, u64x8 m, int i)
67 {
68
69   d0 = i == 1 ? (u8x64) u64x8_align_right (d0, d0, 1) : d0;
70   d1 = i == 1 ? (u8x64) u64x8_align_right (d1, d1, 1) : d1;
71   d0 = i == 2 ? (u8x64) u64x8_align_right (d0, d0, 2) : d0;
72   d1 = i == 2 ? (u8x64) u64x8_align_right (d1, d1, 2) : d1;
73   d0 = i == 3 ? (u8x64) u64x8_align_right (d0, d0, 3) : d0;
74   d1 = i == 3 ? (u8x64) u64x8_align_right (d1, d1, 3) : d1;
75   d0 = i == 4 ? (u8x64) u64x8_align_right (d0, d0, 4) : d0;
76   d1 = i == 4 ? (u8x64) u64x8_align_right (d1, d1, 4) : d1;
77   d0 = i == 5 ? (u8x64) u64x8_align_right (d0, d0, 5) : d0;
78   d1 = i == 5 ? (u8x64) u64x8_align_right (d1, d1, 5) : d1;
79   d0 = i == 6 ? (u8x64) u64x8_align_right (d0, d0, 6) : d0;
80   d1 = i == 6 ? (u8x64) u64x8_align_right (d1, d1, 6) : d1;
81
82   d0 = u8x64_permute2 (__clib_toeplitz_hash_gfni_permute, d0, d1);
83
84   return (u64x8) _mm512_gf2p8affine_epi64_epi8 ((__m512i) d0, (__m512i) m, 0);
85 }
86 #endif
87
88 static_always_inline u32
89 clib_toeplitz_hash (clib_toeplitz_hash_key_t *k, u8 *data, int n_bytes)
90 {
91   u8 *key = k->data;
92   /* key must be 4 bytes longer than data */
93   ASSERT (k->key_length - n_bytes >= 4);
94
95 #if defined(__GFNI__) && defined(__AVX512F__)
96   u8x64 d0;
97   u64x8 h0 = {};
98   u64x8u *m = (u64x8u *) ((u8 *) k + k->gfni_offset);
99
100   /* move data ptr backwards for 3 byte so mask load "prepends" three zeros */
101   data -= 3;
102   n_bytes += 3;
103
104   if (n_bytes < 64)
105     {
106       d0 = u8x64_mask_load_zero ((u8 *) data, pow2_mask (n_bytes - 3) << 3);
107       goto last8;
108     }
109
110   d0 = u8x64_mask_load_zero ((u8 *) data, -1ULL << 3);
111 next56:
112   h0 = u64x8_xor3 (h0, clib_toeplitz_hash_gfni_one (d0, m[0], 0),
113                    clib_toeplitz_hash_gfni_one (d0, m[1], 1));
114   h0 = u64x8_xor3 (h0, clib_toeplitz_hash_gfni_one (d0, m[2], 2),
115                    clib_toeplitz_hash_gfni_one (d0, m[3], 3));
116   h0 = u64x8_xor3 (h0, clib_toeplitz_hash_gfni_one (d0, m[4], 4),
117                    clib_toeplitz_hash_gfni_one (d0, m[5], 5));
118   h0 ^= clib_toeplitz_hash_gfni_one (d0, m[6], 6);
119   n_bytes -= 56;
120   data += 56;
121   m += 7;
122
123   if (n_bytes >= 64)
124     {
125       d0 = *(u8x64u *) data;
126       goto next56;
127     }
128
129   if (n_bytes == 0)
130     goto done;
131
132   d0 = u8x64_mask_load_zero ((u8 *) data, pow2_mask (n_bytes));
133 last8:
134   h0 ^= clib_toeplitz_hash_gfni_one (d0, m[0], 0);
135   n_bytes -= 8;
136
137   if (n_bytes > 0)
138     {
139       m += 1;
140       d0 = (u8x64) u64x8_align_right (u64x8_zero (), d0, 1);
141       goto last8;
142     }
143
144 done:
145   return u64x8_hxor (h0);
146 #elif defined(CLIB_HAVE_VEC256)
147   u64x4 v4, shift = { 0, 1, 2, 3 };
148   u32x8 h0 = {};
149
150   while (n_bytes >= 4)
151     {
152       v4 = u64x4_splat (clib_net_to_host_u64 (*(u64u *) key)) << shift;
153
154       h0 = toeplitz_hash_one_x8 (h0, v4, data[0], 0);
155       h0 = toeplitz_hash_one_x8 (h0, v4, data[1], 1);
156       h0 = toeplitz_hash_one_x8 (h0, v4, data[2], 2);
157       h0 = toeplitz_hash_one_x8 (h0, v4, data[3], 3);
158
159       data += 4;
160       key += 4;
161       n_bytes -= 4;
162     }
163
164   if (n_bytes)
165     {
166       u64 v = (u64) clib_net_to_host_u32 ((u64) (*(u32u *) key)) << 32;
167       v |= (u64) key[4] << 24;
168
169       if (n_bytes == 3)
170         {
171           v |= (u64) key[5] << 16;
172           v |= (u64) key[6] << 8;
173           v4 = u64x4_splat (v) << shift;
174           h0 = toeplitz_hash_one_x8 (h0, v4, data[0], 0);
175           h0 = toeplitz_hash_one_x8 (h0, v4, data[1], 1);
176           h0 = toeplitz_hash_one_x8 (h0, v4, data[2], 2);
177         }
178       else if (n_bytes == 2)
179         {
180           v |= (u64) key[5] << 16;
181           v4 = u64x4_splat (v) << shift;
182           h0 = toeplitz_hash_one_x8 (h0, v4, data[0], 0);
183           h0 = toeplitz_hash_one_x8 (h0, v4, data[1], 1);
184         }
185       else
186         {
187           v4 = u64x4_splat (v) << shift;
188           h0 = toeplitz_hash_one_x8 (h0, v4, data[0], 0);
189         }
190     }
191
192   return u32x8_hxor (h0);
193 #endif
194   u64 v, hash = 0;
195
196   while (n_bytes >= 4)
197     {
198       v = clib_net_to_host_u64 (*(u64u *) key);
199
200       for (u8 bit = 1 << 7, byte = data[0]; bit; bit >>= 1, v <<= 1)
201         hash ^= byte & bit ? v : 0;
202       for (u8 bit = 1 << 7, byte = data[1]; bit; bit >>= 1, v <<= 1)
203         hash ^= byte & bit ? v : 0;
204       for (u8 bit = 1 << 7, byte = data[2]; bit; bit >>= 1, v <<= 1)
205         hash ^= byte & bit ? v : 0;
206       for (u8 bit = 1 << 7, byte = data[3]; bit; bit >>= 1, v <<= 1)
207         hash ^= byte & bit ? v : 0;
208
209       data += 4;
210       key += 4;
211       n_bytes -= 4;
212     }
213
214   if (n_bytes)
215     {
216       v = (u64) clib_net_to_host_u32 ((u64) (*(u32u *) key)) << 32;
217       v |= (u64) key[4] << 24;
218       for (u8 bit = 1 << 7, byte = data[0]; bit; bit >>= 1, v <<= 1)
219         hash ^= byte & bit ? v : 0;
220       if (n_bytes > 1)
221         {
222           v |= (u64) key[5] << 24;
223           for (u8 bit = 1 << 7, byte = data[1]; bit; bit >>= 1, v <<= 1)
224             hash ^= byte & bit ? v : 0;
225         }
226       if (n_bytes > 2)
227         {
228           v |= (u64) key[6] << 24;
229           for (u8 bit = 1 << 7, byte = data[2]; bit; bit >>= 1, v <<= 1)
230             hash ^= byte & bit ? v : 0;
231         }
232     }
233   return hash >> 32;
234 }
235
236 static_always_inline void
237 clib_toeplitz_hash_x4 (clib_toeplitz_hash_key_t *k, u8 *data0, u8 *data1,
238                        u8 *data2, u8 *data3, u32 *hash0, u32 *hash1,
239                        u32 *hash2, u32 *hash3, int n_bytes)
240 {
241   /* key must be 4 bytes longer than data */
242   ASSERT (k->key_length - n_bytes >= 4);
243 #if defined(__GFNI__) && defined(__AVX512F__)
244   u64x8u *m = (u64x8u *) ((u8 *) k + k->gfni_offset);
245   u8x64 d0, d1, d2, d3;
246   u64x8 h0 = {}, h2 = {};
247   u64 h, mask;
248
249   /* move data ptr backwards for 3 byte so mask load "prepends" three zeros */
250   data0 -= 3;
251   data1 -= 3;
252   data2 -= 3;
253   data3 -= 3;
254   n_bytes += 3;
255
256   if (n_bytes < 64)
257     {
258       mask = pow2_mask (n_bytes - 3) << 3;
259       d0 = u8x64_mask_load_zero ((u8 *) data0, mask);
260       d1 = u8x64_mask_load_zero ((u8 *) data1, mask);
261       d2 = u8x64_mask_load_zero ((u8 *) data2, mask);
262       d3 = u8x64_mask_load_zero ((u8 *) data3, mask);
263       goto last8;
264     }
265
266   mask = -1ULL << 3;
267   d0 = u8x64_mask_load_zero ((u8 *) data0, mask);
268   d1 = u8x64_mask_load_zero ((u8 *) data1, mask);
269   d2 = u8x64_mask_load_zero ((u8 *) data2, mask);
270   d3 = u8x64_mask_load_zero ((u8 *) data3, mask);
271 next56:
272   h0 = u64x8_xor3 (h0, clib_toeplitz_hash_gfni_two (d0, d1, m[0], 0),
273                    clib_toeplitz_hash_gfni_two (d0, d1, m[1], 1));
274   h2 = u64x8_xor3 (h2, clib_toeplitz_hash_gfni_two (d2, d3, m[0], 0),
275                    clib_toeplitz_hash_gfni_two (d2, d3, m[1], 1));
276
277   h0 = u64x8_xor3 (h0, clib_toeplitz_hash_gfni_two (d0, d1, m[2], 2),
278                    clib_toeplitz_hash_gfni_two (d0, d1, m[3], 3));
279   h2 = u64x8_xor3 (h2, clib_toeplitz_hash_gfni_two (d2, d3, m[2], 2),
280                    clib_toeplitz_hash_gfni_two (d2, d3, m[3], 3));
281
282   h0 = u64x8_xor3 (h0, clib_toeplitz_hash_gfni_two (d0, d1, m[4], 4),
283                    clib_toeplitz_hash_gfni_two (d0, d1, m[5], 5));
284   h2 = u64x8_xor3 (h2, clib_toeplitz_hash_gfni_two (d2, d3, m[4], 4),
285                    clib_toeplitz_hash_gfni_two (d2, d3, m[5], 5));
286
287   h0 ^= clib_toeplitz_hash_gfni_two (d0, d1, m[6], 6);
288   h2 ^= clib_toeplitz_hash_gfni_two (d2, d3, m[6], 6);
289
290   n_bytes -= 56;
291   data0 += 56;
292   data1 += 56;
293   data2 += 56;
294   data3 += 56;
295   m += 7;
296
297   if (n_bytes >= 64)
298     {
299       d0 = *(u8x64u *) data0;
300       d1 = *(u8x64u *) data1;
301       d2 = *(u8x64u *) data2;
302       d3 = *(u8x64u *) data3;
303       goto next56;
304     }
305
306   if (n_bytes == 0)
307     goto done;
308
309   mask = pow2_mask (n_bytes);
310   d0 = u8x64_mask_load_zero ((u8 *) data0, mask);
311   d1 = u8x64_mask_load_zero ((u8 *) data1, mask);
312   d2 = u8x64_mask_load_zero ((u8 *) data2, mask);
313   d3 = u8x64_mask_load_zero ((u8 *) data3, mask);
314 last8:
315   h0 ^= clib_toeplitz_hash_gfni_two (d0, d1, m[0], 0);
316   h2 ^= clib_toeplitz_hash_gfni_two (d2, d3, m[0], 0);
317   n_bytes -= 8;
318
319   if (n_bytes > 0)
320     {
321       u64x8 zero = {};
322       m += 1;
323       d0 = (u8x64) u64x8_align_right (zero, d0, 1);
324       d1 = (u8x64) u64x8_align_right (zero, d1, 1);
325       d2 = (u8x64) u64x8_align_right (zero, d2, 1);
326       d3 = (u8x64) u64x8_align_right (zero, d3, 1);
327       goto last8;
328     }
329
330 done:
331   h = u64x8_hxor (h0);
332   *hash0 = h;
333   *hash1 = h >> 32;
334   h = u64x8_hxor (h2);
335   *hash2 = h;
336   *hash3 = h >> 32;
337 #elif defined(CLIB_HAVE_VEC256)
338   u8 *key = k->data;
339   u64x4 v4, shift = { 0, 1, 2, 3 };
340   u32x8 h0 = {}, h1 = {}, h2 = {}, h3 = {};
341
342   while (n_bytes >= 4)
343     {
344       v4 = u64x4_splat (clib_net_to_host_u64 (*(u64u *) key)) << shift;
345
346       h0 = toeplitz_hash_one_x8 (h0, v4, data0[0], 0);
347       h1 = toeplitz_hash_one_x8 (h1, v4, data1[0], 0);
348       h2 = toeplitz_hash_one_x8 (h2, v4, data2[0], 0);
349       h3 = toeplitz_hash_one_x8 (h3, v4, data3[0], 0);
350
351       h0 = toeplitz_hash_one_x8 (h0, v4, data0[1], 1);
352       h1 = toeplitz_hash_one_x8 (h1, v4, data1[1], 1);
353       h2 = toeplitz_hash_one_x8 (h2, v4, data2[1], 1);
354       h3 = toeplitz_hash_one_x8 (h3, v4, data3[1], 1);
355
356       h0 = toeplitz_hash_one_x8 (h0, v4, data0[2], 2);
357       h1 = toeplitz_hash_one_x8 (h1, v4, data1[2], 2);
358       h2 = toeplitz_hash_one_x8 (h2, v4, data2[2], 2);
359       h3 = toeplitz_hash_one_x8 (h3, v4, data3[2], 2);
360
361       h0 = toeplitz_hash_one_x8 (h0, v4, data0[3], 3);
362       h1 = toeplitz_hash_one_x8 (h1, v4, data1[3], 3);
363       h2 = toeplitz_hash_one_x8 (h2, v4, data2[3], 3);
364       h3 = toeplitz_hash_one_x8 (h3, v4, data3[3], 3);
365
366       data0 += 4;
367       data1 += 4;
368       data2 += 4;
369       data3 += 4;
370       key += 4;
371       n_bytes -= 4;
372     }
373
374   if (n_bytes)
375     {
376       u64 v = (u64) clib_net_to_host_u32 ((u64) (*(u32u *) key)) << 32;
377       v |= (u64) key[4] << 24;
378
379       if (n_bytes == 3)
380         {
381           v |= (u64) key[5] << 16;
382           v |= (u64) key[6] << 8;
383           v4 = u64x4_splat (v) << shift;
384           h0 = toeplitz_hash_one_x8 (h0, v4, data0[0], 0);
385           h1 = toeplitz_hash_one_x8 (h1, v4, data1[0], 0);
386           h2 = toeplitz_hash_one_x8 (h2, v4, data2[0], 0);
387           h3 = toeplitz_hash_one_x8 (h3, v4, data3[0], 0);
388
389           h0 = toeplitz_hash_one_x8 (h0, v4, data0[1], 1);
390           h1 = toeplitz_hash_one_x8 (h1, v4, data1[1], 1);
391           h2 = toeplitz_hash_one_x8 (h2, v4, data2[1], 1);
392           h3 = toeplitz_hash_one_x8 (h3, v4, data3[1], 1);
393
394           h0 = toeplitz_hash_one_x8 (h0, v4, data0[2], 2);
395           h1 = toeplitz_hash_one_x8 (h1, v4, data1[2], 2);
396           h2 = toeplitz_hash_one_x8 (h2, v4, data2[2], 2);
397           h3 = toeplitz_hash_one_x8 (h3, v4, data3[2], 2);
398         }
399       else if (n_bytes == 2)
400         {
401           v |= (u64) key[5] << 16;
402           v4 = u64x4_splat (v) << shift;
403           h0 = toeplitz_hash_one_x8 (h0, v4, data0[0], 0);
404           h1 = toeplitz_hash_one_x8 (h1, v4, data1[0], 0);
405           h2 = toeplitz_hash_one_x8 (h2, v4, data2[0], 0);
406           h3 = toeplitz_hash_one_x8 (h3, v4, data3[0], 0);
407
408           h0 = toeplitz_hash_one_x8 (h0, v4, data0[1], 1);
409           h1 = toeplitz_hash_one_x8 (h1, v4, data1[1], 1);
410           h2 = toeplitz_hash_one_x8 (h2, v4, data2[1], 1);
411           h3 = toeplitz_hash_one_x8 (h3, v4, data3[1], 1);
412         }
413       else
414         {
415           v4 = u64x4_splat (v) << shift;
416           h0 = toeplitz_hash_one_x8 (h0, v4, data0[0], 0);
417           h1 = toeplitz_hash_one_x8 (h1, v4, data1[0], 0);
418           h2 = toeplitz_hash_one_x8 (h2, v4, data2[0], 0);
419           h3 = toeplitz_hash_one_x8 (h3, v4, data3[0], 0);
420         }
421     }
422
423   *hash0 = u32x8_hxor (h0);
424   *hash1 = u32x8_hxor (h1);
425   *hash2 = u32x8_hxor (h2);
426   *hash3 = u32x8_hxor (h3);
427 #else
428   u8 *key = k->data;
429   u64 v, h0 = 0, h1 = 0, h2 = 0, h3 = 0;
430
431   while (n_bytes >= 4)
432     {
433       v = clib_net_to_host_u64 (*(u64u *) key);
434
435       for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
436         {
437           h0 ^= data0[0] & bit ? v : 0;
438           h1 ^= data1[0] & bit ? v : 0;
439           h2 ^= data2[0] & bit ? v : 0;
440           h3 ^= data3[0] & bit ? v : 0;
441         }
442       for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
443         {
444           h0 ^= data0[1] & bit ? v : 0;
445           h1 ^= data1[1] & bit ? v : 0;
446           h2 ^= data2[1] & bit ? v : 0;
447           h3 ^= data3[1] & bit ? v : 0;
448         }
449       for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
450         {
451           h0 ^= data0[2] & bit ? v : 0;
452           h1 ^= data1[2] & bit ? v : 0;
453           h2 ^= data2[2] & bit ? v : 0;
454           h3 ^= data3[2] & bit ? v : 0;
455         }
456       for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
457         {
458           h0 ^= data0[3] & bit ? v : 0;
459           h1 ^= data1[3] & bit ? v : 0;
460           h2 ^= data2[3] & bit ? v : 0;
461           h3 ^= data3[3] & bit ? v : 0;
462         }
463
464       data0 += 4;
465       data1 += 4;
466       data2 += 4;
467       data3 += 4;
468       key += 4;
469       n_bytes -= 4;
470     }
471
472   if (n_bytes)
473     {
474       v = (u64) clib_net_to_host_u32 ((u64) (*(u32u *) key)) << 32;
475       v |= (u64) key[4] << 24;
476       for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
477         {
478           h0 ^= data0[0] & bit ? v : 0;
479           h1 ^= data1[0] & bit ? v : 0;
480           h2 ^= data2[0] & bit ? v : 0;
481           h3 ^= data3[0] & bit ? v : 0;
482         }
483       if (n_bytes > 1)
484         {
485           v |= (u64) key[5] << 24;
486           for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
487             {
488               h0 ^= data0[1] & bit ? v : 0;
489               h1 ^= data1[1] & bit ? v : 0;
490               h2 ^= data2[1] & bit ? v : 0;
491               h3 ^= data3[1] & bit ? v : 0;
492             }
493         }
494       if (n_bytes > 2)
495         {
496           v |= (u64) key[6] << 24;
497           for (u8 bit = 1 << 7; bit; bit >>= 1, v <<= 1)
498             {
499               h0 ^= data0[2] & bit ? v : 0;
500               h1 ^= data1[2] & bit ? v : 0;
501               h2 ^= data2[2] & bit ? v : 0;
502               h3 ^= data3[2] & bit ? v : 0;
503             }
504         }
505     }
506   *hash0 = h0 >> 32;
507   *hash1 = h1 >> 32;
508   *hash2 = h2 >> 32;
509   *hash3 = h3 >> 32;
510 #endif
511 }
512
513 #endif