vppinfra: native AES-CTR implementation
[vpp.git] / src / vppinfra / crypto / aes_cbc.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2023 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_cbc_h__
6 #define __crypto_aes_cbc_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/crypto/aes.h>
11
12 typedef struct
13 {
14   const u8x16 encrypt_key[15];
15   const u8x16 decrypt_key[15];
16 } aes_cbc_key_data_t;
17
18 static_always_inline void
19 clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
20                       const u8 *iv, aes_key_size_t ks, u8 *dst)
21 {
22   int rounds = AES_KEY_ROUNDS (ks);
23   u8x16 r, *k = (u8x16 *) kd->encrypt_key;
24
25   r = *(u8x16u *) iv;
26
27   for (int i = 0; i < len; i += 16)
28     {
29       int j;
30 #if __x86_64__
31       r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
32       for (j = 1; j < rounds; j++)
33         r = aes_enc_round_x1 (r, k[j]);
34       r = aes_enc_last_round_x1 (r, k[rounds]);
35 #else
36       r ^= *(u8x16u *) (src + i);
37       for (j = 1; j < rounds - 1; j++)
38         r = vaesmcq_u8 (vaeseq_u8 (r, k[j]));
39       r = vaeseq_u8 (r, k[j]) ^ k[rounds];
40 #endif
41       *(u8x16u *) (dst + i) = r;
42     }
43 }
44
45 static_always_inline void
46 clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
47                          uword len, const u8 *iv, u8 *ciphertext)
48 {
49   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
50 }
51
52 static_always_inline void
53 clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
54                          uword len, const u8 *iv, u8 *ciphertext)
55 {
56   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
57 }
58
59 static_always_inline void
60 clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
61                          uword len, const u8 *iv, u8 *ciphertext)
62 {
63   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
64 }
65
66 static_always_inline void __clib_unused
67 aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
68              int rounds)
69 {
70   u8x16 r[4], c[4], f;
71
72   f = iv[0];
73   while (count >= 64)
74     {
75       c[0] = r[0] = src[0];
76       c[1] = r[1] = src[1];
77       c[2] = r[2] = src[2];
78       c[3] = r[3] = src[3];
79
80 #if __x86_64__
81       r[0] ^= k[0];
82       r[1] ^= k[0];
83       r[2] ^= k[0];
84       r[3] ^= k[0];
85
86       for (int i = 1; i < rounds; i++)
87         {
88           r[0] = aes_dec_round_x1 (r[0], k[i]);
89           r[1] = aes_dec_round_x1 (r[1], k[i]);
90           r[2] = aes_dec_round_x1 (r[2], k[i]);
91           r[3] = aes_dec_round_x1 (r[3], k[i]);
92         }
93
94       r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
95       r[1] = aes_dec_last_round_x1 (r[1], k[rounds]);
96       r[2] = aes_dec_last_round_x1 (r[2], k[rounds]);
97       r[3] = aes_dec_last_round_x1 (r[3], k[rounds]);
98 #else
99       for (int i = 0; i < rounds - 1; i++)
100         {
101           r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
102           r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
103           r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
104           r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
105         }
106       r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
107       r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
108       r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
109       r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
110 #endif
111       dst[0] = r[0] ^ f;
112       dst[1] = r[1] ^ c[0];
113       dst[2] = r[2] ^ c[1];
114       dst[3] = r[3] ^ c[2];
115       f = c[3];
116
117       count -= 64;
118       src += 4;
119       dst += 4;
120     }
121
122   while (count > 0)
123     {
124       c[0] = r[0] = src[0];
125 #if __x86_64__
126       r[0] ^= k[0];
127       for (int i = 1; i < rounds; i++)
128         r[0] = aes_dec_round_x1 (r[0], k[i]);
129       r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
130 #else
131       c[0] = r[0] = src[0];
132       for (int i = 0; i < rounds - 1; i++)
133         r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
134       r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
135 #endif
136       dst[0] = r[0] ^ f;
137       f = c[0];
138
139       count -= 16;
140       src += 1;
141       dst += 1;
142     }
143 }
144
145 #if __x86_64__
146 #if defined(__VAES__) && defined(__AVX512F__)
147
148 static_always_inline u8x64
149 aes_block_load_x4 (u8 *src[], int i)
150 {
151   u8x64 r = {};
152   r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
153   r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
154   r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
155   r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
156   return r;
157 }
158
159 static_always_inline void
160 aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
161 {
162   aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
163   aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
164   aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
165   aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
166 }
167
168 static_always_inline u8x64
169 aes4_cbc_dec_permute (u8x64 a, u8x64 b)
170 {
171   return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
172 }
173
174 static_always_inline void
175 aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
176               aes_key_size_t rounds)
177 {
178   u8x64 f, k4, r[4], c[4] = {};
179   __mmask8 m;
180   int i, n_blocks = count >> 4;
181
182   f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);
183
184   while (n_blocks >= 16)
185     {
186       k4 = u8x64_splat_u8x16 (k[0]);
187       c[0] = src[0];
188       c[1] = src[1];
189       c[2] = src[2];
190       c[3] = src[3];
191
192       r[0] = c[0] ^ k4;
193       r[1] = c[1] ^ k4;
194       r[2] = c[2] ^ k4;
195       r[3] = c[3] ^ k4;
196
197       for (i = 1; i < rounds; i++)
198         {
199           k4 = u8x64_splat_u8x16 (k[i]);
200           r[0] = aes_dec_round_x4 (r[0], k4);
201           r[1] = aes_dec_round_x4 (r[1], k4);
202           r[2] = aes_dec_round_x4 (r[2], k4);
203           r[3] = aes_dec_round_x4 (r[3], k4);
204         }
205
206       k4 = u8x64_splat_u8x16 (k[i]);
207       r[0] = aes_dec_last_round_x4 (r[0], k4);
208       r[1] = aes_dec_last_round_x4 (r[1], k4);
209       r[2] = aes_dec_last_round_x4 (r[2], k4);
210       r[3] = aes_dec_last_round_x4 (r[3], k4);
211
212       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
213       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
214       dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
215       dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
216       f = c[3];
217
218       n_blocks -= 16;
219       src += 4;
220       dst += 4;
221     }
222
223   if (n_blocks >= 12)
224     {
225       k4 = u8x64_splat_u8x16 (k[0]);
226       c[0] = src[0];
227       c[1] = src[1];
228       c[2] = src[2];
229
230       r[0] = c[0] ^ k4;
231       r[1] = c[1] ^ k4;
232       r[2] = c[2] ^ k4;
233
234       for (i = 1; i < rounds; i++)
235         {
236           k4 = u8x64_splat_u8x16 (k[i]);
237           r[0] = aes_dec_round_x4 (r[0], k4);
238           r[1] = aes_dec_round_x4 (r[1], k4);
239           r[2] = aes_dec_round_x4 (r[2], k4);
240         }
241
242       k4 = u8x64_splat_u8x16 (k[i]);
243       r[0] = aes_dec_last_round_x4 (r[0], k4);
244       r[1] = aes_dec_last_round_x4 (r[1], k4);
245       r[2] = aes_dec_last_round_x4 (r[2], k4);
246
247       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
248       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
249       dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
250       f = c[2];
251
252       n_blocks -= 12;
253       src += 3;
254       dst += 3;
255     }
256   else if (n_blocks >= 8)
257     {
258       k4 = u8x64_splat_u8x16 (k[0]);
259       c[0] = src[0];
260       c[1] = src[1];
261
262       r[0] = c[0] ^ k4;
263       r[1] = c[1] ^ k4;
264
265       for (i = 1; i < rounds; i++)
266         {
267           k4 = u8x64_splat_u8x16 (k[i]);
268           r[0] = aes_dec_round_x4 (r[0], k4);
269           r[1] = aes_dec_round_x4 (r[1], k4);
270         }
271
272       k4 = u8x64_splat_u8x16 (k[i]);
273       r[0] = aes_dec_last_round_x4 (r[0], k4);
274       r[1] = aes_dec_last_round_x4 (r[1], k4);
275
276       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
277       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
278       f = c[1];
279
280       n_blocks -= 8;
281       src += 2;
282       dst += 2;
283     }
284   else if (n_blocks >= 4)
285     {
286       c[0] = src[0];
287
288       r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);
289
290       for (i = 1; i < rounds; i++)
291         r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
292
293       r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
294
295       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
296       f = c[0];
297
298       n_blocks -= 4;
299       src += 1;
300       dst += 1;
301     }
302
303   if (n_blocks > 0)
304     {
305       k4 = u8x64_splat_u8x16 (k[0]);
306       m = (1 << (n_blocks * 2)) - 1;
307       c[0] =
308         (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
309       f = aes4_cbc_dec_permute (f, c[0]);
310       r[0] = c[0] ^ k4;
311       for (i = 1; i < rounds; i++)
312         r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
313       r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
314       _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
315     }
316 }
317 #elif defined(__VAES__)
318
319 static_always_inline u8x32
320 aes_block_load_x2 (u8 *src[], int i)
321 {
322   u8x32 r = {};
323   r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
324   r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
325   return r;
326 }
327
328 static_always_inline void
329 aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
330 {
331   aes_block_store (dst[0] + i, u8x32_extract_lo (r));
332   aes_block_store (dst[1] + i, u8x32_extract_hi (r));
333 }
334
335 static_always_inline u8x32
336 aes2_cbc_dec_permute (u8x32 a, u8x32 b)
337 {
338   return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
339 }
340
341 static_always_inline void
342 aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
343               aes_key_size_t rounds)
344 {
345   u8x32 k2, f = {}, r[4], c[4] = {};
346   int i, n_blocks = count >> 4;
347
348   f = u8x32_insert_hi (f, *iv);
349
350   while (n_blocks >= 8)
351     {
352       k2 = u8x32_splat_u8x16 (k[0]);
353       c[0] = src[0];
354       c[1] = src[1];
355       c[2] = src[2];
356       c[3] = src[3];
357
358       r[0] = c[0] ^ k2;
359       r[1] = c[1] ^ k2;
360       r[2] = c[2] ^ k2;
361       r[3] = c[3] ^ k2;
362
363       for (i = 1; i < rounds; i++)
364         {
365           k2 = u8x32_splat_u8x16 (k[i]);
366           r[0] = aes_dec_round_x2 (r[0], k2);
367           r[1] = aes_dec_round_x2 (r[1], k2);
368           r[2] = aes_dec_round_x2 (r[2], k2);
369           r[3] = aes_dec_round_x2 (r[3], k2);
370         }
371
372       k2 = u8x32_splat_u8x16 (k[i]);
373       r[0] = aes_dec_last_round_x2 (r[0], k2);
374       r[1] = aes_dec_last_round_x2 (r[1], k2);
375       r[2] = aes_dec_last_round_x2 (r[2], k2);
376       r[3] = aes_dec_last_round_x2 (r[3], k2);
377
378       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
379       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
380       dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
381       dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
382       f = c[3];
383
384       n_blocks -= 8;
385       src += 4;
386       dst += 4;
387     }
388
389   if (n_blocks >= 6)
390     {
391       k2 = u8x32_splat_u8x16 (k[0]);
392       c[0] = src[0];
393       c[1] = src[1];
394       c[2] = src[2];
395
396       r[0] = c[0] ^ k2;
397       r[1] = c[1] ^ k2;
398       r[2] = c[2] ^ k2;
399
400       for (i = 1; i < rounds; i++)
401         {
402           k2 = u8x32_splat_u8x16 (k[i]);
403           r[0] = aes_dec_round_x2 (r[0], k2);
404           r[1] = aes_dec_round_x2 (r[1], k2);
405           r[2] = aes_dec_round_x2 (r[2], k2);
406         }
407
408       k2 = u8x32_splat_u8x16 (k[i]);
409       r[0] = aes_dec_last_round_x2 (r[0], k2);
410       r[1] = aes_dec_last_round_x2 (r[1], k2);
411       r[2] = aes_dec_last_round_x2 (r[2], k2);
412
413       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
414       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
415       dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
416       f = c[2];
417
418       n_blocks -= 6;
419       src += 3;
420       dst += 3;
421     }
422   else if (n_blocks >= 4)
423     {
424       k2 = u8x32_splat_u8x16 (k[0]);
425       c[0] = src[0];
426       c[1] = src[1];
427
428       r[0] = c[0] ^ k2;
429       r[1] = c[1] ^ k2;
430
431       for (i = 1; i < rounds; i++)
432         {
433           k2 = u8x32_splat_u8x16 (k[i]);
434           r[0] = aes_dec_round_x2 (r[0], k2);
435           r[1] = aes_dec_round_x2 (r[1], k2);
436         }
437
438       k2 = u8x32_splat_u8x16 (k[i]);
439       r[0] = aes_dec_last_round_x2 (r[0], k2);
440       r[1] = aes_dec_last_round_x2 (r[1], k2);
441
442       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
443       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
444       f = c[1];
445
446       n_blocks -= 4;
447       src += 2;
448       dst += 2;
449     }
450   else if (n_blocks >= 2)
451     {
452       k2 = u8x32_splat_u8x16 (k[0]);
453       c[0] = src[0];
454       r[0] = c[0] ^ k2;
455
456       for (i = 1; i < rounds; i++)
457         r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
458
459       r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
460       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
461       f = c[0];
462
463       n_blocks -= 2;
464       src += 1;
465       dst += 1;
466     }
467
468   if (n_blocks > 0)
469     {
470       u8x16 rl = *(u8x16u *) src ^ k[0];
471       for (i = 1; i < rounds; i++)
472         rl = aes_dec_round_x1 (rl, k[i]);
473       rl = aes_dec_last_round_x1 (rl, k[i]);
474       *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
475     }
476 }
477 #endif
478 #endif
479
480 static_always_inline void
481 clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
482                          aes_key_size_t ks)
483 {
484   u8x16 e[15], d[15];
485   aes_key_expand (e, key, ks);
486   aes_key_enc_to_dec (e, d, ks);
487   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
488     {
489       ((u8x16 *) kd->decrypt_key)[i] = d[i];
490       ((u8x16 *) kd->encrypt_key)[i] = e[i];
491     }
492 }
493
494 static_always_inline void
495 clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
496 {
497   clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
498 }
499 static_always_inline void
500 clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
501 {
502   clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
503 }
504 static_always_inline void
505 clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
506 {
507   clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
508 }
509
510 static_always_inline void
511 clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
512                       uword len, const u8 *iv, aes_key_size_t ks,
513                       u8 *plaintext)
514 {
515   int rounds = AES_KEY_ROUNDS (ks);
516 #if defined(__VAES__) && defined(__AVX512F__)
517   aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
518                 (u8x16u *) iv, (int) len, rounds);
519 #elif defined(__VAES__)
520   aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
521                 (u8x16u *) iv, (int) len, rounds);
522 #else
523   aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
524                (u8x16u *) iv, (int) len, rounds);
525 #endif
526 }
527
528 static_always_inline void
529 clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
530                          uword len, const u8 *iv, u8 *plaintext)
531 {
532   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
533 }
534
535 static_always_inline void
536 clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
537                          uword len, const u8 *iv, u8 *plaintext)
538 {
539   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
540 }
541
542 static_always_inline void
543 clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
544                          uword len, const u8 *iv, u8 *plaintext)
545 {
546   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
547 }
548
549 #endif /* __crypto_aes_cbc_h__ */