/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2023 Cisco Systems, Inc.
 */
5 #ifndef __crypto_aes_cbc_h__
6 #define __crypto_aes_cbc_h__
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/crypto/aes.h>
14 const u8x16 encrypt_key[15];
15 const u8x16 decrypt_key[15];
18 static_always_inline void
19 clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
20 const u8 *iv, aes_key_size_t ks, u8 *dst)
22 int rounds = AES_KEY_ROUNDS (ks);
23 u8x16 r, *k = (u8x16 *) kd->encrypt_key;
27 for (int i = 0; i < len; i += 16)
30 r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
31 for (j = 1; j < rounds; j++)
32 r = aes_enc_round_x1 (r, k[j]);
33 r = aes_enc_last_round_x1 (r, k[rounds]);
34 *(u8x16u *) (dst + i) = r;
38 static_always_inline void
39 clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
40 uword len, const u8 *iv, u8 *ciphertext)
42 clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
45 static_always_inline void
46 clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
47 uword len, const u8 *iv, u8 *ciphertext)
49 clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
52 static_always_inline void
53 clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
54 uword len, const u8 *iv, u8 *ciphertext)
56 clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
/* Portable single-lane CBC decrypt kernel (used when VAES is not
 * available).  Decrypts 'count' bytes of ciphertext at 'src' into 'dst',
 * chaining from *iv.  Two ISA variants sit side by side: x86 AES-NI
 * (aes_dec_round_x1 / aes_dec_last_round_x1) and Arm crypto extensions
 * (vaesdq_u8 / vaesimcq_u8), selected by preprocessor conditionals.
 *
 * NOTE(review): this copy of the file is missing several lines of this
 * function (opening braces, the ciphertext loads into r[]/c[], the
 * #if/#else/#endif markers and the pointer/counter bookkeeping) —
 * restore from upstream before editing the logic. */
static_always_inline void __clib_unused
aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
  /* --- 4-block batch, x86 AES-NI path ---
   * middle rounds run on four independent blocks so the AESDEC pipelines
   * stay busy (CBC decrypt, unlike encrypt, parallelizes freely) */
  for (int i = 1; i < rounds; i++)
    r[0] = aes_dec_round_x1 (r[0], k[i]);
    r[1] = aes_dec_round_x1 (r[1], k[i]);
    r[2] = aes_dec_round_x1 (r[2], k[i]);
    r[3] = aes_dec_round_x1 (r[3], k[i]);
  /* final round uses AESDECLAST with the last round key */
  r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
  r[1] = aes_dec_last_round_x1 (r[1], k[rounds]);
  r[2] = aes_dec_last_round_x1 (r[2], k[rounds]);
  r[3] = aes_dec_last_round_x1 (r[3], k[rounds]);
  /* --- 4-block batch, Arm path ---
   * AESD + AESIMC for all but the final round ... */
  for (int i = 0; i < rounds - 1; i++)
    r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
    r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
    r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
    r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
  /* ... then a bare AESD xored with the last round key */
  r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
  r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
  r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
  r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
  /* CBC final step: xor each decrypted block with the previous ciphertext
   * block, saved in c[] before decryption started */
  dst[1] = r[1] ^ c[0];
  dst[2] = r[2] ^ c[1];
  dst[3] = r[3] ^ c[2];
  /* --- single-block tail loop ---
   * save the ciphertext (it becomes the next chain value), x86 path */
  c[0] = r[0] = src[0];
  for (int i = 1; i < rounds; i++)
    r[0] = aes_dec_round_x1 (r[0], k[i]);
  r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
  /* same single-block step, Arm path */
  c[0] = r[0] = src[0];
  for (int i = 0; i < rounds - 1; i++)
    r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
  r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
