/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2023 Cisco Systems, Inc.
 */

#ifndef __crypto_aes_cbc_h__
#define __crypto_aes_cbc_h__

#include <vppinfra/clib.h>
#include <vppinfra/vector.h>
#include <vppinfra/crypto/aes.h>

typedef struct
{
  const u8x16 encrypt_key[15];
  const u8x16 decrypt_key[15];
} aes_cbc_key_data_t;
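
/* The two 15-entry arrays above hold the largest possible schedule:
 * AES-256 has 14 rounds, so AES_KEY_ROUNDS (AES_KEY_256) + 1 = 15 round
 * keys are stored for each direction. */

/* CBC encryption is inherently serial: each plaintext block is XORed with
 * the previous ciphertext block before being encrypted, so consecutive
 * blocks cannot be pipelined. The loop below therefore processes one
 * 16-byte block per iteration, using per-architecture AES round
 * primitives. */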
static_always_inline void
clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
                      const u8 *iv, aes_key_size_t ks, u8 *dst)
{
  int rounds = AES_KEY_ROUNDS (ks);
  u8x16 r, *k = (u8x16 *) kd->encrypt_key;

  r = *(u8x16u *) iv;

  for (int i = 0; i < len; i += 16)
    {
      int j;
#if defined(__x86_64__) || defined(__i386__)
      r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
      for (j = 1; j < rounds; j++)
        r = aes_enc_round (r, k[j]);
      r = aes_enc_last_round (r, k[rounds]);
#else
      r ^= *(u8x16u *) (src + i);
      for (j = 1; j < rounds - 1; j++)
        r = vaesmcq_u8 (vaeseq_u8 (r, k[j]));
      r = vaeseq_u8 (r, k[j]) ^ k[rounds];
#endif
      *(u8x16u *) (dst + i) = r;
    }
}
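
/* Fixed-key-size wrappers: these pin the aes_key_size_t argument so the
 * round count can constant-fold at each call site. */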
static_always_inline void
clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
                         uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
}

static_always_inline void
clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
                         uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
}

static_always_inline void
clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
                         uword len, const u8 *iv, u8 *ciphertext)
{
  clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
}
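
/* Unlike encryption, CBC decryption parallelizes: plaintext[i] =
 * D(ciphertext[i]) ^ ciphertext[i-1], and all ciphertext blocks are known
 * up front. The generic path below decrypts four blocks per iteration to
 * hide AES round-instruction latency, then drains the tail one block at a
 * time. */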
static_always_inline void __clib_unused
aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
             int rounds)
{
  u8x16 r[4], c[4], f;

  f = iv[0];
  while (count >= 64)
    {
      c[0] = r[0] = src[0];
      c[1] = r[1] = src[1];
      c[2] = r[2] = src[2];
      c[3] = r[3] = src[3];
#if defined(__x86_64__) || defined(__i386__)
      r[0] ^= k[0];
      r[1] ^= k[0];
      r[2] ^= k[0];
      r[3] ^= k[0];
      for (int i = 1; i < rounds; i++)
        {
          r[0] = aes_dec_round (r[0], k[i]);
          r[1] = aes_dec_round (r[1], k[i]);
          r[2] = aes_dec_round (r[2], k[i]);
          r[3] = aes_dec_round (r[3], k[i]);
        }
      r[0] = aes_dec_last_round (r[0], k[rounds]);
      r[1] = aes_dec_last_round (r[1], k[rounds]);
      r[2] = aes_dec_last_round (r[2], k[rounds]);
      r[3] = aes_dec_last_round (r[3], k[rounds]);
#else
      for (int i = 0; i < rounds - 1; i++)
        {
          r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
          r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
          r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
          r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
        }
      r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
      r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
      r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
      r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
#endif
      dst[0] = r[0] ^ f;
      dst[1] = r[1] ^ c[0];
      dst[2] = r[2] ^ c[1];
      dst[3] = r[3] ^ c[2];
      f = c[3];
      count -= 64;
      src += 4;
      dst += 4;
    }

  while (count > 0)
    {
#if defined(__x86_64__) || defined(__i386__)
      c[0] = r[0] = src[0];
      r[0] ^= k[0];
      for (int i = 1; i < rounds; i++)
        r[0] = aes_dec_round (r[0], k[i]);
      r[0] = aes_dec_last_round (r[0], k[rounds]);
#else
      c[0] = r[0] = src[0];
      for (int i = 0; i < rounds - 1; i++)
        r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
      r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
#endif
      dst[0] = r[0] ^ f;
      f = c[0];
      count -= 16;
      src += 1;
      dst += 1;
    }
}
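
/* With VAES + AVX-512, four 16-byte blocks fit in one 512-bit register,
 * and the main decrypt loop below keeps four such registers in flight
 * (16 blocks per iteration). The _x4 helpers gather/scatter one block
 * from each of four separate buffers into a single u8x64. */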
#if defined(__VAES__) && defined(__AVX512F__)

static_always_inline u8x64
aes_block_load_x4 (u8 *src[], int i)
{
  u8x64 r = {};
  r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
  r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
  return r;
}

static_always_inline void
aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
{
  aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
  aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
  aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
  aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
}

static_always_inline u8x64
aes4_cbc_dec_permute (u8x64 a, u8x64 b)
{
  return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
}
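
/* u64x8_shuffle2 with lane indices 6..13 selects the top 128 bits of a
 * followed by the low 384 bits of b, i.e. it builds the vector of
 * "previous" ciphertext blocks { a[3], b[0], b[1], b[2] } needed for the
 * final CBC XOR. */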
static_always_inline void
aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
              aes_key_size_t rounds)
{
  u8x64 f, k4, r[4], c[4] = {};
  __mmask8 m;
  int i, n_blocks = count >> 4;

  f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);

  while (n_blocks >= 16)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      c[3] = src[3];
      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;
      r[2] = c[2] ^ k4;
      r[3] = c[3] ^ k4;

      for (i = 1; i < rounds; i++)
        {
          k4 = u8x64_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x4 (r[0], k4);
          r[1] = aes_dec_round_x4 (r[1], k4);
          r[2] = aes_dec_round_x4 (r[2], k4);
          r[3] = aes_dec_round_x4 (r[3], k4);
        }

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);
      r[3] = aes_dec_last_round_x4 (r[3], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
      f = c[3];

      n_blocks -= 16;
      src += 4;
      dst += 4;
    }

  if (n_blocks >= 12)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;
      r[2] = c[2] ^ k4;

      for (i = 1; i < rounds; i++)
        {
          k4 = u8x64_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x4 (r[0], k4);
          r[1] = aes_dec_round_x4 (r[1], k4);
          r[2] = aes_dec_round_x4 (r[2], k4);
        }

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      f = c[2];

      n_blocks -= 12;
      src += 3;
      dst += 3;
    }
  else if (n_blocks >= 8)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      r[0] = c[0] ^ k4;
      r[1] = c[1] ^ k4;

      for (i = 1; i < rounds; i++)
        {
          k4 = u8x64_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x4 (r[0], k4);
          r[1] = aes_dec_round_x4 (r[1], k4);
        }

      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      f = c[1];

      n_blocks -= 8;
      src += 2;
      dst += 2;
    }
  else if (n_blocks >= 4)
    {
      c[0] = src[0];
      r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);

      for (i = 1; i < rounds; i++)
        r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));

      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));

      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      f = c[0];

      n_blocks -= 4;
      src += 1;
      dst += 1;
    }

  if (n_blocks > 0)
    {
      k4 = u8x64_splat_u8x16 (k[0]);
      m = (1 << (n_blocks * 2)) - 1;
      c[0] =
        (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
      f = aes4_cbc_dec_permute (f, c[0]);
      r[0] = c[0] ^ k4;
      for (i = 1; i < rounds; i++)
        r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
    }
}
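
/* With VAES but only 256-bit vectors (AVX2), the same scheme is applied
 * with two blocks per u8x32 register; the main loop handles 8 blocks per
 * iteration. */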
#elif defined(__VAES__)

static_always_inline u8x32
aes_block_load_x2 (u8 *src[], int i)
{
  u8x32 r = {};
  r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
  r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
  return r;
}

static_always_inline void
aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
{
  aes_block_store (dst[0] + i, u8x32_extract_lo (r));
  aes_block_store (dst[1] + i, u8x32_extract_hi (r));
}

static_always_inline u8x32
aes2_cbc_dec_permute (u8x32 a, u8x32 b)
{
  return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
}
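
/* Lane indices 2..5 select the high 128 bits of a and the low 128 bits of
 * b, forming the { a[1], b[0] } pair of previous ciphertext blocks. */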
static_always_inline void
aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
              aes_key_size_t rounds)
{
  u8x32 k2, f = {}, r[4], c[4] = {};
  int i, n_blocks = count >> 4;

  f = u8x32_insert_hi (f, *iv);

  while (n_blocks >= 8)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      c[3] = src[3];
      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;
      r[2] = c[2] ^ k2;
      r[3] = c[3] ^ k2;

      for (i = 1; i < rounds; i++)
        {
          k2 = u8x32_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x2 (r[0], k2);
          r[1] = aes_dec_round_x2 (r[1], k2);
          r[2] = aes_dec_round_x2 (r[2], k2);
          r[3] = aes_dec_round_x2 (r[3], k2);
        }

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);
      r[3] = aes_dec_last_round_x2 (r[3], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
      f = c[3];

      n_blocks -= 8;
      src += 4;
      dst += 4;
    }

  if (n_blocks >= 6)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      c[2] = src[2];
      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;
      r[2] = c[2] ^ k2;

      for (i = 1; i < rounds; i++)
        {
          k2 = u8x32_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x2 (r[0], k2);
          r[1] = aes_dec_round_x2 (r[1], k2);
          r[2] = aes_dec_round_x2 (r[2], k2);
        }

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      f = c[2];

      n_blocks -= 6;
      src += 3;
      dst += 3;
    }
  else if (n_blocks >= 4)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      c[1] = src[1];
      r[0] = c[0] ^ k2;
      r[1] = c[1] ^ k2;

      for (i = 1; i < rounds; i++)
        {
          k2 = u8x32_splat_u8x16 (k[i]);
          r[0] = aes_dec_round_x2 (r[0], k2);
          r[1] = aes_dec_round_x2 (r[1], k2);
        }

      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);

      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      f = c[1];

      n_blocks -= 4;
      src += 2;
      dst += 2;
    }
  else if (n_blocks >= 2)
    {
      k2 = u8x32_splat_u8x16 (k[0]);
      c[0] = src[0];
      r[0] = c[0] ^ k2;

      for (i = 1; i < rounds; i++)
        r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));

      r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      f = c[0];

      n_blocks -= 2;
      src += 1;
      dst += 1;
    }

  if (n_blocks > 0)
    {
      u8x16 rl = *(u8x16u *) src ^ k[0];
      for (i = 1; i < rounds; i++)
        rl = aes_dec_round (rl, k[i]);
      rl = aes_dec_last_round (rl, k[i]);
      *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
    }
}
#endif
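
/* Key expansion: the encrypt schedule comes straight from aes_key_expand;
 * the decrypt schedule is derived from it with aes_key_enc_to_dec, so both
 * directions are precomputed once per key. */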
static_always_inline void
clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
                         aes_key_size_t ks)
{
  u8x16 e[15], d[15];

  aes_key_expand (e, key, ks);
  aes_key_enc_to_dec (e, d, ks);
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    {
      ((u8x16 *) kd->decrypt_key)[i] = d[i];
      ((u8x16 *) kd->encrypt_key)[i] = e[i];
    }
}
static_always_inline void
clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
}

static_always_inline void
clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
}

static_always_inline void
clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
{
  clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
}
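
/* Top-level decrypt dispatch: the widest ISA variant available at compile
 * time is selected; all three variants share the same semantics. */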
static_always_inline void
clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                      uword len, const u8 *iv, aes_key_size_t ks,
                      u8 *plaintext)
{
  int rounds = AES_KEY_ROUNDS (ks);
#if defined(__VAES__) && defined(__AVX512F__)
  aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
                (u8x16u *) iv, (int) len, rounds);
#elif defined(__VAES__)
  aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
                (u8x16u *) iv, (int) len, rounds);
#else
  aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
               (u8x16u *) iv, (int) len, rounds);
#endif
}
static_always_inline void
clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                         uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
}

static_always_inline void
clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                         uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
}

static_always_inline void
clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
                         uword len, const u8 *iv, u8 *plaintext)
{
  clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
}
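
/* Minimal usage sketch (hypothetical buffer names; assumes len is a
 * multiple of 16 and iv points to 16 bytes):
 *
 *   aes_cbc_key_data_t kd;
 *   u8 key[16], iv[16], pt[64], ct[64], out[64];
 *
 *   clib_aes128_cbc_key_expand (&kd, key);
 *   clib_aes128_cbc_encrypt (&kd, pt, 64, iv, ct);
 *   clib_aes128_cbc_decrypt (&kd, ct, 64, iv, out); // out matches pt
 */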
#endif /* __crypto_aes_cbc_h__ */