vppinfra: AES-CBC and AES-GCM refactor and optimizations
[vpp.git] / src / vppinfra / crypto / aes_gcm.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2023 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_gcm_h__
6 #define __crypto_aes_gcm_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/cache.h>
11 #include <vppinfra/string.h>
12 #include <vppinfra/crypto/aes.h>
13 #include <vppinfra/crypto/ghash.h>
14
15 #define NUM_HI 36
16 #if defined(__VAES__) && defined(__AVX512F__)
17 typedef u8x64 aes_data_t;
18 typedef u8x64u aes_ghash_t;
19 typedef u8x64u aes_mem_t;
20 typedef u32x16 aes_gcm_counter_t;
21 #define N                              64
22 #define aes_gcm_load_partial(p, n)     u8x64_load_partial ((u8 *) (p), n)
23 #define aes_gcm_store_partial(v, p, n) u8x64_store_partial (v, (u8 *) (p), n)
24 #define aes_gcm_splat(v)               u8x64_splat (v)
25 #define aes_gcm_reflect(r)             u8x64_reflect_u8x16 (r)
26 #define aes_gcm_ghash_reduce(c)        ghash4_reduce (&(c)->gd)
27 #define aes_gcm_ghash_reduce2(c)       ghash4_reduce2 (&(c)->gd)
28 #define aes_gcm_ghash_final(c)         (c)->T = ghash4_final (&(c)->gd)
29 #elif defined(__VAES__)
30 typedef u8x32 aes_data_t;
31 typedef u8x32u aes_ghash_t;
32 typedef u8x32u aes_mem_t;
33 typedef u32x8 aes_gcm_counter_t;
34 #define N                              32
35 #define aes_gcm_load_partial(p, n)     u8x32_load_partial ((u8 *) (p), n)
36 #define aes_gcm_store_partial(v, p, n) u8x32_store_partial (v, (u8 *) (p), n)
37 #define aes_gcm_splat(v)               u8x32_splat (v)
38 #define aes_gcm_reflect(r)             u8x32_reflect_u8x16 (r)
39 #define aes_gcm_ghash_reduce(c)        ghash2_reduce (&(c)->gd)
40 #define aes_gcm_ghash_reduce2(c)       ghash2_reduce2 (&(c)->gd)
41 #define aes_gcm_ghash_final(c)         (c)->T = ghash2_final (&(c)->gd)
42 #else
43 typedef u8x16 aes_data_t;
44 typedef u8x16 aes_ghash_t;
45 typedef u8x16u aes_mem_t;
46 typedef u32x4 aes_gcm_counter_t;
47 #define N                              16
48 #define aes_gcm_load_partial(p, n)     u8x16_load_partial ((u8 *) (p), n)
49 #define aes_gcm_store_partial(v, p, n) u8x16_store_partial (v, (u8 *) (p), n)
50 #define aes_gcm_splat(v)               u8x16_splat (v)
51 #define aes_gcm_reflect(r)             u8x16_reflect (r)
52 #define aes_gcm_ghash_reduce(c)        ghash_reduce (&(c)->gd)
53 #define aes_gcm_ghash_reduce2(c)       ghash_reduce2 (&(c)->gd)
54 #define aes_gcm_ghash_final(c)         (c)->T = ghash_final (&(c)->gd)
55 #endif
56 #define N_LANES (N / 16)
57
58 typedef enum
59 {
60   AES_GCM_OP_UNKNONW = 0,
61   AES_GCM_OP_ENCRYPT,
62   AES_GCM_OP_DECRYPT,
63   AES_GCM_OP_GMAC
64 } aes_gcm_op_t;
65
66 typedef union
67 {
68   u8x16 x1;
69   u8x32 x2;
70   u8x64 x4;
71   u8x16 lanes[4];
72 } __clib_aligned (64)
73 aes_gcm_expaned_key_t;
74
75 typedef struct
76 {
77   /* pre-calculated hash key values */
78   const u8x16 Hi[NUM_HI];
79   /* extracted AES key */
80   const aes_gcm_expaned_key_t Ke[AES_KEY_ROUNDS (AES_KEY_256) + 1];
81 } aes_gcm_key_data_t;
82
83 typedef struct
84 {
85   aes_gcm_op_t operation;
86   int last;
87   u8 rounds;
88   uword data_bytes;
89   uword aad_bytes;
90
91   u8x16 T;
92
93   /* hash */
94   const u8x16 *Hi;
95   const aes_ghash_t *next_Hi;
96
97   /* expaded keys */
98   const aes_gcm_expaned_key_t *Ke;
99
100   /* counter */
101   u32 counter;
102   u8x16 EY0;
103   aes_gcm_counter_t Y;
104
105   /* ghash */
106   ghash_data_t gd;
107 } aes_gcm_ctx_t;
108
109 static_always_inline void
110 aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
111 {
112   uword hash_offset = NUM_HI - n_lanes;
113   ctx->next_Hi = (aes_ghash_t *) (ctx->Hi + hash_offset);
114 #if N_LANES == 4
115   u8x64 tag4 = {};
116   tag4 = u8x64_insert_u8x16 (tag4, ctx->T, 0);
117   ghash4_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ tag4, *ctx->next_Hi++);
118 #elif N_LANES == 2
119   u8x32 tag2 = {};
120   tag2 = u8x32_insert_lo (tag2, ctx->T);
121   ghash2_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ tag2, *ctx->next_Hi++);
122 #else
123   ghash_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ ctx->T, *ctx->next_Hi++);
124 #endif
125 }
126
127 static_always_inline void
128 aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
129 {
130 #if N_LANES == 4
131   ghash4_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
132 #elif N_LANES == 2
133   ghash2_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
134 #else
135   ghash_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
136 #endif
137 }
138
139 static_always_inline void
140 aes_gcm_ghash_mul_bit_len (aes_gcm_ctx_t *ctx)
141 {
142   u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
143 #if N_LANES == 4
144   u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
145   u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), r, 0);
146   ghash4_mul_next (&ctx->gd, r4, h);
147 #elif N_LANES == 2
148   u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
149   u8x32 r2 = u8x32_insert_lo (u8x32_zero (), r);
150   ghash2_mul_next (&ctx->gd, r2, h);
151 #else
152   ghash_mul_next (&ctx->gd, r, ctx->Hi[NUM_HI - 1]);
153 #endif
154 }
155
156 static_always_inline void
157 aes_gcm_enc_ctr0_round (aes_gcm_ctx_t *ctx, int aes_round)
158 {
159   if (aes_round == 0)
160     ctx->EY0 ^= ctx->Ke[0].x1;
161   else if (aes_round == ctx->rounds)
162     ctx->EY0 = aes_enc_last_round (ctx->EY0, ctx->Ke[aes_round].x1);
163   else
164     ctx->EY0 = aes_enc_round (ctx->EY0, ctx->Ke[aes_round].x1);
165 }
166
167 static_always_inline void
168 aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
169 {
170   uword i;
171   aes_data_t r = {};
172   const aes_mem_t *d = (aes_mem_t *) data;
173
174   for (; n_left >= 8 * N; n_left -= 8 * N, d += 8)
175     {
176       if (ctx->operation == AES_GCM_OP_GMAC && n_left == N * 8)
177         {
178           aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1);
179           for (i = 1; i < 8; i++)
180             aes_gcm_ghash_mul_next (ctx, d[i]);
181           aes_gcm_ghash_mul_bit_len (ctx);
182           aes_gcm_ghash_reduce (ctx);
183           aes_gcm_ghash_reduce2 (ctx);
184           aes_gcm_ghash_final (ctx);
185           goto done;
186         }
187
188       aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES);
189       for (i = 1; i < 8; i++)
190         aes_gcm_ghash_mul_next (ctx, d[i]);
191       aes_gcm_ghash_reduce (ctx);
192       aes_gcm_ghash_reduce2 (ctx);
193       aes_gcm_ghash_final (ctx);
194     }
195
196   if (n_left > 0)
197     {
198       int n_lanes = (n_left + 15) / 16;
199
200       if (ctx->operation == AES_GCM_OP_GMAC)
201         n_lanes++;
202
203       if (n_left < N)
204         {
205           clib_memcpy_fast (&r, d, n_left);
206           aes_gcm_ghash_mul_first (ctx, r, n_lanes);
207         }
208       else
209         {
210           aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
211           n_left -= N;
212           i = 1;
213
214           if (n_left >= 4 * N)
215             {
216               aes_gcm_ghash_mul_next (ctx, d[i]);
217               aes_gcm_ghash_mul_next (ctx, d[i + 1]);
218               aes_gcm_ghash_mul_next (ctx, d[i + 2]);
219               aes_gcm_ghash_mul_next (ctx, d[i + 3]);
220               n_left -= 4 * N;
221               i += 4;
222             }
223           if (n_left >= 2 * N)
224             {
225               aes_gcm_ghash_mul_next (ctx, d[i]);
226               aes_gcm_ghash_mul_next (ctx, d[i + 1]);
227               n_left -= 2 * N;
228               i += 2;
229             }
230
231           if (n_left >= N)
232             {
233               aes_gcm_ghash_mul_next (ctx, d[i]);
234               n_left -= N;
235               i += 1;
236             }
237
238           if (n_left)
239             {
240               clib_memcpy_fast (&r, d + i, n_left);
241               aes_gcm_ghash_mul_next (ctx, r);
242             }
243         }
244
245       if (ctx->operation == AES_GCM_OP_GMAC)
246         aes_gcm_ghash_mul_bit_len (ctx);
247       aes_gcm_ghash_reduce (ctx);
248       aes_gcm_ghash_reduce2 (ctx);
249       aes_gcm_ghash_final (ctx);
250     }
251   else if (ctx->operation == AES_GCM_OP_GMAC)
252     {
253       u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
254       ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
255     }
256
257 done:
258   /* encrypt counter 0 E(Y0, k) */
259   if (ctx->operation == AES_GCM_OP_GMAC)
260     for (int i = 0; i < ctx->rounds + 1; i += 1)
261       aes_gcm_enc_ctr0_round (ctx, i);
262 }
263
264 static_always_inline void
265 aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
266 {
267   const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0];
268   uword i = 0;
269
270 #if N_LANES == 4
271   const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
272                                 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
273
274   const u32x16 ctr_4444 = {
275     4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
276   };
277
278   /* As counter is stored in network byte order for performance reasons we
279      are incrementing least significant byte only except in case where we
280      overlow. As we are processing four 512-blocks in parallel except the
281      last round, overflow can happen only when n == 4 */
282
283   if (n_blocks == 4)
284     for (; i < 2; i++)
285       {
286         r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
287         ctx->Y += ctr_inv_4444;
288       }
289
290   if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 242))
291     {
292       u32x16 Yr = (u32x16) aes_gcm_reflect ((u8x64) ctx->Y);
293
294       for (; i < n_blocks; i++)
295         {
296           r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
297           Yr += ctr_4444;
298           ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr);
299         }
300     }
301   else
302     {
303       for (; i < n_blocks; i++)
304         {
305           r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
306           ctx->Y += ctr_inv_4444;
307         }
308     }
309   ctx->counter += n_blocks * 4;
310 #elif N_LANES == 2
311   const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
312   const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
313
314   /* As counter is stored in network byte order for performance reasons we
315      are incrementing least significant byte only except in case where we
316      overlow. As we are processing four 512-blocks in parallel except the
317      last round, overflow can happen only when n == 4 */
318
319   if (n_blocks == 4)
320     for (; i < 2; i++)
321       {
322         r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
323         ctx->Y += ctr_inv_22;
324       }
325
326   if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 250))
327     {
328       u32x8 Yr = (u32x8) aes_gcm_reflect ((u8x32) ctx->Y);
329
330       for (; i < n_blocks; i++)
331         {
332           r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
333           Yr += ctr_22;
334           ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr);
335         }
336     }
337   else
338     {
339       for (; i < n_blocks; i++)
340         {
341           r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
342           ctx->Y += ctr_inv_22;
343         }
344     }
345   ctx->counter += n_blocks * 2;
346 #else
347   const u32x4 ctr_inv_1 = { 0, 0, 0, 1 << 24 };
348
349   if (PREDICT_TRUE ((u8) ctx->counter < 0xfe) || n_blocks < 3)
350     {
351       for (; i < n_blocks; i++)
352         {
353           r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
354           ctx->Y += ctr_inv_1;
355         }
356       ctx->counter += n_blocks;
357     }
358   else
359     {
360       r[i++] = Ke0.x1 ^ (u8x16) ctx->Y;
361       ctx->Y += ctr_inv_1;
362       ctx->counter += 1;
363
364       for (; i < n_blocks; i++)
365         {
366           r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
367           ctx->counter++;
368           ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
369         }
370     }
371 #endif
372 }
373
374 static_always_inline void
375 aes_gcm_enc_round (aes_data_t *r, const aes_gcm_expaned_key_t *Ke,
376                    uword n_blocks)
377 {
378   for (int i = 0; i < n_blocks; i++)
379 #if N_LANES == 4
380     r[i] = aes_enc_round_x4 (r[i], Ke->x4);
381 #elif N_LANES == 2
382     r[i] = aes_enc_round_x2 (r[i], Ke->x2);
383 #else
384     r[i] = aes_enc_round (r[i], Ke->x1);
385 #endif
386 }
387
388 static_always_inline void
389 aes_gcm_enc_last_round (aes_gcm_ctx_t *ctx, aes_data_t *r, aes_data_t *d,
390                         const aes_gcm_expaned_key_t *Ke, uword n_blocks)
391 {
392   /* additional ronuds for AES-192 and AES-256 */
393   for (int i = 10; i < ctx->rounds; i++)
394     aes_gcm_enc_round (r, Ke + i, n_blocks);
395
396   for (int i = 0; i < n_blocks; i++)
397 #if N_LANES == 4
398     d[i] ^= aes_enc_last_round_x4 (r[i], Ke[ctx->rounds].x4);
399 #elif N_LANES == 2
400     d[i] ^= aes_enc_last_round_x2 (r[i], Ke[ctx->rounds].x2);
401 #else
402     d[i] ^= aes_enc_last_round (r[i], Ke[ctx->rounds].x1);
403 #endif
404 }
405
406 static_always_inline void
407 aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
408               u32 n_bytes, int with_ghash)
409 {
410   const aes_gcm_expaned_key_t *k = ctx->Ke;
411   const aes_mem_t *sv = (aes_mem_t *) src;
412   aes_mem_t *dv = (aes_mem_t *) dst;
413   uword ghash_blocks, gc = 1;
414   aes_data_t r[4];
415   u32 i, n_lanes;
416
417   if (ctx->operation == AES_GCM_OP_ENCRYPT)
418     {
419       ghash_blocks = 4;
420       n_lanes = N_LANES * 4;
421     }
422   else
423     {
424       ghash_blocks = n;
425       n_lanes = n * N_LANES;
426 #if N_LANES != 1
427       if (ctx->last)
428         n_lanes = (n_bytes + 15) / 16;
429 #endif
430     }
431
432   n_bytes -= (n - 1) * N;
433
434   /* AES rounds 0 and 1 */
435   aes_gcm_enc_first_round (ctx, r, n);
436   aes_gcm_enc_round (r, k + 1, n);
437
438   /* load data - decrypt round */
439   if (ctx->operation == AES_GCM_OP_DECRYPT)
440     {
441       for (i = 0; i < n - ctx->last; i++)
442         d[i] = sv[i];
443
444       if (ctx->last)
445         d[n - 1] = aes_gcm_load_partial ((u8 *) (sv + n - 1), n_bytes);
446     }
447
448   /* GHASH multiply block 0 */
449   if (with_ghash)
450     aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
451
452   /* AES rounds 2 and 3 */
453   aes_gcm_enc_round (r, k + 2, n);
454   aes_gcm_enc_round (r, k + 3, n);
455
456   /* GHASH multiply block 1 */
457   if (with_ghash && gc++ < ghash_blocks)
458     aes_gcm_ghash_mul_next (ctx, (d[1]));
459
460   /* AES rounds 4 and 5 */
461   aes_gcm_enc_round (r, k + 4, n);
462   aes_gcm_enc_round (r, k + 5, n);
463
464   /* GHASH multiply block 2 */
465   if (with_ghash && gc++ < ghash_blocks)
466     aes_gcm_ghash_mul_next (ctx, (d[2]));
467
468   /* AES rounds 6 and 7 */
469   aes_gcm_enc_round (r, k + 6, n);
470   aes_gcm_enc_round (r, k + 7, n);
471
472   /* GHASH multiply block 3 */
473   if (with_ghash && gc++ < ghash_blocks)
474     aes_gcm_ghash_mul_next (ctx, (d[3]));
475
476   /* load 4 blocks of data - decrypt round */
477   if (ctx->operation == AES_GCM_OP_ENCRYPT)
478     {
479       for (i = 0; i < n - ctx->last; i++)
480         d[i] = sv[i];
481
482       if (ctx->last)
483         d[n - 1] = aes_gcm_load_partial (sv + n - 1, n_bytes);
484     }
485
486   /* AES rounds 8 and 9 */
487   aes_gcm_enc_round (r, k + 8, n);
488   aes_gcm_enc_round (r, k + 9, n);
489
490   /* AES last round(s) */
491   aes_gcm_enc_last_round (ctx, r, d, k, n);
492
493   /* store data */
494   for (i = 0; i < n - ctx->last; i++)
495     dv[i] = d[i];
496
497   if (ctx->last)
498     aes_gcm_store_partial (d[n - 1], dv + n - 1, n_bytes);
499
500   /* GHASH reduce 1st step */
501   aes_gcm_ghash_reduce (ctx);
502
503   /* GHASH reduce 2nd step */
504   if (with_ghash)
505     aes_gcm_ghash_reduce2 (ctx);
506
507   /* GHASH final step */
508   if (with_ghash)
509     aes_gcm_ghash_final (ctx);
510 }
511
512 static_always_inline void
513 aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst,
514                      int with_ghash)
515 {
516   const aes_gcm_expaned_key_t *k = ctx->Ke;
517   const aes_mem_t *sv = (aes_mem_t *) src;
518   aes_mem_t *dv = (aes_mem_t *) dst;
519   aes_data_t r[4];
520
521   /* AES rounds 0 and 1 */
522   aes_gcm_enc_first_round (ctx, r, 4);
523   aes_gcm_enc_round (r, k + 1, 4);
524
525   /* load 4 blocks of data - decrypt round */
526   if (ctx->operation == AES_GCM_OP_DECRYPT)
527     for (int i = 0; i < 4; i++)
528       d[i] = sv[i];
529
530   /* GHASH multiply block 0 */
531   aes_gcm_ghash_mul_first (ctx, d[0], N_LANES * 8);
532
533   /* AES rounds 2 and 3 */
534   aes_gcm_enc_round (r, k + 2, 4);
535   aes_gcm_enc_round (r, k + 3, 4);
536
537   /* GHASH multiply block 1 */
538   aes_gcm_ghash_mul_next (ctx, (d[1]));
539
540   /* AES rounds 4 and 5 */
541   aes_gcm_enc_round (r, k + 4, 4);
542   aes_gcm_enc_round (r, k + 5, 4);
543
544   /* GHASH multiply block 2 */
545   aes_gcm_ghash_mul_next (ctx, (d[2]));
546
547   /* AES rounds 6 and 7 */
548   aes_gcm_enc_round (r, k + 6, 4);
549   aes_gcm_enc_round (r, k + 7, 4);
550
551   /* GHASH multiply block 3 */
552   aes_gcm_ghash_mul_next (ctx, (d[3]));
553
554   /* AES rounds 8 and 9 */
555   aes_gcm_enc_round (r, k + 8, 4);
556   aes_gcm_enc_round (r, k + 9, 4);
557
558   /* load 4 blocks of data - encrypt round */
559   if (ctx->operation == AES_GCM_OP_ENCRYPT)
560     for (int i = 0; i < 4; i++)
561       d[i] = sv[i];
562
563   /* AES last round(s) */
564   aes_gcm_enc_last_round (ctx, r, d, k, 4);
565
566   /* store 4 blocks of data */
567   for (int i = 0; i < 4; i++)
568     dv[i] = d[i];
569
570   /* load next 4 blocks of data data - decrypt round */
571   if (ctx->operation == AES_GCM_OP_DECRYPT)
572     for (int i = 0; i < 4; i++)
573       d[i] = sv[i + 4];
574
575   /* GHASH multiply block 4 */
576   aes_gcm_ghash_mul_next (ctx, (d[0]));
577
578   /* AES rounds 0 and 1 */
579   aes_gcm_enc_first_round (ctx, r, 4);
580   aes_gcm_enc_round (r, k + 1, 4);
581
582   /* GHASH multiply block 5 */
583   aes_gcm_ghash_mul_next (ctx, (d[1]));
584
585   /* AES rounds 2 and 3 */
586   aes_gcm_enc_round (r, k + 2, 4);
587   aes_gcm_enc_round (r, k + 3, 4);
588
589   /* GHASH multiply block 6 */
590   aes_gcm_ghash_mul_next (ctx, (d[2]));
591
592   /* AES rounds 4 and 5 */
593   aes_gcm_enc_round (r, k + 4, 4);
594   aes_gcm_enc_round (r, k + 5, 4);
595
596   /* GHASH multiply block 7 */
597   aes_gcm_ghash_mul_next (ctx, (d[3]));
598
599   /* AES rounds 6 and 7 */
600   aes_gcm_enc_round (r, k + 6, 4);
601   aes_gcm_enc_round (r, k + 7, 4);
602
603   /* GHASH reduce 1st step */
604   aes_gcm_ghash_reduce (ctx);
605
606   /* AES rounds 8 and 9 */
607   aes_gcm_enc_round (r, k + 8, 4);
608   aes_gcm_enc_round (r, k + 9, 4);
609
610   /* GHASH reduce 2nd step */
611   aes_gcm_ghash_reduce2 (ctx);
612
613   /* load 4 blocks of data - encrypt round */
614   if (ctx->operation == AES_GCM_OP_ENCRYPT)
615     for (int i = 0; i < 4; i++)
616       d[i] = sv[i + 4];
617
618   /* AES last round(s) */
619   aes_gcm_enc_last_round (ctx, r, d, k, 4);
620
621   /* store data */
622   for (int i = 0; i < 4; i++)
623     dv[i + 4] = d[i];
624
625   /* GHASH final step */
626   aes_gcm_ghash_final (ctx);
627 }
628
629 static_always_inline void
630 aes_gcm_mask_bytes (aes_data_t *d, uword n_bytes)
631 {
632   const union
633   {
634     u8 b[64];
635     aes_data_t r;
636   } scale = {
637     .b = { 0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
638            16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
639            32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
640            48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 },
641   };
642
643   d[0] &= (aes_gcm_splat (n_bytes) > scale.r);
644 }
645
646 static_always_inline void
647 aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
648                    u32 n_bytes)
649 {
650   int n_lanes = (N_LANES == 1 ? n_blocks : (n_bytes + 15) / 16) + 1;
651   n_bytes -= (n_blocks - 1) * N;
652   int i;
653
654   aes_gcm_enc_ctr0_round (ctx, 0);
655   aes_gcm_enc_ctr0_round (ctx, 1);
656
657   if (n_bytes != N)
658     aes_gcm_mask_bytes (d + n_blocks - 1, n_bytes);
659
660   aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
661
662   aes_gcm_enc_ctr0_round (ctx, 2);
663   aes_gcm_enc_ctr0_round (ctx, 3);
664
665   if (n_blocks > 1)
666     aes_gcm_ghash_mul_next (ctx, d[1]);
667
668   aes_gcm_enc_ctr0_round (ctx, 4);
669   aes_gcm_enc_ctr0_round (ctx, 5);
670
671   if (n_blocks > 2)
672     aes_gcm_ghash_mul_next (ctx, d[2]);
673
674   aes_gcm_enc_ctr0_round (ctx, 6);
675   aes_gcm_enc_ctr0_round (ctx, 7);
676
677   if (n_blocks > 3)
678     aes_gcm_ghash_mul_next (ctx, d[3]);
679
680   aes_gcm_enc_ctr0_round (ctx, 8);
681   aes_gcm_enc_ctr0_round (ctx, 9);
682
683   aes_gcm_ghash_mul_bit_len (ctx);
684   aes_gcm_ghash_reduce (ctx);
685
686   for (i = 10; i < ctx->rounds; i++)
687     aes_gcm_enc_ctr0_round (ctx, i);
688
689   aes_gcm_ghash_reduce2 (ctx);
690
691   aes_gcm_ghash_final (ctx);
692
693   aes_gcm_enc_ctr0_round (ctx, i);
694 }
695
696 static_always_inline void
697 aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
698 {
699   aes_data_t d[4];
700
701   if (PREDICT_FALSE (n_left == 0))
702     {
703       int i;
704       for (i = 0; i < ctx->rounds + 1; i++)
705         aes_gcm_enc_ctr0_round (ctx, i);
706       return;
707     }
708
709   if (n_left < 4 * N)
710     {
711       ctx->last = 1;
712       if (n_left > 3 * N)
713         {
714           aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 0);
715           aes_gcm_calc_last (ctx, d, 4, n_left);
716         }
717       else if (n_left > 2 * N)
718         {
719           aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 0);
720           aes_gcm_calc_last (ctx, d, 3, n_left);
721         }
722       else if (n_left > N)
723         {
724           aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 0);
725           aes_gcm_calc_last (ctx, d, 2, n_left);
726         }
727       else
728         {
729           aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 0);
730           aes_gcm_calc_last (ctx, d, 1, n_left);
731         }
732       return;
733     }
734   aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0);
735
736   /* next */
737   n_left -= 4 * N;
738   dst += 4 * N;
739   src += 4 * N;
740
741   for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N)
742     aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
743
744   if (n_left >= 4 * N)
745     {
746       aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 1);
747
748       /* next */
749       n_left -= 4 * N;
750       dst += 4 * N;
751       src += 4 * N;
752     }
753
754   if (n_left == 0)
755     {
756       aes_gcm_calc_last (ctx, d, 4, 4 * N);
757       return;
758     }
759
760   ctx->last = 1;
761
762   if (n_left > 3 * N)
763     {
764       aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
765       aes_gcm_calc_last (ctx, d, 4, n_left);
766     }
767   else if (n_left > 2 * N)
768     {
769       aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
770       aes_gcm_calc_last (ctx, d, 3, n_left);
771     }
772   else if (n_left > N)
773     {
774       aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
775       aes_gcm_calc_last (ctx, d, 2, n_left);
776     }
777   else
778     {
779       aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
780       aes_gcm_calc_last (ctx, d, 1, n_left);
781     }
782 }
783
784 static_always_inline void
785 aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
786 {
787   aes_data_t d[4] = {};
788   for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N)
789     aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
790
791   if (n_left >= 4 * N)
792     {
793       aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 1);
794
795       /* next */
796       n_left -= 4 * N;
797       dst += N * 4;
798       src += N * 4;
799     }
800
801   if (n_left == 0)
802     goto done;
803
804   ctx->last = 1;
805
806   if (n_left > 3 * N)
807     aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
808   else if (n_left > 2 * N)
809     aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
810   else if (n_left > N)
811     aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
812   else
813     aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
814
815   u8x16 r;
816 done:
817   r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
818   ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
819
820   /* encrypt counter 0 E(Y0, k) */
821   for (int i = 0; i < ctx->rounds + 1; i += 1)
822     aes_gcm_enc_ctr0_round (ctx, i);
823 }
824
825 static_always_inline int
826 aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
827          u32 data_bytes, u32 aad_bytes, u8 tag_len,
828          const aes_gcm_key_data_t *kd, int aes_rounds, aes_gcm_op_t op)
829 {
830   u8 *addt = (u8 *) aad;
831   u32x4 Y0;
832
833   aes_gcm_ctx_t _ctx = { .counter = 2,
834                          .rounds = aes_rounds,
835                          .operation = op,
836                          .data_bytes = data_bytes,
837                          .aad_bytes = aad_bytes,
838                          .Hi = kd->Hi },
839                 *ctx = &_ctx;
840
841   /* initalize counter */
842   Y0 = (u32x4) (u64x2){ *(u64u *) ivp, 0 };
843   Y0[2] = *(u32u *) (ivp + 8);
844   Y0[3] = 1 << 24;
845   ctx->EY0 = (u8x16) Y0;
846   ctx->Ke = kd->Ke;
847 #if N_LANES == 4
848   ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
849     0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
850   };
851 #elif N_LANES == 2
852   ctx->Y =
853     u32x8_splat_u32x4 (Y0) + (u32x8){ 0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24 };
854 #else
855   ctx->Y = Y0 + (u32x4){ 0, 0, 0, 1 << 24 };
856 #endif
857
858   /* calculate ghash for AAD */
859   aes_gcm_ghash (ctx, addt, aad_bytes);
860
861   clib_prefetch_load (tag);
862
863   /* ghash and encrypt/edcrypt  */
864   if (op == AES_GCM_OP_ENCRYPT)
865     aes_gcm_enc (ctx, src, dst, data_bytes);
866   else if (op == AES_GCM_OP_DECRYPT)
867     aes_gcm_dec (ctx, src, dst, data_bytes);
868
869   /* final tag is */
870   ctx->T = u8x16_reflect (ctx->T) ^ ctx->EY0;
871
872   /* tag_len 16 -> 0 */
873   tag_len &= 0xf;
874
875   if (op == AES_GCM_OP_ENCRYPT || op == AES_GCM_OP_GMAC)
876     {
877       /* store tag */
878       if (tag_len)
879         u8x16_store_partial (ctx->T, tag, tag_len);
880       else
881         ((u8x16u *) tag)[0] = ctx->T;
882     }
883   else
884     {
885       /* check tag */
886       if (tag_len)
887         {
888           u16 mask = pow2_mask (tag_len);
889           u8x16 expected = u8x16_load_partial (tag, tag_len);
890           if ((u8x16_msb_mask (expected == ctx->T) & mask) == mask)
891             return 1;
892         }
893       else
894         {
895           if (u8x16_is_equal (ctx->T, *(u8x16u *) tag))
896             return 1;
897         }
898     }
899   return 0;
900 }
901
902 static_always_inline void
903 clib_aes_gcm_key_expand (aes_gcm_key_data_t *kd, const u8 *key,
904                          aes_key_size_t ks)
905 {
906   u8x16 H;
907   u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
908   aes_gcm_expaned_key_t *Ke = (aes_gcm_expaned_key_t *) kd->Ke;
909
910   /* expand AES key */
911   aes_key_expand (ek, key, ks);
912   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
913     Ke[i].lanes[0] = Ke[i].lanes[1] = Ke[i].lanes[2] = Ke[i].lanes[3] = ek[i];
914
915   /* pre-calculate H */
916   H = aes_encrypt_block (u8x16_zero (), ek, ks);
917   H = u8x16_reflect (H);
918   ghash_precompute (H, (u8x16 *) kd->Hi, ARRAY_LEN (kd->Hi));
919 }
920
921 static_always_inline void
922 clib_aes128_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
923                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
924                      const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
925 {
926   aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
927            tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_ENCRYPT);
928 }
929
930 static_always_inline void
931 clib_aes256_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
932                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
933                      const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
934 {
935   aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
936            tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_ENCRYPT);
937 }
938
939 static_always_inline int
940 clib_aes128_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
941                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
942                      const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
943 {
944   return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
945                   data_bytes, aad_bytes, tag_bytes, kd,
946                   AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_DECRYPT);
947 }
948
949 static_always_inline int
950 clib_aes256_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
951                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
952                      const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
953 {
954   return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
955                   data_bytes, aad_bytes, tag_bytes, kd,
956                   AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_DECRYPT);
957 }
958
959 static_always_inline void
960 clib_aes128_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
961                   const u8 *iv, u32 tag_bytes, u8 *tag)
962 {
963   aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
964            AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_GMAC);
965 }
966
967 static_always_inline void
968 clib_aes256_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
969                   const u8 *iv, u32 tag_bytes, u8 *tag)
970 {
971   aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
972            AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_GMAC);
973 }
974
975 #endif /* __crypto_aes_gcm_h__ */