139 #if defined(__VAES__) && defined(__AVX512F__)
141 static_always_inline u8x64
142 aes_block_load_x4 (u8 *src[], int i)
145 r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
146 r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
147 r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
148 r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
152 static_always_inline void
153 aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
155 aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
156 aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
157 aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
158 aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
161 static_always_inline u8x64
162 aes4_cbc_dec_permute (u8x64 a, u8x64 b)
164 return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
/* AVX512 + VAES CBC decrypt: operates on 512-bit registers holding four
 * AES blocks each.  Handles batches of 16, 12, 8 and 4 blocks, plus a
 * masked tail for the remainder.  'f' always carries the ciphertext block
 * preceding the current batch (the CBC chain value); aes4_cbc_dec_permute
 * shifts the saved ciphertext stream right by one block to build the
 * per-lane xor operand.
 *
 * NOTE(review): this excerpt is missing lines (the ciphertext loads into
 * c[]/r[], several braces, the declaration of mask 'm', two branch
 * headers and the loop bookkeeping) — restore from upstream before
 * editing the logic. */
static_always_inline void
aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
	      aes_key_size_t rounds)
  u8x64 f, k4, r[4], c[4] = {};
  /* count is in bytes; 16 bytes per AES block */
  int i, n_blocks = count >> 4;

  /* seed the chain: the IV occupies the top 128-bit lane of f */
  f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);

  /* 16-block batches (4 x u8x64) */
  while (n_blocks >= 16)
      k4 = u8x64_splat_u8x16 (k[0]);
      /* middle rounds; each round key is broadcast to all four lanes */
      for (i = 1; i < rounds; i++)
	  k4 = u8x64_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x4 (r[0], k4);
	  r[1] = aes_dec_round_x4 (r[1], k4);
	  r[2] = aes_dec_round_x4 (r[2], k4);
	  r[3] = aes_dec_round_x4 (r[3], k4);
      /* i == rounds here: final round key */
      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);
      r[3] = aes_dec_last_round_x4 (r[3], k4);
      /* CBC xor with the one-block-shifted ciphertext stream */
      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
  /* 12-block batch (3 x u8x64) — same pattern on r[0..2]; the branch
   * header is on a line missing from this copy */
      k4 = u8x64_splat_u8x16 (k[0]);
      for (i = 1; i < rounds; i++)
	  k4 = u8x64_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x4 (r[0], k4);
	  r[1] = aes_dec_round_x4 (r[1], k4);
	  r[2] = aes_dec_round_x4 (r[2], k4);
      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      r[2] = aes_dec_last_round_x4 (r[2], k4);
      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
  /* 8-block batch (2 x u8x64) */
  else if (n_blocks >= 8)
      k4 = u8x64_splat_u8x16 (k[0]);
      for (i = 1; i < rounds; i++)
	  k4 = u8x64_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x4 (r[0], k4);
	  r[1] = aes_dec_round_x4 (r[1], k4);
      k4 = u8x64_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x4 (r[0], k4);
      r[1] = aes_dec_last_round_x4 (r[1], k4);
      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
  /* 4-block batch (1 x u8x64) */
  else if (n_blocks >= 4)
      r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);
      for (i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
  /* masked tail for 1-3 leftover blocks: the qword mask covers two
   * 64-bit lanes per block ('m' is declared on a missing line) */
      k4 = u8x64_splat_u8x16 (k[0]);
      m = (1 << (n_blocks * 2)) - 1;
      /* NOTE(review): the assignment target of this masked load (into
       * c[0]/r[0]) is on a preceding line missing from this copy */
      (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
      f = aes4_cbc_dec_permute (f, c[0]);
      for (i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
      _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
310 #elif defined(__VAES__)
312 static_always_inline u8x32
313 aes_block_load_x2 (u8 *src[], int i)
316 r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
317 r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
321 static_always_inline void
322 aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
324 aes_block_store (dst[0] + i, u8x32_extract_lo (r));
325 aes_block_store (dst[1] + i, u8x32_extract_hi (r));
328 static_always_inline u8x32
329 aes2_cbc_dec_permute (u8x32 a, u8x32 b)
331 return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
/* AVX2 + VAES CBC decrypt: 256-bit registers, two AES blocks each.
 * Handles batches of 8, 6, 4 and 2 blocks, then a final odd block with
 * plain 128-bit operations.  'f' carries the ciphertext block preceding
 * the current batch (CBC chain value) in its high lane.
 *
 * NOTE(review): the ciphertext loads into c[]/r[], several braces, two
 * branch headers and the loop bookkeeping are missing from this excerpt —
 * restore from upstream before editing the logic. */
static_always_inline void
aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
	      aes_key_size_t rounds)
  u8x32 k2, f = {}, r[4], c[4] = {};
  /* count is in bytes; 16 bytes per AES block */
  int i, n_blocks = count >> 4;

  /* seed the chain: IV in the high 128-bit lane */
  f = u8x32_insert_hi (f, *iv);

  /* 8-block batches (4 x u8x32) */
  while (n_blocks >= 8)
      k2 = u8x32_splat_u8x16 (k[0]);
      /* middle rounds; round key broadcast to both lanes */
      for (i = 1; i < rounds; i++)
	  k2 = u8x32_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x2 (r[0], k2);
	  r[1] = aes_dec_round_x2 (r[1], k2);
	  r[2] = aes_dec_round_x2 (r[2], k2);
	  r[3] = aes_dec_round_x2 (r[3], k2);
      /* i == rounds here: final round key */
      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);
      r[3] = aes_dec_last_round_x2 (r[3], k2);
      /* CBC xor with the one-block-shifted ciphertext stream */
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
      dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
  /* 6-block batch (3 x u8x32) — branch header is on a missing line */
      k2 = u8x32_splat_u8x16 (k[0]);
      for (i = 1; i < rounds; i++)
	  k2 = u8x32_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x2 (r[0], k2);
	  r[1] = aes_dec_round_x2 (r[1], k2);
	  r[2] = aes_dec_round_x2 (r[2], k2);
      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      r[2] = aes_dec_last_round_x2 (r[2], k2);
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
      dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
  /* 4-block batch (2 x u8x32) */
  else if (n_blocks >= 4)
      k2 = u8x32_splat_u8x16 (k[0]);
      for (i = 1; i < rounds; i++)
	  k2 = u8x32_splat_u8x16 (k[i]);
	  r[0] = aes_dec_round_x2 (r[0], k2);
	  r[1] = aes_dec_round_x2 (r[1], k2);
      k2 = u8x32_splat_u8x16 (k[i]);
      r[0] = aes_dec_last_round_x2 (r[0], k2);
      r[1] = aes_dec_last_round_x2 (r[1], k2);
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
      dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
  /* 2-block batch (1 x u8x32) */
  else if (n_blocks >= 2)
      k2 = u8x32_splat_u8x16 (k[0]);
      for (i = 1; i < rounds; i++)
	r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
      r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
      dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
  /* final odd block, handled with plain 128-bit ops; xor with the chain
   * value held in the high lane of f */
      u8x16 rl = *(u8x16u *) src ^ k[0];
      for (i = 1; i < rounds; i++)
	rl = aes_dec_round_x1 (rl, k[i]);
      rl = aes_dec_last_round_x1 (rl, k[i]);
      *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
473 static_always_inline void
474 clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
478 aes_key_expand (e, key, ks);
479 aes_key_enc_to_dec (e, d, ks);
480 for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
482 ((u8x16 *) kd->decrypt_key)[i] = d[i];
483 ((u8x16 *) kd->encrypt_key)[i] = e[i];
487 static_always_inline void
488 clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
490 clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
492 static_always_inline void
493 clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
495 clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
497 static_always_inline void
498 clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
500 clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
503 static_always_inline void
504 clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
505 uword len, const u8 *iv, aes_key_size_t ks,
508 int rounds = AES_KEY_ROUNDS (ks);
509 #if defined(__VAES__) && defined(__AVX512F__)
510 aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
511 (u8x16u *) iv, (int) len, rounds);
512 #elif defined(__VAES__)
513 aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
514 (u8x16u *) iv, (int) len, rounds);
516 aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
517 (u8x16u *) iv, (int) len, rounds);
521 static_always_inline void
522 clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
523 uword len, const u8 *iv, u8 *plaintext)
525 clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
528 static_always_inline void
529 clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
530 uword len, const u8 *iv, u8 *plaintext)
532 clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
535 static_always_inline void
536 clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
537 uword len, const u8 *iv, u8 *plaintext)
539 clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
542 #endif /* __crypto_aes_cbc_h__ */