vppinfra: new vectorized ip checksum functions incl. csum_and_copy
[vpp.git] / src / vppinfra / vector / ip_csum.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2021 Cisco Systems, Inc.
3  */
4
5 #ifndef included_vector_ip_csum_h
6 #define included_vector_ip_csum_h
7 #include <vppinfra/clib.h>
8 typedef struct
9 {
10   u64 sum;
11   u8 odd;
12 } clib_ip_csum_t;
13
14 #if defined(CLIB_HAVE_VEC128)
15 static_always_inline u64x2
16 clib_ip_csum_cvt_and_add_4 (u32x4 v)
17 {
18   return ((u64x2) u32x4_interleave_lo ((u32x4) v, u32x4_zero ()) +
19           (u64x2) u32x4_interleave_hi ((u32x4) v, u32x4_zero ()));
20 }
21 static_always_inline u64
22 clib_ip_csum_hadd_2 (u64x2 v)
23 {
24   return v[0] + v[1];
25 }
26 #endif
27
28 #if defined(CLIB_HAVE_VEC256)
29 static_always_inline u64x4
30 clib_ip_csum_cvt_and_add_8 (u32x8 v)
31 {
32   return ((u64x4) u32x8_interleave_lo ((u32x8) v, u32x8_zero ()) +
33           (u64x4) u32x8_interleave_hi ((u32x8) v, u32x8_zero ()));
34 }
35 static_always_inline u64
36 clib_ip_csum_hadd_4 (u64x4 v)
37 {
38   return clib_ip_csum_hadd_2 (u64x4_extract_lo (v) + u64x4_extract_hi (v));
39 }
40 #endif
41
42 #if defined(CLIB_HAVE_VEC512)
43 static_always_inline u64x8
44 clib_ip_csum_cvt_and_add_16 (u32x16 v)
45 {
46   return ((u64x8) u32x16_interleave_lo ((u32x16) v, u32x16_zero ()) +
47           (u64x8) u32x16_interleave_hi ((u32x16) v, u32x16_zero ()));
48 }
49 static_always_inline u64
50 clib_ip_csum_hadd_8 (u64x8 v)
51 {
52   return clib_ip_csum_hadd_4 (u64x8_extract_lo (v) + u64x8_extract_hi (v));
53 }
54 #endif
55
56 static_always_inline void
57 clib_ip_csum_inline (clib_ip_csum_t *c, u8 *dst, u8 *src, u16 count,
58                      int is_copy)
59 {
60   if (c->odd)
61     {
62       c->odd = 0;
63       c->sum += (u16) src[0] << 8;
64       count--;
65       src++;
66       if (is_copy)
67         dst++[0] = src[0];
68     }
69
70 #if defined(CLIB_HAVE_VEC512)
71   u64x8 sum8 = {};
72
73   while (count >= 512)
74     {
75       u32x16u *s = (u32x16u *) src;
76       sum8 += clib_ip_csum_cvt_and_add_16 (s[0]);
77       sum8 += clib_ip_csum_cvt_and_add_16 (s[1]);
78       sum8 += clib_ip_csum_cvt_and_add_16 (s[2]);
79       sum8 += clib_ip_csum_cvt_and_add_16 (s[3]);
80       sum8 += clib_ip_csum_cvt_and_add_16 (s[8]);
81       sum8 += clib_ip_csum_cvt_and_add_16 (s[5]);
82       sum8 += clib_ip_csum_cvt_and_add_16 (s[6]);
83       sum8 += clib_ip_csum_cvt_and_add_16 (s[7]);
84       count -= 512;
85       src += 512;
86       if (is_copy)
87         {
88           u32x16u *d = (u32x16u *) dst;
89           d[0] = s[0];
90           d[1] = s[1];
91           d[2] = s[2];
92           d[3] = s[3];
93           d[4] = s[4];
94           d[5] = s[5];
95           d[6] = s[6];
96           d[7] = s[7];
97           dst += 512;
98         }
99     }
100
101   while (count >= 64)
102     {
103       u32x16u *s = (u32x16u *) src;
104       sum8 += clib_ip_csum_cvt_and_add_16 (s[0]);
105       count -= 64;
106       src += 64;
107       if (is_copy)
108         {
109           u32x16u *d = (u32x16u *) dst;
110           d[0] = s[0];
111           dst += 512;
112         }
113     }
114
115 #ifdef CLIB_HAVE_VEC256_MASK_LOAD_STORE
116   if (count)
117     {
118       u64 mask = pow2_mask (count);
119       u32x16 v = (u32x16) u8x64_mask_load_zero (src, mask);
120       sum8 += clib_ip_csum_cvt_and_add_16 (v);
121       c->odd = count & 1;
122       if (is_copy)
123         u32x16_mask_store (v, dst, mask);
124     }
125   c->sum += clib_ip_csum_hadd_8 (sum8);
126   return;
127 #endif
128
129   c->sum += clib_ip_csum_hadd_8 (sum8);
130 #elif defined(CLIB_HAVE_VEC256)
131   u64x4 sum4 = {};
132
133   while (count >= 256)
134     {
135       u32x8u *s = (u32x8u *) src;
136       sum4 += clib_ip_csum_cvt_and_add_8 (s[0]);
137       sum4 += clib_ip_csum_cvt_and_add_8 (s[1]);
138       sum4 += clib_ip_csum_cvt_and_add_8 (s[2]);
139       sum4 += clib_ip_csum_cvt_and_add_8 (s[3]);
140       sum4 += clib_ip_csum_cvt_and_add_8 (s[4]);
141       sum4 += clib_ip_csum_cvt_and_add_8 (s[5]);
142       sum4 += clib_ip_csum_cvt_and_add_8 (s[6]);
143       sum4 += clib_ip_csum_cvt_and_add_8 (s[7]);
144       count -= 256;
145       src += 256;
146       if (is_copy)
147         {
148           u32x8u *d = (u32x8u *) dst;
149           d[0] = s[0];
150           d[1] = s[1];
151           d[2] = s[2];
152           d[3] = s[3];
153           d[4] = s[4];
154           d[5] = s[5];
155           d[6] = s[6];
156           d[7] = s[7];
157           dst += 256;
158         }
159     }
160
161   while (count >= 32)
162     {
163       u32x8u *s = (u32x8u *) src;
164       sum4 += clib_ip_csum_cvt_and_add_8 (s[0]);
165       count -= 32;
166       src += 32;
167       if (is_copy)
168         {
169           u32x8u *d = (u32x8u *) dst;
170           d[0] = s[0];
171           dst += 32;
172         }
173     }
174
175 #ifdef CLIB_HAVE_VEC256_MASK_LOAD_STORE
176   if (count)
177     {
178       u32 mask = pow2_mask (count);
179       u32x8 v = (u32x8) u8x32_mask_load_zero (src, mask);
180       sum4 += clib_ip_csum_cvt_and_add_8 (v);
181       c->odd = count & 1;
182       if (is_copy)
183         u32x8_mask_store (v, dst, mask);
184     }
185   c->sum += clib_ip_csum_hadd_4 (sum4);
186   return;
187 #endif
188
189   c->sum += clib_ip_csum_hadd_4 (sum4);
190 #elif defined(CLIB_HAVE_VEC128)
191   u64x2 sum2 = {};
192
193   while (count >= 128)
194     {
195       u32x4u *s = (u32x4u *) src;
196       sum2 += clib_ip_csum_cvt_and_add_4 (s[0]);
197       sum2 += clib_ip_csum_cvt_and_add_4 (s[1]);
198       sum2 += clib_ip_csum_cvt_and_add_4 (s[2]);
199       sum2 += clib_ip_csum_cvt_and_add_4 (s[3]);
200       sum2 += clib_ip_csum_cvt_and_add_4 (s[4]);
201       sum2 += clib_ip_csum_cvt_and_add_4 (s[5]);
202       sum2 += clib_ip_csum_cvt_and_add_4 (s[6]);
203       sum2 += clib_ip_csum_cvt_and_add_4 (s[7]);
204       count -= 128;
205       src += 128;
206       if (is_copy)
207         {
208           u32x4u *d = (u32x4u *) dst;
209           d[0] = s[0];
210           d[1] = s[1];
211           d[2] = s[2];
212           d[3] = s[3];
213           d[4] = s[4];
214           d[5] = s[5];
215           d[6] = s[6];
216           d[7] = s[7];
217           dst += 128;
218         }
219     }
220
221   while (count >= 16)
222     {
223       u32x4u *s = (u32x4u *) src;
224       sum2 += clib_ip_csum_cvt_and_add_4 (s[0]);
225       count -= 16;
226       src += 16;
227       if (is_copy)
228         {
229           u32x4u *d = (u32x4u *) dst;
230           d[0] = s[0];
231           dst += 16;
232         }
233     }
234   c->sum += clib_ip_csum_hadd_2 (sum2);
235 #else
236   while (count >= 4)
237     {
238       u32 v = *((u32 *) src);
239       c->sum += v;
240       count -= 4;
241       src += 4;
242       if (is_copy)
243         {
244           *(u32 *) dst = v;
245           dst += 4;
246         }
247     }
248 #endif
249   while (count >= 2)
250     {
251       u16 v = *((u16 *) src);
252       c->sum += v;
253       count -= 2;
254       src += 2;
255       if (is_copy)
256         {
257           *(u16 *) dst = v;
258           dst += 2;
259         }
260     }
261
262   if (count)
263     {
264       c->odd = 1;
265       c->sum += (u16) src[0];
266       if (is_copy)
267         dst[0] = src[0];
268     }
269 }
270
271 static_always_inline u16
272 clib_ip_csum_fold (clib_ip_csum_t *c)
273 {
274   u64 sum = c->sum;
275 #if defined(__x86_64__) && defined(__BMI2__)
276   u64 tmp = sum;
277   asm volatile(
278     /* using ADC is much faster than mov, shift, add sequence
279      * compiler produces */
280     "shr        $32, %[sum]                     \n\t"
281     "add        %k[tmp], %k[sum]                \n\t"
282     "mov        $16, %k[tmp]                    \n\t"
283     "shrx       %k[tmp], %k[sum], %k[tmp]       \n\t"
284     "adc        %w[tmp], %w[sum]                \n\t"
285     "adc        $0, %w[sum]                     \n\t"
286     : [ sum ] "+&r"(sum), [ tmp ] "+&r"(tmp));
287 #else
288   sum = ((u32) sum) + (sum >> 32);
289   sum = ((u16) sum) + (sum >> 16);
290   sum = ((u16) sum) + (sum >> 16);
291 #endif
292   return (~((u16) sum));
293 }
294
295 static_always_inline void
296 clib_ip_csum_chunk (clib_ip_csum_t *c, u8 *src, u16 count)
297 {
298   return clib_ip_csum_inline (c, 0, src, count, 0);
299 }
300
301 static_always_inline void
302 clib_ip_csum_and_copy_chunk (clib_ip_csum_t *c, u8 *src, u8 *dst, u16 count)
303 {
304   return clib_ip_csum_inline (c, dst, src, count, 1);
305 }
306
307 static_always_inline u16
308 clib_ip_csum (u8 *src, u16 count)
309 {
310   clib_ip_csum_t c = {};
311   if (COMPILE_TIME_CONST (count) && count == 12)
312     {
313       for (int i = 0; i < 3; i++)
314         c.sum += ((u32 *) src)[i];
315     }
316   else if (COMPILE_TIME_CONST (count) && count == 20)
317     {
318       for (int i = 0; i < 5; i++)
319         c.sum += ((u32 *) src)[i];
320     }
321   else if (COMPILE_TIME_CONST (count) && count == 40)
322     {
323       for (int i = 0; i < 10; i++)
324         c.sum += ((u32 *) src)[i];
325     }
326   else
327     clib_ip_csum_inline (&c, 0, src, count, 0);
328   return clib_ip_csum_fold (&c);
329 }
330
331 static_always_inline u16
332 clib_ip_csum_and_copy (u8 *dst, u8 *src, u16 count)
333 {
334   clib_ip_csum_t c = {};
335   clib_ip_csum_inline (&c, dst, src, count, 1);
336   return clib_ip_csum_fold (&c);
337 }
338
339 #endif