vppinfra: toeplitz hash
[vpp.git] / src / vppinfra / vector / toeplitz.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2021 Cisco Systems, Inc.
3  */
4
5 #ifndef included_vector_toeplitz_h
6 #define included_vector_toeplitz_h
7 #include <vppinfra/clib.h>
8
9 typedef struct
10 {
11   u16 key_length;
12   u16 gfni_offset;
13   u8 data[];
14 } clib_toeplitz_hash_key_t;
15
16 clib_toeplitz_hash_key_t *clib_toeplitz_hash_key_init (u8 *key, u32 keylen);
17 void clib_toeplitz_hash_key_free (clib_toeplitz_hash_key_t *k);
18
19 #if defined(__GFNI__) && defined(__AVX512F__)
20
21 #define u64x8_gf2p8_affine(d, m, imm)                                         \
22   (u64x8) _mm512_gf2p8affine_epi64_epi8 ((__m512i) (d), (__m512i) (m), imm)
23
24 #endif
25
26 #ifdef CLIB_HAVE_VEC256
27 static_always_inline u32x8
28 toeplitz_hash_one_x8 (u32x8 hash, u64x4 v4, u8 data, u8 off)
29 {
30   u32x8 v8 = u32x8_shuffle2 (v4 << (off * 8), v4 << (off * 8 + 4),
31                              /*uppper 32 bits of each u64 in reverse order */
32                              15, 13, 11, 9, 7, 5, 3, 1);
33
34 #ifdef CLIB_HAVE_VEC256_MASK_BITWISE_OPS
35   return u32x8_mask_xor (hash, v8, data);
36 #else
37   static const u32x8 bits = { 1, 2, 4, 8, 16, 32, 64, 128 };
38   return hash ^ (((u32x8_splat (data) & bits) != u32x8_zero ()) & v8);
39 #endif
40 }
41 #endif
42
43 static_always_inline u32
44 clib_toeplitz_hash (clib_toeplitz_hash_key_t *k, u8 *data, int n_bytes)
45 {
46   u8 *key = k->data;
47   /* key must be 4 bytes longer than data */
48   ASSERT (k->key_length - n_bytes >= 4);
49
50 #if defined(__GFNI__) && defined(__AVX512F__)
51   u8x64 a, b, dv;
52   u64x8 xor_sum_x8 = {};
53   u64x8u *m = (u64x8u *) ((u8 *) k + k->gfni_offset);
54
55   u8x64 idx = { 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x01, 0x02,
56                 0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x02, 0x03, 0x04, 0x05,
57                 0x02, 0x03, 0x04, 0x05, 0x03, 0x04, 0x05, 0x06, 0x03, 0x04,
58                 0x05, 0x06, 0x04, 0x05, 0x06, 0x07, 0x04, 0x05, 0x06, 0x07,
59                 0x05, 0x06, 0x07, 0x08, 0x05, 0x06, 0x07, 0x08, 0x06, 0x07,
60                 0x08, 0x09, 0x06, 0x07, 0x08, 0x09, 0x07, 0x08, 0x09, 0x0a,
61                 0x07, 0x08, 0x09, 0x0a };
62
63   /* move data ptr backwards for 3 byte so mask load "prepends" three zeros */
64   data -= 3;
65   n_bytes += 3;
66
67   if (n_bytes < 64)
68     {
69       dv = u8x64_mask_load_zero ((u8 *) data, pow2_mask (n_bytes - 3) << 3);
70       goto last8;
71     }
72
73   dv = u8x64_mask_load_zero ((u8 *) data, -1ULL << 3);
74 next56:
75   a = u8x64_permute (idx, dv);
76   b = u8x64_permute (idx, (u8x64) u64x8_align_right (dv, dv, 1));
77   xor_sum_x8 = u64x8_xor3 (xor_sum_x8, u64x8_gf2p8_affine (a, m[0], 0),
78                            u64x8_gf2p8_affine (b, m[1], 0));
79
80   a = u8x64_permute (idx, (u8x64) u64x8_align_right (dv, dv, 2));
81   b = u8x64_permute (idx, (u8x64) u64x8_align_right (dv, dv, 3));
82   xor_sum_x8 = u64x8_xor3 (xor_sum_x8, u64x8_gf2p8_affine (a, m[2], 0),
83                            u64x8_gf2p8_affine (b, m[3], 0));
84
85   a = u8x64_permute (idx, (u8x64) u64x8_align_right (dv, dv, 4));
86   b = u8x64_permute (idx, (u8x64) u64x8_align_right (dv, dv, 5));
87   xor_sum_x8 = u64x8_xor3 (xor_sum_x8, u64x8_gf2p8_affine (a, m[4], 0),
88                            u64x8_gf2p8_affine (b, m[5], 0));
89
90   a = u8x64_permute (idx, (u8x64) u64x8_align_right (dv, dv, 6));
91   xor_sum_x8 ^= u64x8_gf2p8_affine (a, m[6], 0);
92   n_bytes -= 56;
93   data += 56;
94   m += 7;
95
96   if (n_bytes >= 64)
97     {
98       dv = *(u8x64u *) data;
99       goto next56;
100     }
101
102   if (n_bytes == 0)
103     goto done;
104
105   dv = u8x64_mask_load_zero ((u8 *) data, pow2_mask (n_bytes));
106 last8:
107   a = u8x64_permute (idx, dv);
108   xor_sum_x8 ^= u64x8_gf2p8_affine (a, m[0], 0);
109   n_bytes -= 8;
110
111   if (n_bytes > 0)
112     {
113       m += 1;
114       dv = (u8x64) u64x8_align_right (u64x8_zero (), dv, 1);
115       goto last8;
116     }
117
118 done:
119   /* horizontal xor */
120   xor_sum_x8 ^= u64x8_align_right (xor_sum_x8, xor_sum_x8, 4);
121   xor_sum_x8 ^= u64x8_align_right (xor_sum_x8, xor_sum_x8, 2);
122   return xor_sum_x8[0] ^ xor_sum_x8[1];
123 #elif defined(CLIB_HAVE_VEC256)
124   u64x4 v4, shift = { 0, 1, 2, 3 };
125   u32x8 hash8 = {};
126   u32x4 hash4;
127
128   while (n_bytes >= 4)
129     {
130       v4 = u64x4_splat (clib_net_to_host_u64 (*(u64u *) key)) << shift;
131
132       hash8 = toeplitz_hash_one_x8 (hash8, v4, data[0], 0);
133       hash8 = toeplitz_hash_one_x8 (hash8, v4, data[1], 1);
134       hash8 = toeplitz_hash_one_x8 (hash8, v4, data[2], 2);
135       hash8 = toeplitz_hash_one_x8 (hash8, v4, data[3], 3);
136
137       data += 4;
138       key += 4;
139       n_bytes -= 4;
140     }
141
142   if (n_bytes)
143     {
144       u64 v = (u64) clib_net_to_host_u32 ((u64) (*(u32u *) key)) << 32;
145       v |= (u64) key[4] << 24;
146
147       if (n_bytes == 3)
148         {
149           v |= (u64) key[5] << 16;
150           v |= (u64) key[6] << 8;
151           v4 = u64x4_splat (v) << shift;
152           hash8 = toeplitz_hash_one_x8 (hash8, v4, data[0], 0);
153           hash8 = toeplitz_hash_one_x8 (hash8, v4, data[1], 1);
154           hash8 = toeplitz_hash_one_x8 (hash8, v4, data[2], 2);
155         }
156       else if (n_bytes == 2)
157         {
158           v |= (u64) key[5] << 16;
159           v4 = u64x4_splat (v) << shift;
160           hash8 = toeplitz_hash_one_x8 (hash8, v4, data[0], 0);
161           hash8 = toeplitz_hash_one_x8 (hash8, v4, data[1], 1);
162         }
163       else
164         {
165           v4 = u64x4_splat (v) << shift;
166           hash8 = toeplitz_hash_one_x8 (hash8, v4, data[0], 0);
167         }
168     }
169
170   hash4 = u32x8_extract_lo (hash8) ^ u32x8_extract_hi (hash8);
171   hash4 ^= (u32x4) u8x16_align_right (hash4, hash4, 8);
172   hash4 ^= (u32x4) u8x16_align_right (hash4, hash4, 4);
173   return hash4[0];
174
175 #endif
176   u64 v, hash = 0;
177
178   while (n_bytes >= 4)
179     {
180       v = clib_net_to_host_u64 (*(u64u *) key);
181
182       for (u8 bit = 1 << 7, byte = data[0]; bit; bit >>= 1, v <<= 1)
183         hash ^= byte & bit ? v : 0;
184       for (u8 bit = 1 << 7, byte = data[1]; bit; bit >>= 1, v <<= 1)
185         hash ^= byte & bit ? v : 0;
186       for (u8 bit = 1 << 7, byte = data[2]; bit; bit >>= 1, v <<= 1)
187         hash ^= byte & bit ? v : 0;
188       for (u8 bit = 1 << 7, byte = data[3]; bit; bit >>= 1, v <<= 1)
189         hash ^= byte & bit ? v : 0;
190
191       data += 4;
192       key += 4;
193       n_bytes -= 4;
194     }
195
196   if (n_bytes)
197     {
198       v = (u64) clib_net_to_host_u32 ((u64) (*(u32u *) key)) << 32;
199       v |= (u64) key[4] << 24;
200       for (u8 bit = 1 << 7, byte = data[0]; bit; bit >>= 1, v <<= 1)
201         hash ^= byte & bit ? v : 0;
202       if (n_bytes > 1)
203         {
204           v |= (u64) key[5] << 24;
205           for (u8 bit = 1 << 7, byte = data[1]; bit; bit >>= 1, v <<= 1)
206             hash ^= byte & bit ? v : 0;
207         }
208       if (n_bytes > 2)
209         {
210           v |= (u64) key[6] << 24;
211           for (u8 bit = 1 << 7, byte = data[2]; bit; bit >>= 1, v <<= 1)
212             hash ^= byte & bit ? v : 0;
213         }
214     }
215   return hash >> 32;
216 }
217
218 #endif