http_static: fix memory hss_session using after be freed
[vpp.git] / src / vppinfra / crypto / aes_cbc.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2023 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_cbc_h__
6 #define __crypto_aes_cbc_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/crypto/aes.h>
11
12 typedef struct
13 {
14   const u8x16 encrypt_key[15];
15   const u8x16 decrypt_key[15];
16 } aes_cbc_key_data_t;
17
18 static_always_inline void
19 clib_aes_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *src, uword len,
20                       const u8 *iv, aes_key_size_t ks, u8 *dst)
21 {
22   int rounds = AES_KEY_ROUNDS (ks);
23   u8x16 r, *k = (u8x16 *) kd->encrypt_key;
24
25   r = *(u8x16u *) iv;
26
27   for (int i = 0; i < len; i += 16)
28     {
29       int j;
30       r = u8x16_xor3 (r, *(u8x16u *) (src + i), k[0]);
31       for (j = 1; j < rounds; j++)
32         r = aes_enc_round_x1 (r, k[j]);
33       r = aes_enc_last_round_x1 (r, k[rounds]);
34       *(u8x16u *) (dst + i) = r;
35     }
36 }
37
38 static_always_inline void
39 clib_aes128_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
40                          uword len, const u8 *iv, u8 *ciphertext)
41 {
42   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_128, ciphertext);
43 }
44
45 static_always_inline void
46 clib_aes192_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
47                          uword len, const u8 *iv, u8 *ciphertext)
48 {
49   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_192, ciphertext);
50 }
51
52 static_always_inline void
53 clib_aes256_cbc_encrypt (const aes_cbc_key_data_t *kd, const u8 *plaintext,
54                          uword len, const u8 *iv, u8 *ciphertext)
55 {
56   clib_aes_cbc_encrypt (kd, plaintext, len, iv, AES_KEY_256, ciphertext);
57 }
58
59 static_always_inline void __clib_unused
60 aes_cbc_dec (const u8x16 *k, u8x16u *src, u8x16u *dst, u8x16u *iv, int count,
61              int rounds)
62 {
63   u8x16 r[4], c[4], f;
64
65   f = iv[0];
66   while (count >= 64)
67     {
68       c[0] = r[0] = src[0];
69       c[1] = r[1] = src[1];
70       c[2] = r[2] = src[2];
71       c[3] = r[3] = src[3];
72
73 #if __x86_64__
74       r[0] ^= k[0];
75       r[1] ^= k[0];
76       r[2] ^= k[0];
77       r[3] ^= k[0];
78
79       for (int i = 1; i < rounds; i++)
80         {
81           r[0] = aes_dec_round_x1 (r[0], k[i]);
82           r[1] = aes_dec_round_x1 (r[1], k[i]);
83           r[2] = aes_dec_round_x1 (r[2], k[i]);
84           r[3] = aes_dec_round_x1 (r[3], k[i]);
85         }
86
87       r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
88       r[1] = aes_dec_last_round_x1 (r[1], k[rounds]);
89       r[2] = aes_dec_last_round_x1 (r[2], k[rounds]);
90       r[3] = aes_dec_last_round_x1 (r[3], k[rounds]);
91 #else
92       for (int i = 0; i < rounds - 1; i++)
93         {
94           r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
95           r[1] = vaesimcq_u8 (vaesdq_u8 (r[1], k[i]));
96           r[2] = vaesimcq_u8 (vaesdq_u8 (r[2], k[i]));
97           r[3] = vaesimcq_u8 (vaesdq_u8 (r[3], k[i]));
98         }
99       r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
100       r[1] = vaesdq_u8 (r[1], k[rounds - 1]) ^ k[rounds];
101       r[2] = vaesdq_u8 (r[2], k[rounds - 1]) ^ k[rounds];
102       r[3] = vaesdq_u8 (r[3], k[rounds - 1]) ^ k[rounds];
103 #endif
104       dst[0] = r[0] ^ f;
105       dst[1] = r[1] ^ c[0];
106       dst[2] = r[2] ^ c[1];
107       dst[3] = r[3] ^ c[2];
108       f = c[3];
109
110       count -= 64;
111       src += 4;
112       dst += 4;
113     }
114
115   while (count > 0)
116     {
117       c[0] = r[0] = src[0];
118 #if __x86_64__
119       r[0] ^= k[0];
120       for (int i = 1; i < rounds; i++)
121         r[0] = aes_dec_round_x1 (r[0], k[i]);
122       r[0] = aes_dec_last_round_x1 (r[0], k[rounds]);
123 #else
124       c[0] = r[0] = src[0];
125       for (int i = 0; i < rounds - 1; i++)
126         r[0] = vaesimcq_u8 (vaesdq_u8 (r[0], k[i]));
127       r[0] = vaesdq_u8 (r[0], k[rounds - 1]) ^ k[rounds];
128 #endif
129       dst[0] = r[0] ^ f;
130       f = c[0];
131
132       count -= 16;
133       src += 1;
134       dst += 1;
135     }
136 }
137
138 #if __x86_64__
139 #if defined(__VAES__) && defined(__AVX512F__)
140
141 static_always_inline u8x64
142 aes_block_load_x4 (u8 *src[], int i)
143 {
144   u8x64 r = {};
145   r = u8x64_insert_u8x16 (r, aes_block_load (src[0] + i), 0);
146   r = u8x64_insert_u8x16 (r, aes_block_load (src[1] + i), 1);
147   r = u8x64_insert_u8x16 (r, aes_block_load (src[2] + i), 2);
148   r = u8x64_insert_u8x16 (r, aes_block_load (src[3] + i), 3);
149   return r;
150 }
151
152 static_always_inline void
153 aes_block_store_x4 (u8 *dst[], int i, u8x64 r)
154 {
155   aes_block_store (dst[0] + i, u8x64_extract_u8x16 (r, 0));
156   aes_block_store (dst[1] + i, u8x64_extract_u8x16 (r, 1));
157   aes_block_store (dst[2] + i, u8x64_extract_u8x16 (r, 2));
158   aes_block_store (dst[3] + i, u8x64_extract_u8x16 (r, 3));
159 }
160
161 static_always_inline u8x64
162 aes4_cbc_dec_permute (u8x64 a, u8x64 b)
163 {
164   return (u8x64) u64x8_shuffle2 (a, b, 6, 7, 8, 9, 10, 11, 12, 13);
165 }
166
167 static_always_inline void
168 aes4_cbc_dec (const u8x16 *k, u8x64u *src, u8x64u *dst, u8x16u *iv, int count,
169               aes_key_size_t rounds)
170 {
171   u8x64 f, k4, r[4], c[4] = {};
172   __mmask8 m;
173   int i, n_blocks = count >> 4;
174
175   f = u8x64_insert_u8x16 (u8x64_zero (), *iv, 3);
176
177   while (n_blocks >= 16)
178     {
179       k4 = u8x64_splat_u8x16 (k[0]);
180       c[0] = src[0];
181       c[1] = src[1];
182       c[2] = src[2];
183       c[3] = src[3];
184
185       r[0] = c[0] ^ k4;
186       r[1] = c[1] ^ k4;
187       r[2] = c[2] ^ k4;
188       r[3] = c[3] ^ k4;
189
190       for (i = 1; i < rounds; i++)
191         {
192           k4 = u8x64_splat_u8x16 (k[i]);
193           r[0] = aes_dec_round_x4 (r[0], k4);
194           r[1] = aes_dec_round_x4 (r[1], k4);
195           r[2] = aes_dec_round_x4 (r[2], k4);
196           r[3] = aes_dec_round_x4 (r[3], k4);
197         }
198
199       k4 = u8x64_splat_u8x16 (k[i]);
200       r[0] = aes_dec_last_round_x4 (r[0], k4);
201       r[1] = aes_dec_last_round_x4 (r[1], k4);
202       r[2] = aes_dec_last_round_x4 (r[2], k4);
203       r[3] = aes_dec_last_round_x4 (r[3], k4);
204
205       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
206       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
207       dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
208       dst[3] = r[3] ^= aes4_cbc_dec_permute (c[2], c[3]);
209       f = c[3];
210
211       n_blocks -= 16;
212       src += 4;
213       dst += 4;
214     }
215
216   if (n_blocks >= 12)
217     {
218       k4 = u8x64_splat_u8x16 (k[0]);
219       c[0] = src[0];
220       c[1] = src[1];
221       c[2] = src[2];
222
223       r[0] = c[0] ^ k4;
224       r[1] = c[1] ^ k4;
225       r[2] = c[2] ^ k4;
226
227       for (i = 1; i < rounds; i++)
228         {
229           k4 = u8x64_splat_u8x16 (k[i]);
230           r[0] = aes_dec_round_x4 (r[0], k4);
231           r[1] = aes_dec_round_x4 (r[1], k4);
232           r[2] = aes_dec_round_x4 (r[2], k4);
233         }
234
235       k4 = u8x64_splat_u8x16 (k[i]);
236       r[0] = aes_dec_last_round_x4 (r[0], k4);
237       r[1] = aes_dec_last_round_x4 (r[1], k4);
238       r[2] = aes_dec_last_round_x4 (r[2], k4);
239
240       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
241       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
242       dst[2] = r[2] ^= aes4_cbc_dec_permute (c[1], c[2]);
243       f = c[2];
244
245       n_blocks -= 12;
246       src += 3;
247       dst += 3;
248     }
249   else if (n_blocks >= 8)
250     {
251       k4 = u8x64_splat_u8x16 (k[0]);
252       c[0] = src[0];
253       c[1] = src[1];
254
255       r[0] = c[0] ^ k4;
256       r[1] = c[1] ^ k4;
257
258       for (i = 1; i < rounds; i++)
259         {
260           k4 = u8x64_splat_u8x16 (k[i]);
261           r[0] = aes_dec_round_x4 (r[0], k4);
262           r[1] = aes_dec_round_x4 (r[1], k4);
263         }
264
265       k4 = u8x64_splat_u8x16 (k[i]);
266       r[0] = aes_dec_last_round_x4 (r[0], k4);
267       r[1] = aes_dec_last_round_x4 (r[1], k4);
268
269       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
270       dst[1] = r[1] ^= aes4_cbc_dec_permute (c[0], c[1]);
271       f = c[1];
272
273       n_blocks -= 8;
274       src += 2;
275       dst += 2;
276     }
277   else if (n_blocks >= 4)
278     {
279       c[0] = src[0];
280
281       r[0] = c[0] ^ u8x64_splat_u8x16 (k[0]);
282
283       for (i = 1; i < rounds; i++)
284         r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
285
286       r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
287
288       dst[0] = r[0] ^= aes4_cbc_dec_permute (f, c[0]);
289       f = c[0];
290
291       n_blocks -= 4;
292       src += 1;
293       dst += 1;
294     }
295
296   if (n_blocks > 0)
297     {
298       k4 = u8x64_splat_u8x16 (k[0]);
299       m = (1 << (n_blocks * 2)) - 1;
300       c[0] =
301         (u8x64) _mm512_mask_loadu_epi64 ((__m512i) c[0], m, (__m512i *) src);
302       f = aes4_cbc_dec_permute (f, c[0]);
303       r[0] = c[0] ^ k4;
304       for (i = 1; i < rounds; i++)
305         r[0] = aes_dec_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
306       r[0] = aes_dec_last_round_x4 (r[0], u8x64_splat_u8x16 (k[i]));
307       _mm512_mask_storeu_epi64 ((__m512i *) dst, m, (__m512i) (r[0] ^ f));
308     }
309 }
310 #elif defined(__VAES__)
311
312 static_always_inline u8x32
313 aes_block_load_x2 (u8 *src[], int i)
314 {
315   u8x32 r = {};
316   r = u8x32_insert_lo (r, aes_block_load (src[0] + i));
317   r = u8x32_insert_hi (r, aes_block_load (src[1] + i));
318   return r;
319 }
320
321 static_always_inline void
322 aes_block_store_x2 (u8 *dst[], int i, u8x32 r)
323 {
324   aes_block_store (dst[0] + i, u8x32_extract_lo (r));
325   aes_block_store (dst[1] + i, u8x32_extract_hi (r));
326 }
327
328 static_always_inline u8x32
329 aes2_cbc_dec_permute (u8x32 a, u8x32 b)
330 {
331   return (u8x32) u64x4_shuffle2 ((u64x4) a, (u64x4) b, 2, 3, 4, 5);
332 }
333
334 static_always_inline void
335 aes2_cbc_dec (const u8x16 *k, u8x32u *src, u8x32u *dst, u8x16u *iv, int count,
336               aes_key_size_t rounds)
337 {
338   u8x32 k2, f = {}, r[4], c[4] = {};
339   int i, n_blocks = count >> 4;
340
341   f = u8x32_insert_hi (f, *iv);
342
343   while (n_blocks >= 8)
344     {
345       k2 = u8x32_splat_u8x16 (k[0]);
346       c[0] = src[0];
347       c[1] = src[1];
348       c[2] = src[2];
349       c[3] = src[3];
350
351       r[0] = c[0] ^ k2;
352       r[1] = c[1] ^ k2;
353       r[2] = c[2] ^ k2;
354       r[3] = c[3] ^ k2;
355
356       for (i = 1; i < rounds; i++)
357         {
358           k2 = u8x32_splat_u8x16 (k[i]);
359           r[0] = aes_dec_round_x2 (r[0], k2);
360           r[1] = aes_dec_round_x2 (r[1], k2);
361           r[2] = aes_dec_round_x2 (r[2], k2);
362           r[3] = aes_dec_round_x2 (r[3], k2);
363         }
364
365       k2 = u8x32_splat_u8x16 (k[i]);
366       r[0] = aes_dec_last_round_x2 (r[0], k2);
367       r[1] = aes_dec_last_round_x2 (r[1], k2);
368       r[2] = aes_dec_last_round_x2 (r[2], k2);
369       r[3] = aes_dec_last_round_x2 (r[3], k2);
370
371       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
372       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
373       dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
374       dst[3] = r[3] ^= aes2_cbc_dec_permute (c[2], c[3]);
375       f = c[3];
376
377       n_blocks -= 8;
378       src += 4;
379       dst += 4;
380     }
381
382   if (n_blocks >= 6)
383     {
384       k2 = u8x32_splat_u8x16 (k[0]);
385       c[0] = src[0];
386       c[1] = src[1];
387       c[2] = src[2];
388
389       r[0] = c[0] ^ k2;
390       r[1] = c[1] ^ k2;
391       r[2] = c[2] ^ k2;
392
393       for (i = 1; i < rounds; i++)
394         {
395           k2 = u8x32_splat_u8x16 (k[i]);
396           r[0] = aes_dec_round_x2 (r[0], k2);
397           r[1] = aes_dec_round_x2 (r[1], k2);
398           r[2] = aes_dec_round_x2 (r[2], k2);
399         }
400
401       k2 = u8x32_splat_u8x16 (k[i]);
402       r[0] = aes_dec_last_round_x2 (r[0], k2);
403       r[1] = aes_dec_last_round_x2 (r[1], k2);
404       r[2] = aes_dec_last_round_x2 (r[2], k2);
405
406       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
407       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
408       dst[2] = r[2] ^= aes2_cbc_dec_permute (c[1], c[2]);
409       f = c[2];
410
411       n_blocks -= 6;
412       src += 3;
413       dst += 3;
414     }
415   else if (n_blocks >= 4)
416     {
417       k2 = u8x32_splat_u8x16 (k[0]);
418       c[0] = src[0];
419       c[1] = src[1];
420
421       r[0] = c[0] ^ k2;
422       r[1] = c[1] ^ k2;
423
424       for (i = 1; i < rounds; i++)
425         {
426           k2 = u8x32_splat_u8x16 (k[i]);
427           r[0] = aes_dec_round_x2 (r[0], k2);
428           r[1] = aes_dec_round_x2 (r[1], k2);
429         }
430
431       k2 = u8x32_splat_u8x16 (k[i]);
432       r[0] = aes_dec_last_round_x2 (r[0], k2);
433       r[1] = aes_dec_last_round_x2 (r[1], k2);
434
435       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
436       dst[1] = r[1] ^= aes2_cbc_dec_permute (c[0], c[1]);
437       f = c[1];
438
439       n_blocks -= 4;
440       src += 2;
441       dst += 2;
442     }
443   else if (n_blocks >= 2)
444     {
445       k2 = u8x32_splat_u8x16 (k[0]);
446       c[0] = src[0];
447       r[0] = c[0] ^ k2;
448
449       for (i = 1; i < rounds; i++)
450         r[0] = aes_dec_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
451
452       r[0] = aes_dec_last_round_x2 (r[0], u8x32_splat_u8x16 (k[i]));
453       dst[0] = r[0] ^= aes2_cbc_dec_permute (f, c[0]);
454       f = c[0];
455
456       n_blocks -= 2;
457       src += 1;
458       dst += 1;
459     }
460
461   if (n_blocks > 0)
462     {
463       u8x16 rl = *(u8x16u *) src ^ k[0];
464       for (i = 1; i < rounds; i++)
465         rl = aes_dec_round_x1 (rl, k[i]);
466       rl = aes_dec_last_round_x1 (rl, k[i]);
467       *(u8x16u *) dst = rl ^ u8x32_extract_hi (f);
468     }
469 }
470 #endif
471 #endif
472
473 static_always_inline void
474 clib_aes_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key,
475                          aes_key_size_t ks)
476 {
477   u8x16 e[15], d[15];
478   aes_key_expand (e, key, ks);
479   aes_key_enc_to_dec (e, d, ks);
480   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
481     {
482       ((u8x16 *) kd->decrypt_key)[i] = d[i];
483       ((u8x16 *) kd->encrypt_key)[i] = e[i];
484     }
485 }
486
487 static_always_inline void
488 clib_aes128_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
489 {
490   clib_aes_cbc_key_expand (kd, key, AES_KEY_128);
491 }
492 static_always_inline void
493 clib_aes192_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
494 {
495   clib_aes_cbc_key_expand (kd, key, AES_KEY_192);
496 }
497 static_always_inline void
498 clib_aes256_cbc_key_expand (aes_cbc_key_data_t *kd, const u8 *key)
499 {
500   clib_aes_cbc_key_expand (kd, key, AES_KEY_256);
501 }
502
503 static_always_inline void
504 clib_aes_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
505                       uword len, const u8 *iv, aes_key_size_t ks,
506                       u8 *plaintext)
507 {
508   int rounds = AES_KEY_ROUNDS (ks);
509 #if defined(__VAES__) && defined(__AVX512F__)
510   aes4_cbc_dec (kd->decrypt_key, (u8x64u *) ciphertext, (u8x64u *) plaintext,
511                 (u8x16u *) iv, (int) len, rounds);
512 #elif defined(__VAES__)
513   aes2_cbc_dec (kd->decrypt_key, (u8x32u *) ciphertext, (u8x32u *) plaintext,
514                 (u8x16u *) iv, (int) len, rounds);
515 #else
516   aes_cbc_dec (kd->decrypt_key, (u8x16u *) ciphertext, (u8x16u *) plaintext,
517                (u8x16u *) iv, (int) len, rounds);
518 #endif
519 }
520
521 static_always_inline void
522 clib_aes128_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
523                          uword len, const u8 *iv, u8 *plaintext)
524 {
525   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_128, plaintext);
526 }
527
528 static_always_inline void
529 clib_aes192_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
530                          uword len, const u8 *iv, u8 *plaintext)
531 {
532   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_192, plaintext);
533 }
534
535 static_always_inline void
536 clib_aes256_cbc_decrypt (const aes_cbc_key_data_t *kd, const u8 *ciphertext,
537                          uword len, const u8 *iv, u8 *plaintext)
538 {
539   clib_aes_cbc_decrypt (kd, ciphertext, len, iv, AES_KEY_256, plaintext);
540 }
541
542 #if __GNUC__ > 4 && !__clang__ && CLIB_DEBUG == 0
543 #pragma GCC optimize("O3")
544 #endif
545
546 #if defined(__VAES__) && defined(__AVX512F__)
547 #define u8xN              u8x64
548 #define u32xN             u32x16
549 #define u32xN_min_scalar  u32x16_min_scalar
550 #define u32xN_is_all_zero u32x16_is_all_zero
551 #define u32xN_splat       u32x16_splat
552 #elif defined(__VAES__)
553 #define u8xN              u8x32
554 #define u32xN             u32x8
555 #define u32xN_min_scalar  u32x8_min_scalar
556 #define u32xN_is_all_zero u32x8_is_all_zero
557 #define u32xN_splat       u32x8_splat
558 #else
559 #define u8xN              u8x16
560 #define u32xN             u32x4
561 #define u32xN_min_scalar  u32x4_min_scalar
562 #define u32xN_is_all_zero u32x4_is_all_zero
563 #define u32xN_splat       u32x4_splat
564 #endif
565
566 static_always_inline u32
567 clib_aes_cbc_encrypt_multi (aes_cbc_key_data_t **key_data,
568                             const uword *key_indices, u8 **plaintext,
569                             const uword *oplen, u8 **iv, aes_key_size_t ks,
570                             u8 **ciphertext, uword n_ops)
571 {
572   int rounds = AES_KEY_ROUNDS (ks);
573   u8 placeholder[8192];
574   u32 i, j, count, n_left = n_ops;
575   u32xN placeholder_mask = {};
576   u32xN len = {};
577   u32 key_index[4 * N_AES_LANES];
578   u8 *src[4 * N_AES_LANES] = {};
579   u8 *dst[4 * N_AES_LANES] = {};
580   u8xN r[4] = {};
581   u8xN k[15][4] = {};
582
583   for (i = 0; i < 4 * N_AES_LANES; i++)
584     key_index[i] = ~0;
585
586 more:
587   for (i = 0; i < 4 * N_AES_LANES; i++)
588     if (len[i] == 0)
589       {
590         if (n_left == 0)
591           {
592             /* no more work to enqueue, so we are enqueueing placeholder buffer
593              */
594             src[i] = dst[i] = placeholder;
595             len[i] = sizeof (placeholder);
596             placeholder_mask[i] = 0;
597           }
598         else
599           {
600             u8x16 t = aes_block_load (iv[0]);
601             ((u8x16 *) r)[i] = t;
602
603             src[i] = plaintext[0];
604             dst[i] = ciphertext[0];
605             len[i] = oplen[0];
606             placeholder_mask[i] = ~0;
607             if (key_index[i] != key_indices[0])
608               {
609                 aes_cbc_key_data_t *kd;
610                 key_index[i] = key_indices[0];
611                 kd = key_data[key_index[i]];
612                 for (j = 0; j < rounds + 1; j++)
613                   ((u8x16 *) k[j])[i] = kd->encrypt_key[j];
614               }
615             n_left--;
616             iv++;
617             ciphertext++;
618             plaintext++;
619             key_indices++;
620             oplen++;
621           }
622       }
623
624   count = u32xN_min_scalar (len);
625
626   ASSERT (count % 16 == 0);
627
628   for (i = 0; i < count; i += 16)
629     {
630 #if defined(__VAES__) && defined(__AVX512F__)
631       r[0] = u8x64_xor3 (r[0], aes_block_load_x4 (src, i), k[0][0]);
632       r[1] = u8x64_xor3 (r[1], aes_block_load_x4 (src + 4, i), k[0][1]);
633       r[2] = u8x64_xor3 (r[2], aes_block_load_x4 (src + 8, i), k[0][2]);
634       r[3] = u8x64_xor3 (r[3], aes_block_load_x4 (src + 12, i), k[0][3]);
635
636       for (j = 1; j < rounds; j++)
637         {
638           r[0] = aes_enc_round_x4 (r[0], k[j][0]);
639           r[1] = aes_enc_round_x4 (r[1], k[j][1]);
640           r[2] = aes_enc_round_x4 (r[2], k[j][2]);
641           r[3] = aes_enc_round_x4 (r[3], k[j][3]);
642         }
643       r[0] = aes_enc_last_round_x4 (r[0], k[j][0]);
644       r[1] = aes_enc_last_round_x4 (r[1], k[j][1]);
645       r[2] = aes_enc_last_round_x4 (r[2], k[j][2]);
646       r[3] = aes_enc_last_round_x4 (r[3], k[j][3]);
647
648       aes_block_store_x4 (dst, i, r[0]);
649       aes_block_store_x4 (dst + 4, i, r[1]);
650       aes_block_store_x4 (dst + 8, i, r[2]);
651       aes_block_store_x4 (dst + 12, i, r[3]);
652 #elif defined(__VAES__)
653       r[0] = u8x32_xor3 (r[0], aes_block_load_x2 (src, i), k[0][0]);
654       r[1] = u8x32_xor3 (r[1], aes_block_load_x2 (src + 2, i), k[0][1]);
655       r[2] = u8x32_xor3 (r[2], aes_block_load_x2 (src + 4, i), k[0][2]);
656       r[3] = u8x32_xor3 (r[3], aes_block_load_x2 (src + 6, i), k[0][3]);
657
658       for (j = 1; j < rounds; j++)
659         {
660           r[0] = aes_enc_round_x2 (r[0], k[j][0]);
661           r[1] = aes_enc_round_x2 (r[1], k[j][1]);
662           r[2] = aes_enc_round_x2 (r[2], k[j][2]);
663           r[3] = aes_enc_round_x2 (r[3], k[j][3]);
664         }
665       r[0] = aes_enc_last_round_x2 (r[0], k[j][0]);
666       r[1] = aes_enc_last_round_x2 (r[1], k[j][1]);
667       r[2] = aes_enc_last_round_x2 (r[2], k[j][2]);
668       r[3] = aes_enc_last_round_x2 (r[3], k[j][3]);
669
670       aes_block_store_x2 (dst, i, r[0]);
671       aes_block_store_x2 (dst + 2, i, r[1]);
672       aes_block_store_x2 (dst + 4, i, r[2]);
673       aes_block_store_x2 (dst + 6, i, r[3]);
674 #else
675 #if __x86_64__
676       r[0] = u8x16_xor3 (r[0], aes_block_load (src[0] + i), k[0][0]);
677       r[1] = u8x16_xor3 (r[1], aes_block_load (src[1] + i), k[0][1]);
678       r[2] = u8x16_xor3 (r[2], aes_block_load (src[2] + i), k[0][2]);
679       r[3] = u8x16_xor3 (r[3], aes_block_load (src[3] + i), k[0][3]);
680
681       for (j = 1; j < rounds; j++)
682         {
683           r[0] = aes_enc_round_x1 (r[0], k[j][0]);
684           r[1] = aes_enc_round_x1 (r[1], k[j][1]);
685           r[2] = aes_enc_round_x1 (r[2], k[j][2]);
686           r[3] = aes_enc_round_x1 (r[3], k[j][3]);
687         }
688
689       r[0] = aes_enc_last_round_x1 (r[0], k[j][0]);
690       r[1] = aes_enc_last_round_x1 (r[1], k[j][1]);
691       r[2] = aes_enc_last_round_x1 (r[2], k[j][2]);
692       r[3] = aes_enc_last_round_x1 (r[3], k[j][3]);
693
694       aes_block_store (dst[0] + i, r[0]);
695       aes_block_store (dst[1] + i, r[1]);
696       aes_block_store (dst[2] + i, r[2]);
697       aes_block_store (dst[3] + i, r[3]);
698 #else
699       r[0] ^= aes_block_load (src[0] + i);
700       r[1] ^= aes_block_load (src[1] + i);
701       r[2] ^= aes_block_load (src[2] + i);
702       r[3] ^= aes_block_load (src[3] + i);
703       for (j = 0; j < rounds - 1; j++)
704         {
705           r[0] = vaesmcq_u8 (vaeseq_u8 (r[0], k[j][0]));
706           r[1] = vaesmcq_u8 (vaeseq_u8 (r[1], k[j][1]));
707           r[2] = vaesmcq_u8 (vaeseq_u8 (r[2], k[j][2]));
708           r[3] = vaesmcq_u8 (vaeseq_u8 (r[3], k[j][3]));
709         }
710       r[0] = vaeseq_u8 (r[0], k[j][0]) ^ k[rounds][0];
711       r[1] = vaeseq_u8 (r[1], k[j][1]) ^ k[rounds][1];
712       r[2] = vaeseq_u8 (r[2], k[j][2]) ^ k[rounds][2];
713       r[3] = vaeseq_u8 (r[3], k[j][3]) ^ k[rounds][3];
714       aes_block_store (dst[0] + i, r[0]);
715       aes_block_store (dst[1] + i, r[1]);
716       aes_block_store (dst[2] + i, r[2]);
717       aes_block_store (dst[3] + i, r[3]);
718 #endif
719 #endif
720     }
721
722   len -= u32xN_splat (count);
723
724   for (i = 0; i < 4 * N_AES_LANES; i++)
725     {
726       src[i] += count;
727       dst[i] += count;
728     }
729
730   if (n_left > 0)
731     goto more;
732
733   if (!u32xN_is_all_zero (len & placeholder_mask))
734     goto more;
735
736   return n_ops;
737 }
738
739 #undef u8xN
740 #undef u32xN
741 #undef u32xN_min_scalar
742 #undef u32xN_is_all_zero
743 #undef u32xN_splat
744
745 #endif /* __crypto_aes_cbc_h__ */