vppinfra: enable AES tests on ARM
[vpp.git] / src / vppinfra / crypto / aes_cbc.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2023 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_cbc_h__
6 #define __crypto_aes_cbc_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/crypto/aes.h>
11
12 typedef struct
13 {
14   const u8x16 encrypt_key[15];
15   const u8x16 decrypt_key[15];
16 } aes_cbc_key_data_t;
17
18 static_always_inline void
19 clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
20                       const u8 *iv, aes_key_size_t ks, u8 *dst)
21 {
22   int rounds = AES_KEY_ROUNDS (ks);
23   u8x16 r, *k = (u8x16 *) kd->encrypt_key;
24
25   r = *(u8x16u *) iv;
26
27   for (int i = 0; i < len; i += 16)
28     {
29       int j;
30       r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
31       for (j = 1; j < rounds; j++)
32         r = aes_enc_round_x1 (r, k[j]);
33       r = aes_enc_last_round_x1 (r, k[rounds]);
34       *(u8x16u *) (dst + i) = r;
35     }
36 }
37
38 static_always_inline void
39 clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
40                          uword len, const u8 *iv, u8 *ciphertext)
41 {
42   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
43 }
44
45 static_always_inline void
46 clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
47                          uword len, const u8 *iv, u8 *ciphertext)
48 {
49   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
50 }
51
52 static_always_inline void
53 clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
54                          uword len, const u8 *iv, u8 *ciphertext)
55 {
56   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
57 }
58
59 static_always_inline void __clib_unused
60 aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
61              int rounds)
62 {
63   u8x16 r[4], c[4], f;
64
65   f = iv[0];
66   while (count >= 64)
67     {
68       c[0] = r[0] = src[0];
69       c[1] = r[1] = src[1];
70       c[2] = r[2] = src[2];
71       c[3] = r[3] = src[3];
72
73 #if __x86_64__
74       r[0] ^= k[0];
75       r[1] ^= k[0];
76       r[2] ^= k[0];
77       r[3] ^= k[0];
78
79       for (int i = 1; i < rounds; i++)
80         {
81           r[0] = aes_dec_round_x1 (r[0], k[i]);
82           r[1] = aes_dec_round_x1 (r[1], k[i]);
83           r[2] = aes_dec_round_x1 (r[2], k[i]);
84           r[3] = aes_dec_round_x1 (r[3], k[i]);
85         }
86
87       r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
88       r[1] = aes_dec_last_round_x1 (r[1], k[rounds]);
89       r[2] = aes_dec_last_round_x1 (r[2], k[rounds]);
90       r[3] = aes_dec_last_round_x1 (r[3], k[rounds]);
91 #else
92       for (int i = 0; i < rounds - 1; i++)
93         {
94           r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
95           r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
96           r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
97           r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
98         }
99       r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
100       r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
101       r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
102       r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
103 #endif
104       dst[0] = r[0] ^ f;
105       dst[1] = r[1] ^ c[0];
106       dst[2] = r[2] ^ c[1];
107       dst[3] = r[3] ^ c[2];
108       f = c[3];
109
110       count -= 64;
111       src += 4;
112       dst += 4;
113     }
114
115   while (count > 0)
116     {
117       c[0] = r[0] = src[0];
118 #if __x86_64__
119       r[0] ^= k[0];
120       for (int i = 1; i < rounds; i++)
121         r[0] = aes_dec_round_x1 (r[0], k[i]);
122       r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
123 #else
124       c[0] = r[0] = src[0];
125       for (int i = 0; i < rounds - 1; i++)
126         r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
127       r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
128 #endif
129       dst[0] = r[0] ^ f;
130       f = c[0];
131
132       count -= 16;
133       src += 1;
134       dst += 1;
135     }
136 }
137
138 #if __x86_64__
139 #if defined(__VAES__) && defined(__AVX512F__)
140
141 static_always_inline u8x64
142 aes_block_load_x4 (u8 *src[], int i)
143 {
144   u8x64 r = {};
145   r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
146   r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
147   r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
148   r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
149   return r;
150 }
151
152 static_always_inline void
153 aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
154 {
155   aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
156   aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
157   aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
158   aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
159 }
160
161 static_always_inline u8x64
162 aes4_cbc_dec_permute (u8x64 a, u8x64 b)
163 {
164   return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
165 }
166
167 static_always_inline void
168 aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
169               aes_key_size_t rounds)
170 {
171   u8x64 f, k4, r[4], c[4] = {};
172   __mmask8 m;
173   int i, n_blocks = count >> 4;
174
175   f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);
176
177   while (n_blocks >= 16)
178     {
179       k4 = u8x64_splat_u8x16 (k[0]);
180       c[0] = src[0];
181       c[1] = src[1];
182       c[2] = src[2];
183       c[3] = src[3];
184
185       r[0] = c[0] ^ k4;
186       r[1] = c[1] ^ k4;
187       r[2] = c[2] ^ k4;
188       r[3] = c[3] ^ k4;
189
190       for (i = 1; i < rounds; i++)
191         {
192           k4 = u8x64_splat_u8x16 (k[i]);
193           r[0] = aes_dec_round_x4 (r[0], k4);
194           r[1] = aes_dec_round_x4 (r[1], k4);
195           r[2] = aes_dec_round_x4 (r[2], k4);
196           r[3] = aes_dec_round_x4 (r[3], k4);
197         }
198
199       k4 = u8x64_splat_u8x16 (k[i]);
200       r[0] = aes_dec_last_round_x4 (r[0], k4);
201       r[1] = aes_dec_last_round_x4 (r[1], k4);
202       r[2] = aes_dec_last_round_x4 (r[2], k4);
203       r[3] = aes_dec_last_round_x4 (r[3], k4);
204
205       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
206       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
207       dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
208       dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
209       f = c[3];
210
211       n_blocks -= 16;
212       src += 4;
213       dst += 4;
214     }
215
216   if (n_blocks >= 12)
217     {
218       k4 = u8x64_splat_u8x16 (k[0]);
219       c[0] = src[0];
220       c[1] = src[1];
221       c[2] = src[2];
222
223       r[0] = c[0] ^ k4;
224       r[1] = c[1] ^ k4;
225       r[2] = c[2] ^ k4;
226
227       for (i = 1; i < rounds; i++)
228         {
229           k4 = u8x64_splat_u8x16 (k[i]);
230           r[0] = aes_dec_round_x4 (r[0], k4);
231           r[1] = aes_dec_round_x4 (r[1], k4);
232           r[2] = aes_dec_round_x4 (r[2], k4);
233         }
234
235       k4 = u8x64_splat_u8x16 (k[i]);
236       r[0] = aes_dec_last_round_x4 (r[0], k4);
237       r[1] = aes_dec_last_round_x4 (r[1], k4);
238       r[2] = aes_dec_last_round_x4 (r[2], k4);
239
240       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
241       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
242       dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
243       f = c[2];
244
245       n_blocks -= 12;
246       src += 3;
247       dst += 3;
248     }
249   else if (n_blocks >= 8)
250     {
251       k4 = u8x64_splat_u8x16 (k[0]);
252       c[0] = src[0];
253       c[1] = src[1];
254
255       r[0] = c[0] ^ k4;
256       r[1] = c[1] ^ k4;
257
258       for (i = 1; i < rounds; i++)
259         {
260           k4 = u8x64_splat_u8x16 (k[i]);
261           r[0] = aes_dec_round_x4 (r[0], k4);
262           r[1] = aes_dec_round_x4 (r[1], k4);
263         }
264
265       k4 = u8x64_splat_u8x16 (k[i]);
266       r[0] = aes_dec_last_round_x4 (r[0], k4);
267       r[1] = aes_dec_last_round_x4 (r[1], k4);
268
269       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
270       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
271       f = c[1];
272
273       n_blocks -= 8;
274       src += 2;
275       dst += 2;
276     }
277   else if (n_blocks >= 4)
278     {
279       c[0] = src[0];
280
281       r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);
282
283       for (i = 1; i < rounds; i++)
284         r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
285
286       r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
287
288       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
289       f = c[0];
290
291       n_blocks -= 4;
292       src += 1;
293       dst += 1;
294     }
295
296   if (n_blocks > 0)
297     {
298       k4 = u8x64_splat_u8x16 (k[0]);
299       m = (1 << (n_blocks * 2)) - 1;
300       c[0] =
301         (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
302       f = aes4_cbc_dec_permute (f, c[0]);
303       r[0] = c[0] ^ k4;
304       for (i = 1; i < rounds; i++)
305         r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
306       r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
307       _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
308     }
309 }
310 #elif defined(__VAES__)
311
312 static_always_inline u8x32
313 aes_block_load_x2 (u8 *src[], int i)
314 {
315   u8x32 r = {};
316   r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
317   r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
318   return r;
319 }
320
321 static_always_inline void
322 aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
323 {
324   aes_block_store (dst[0] + i, u8x32_extract_lo (r));
325   aes_block_store (dst[1] + i, u8x32_extract_hi (r));
326 }
327
328 static_always_inline u8x32
329 aes2_cbc_dec_permute (u8x32 a, u8x32 b)
330 {
331   return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
332 }
333
334 static_always_inline void
335 aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
336               aes_key_size_t rounds)
337 {
338   u8x32 k2, f = {}, r[4], c[4] = {};
339   int i, n_blocks = count >> 4;
340
341   f = u8x32_insert_hi (f, *iv);
342
343   while (n_blocks >= 8)
344     {
345       k2 = u8x32_splat_u8x16 (k[0]);
346       c[0] = src[0];
347       c[1] = src[1];
348       c[2] = src[2];
349       c[3] = src[3];
350
351       r[0] = c[0] ^ k2;
352       r[1] = c[1] ^ k2;
353       r[2] = c[2] ^ k2;
354       r[3] = c[3] ^ k2;
355
356       for (i = 1; i < rounds; i++)
357         {
358           k2 = u8x32_splat_u8x16 (k[i]);
359           r[0] = aes_dec_round_x2 (r[0], k2);
360           r[1] = aes_dec_round_x2 (r[1], k2);
361           r[2] = aes_dec_round_x2 (r[2], k2);
362           r[3] = aes_dec_round_x2 (r[3], k2);
363         }
364
365       k2 = u8x32_splat_u8x16 (k[i]);
366       r[0] = aes_dec_last_round_x2 (r[0], k2);
367       r[1] = aes_dec_last_round_x2 (r[1], k2);
368       r[2] = aes_dec_last_round_x2 (r[2], k2);
369       r[3] = aes_dec_last_round_x2 (r[3], k2);
370
371       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
372       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
373       dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
374       dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
375       f = c[3];
376
377       n_blocks -= 8;
378       src += 4;
379       dst += 4;
380     }
381
382   if (n_blocks >= 6)
383     {
384       k2 = u8x32_splat_u8x16 (k[0]);
385       c[0] = src[0];
386       c[1] = src[1];
387       c[2] = src[2];
388
389       r[0] = c[0] ^ k2;
390       r[1] = c[1] ^ k2;
391       r[2] = c[2] ^ k2;
392
393       for (i = 1; i < rounds; i++)
394         {
395           k2 = u8x32_splat_u8x16 (k[i]);
396           r[0] = aes_dec_round_x2 (r[0], k2);
397           r[1] = aes_dec_round_x2 (r[1], k2);
398           r[2] = aes_dec_round_x2 (r[2], k2);
399         }
400
401       k2 = u8x32_splat_u8x16 (k[i]);
402       r[0] = aes_dec_last_round_x2 (r[0], k2);
403       r[1] = aes_dec_last_round_x2 (r[1], k2);
404       r[2] = aes_dec_last_round_x2 (r[2], k2);
405
406       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
407       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
408       dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
409       f = c[2];
410
411       n_blocks -= 6;
412       src += 3;
413       dst += 3;
414     }
415   else if (n_blocks >= 4)
416     {
417       k2 = u8x32_splat_u8x16 (k[0]);
418       c[0] = src[0];
419       c[1] = src[1];
420
421       r[0] = c[0] ^ k2;
422       r[1] = c[1] ^ k2;
423
424       for (i = 1; i < rounds; i++)
425         {
426           k2 = u8x32_splat_u8x16 (k[i]);
427           r[0] = aes_dec_round_x2 (r[0], k2);
428           r[1] = aes_dec_round_x2 (r[1], k2);
429         }
430
431       k2 = u8x32_splat_u8x16 (k[i]);
432       r[0] = aes_dec_last_round_x2 (r[0], k2);
433       r[1] = aes_dec_last_round_x2 (r[1], k2);
434
435       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
436       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
437       f = c[1];
438
439       n_blocks -= 4;
440       src += 2;
441       dst += 2;
442     }
443   else if (n_blocks >= 2)
444     {
445       k2 = u8x32_splat_u8x16 (k[0]);
446       c[0] = src[0];
447       r[0] = c[0] ^ k2;
448
449       for (i = 1; i < rounds; i++)
450         r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
451
452       r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
453       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
454       f = c[0];
455
456       n_blocks -= 2;
457       src += 1;
458       dst += 1;
459     }
460
461   if (n_blocks > 0)
462     {
463       u8x16 rl = *(u8x16u *) src ^ k[0];
464       for (i = 1; i < rounds; i++)
465         rl = aes_dec_round_x1 (rl, k[i]);
466       rl = aes_dec_last_round_x1 (rl, k[i]);
467       *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
468     }
469 }
470 #endif
471 #endif
472
473 static_always_inline void
474 clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
475                          aes_key_size_t ks)
476 {
477   u8x16 e[15], d[15];
478   aes_key_expand (e, key, ks);
479   aes_key_enc_to_dec (e, d, ks);
480   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
481     {
482       ((u8x16 *) kd->decrypt_key)[i] = d[i];
483       ((u8x16 *) kd->encrypt_key)[i] = e[i];
484     }
485 }
486
487 static_always_inline void
488 clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
489 {
490   clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
491 }
492 static_always_inline void
493 clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
494 {
495   clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
496 }
497 static_always_inline void
498 clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
499 {
500   clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
501 }
502
503 static_always_inline void
504 clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
505                       uword len, const u8 *iv, aes_key_size_t ks,
506                       u8 *plaintext)
507 {
508   int rounds = AES_KEY_ROUNDS (ks);
509 #if defined(__VAES__) && defined(__AVX512F__)
510   aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
511                 (u8x16u *) iv, (int) len, rounds);
512 #elif defined(__VAES__)
513   aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
514                 (u8x16u *) iv, (int) len, rounds);
515 #else
516   aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
517                (u8x16u *) iv, (int) len, rounds);
518 #endif
519 }
520
521 static_always_inline void
522 clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
523                          uword len, const u8 *iv, u8 *plaintext)
524 {
525   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
526 }
527
528 static_always_inline void
529 clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
530                          uword len, const u8 *iv, u8 *plaintext)
531 {
532   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
533 }
534
535 static_always_inline void
536 clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
537                          uword len, const u8 *iv, u8 *plaintext)
538 {
539   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
540 }
541
542 #endif /* __crypto_aes_cbc_h__ */