/* SPDX-License-Identifier: Apache-2.0
 * Copyright(c) 2023 Cisco Systems, Inc.
 */

#ifndef __crypto_aes_gcm_h__
#define __crypto_aes_gcm_h__

#include <vppinfra/clib.h>
#include <vppinfra/vector.h>
#include <vppinfra/cache.h>
#include <vppinfra/string.h>
#include <vppinfra/crypto/aes.h>
#include <vppinfra/crypto/ghash.h>
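
/* The GHASH working type and the helper macros below are selected by the
 * number of AES lanes processed in parallel: 4 lanes use 512-bit vectors,
 * 2 lanes use 256-bit vectors, otherwise plain 128-bit vectors are used. */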
#if N_AES_LANES == 4
typedef u8x64u aes_ghash_t;
#define aes_gcm_splat(v) u8x64_splat (v)
#define aes_gcm_ghash_reduce(c) ghash4_reduce (&(c)->gd)
#define aes_gcm_ghash_reduce2(c) ghash4_reduce2 (&(c)->gd)
#define aes_gcm_ghash_final(c) (c)->T = ghash4_final (&(c)->gd)
#elif N_AES_LANES == 2
typedef u8x32u aes_ghash_t;
#define aes_gcm_splat(v) u8x32_splat (v)
#define aes_gcm_ghash_reduce(c) ghash2_reduce (&(c)->gd)
#define aes_gcm_ghash_reduce2(c) ghash2_reduce2 (&(c)->gd)
#define aes_gcm_ghash_final(c) (c)->T = ghash2_final (&(c)->gd)
#else
typedef u8x16 aes_ghash_t;
#define aes_gcm_splat(v) u8x16_splat (v)
#define aes_gcm_ghash_reduce(c) ghash_reduce (&(c)->gd)
#define aes_gcm_ghash_reduce2(c) ghash_reduce2 (&(c)->gd)
#define aes_gcm_ghash_final(c) (c)->T = ghash_final (&(c)->gd)
#endif

  AES_GCM_OP_UNKNONW = 0,

  /* pre-calculated hash key values */
  const u8x16 Hi[NUM_HI];
  /* expanded AES key */
  const aes_expaned_key_t Ke[AES_KEY_ROUNDS (AES_KEY_256) + 1];

  aes_gcm_op_t operation;
  const aes_ghash_t *next_Hi;
  const aes_expaned_key_t *Ke;
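
/* Return the final GHASH block mandated by GCM: the ciphertext and AAD
 * lengths in bits (bytes << 3) packed into a single 128-bit vector. */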
static_always_inline u8x16
aes_gcm_final_block (aes_gcm_ctx_t *ctx)
{
  return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
}
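
/* Start a GHASH run: point next_Hi at the pre-computed hash key powers so
 * that exactly n_lanes multiplies remain, fold the running tag T into the
 * first data block and issue the first multiply. */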
static_always_inline void
aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
{
  uword hash_offset = NUM_HI - n_lanes;
  ctx->next_Hi = (aes_ghash_t *) (ctx->Hi + hash_offset);
#if N_AES_LANES == 4
  u8x64 tag4 = {};
  tag4 = u8x64_insert_u8x16 (tag4, ctx->T, 0);
  ghash4_mul_first (&ctx->gd, aes_reflect (data) ^ tag4, *ctx->next_Hi++);
#elif N_AES_LANES == 2
  u8x32 tag2 = {};
  tag2 = u8x32_insert_lo (tag2, ctx->T);
  ghash2_mul_first (&ctx->gd, aes_reflect (data) ^ tag2, *ctx->next_Hi++);
#else
  ghash_mul_first (&ctx->gd, aes_reflect (data) ^ ctx->T, *ctx->next_Hi++);
#endif
}
static_always_inline void
aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
{
#if N_AES_LANES == 4
  ghash4_mul_next (&ctx->gd, aes_reflect (data), *ctx->next_Hi++);
#elif N_AES_LANES == 2
  ghash2_mul_next (&ctx->gd, aes_reflect (data), *ctx->next_Hi++);
#else
  ghash_mul_next (&ctx->gd, aes_reflect (data), *ctx->next_Hi++);
#endif
}
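
/* GHASH the final length block produced by aes_gcm_final_block (); only the
 * lowest 128-bit lane of the wide vector carries data. */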
static_always_inline void
aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx)
{
#if N_AES_LANES == 4
  u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
  u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0);
  ghash4_mul_next (&ctx->gd, r4, h);
#elif N_AES_LANES == 2
  u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
  u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx));
  ghash2_mul_next (&ctx->gd, r2, h);
#else
  ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]);
#endif
}
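
/* Perform one AES round of encrypting counter block 0 (Y0): round 0 is the
 * initial AddRoundKey, round ctx->rounds the final round. The finished
 * E(K, Y0) is XORed with the GHASH output to form the tag. */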
static_always_inline void
aes_gcm_enc_ctr0_round (aes_gcm_ctx_t *ctx, int aes_round)
{
  if (aes_round == 0)
    ctx->EY0 ^= ctx->Ke[0].x1;
  else if (aes_round == ctx->rounds)
    ctx->EY0 = aes_enc_last_round_x1 (ctx->EY0, ctx->Ke[aes_round].x1);
  else
    ctx->EY0 = aes_enc_round_x1 (ctx->EY0, ctx->Ke[aes_round].x1);
}
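
/* GHASH an arbitrary number of bytes: the AAD, or the whole message in the
 * GMAC case. Full groups of 8 vector blocks are hashed in the main loop and
 * the remainder is hashed lane by lane; for GMAC the length block and the
 * counter-0 encryption are folded in here as well. */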
static_always_inline void
aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
{
  const aes_mem_t *d = (aes_mem_t *) data;

  for (int n = 8 * N_AES_BYTES; n_left >= n; n_left -= n, d += 8)
    if (ctx->operation == AES_GCM_OP_GMAC && n_left == n)
        aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_AES_LANES + 1);
        for (i = 1; i < 8; i++)
          aes_gcm_ghash_mul_next (ctx, d[i]);
        aes_gcm_ghash_mul_final_block (ctx);
        aes_gcm_ghash_reduce (ctx);
        aes_gcm_ghash_reduce2 (ctx);
        aes_gcm_ghash_final (ctx);

    aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_AES_LANES);
    for (i = 1; i < 8; i++)
      aes_gcm_ghash_mul_next (ctx, d[i]);
    aes_gcm_ghash_reduce (ctx);
    aes_gcm_ghash_reduce2 (ctx);
    aes_gcm_ghash_final (ctx);

  int n_lanes = (n_left + 15) / 16;

  if (ctx->operation == AES_GCM_OP_GMAC)
    n_lanes++;

  if (n_left < N_AES_BYTES)
      clib_memcpy_fast (&r, d, n_left);
      aes_gcm_ghash_mul_first (ctx, r, n_lanes);
  else
      aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
      n_left -= N_AES_BYTES;

      if (n_left >= 4 * N_AES_BYTES)
          aes_gcm_ghash_mul_next (ctx, d[i]);
          aes_gcm_ghash_mul_next (ctx, d[i + 1]);
          aes_gcm_ghash_mul_next (ctx, d[i + 2]);
          aes_gcm_ghash_mul_next (ctx, d[i + 3]);
          n_left -= 4 * N_AES_BYTES;

      if (n_left >= 2 * N_AES_BYTES)
          aes_gcm_ghash_mul_next (ctx, d[i]);
          aes_gcm_ghash_mul_next (ctx, d[i + 1]);
          n_left -= 2 * N_AES_BYTES;

      if (n_left >= N_AES_BYTES)
          aes_gcm_ghash_mul_next (ctx, d[i]);
          n_left -= N_AES_BYTES;

      clib_memcpy_fast (&r, d + i, n_left);
      aes_gcm_ghash_mul_next (ctx, r);

  if (ctx->operation == AES_GCM_OP_GMAC)
    aes_gcm_ghash_mul_final_block (ctx);
  aes_gcm_ghash_reduce (ctx);
  aes_gcm_ghash_reduce2 (ctx);
  aes_gcm_ghash_final (ctx);
  else if (ctx->operation == AES_GCM_OP_GMAC)
    ctx->T =
      ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]);

  /* encrypt counter 0 E(Y0, k) */
  if (ctx->operation == AES_GCM_OP_GMAC)
    for (int i = 0; i < ctx->rounds + 1; i += 1)
      aes_gcm_enc_ctr0_round (ctx, i);
}
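
/* Compute the initial AES round (AddRoundKey with the first round key) for
 * n_blocks counter blocks into r[] and advance the counter; the remaining
 * rounds are applied by the callers below. */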
static_always_inline void
aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
{
  const aes_expaned_key_t Ke0 = ctx->Ke[0];

  /* As the counter is kept in network byte order, for performance reasons we
     increment only its least significant byte, except when that byte would
     overflow. Since four 128-, 256- or 512-bit blocks are processed in
     parallel in all but the last call, overflow can only happen when
     n_blocks == 4 */
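
  /* For example, with a single 128-bit lane the counter block Y holds the
     12-byte IV in words 0-2 and the 32-bit block counter, in big-endian byte
     order, in word 3; adding 1 << 24 to that word as a native little-endian
     integer increments only the least significant byte of the big-endian
     counter. */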

#if N_AES_LANES == 4
  const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
                                0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };

  const u32x16 ctr_4444 = {
    4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
  };

      r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
      ctx->Y += ctr_inv_4444;

  if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 242))
      u32x16 Yr = (u32x16) aes_reflect ((u8x64) ctx->Y);

      for (; i < n_blocks; i++)
          r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
          Yr += ctr_4444;
          ctx->Y = (u32x16) aes_reflect ((u8x64) Yr);

      for (; i < n_blocks; i++)
          r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_4444;

  ctx->counter += n_blocks * 4;
#elif N_AES_LANES == 2
  const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
  const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };

      r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
      ctx->Y += ctr_inv_22;

  if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 250))
      u32x8 Yr = (u32x8) aes_reflect ((u8x32) ctx->Y);

      for (; i < n_blocks; i++)
          r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
          Yr += ctr_22;
          ctx->Y = (u32x8) aes_reflect ((u8x32) Yr);

      for (; i < n_blocks; i++)
          r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_22;

  ctx->counter += n_blocks * 2;
#else
  const u32x4 ctr_inv_1 = { 0, 0, 0, 1 << 24 };

  if (PREDICT_TRUE ((u8) ctx->counter < 0xfe) || n_blocks < 3)
      for (; i < n_blocks; i++)
          r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */

      ctx->counter += n_blocks;

      r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */

      for (; i < n_blocks; i++)
          r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */

          ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
#endif
}
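
/* Run the extra AES rounds required by AES-192/AES-256 followed by the final
 * round, combining the resulting keystream with the data blocks in d[]. */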
static_always_inline void
aes_gcm_enc_last_round (aes_gcm_ctx_t *ctx, aes_data_t *r, aes_data_t *d,
                        const aes_expaned_key_t *Ke, uword n_blocks)
{
  /* additional rounds for AES-192 and AES-256 */
  for (int i = 10; i < ctx->rounds; i++)
    aes_enc_round (r, Ke + i, n_blocks);

  aes_enc_last_round (r, d, Ke + ctx->rounds, n_blocks);
}
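
/* Encrypt or decrypt up to 4 vector blocks (n), optionally interleaving the
 * GHASH of the ciphertext blocks held in d[] with the AES rounds to hide
 * multiply latency. n_bytes covers the (possibly partial) last block. */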
static_always_inline void
aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
              u32 n_bytes, int with_ghash)
{
  const aes_expaned_key_t *k = ctx->Ke;
  const aes_mem_t *sv = (aes_mem_t *) src;
  aes_mem_t *dv = (aes_mem_t *) dst;
  uword ghash_blocks, gc = 1;

  if (ctx->operation == AES_GCM_OP_ENCRYPT)
      n_lanes = N_AES_LANES * 4;
  else
      n_lanes = n * N_AES_LANES;
      if (ctx->last)
        n_lanes = (n_bytes + 15) / 16;

  n_bytes -= (n - 1) * N_AES_BYTES;

  /* AES rounds 0 and 1 */
  aes_gcm_enc_first_round (ctx, r, n);
  aes_enc_round (r, k + 1, n);

  /* load data - decrypt round */
  if (ctx->operation == AES_GCM_OP_DECRYPT)
      for (i = 0; i < n - ctx->last; i++)
        d[i] = sv[i];

      if (ctx->last)
        d[n - 1] = aes_load_partial ((u8 *) (sv + n - 1), n_bytes);

  /* GHASH multiply block 0 */
  if (with_ghash)
    aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);

  /* AES rounds 2 and 3 */
  aes_enc_round (r, k + 2, n);
  aes_enc_round (r, k + 3, n);

  /* GHASH multiply block 1 */
  if (with_ghash && gc++ < ghash_blocks)
    aes_gcm_ghash_mul_next (ctx, (d[1]));

  /* AES rounds 4 and 5 */
  aes_enc_round (r, k + 4, n);
  aes_enc_round (r, k + 5, n);

  /* GHASH multiply block 2 */
  if (with_ghash && gc++ < ghash_blocks)
    aes_gcm_ghash_mul_next (ctx, (d[2]));

  /* AES rounds 6 and 7 */
  aes_enc_round (r, k + 6, n);
  aes_enc_round (r, k + 7, n);

  /* GHASH multiply block 3 */
  if (with_ghash && gc++ < ghash_blocks)
    aes_gcm_ghash_mul_next (ctx, (d[3]));

  /* load data - encrypt round */
  if (ctx->operation == AES_GCM_OP_ENCRYPT)
      for (i = 0; i < n - ctx->last; i++)
        d[i] = sv[i];

      if (ctx->last)
        d[n - 1] = aes_load_partial (sv + n - 1, n_bytes);

  /* AES rounds 8 and 9 */
  aes_enc_round (r, k + 8, n);
  aes_enc_round (r, k + 9, n);

  /* AES last round(s) */
  aes_gcm_enc_last_round (ctx, r, d, k, n);

  for (i = 0; i < n - ctx->last; i++)
    dv[i] = d[i];

  if (ctx->last)
    aes_store_partial (d[n - 1], dv + n - 1, n_bytes);

  /* GHASH reduce 1st step */
  aes_gcm_ghash_reduce (ctx);

  /* GHASH reduce 2nd step */
  aes_gcm_ghash_reduce2 (ctx);

  /* GHASH final step */
  aes_gcm_ghash_final (ctx);
}
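
/* Process 8 vector blocks (two batches of 4), interleaving all GHASH
 * multiplies and reductions with the AES rounds; used for the bulk of long
 * messages. */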
static_always_inline void
aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst)
{
  const aes_expaned_key_t *k = ctx->Ke;
  const aes_mem_t *sv = (aes_mem_t *) src;
  aes_mem_t *dv = (aes_mem_t *) dst;

  /* AES rounds 0 and 1 */
  aes_gcm_enc_first_round (ctx, r, 4);
  aes_enc_round (r, k + 1, 4);

  /* load 4 blocks of data - decrypt round */
  if (ctx->operation == AES_GCM_OP_DECRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = sv[i];

  /* GHASH multiply block 0 */
  aes_gcm_ghash_mul_first (ctx, d[0], N_AES_LANES * 8);

  /* AES rounds 2 and 3 */
  aes_enc_round (r, k + 2, 4);
  aes_enc_round (r, k + 3, 4);

  /* GHASH multiply block 1 */
  aes_gcm_ghash_mul_next (ctx, (d[1]));

  /* AES rounds 4 and 5 */
  aes_enc_round (r, k + 4, 4);
  aes_enc_round (r, k + 5, 4);

  /* GHASH multiply block 2 */
  aes_gcm_ghash_mul_next (ctx, (d[2]));

  /* AES rounds 6 and 7 */
  aes_enc_round (r, k + 6, 4);
  aes_enc_round (r, k + 7, 4);

  /* GHASH multiply block 3 */
  aes_gcm_ghash_mul_next (ctx, (d[3]));

  /* AES rounds 8 and 9 */
  aes_enc_round (r, k + 8, 4);
  aes_enc_round (r, k + 9, 4);

  /* load 4 blocks of data - encrypt round */
  if (ctx->operation == AES_GCM_OP_ENCRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = sv[i];

  /* AES last round(s) */
  aes_gcm_enc_last_round (ctx, r, d, k, 4);

  /* store 4 blocks of data */
  for (int i = 0; i < 4; i++)
    dv[i] = d[i];

  /* load next 4 blocks of data - decrypt round */
  if (ctx->operation == AES_GCM_OP_DECRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = sv[i + 4];

  /* GHASH multiply block 4 */
  aes_gcm_ghash_mul_next (ctx, (d[0]));

  /* AES rounds 0 and 1 */
  aes_gcm_enc_first_round (ctx, r, 4);
  aes_enc_round (r, k + 1, 4);

  /* GHASH multiply block 5 */
  aes_gcm_ghash_mul_next (ctx, (d[1]));

  /* AES rounds 2 and 3 */
  aes_enc_round (r, k + 2, 4);
  aes_enc_round (r, k + 3, 4);

  /* GHASH multiply block 6 */
  aes_gcm_ghash_mul_next (ctx, (d[2]));

  /* AES rounds 4 and 5 */
  aes_enc_round (r, k + 4, 4);
  aes_enc_round (r, k + 5, 4);

  /* GHASH multiply block 7 */
  aes_gcm_ghash_mul_next (ctx, (d[3]));

  /* AES rounds 6 and 7 */
  aes_enc_round (r, k + 6, 4);
  aes_enc_round (r, k + 7, 4);

  /* GHASH reduce 1st step */
  aes_gcm_ghash_reduce (ctx);

  /* AES rounds 8 and 9 */
  aes_enc_round (r, k + 8, 4);
  aes_enc_round (r, k + 9, 4);

  /* GHASH reduce 2nd step */
  aes_gcm_ghash_reduce2 (ctx);

  /* load 4 blocks of data - encrypt round */
  if (ctx->operation == AES_GCM_OP_ENCRYPT)
    for (int i = 0; i < 4; i++)
      d[i] = sv[i + 4];

  /* AES last round(s) */
  aes_gcm_enc_last_round (ctx, r, d, k, 4);

  for (int i = 0; i < 4; i++)
    dv[i + 4] = d[i];

  /* GHASH final step */
  aes_gcm_ghash_final (ctx);
}
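
/* Clear the bytes of the last (partial) block beyond n_bytes so they do not
 * contribute to GHASH. */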
static_always_inline void
aes_gcm_mask_bytes (aes_data_t *d, uword n_bytes)
{
  const union { u8 b[64]; aes_data_t r; } scale = {
    .b = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
           16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
           32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
           48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 },
  };

  d[0] &= (aes_gcm_splat (n_bytes) > scale.r);
}
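
/* Process the last 1-4 blocks of a message: GHASH them together with the
 * final length block while interleaving the AES rounds that compute
 * E(K, Y0) for the tag. */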
static_always_inline void
aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
                   u32 n_bytes)
{
  int n_lanes = (N_AES_LANES == 1 ? n_blocks : (n_bytes + 15) / 16) + 1;
  n_bytes -= (n_blocks - 1) * N_AES_BYTES;

  aes_gcm_enc_ctr0_round (ctx, 0);
  aes_gcm_enc_ctr0_round (ctx, 1);

  if (n_bytes != N_AES_BYTES)
    aes_gcm_mask_bytes (d + n_blocks - 1, n_bytes);

  aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);

  aes_gcm_enc_ctr0_round (ctx, 2);
  aes_gcm_enc_ctr0_round (ctx, 3);

  if (n_blocks > 1)
    aes_gcm_ghash_mul_next (ctx, d[1]);

  aes_gcm_enc_ctr0_round (ctx, 4);
  aes_gcm_enc_ctr0_round (ctx, 5);

  if (n_blocks > 2)
    aes_gcm_ghash_mul_next (ctx, d[2]);

  aes_gcm_enc_ctr0_round (ctx, 6);
  aes_gcm_enc_ctr0_round (ctx, 7);

  if (n_blocks > 3)
    aes_gcm_ghash_mul_next (ctx, d[3]);

  aes_gcm_enc_ctr0_round (ctx, 8);
  aes_gcm_enc_ctr0_round (ctx, 9);

  aes_gcm_ghash_mul_final_block (ctx);
  aes_gcm_ghash_reduce (ctx);

  for (i = 10; i < ctx->rounds; i++)
    aes_gcm_enc_ctr0_round (ctx, i);

  aes_gcm_ghash_reduce2 (ctx);

  aes_gcm_ghash_final (ctx);

  aes_gcm_enc_ctr0_round (ctx, i);
}
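
/* Encrypt and authenticate n_left bytes: a first batch of up to 4 blocks is
 * encrypted without GHASH (the ciphertext does not exist yet), 8-block
 * batches then interleave GHASH of the previously produced ciphertext with
 * encryption of the next blocks, and the tail is finished in
 * aes_gcm_calc_last (). */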
static_always_inline void
aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
{
  if (PREDICT_FALSE (n_left == 0))
    {
      for (i = 0; i < ctx->rounds + 1; i++)
        aes_gcm_enc_ctr0_round (ctx, i);
      return;
    }

  if (n_left < 4 * N_AES_BYTES)
      if (n_left > 3 * N_AES_BYTES)
          aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 0);
          aes_gcm_calc_last (ctx, d, 4, n_left);
      else if (n_left > 2 * N_AES_BYTES)
          aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 0);
          aes_gcm_calc_last (ctx, d, 3, n_left);
      else if (n_left > N_AES_BYTES)
          aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 0);
          aes_gcm_calc_last (ctx, d, 2, n_left);
      else
          aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 0);
          aes_gcm_calc_last (ctx, d, 1, n_left);
      return;

  aes_gcm_calc (ctx, d, src, dst, 4, 4 * N_AES_BYTES, /* with_ghash */ 0);

  n_left -= 4 * N_AES_BYTES;
  dst += 4 * N_AES_BYTES;
  src += 4 * N_AES_BYTES;

  for (int n = 8 * N_AES_BYTES; n_left >= n; n_left -= n, src += n, dst += n)
    aes_gcm_calc_double (ctx, d, src, dst);

  if (n_left >= 4 * N_AES_BYTES)
      aes_gcm_calc (ctx, d, src, dst, 4, 4 * N_AES_BYTES, /* with_ghash */ 1);

      n_left -= 4 * N_AES_BYTES;
      dst += 4 * N_AES_BYTES;
      src += 4 * N_AES_BYTES;

      aes_gcm_calc_last (ctx, d, 4, 4 * N_AES_BYTES);

  if (n_left > 3 * N_AES_BYTES)
      aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
      aes_gcm_calc_last (ctx, d, 4, n_left);
  else if (n_left > 2 * N_AES_BYTES)
      aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
      aes_gcm_calc_last (ctx, d, 3, n_left);
  else if (n_left > N_AES_BYTES)
      aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
      aes_gcm_calc_last (ctx, d, 2, n_left);
  else
      aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
      aes_gcm_calc_last (ctx, d, 1, n_left);
}
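
/* Decrypt and authenticate n_left bytes. Since GHASH operates on the
 * ciphertext, which is available up front, every call can interleave GHASH
 * with the AES rounds; the final length block is hashed together with the
 * counter-0 encryption at the end. */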
static_always_inline void
aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
{
  aes_data_t d[4] = {};

  /* main decryption loop */
  for (int n = 8 * N_AES_BYTES; n_left >= n; n_left -= n, dst += n, src += n)
    aes_gcm_calc_double (ctx, d, src, dst);

  if (n_left >= 4 * N_AES_BYTES)
      aes_gcm_calc (ctx, d, src, dst, 4, 4 * N_AES_BYTES, /* with_ghash */ 1);

      n_left -= 4 * N_AES_BYTES;
      dst += N_AES_BYTES * 4;
      src += N_AES_BYTES * 4;

  if (n_left > 3 * N_AES_BYTES)
    aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
  else if (n_left > 2 * N_AES_BYTES)
    aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
  else if (n_left > N_AES_BYTES)
    aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
  else
    aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);

  /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM
   * (bit length) block */
  aes_gcm_enc_ctr0_round (ctx, 0);
  aes_gcm_enc_ctr0_round (ctx, 1);

  ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T,
                   ctx->Hi[NUM_HI - 1]);

  aes_gcm_enc_ctr0_round (ctx, 2);
  aes_gcm_enc_ctr0_round (ctx, 3);

  aes_gcm_enc_ctr0_round (ctx, 4);
  aes_gcm_enc_ctr0_round (ctx, 5);

  aes_gcm_enc_ctr0_round (ctx, 6);
  aes_gcm_enc_ctr0_round (ctx, 7);

  ctx->T = ghash_final (&gd);

  aes_gcm_enc_ctr0_round (ctx, 8);
  aes_gcm_enc_ctr0_round (ctx, 9);

  for (int i = 10; i < ctx->rounds + 1; i += 1)
    aes_gcm_enc_ctr0_round (ctx, i);
}
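
/* Single entry point used by all wrappers below: builds the initial counter
 * block from the 12-byte IV (the block counter starts at 2 for the first
 * data block), GHASHes the AAD (or, for GMAC, the whole message), runs the
 * requested operation and finally stores the tag (encrypt/GMAC) or checks it
 * (decrypt); for decryption the return value reports whether the computed
 * tag matched the supplied one. */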
static_always_inline int
aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
         u32 data_bytes, u32 aad_bytes, u8 tag_len,
         const aes_gcm_key_data_t *kd, int aes_rounds, aes_gcm_op_t op)
{
  u8 *addt = (u8 *) aad;

  aes_gcm_ctx_t _ctx = { .counter = 2,
                         .rounds = aes_rounds,
                         .data_bytes = data_bytes,
                         .aad_bytes = aad_bytes,

  /* initialize counter */
  Y0 = (u32x4) (u64x2){ *(u64u *) ivp, 0 };
  Y0[2] = *(u32u *) (ivp + 8);
  ctx->EY0 = (u8x16) Y0;

#if N_AES_LANES == 4
  ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
    0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
  };
#elif N_AES_LANES == 2
  ctx->Y =
    u32x8_splat_u32x4 (Y0) + (u32x8){ 0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24 };
#else
  ctx->Y = Y0 + (u32x4){ 0, 0, 0, 1 << 24 };
#endif

  /* calculate ghash for AAD */
  aes_gcm_ghash (ctx, addt, aad_bytes);

  /* ghash and encrypt/decrypt */
  if (op == AES_GCM_OP_ENCRYPT)
    aes_gcm_enc (ctx, src, dst, data_bytes);
  else if (op == AES_GCM_OP_DECRYPT)
    aes_gcm_dec (ctx, src, dst, data_bytes);

  ctx->T = u8x16_reflect (ctx->T) ^ ctx->EY0;

  /* tag_len 16 -> 0 */

  if (op == AES_GCM_OP_ENCRYPT || op == AES_GCM_OP_GMAC)
      if (tag_len)
        u8x16_store_partial (ctx->T, tag, tag_len);
      else
        ((u8x16u *) tag)[0] = ctx->T;

      u16 mask = pow2_mask (tag_len);
      u8x16 expected = u8x16_load_partial (tag, tag_len);
      if ((u8x16_msb_mask (expected == ctx->T) & mask) == mask)
        return 1;

      if (u8x16_is_equal (ctx->T, *(u8x16u *) tag))
        return 1;
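
/* Expand the AES key, broadcast each round key into all vector lanes and
 * pre-compute the GHASH key H = E(K, 0) together with its powers (kd->Hi). */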
static_always_inline void
clib_aes_gcm_key_expand (aes_gcm_key_data_t *kd, const u8 *key,
                         aes_key_size_t ks)
{
  u8x16 H;
  u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
  aes_expaned_key_t *Ke = (aes_expaned_key_t *) kd->Ke;

  aes_key_expand (ek, key, ks);
  for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
    Ke[i].lanes[0] = Ke[i].lanes[1] = Ke[i].lanes[2] = Ke[i].lanes[3] = ek[i];

  /* pre-calculate H */
  H = aes_encrypt_block (u8x16_zero (), ek, ks);
  H = u8x16_reflect (H);
  ghash_precompute (H, (u8x16 *) kd->Hi, ARRAY_LEN (kd->Hi));
}
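
/* Convenience wrappers around aes_gcm () for AES-128/AES-256 GCM encryption,
 * decryption and GMAC. A minimal usage sketch (buffer names and sizes are
 * illustrative only):
 *
 *   aes_gcm_key_data_t kd;
 *   u8 tag[16];
 *   clib_aes_gcm_key_expand (&kd, key, AES_KEY_128);
 *   clib_aes128_gcm_enc (&kd, pt, pt_len, aad, aad_len, iv, 16, ct, tag);
 *
 * iv points to the 12-byte IV; the decrypt variants return the result of the
 * tag check. */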
static_always_inline void
clib_aes128_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
                     const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
{
  aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
           tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_ENCRYPT);
}

static_always_inline void
clib_aes256_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
                     const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
{
  aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
           tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_ENCRYPT);
}

static_always_inline int
clib_aes128_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
                     const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
{
  return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
                  data_bytes, aad_bytes, tag_bytes, kd,
                  AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_DECRYPT);
}

static_always_inline int
clib_aes256_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
                     u32 data_bytes, const u8 *aad, u32 aad_bytes,
                     const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
{
  return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
                  data_bytes, aad_bytes, tag_bytes, kd,
                  AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_DECRYPT);
}

static_always_inline void
clib_aes128_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
                  const u8 *iv, u32 tag_bytes, u8 *tag)
{
  aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
           AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_GMAC);
}

static_always_inline void
clib_aes256_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
                  const u8 *iv, u32 tag_bytes, u8 *tag)
{
  aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
           AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_GMAC);
}

#endif /* __crypto_aes_gcm_h__ */