vppinfra: native AES-CTR implementation
[vpp.git] / src / vppinfra / crypto / aes_gcm.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2023 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_gcm_h__
6 #define __crypto_aes_gcm_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/cache.h>
11 #include <vppinfra/string.h>
12 #include <vppinfra/crypto/aes.h>
13 #include <vppinfra/crypto/ghash.h>
14
15 #define NUM_HI 36
16 #if N_AES_LANES == 4
17 typedef u8x64u aes_ghash_t;
18 #define aes_gcm_splat(v)               u8x64_splat (v)
19 #define aes_gcm_ghash_reduce(c)        ghash4_reduce (&(c)->gd)
20 #define aes_gcm_ghash_reduce2(c)       ghash4_reduce2 (&(c)->gd)
21 #define aes_gcm_ghash_final(c)         (c)->T = ghash4_final (&(c)->gd)
22 #elif N_AES_LANES == 2
23 typedef u8x32u aes_ghash_t;
24 #define aes_gcm_splat(v)               u8x32_splat (v)
25 #define aes_gcm_ghash_reduce(c)        ghash2_reduce (&(c)->gd)
26 #define aes_gcm_ghash_reduce2(c)       ghash2_reduce2 (&(c)->gd)
27 #define aes_gcm_ghash_final(c)         (c)->T = ghash2_final (&(c)->gd)
28 #else
29 typedef u8x16 aes_ghash_t;
30 #define aes_gcm_splat(v)               u8x16_splat (v)
31 #define aes_gcm_ghash_reduce(c)        ghash_reduce (&(c)->gd)
32 #define aes_gcm_ghash_reduce2(c)       ghash_reduce2 (&(c)->gd)
33 #define aes_gcm_ghash_final(c)         (c)->T = ghash_final (&(c)->gd)
34 #endif
35
36 typedef enum
37 {
38   AES_GCM_OP_UNKNONW = 0,
39   AES_GCM_OP_ENCRYPT,
40   AES_GCM_OP_DECRYPT,
41   AES_GCM_OP_GMAC
42 } aes_gcm_op_t;
43
44 typedef struct
45 {
46   /* pre-calculated hash key values */
47   const u8x16 Hi[NUM_HI];
48   /* extracted AES key */
49   const aes_expaned_key_t Ke[AES_KEY_ROUNDS (AES_KEY_256) + 1];
50 } aes_gcm_key_data_t;
51
52 typedef struct
53 {
54   aes_gcm_op_t operation;
55   int last;
56   u8 rounds;
57   uword data_bytes;
58   uword aad_bytes;
59
60   u8x16 T;
61
62   /* hash */
63   const u8x16 *Hi;
64   const aes_ghash_t *next_Hi;
65
66   /* expaded keys */
67   const aes_expaned_key_t *Ke;
68
69   /* counter */
70   u32 counter;
71   u8x16 EY0;
72   aes_counter_t Y;
73
74   /* ghash */
75   ghash_ctx_t gd;
76 } aes_gcm_ctx_t;
77
78 static_always_inline u8x16
79 aes_gcm_final_block (aes_gcm_ctx_t *ctx)
80 {
81   return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
82 }
83
84 static_always_inline void
85 aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
86 {
87   uword hash_offset = NUM_HI - n_lanes;
88   ctx->next_Hi = (aes_ghash_t *) (ctx->Hi + hash_offset);
89 #if N_AES_LANES == 4
90   u8x64 tag4 = {};
91   tag4 = u8x64_insert_u8x16 (tag4, ctx->T, 0);
92   ghash4_mul_first (&ctx->gd, aes_reflect (data) ^ tag4, *ctx->next_Hi++);
93 #elif N_AES_LANES == 2
94   u8x32 tag2 = {};
95   tag2 = u8x32_insert_lo (tag2, ctx->T);
96   ghash2_mul_first (&ctx->gd, aes_reflect (data) ^ tag2, *ctx->next_Hi++);
97 #else
98   ghash_mul_first (&ctx->gd, aes_reflect (data) ^ ctx->T, *ctx->next_Hi++);
99 #endif
100 }
101
102 static_always_inline void
103 aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
104 {
105 #if N_AES_LANES == 4
106   ghash4_mul_next (&ctx->gd, aes_reflect (data), *ctx->next_Hi++);
107 #elif N_AES_LANES == 2
108   ghash2_mul_next (&ctx->gd, aes_reflect (data), *ctx->next_Hi++);
109 #else
110   ghash_mul_next (&ctx->gd, aes_reflect (data), *ctx->next_Hi++);
111 #endif
112 }
113
114 static_always_inline void
115 aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx)
116 {
117 #if N_AES_LANES == 4
118   u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
119   u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0);
120   ghash4_mul_next (&ctx->gd, r4, h);
121 #elif N_AES_LANES == 2
122   u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
123   u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx));
124   ghash2_mul_next (&ctx->gd, r2, h);
125 #else
126   ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]);
127 #endif
128 }
129
130 static_always_inline void
131 aes_gcm_enc_ctr0_round (aes_gcm_ctx_t *ctx, int aes_round)
132 {
133   if (aes_round == 0)
134     ctx->EY0 ^= ctx->Ke[0].x1;
135   else if (aes_round == ctx->rounds)
136     ctx->EY0 = aes_enc_last_round_x1 (ctx->EY0, ctx->Ke[aes_round].x1);
137   else
138     ctx->EY0 = aes_enc_round_x1 (ctx->EY0, ctx->Ke[aes_round].x1);
139 }
140
141 static_always_inline void
142 aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
143 {
144   uword i;
145   aes_data_t r = {};
146   const aes_mem_t *d = (aes_mem_t *) data;
147
148   for (int n = 8 * N_AES_BYTES; n_left >= n; n_left -= n, d += 8)
149     {
150       if (ctx->operation == AES_GCM_OP_GMAC && n_left == n)
151         {
152           aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_AES_LANES + 1);
153           for (i = 1; i < 8; i++)
154             aes_gcm_ghash_mul_next (ctx, d[i]);
155           aes_gcm_ghash_mul_final_block (ctx);
156           aes_gcm_ghash_reduce (ctx);
157           aes_gcm_ghash_reduce2 (ctx);
158           aes_gcm_ghash_final (ctx);
159           goto done;
160         }
161
162       aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_AES_LANES);
163       for (i = 1; i < 8; i++)
164         aes_gcm_ghash_mul_next (ctx, d[i]);
165       aes_gcm_ghash_reduce (ctx);
166       aes_gcm_ghash_reduce2 (ctx);
167       aes_gcm_ghash_final (ctx);
168     }
169
170   if (n_left > 0)
171     {
172       int n_lanes = (n_left + 15) / 16;
173
174       if (ctx->operation == AES_GCM_OP_GMAC)
175         n_lanes++;
176
177       if (n_left < N_AES_BYTES)
178         {
179           clib_memcpy_fast (&r, d, n_left);
180           aes_gcm_ghash_mul_first (ctx, r, n_lanes);
181         }
182       else
183         {
184           aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
185           n_left -= N_AES_BYTES;
186           i = 1;
187
188           if (n_left >= 4 * N_AES_BYTES)
189             {
190               aes_gcm_ghash_mul_next (ctx, d[i]);
191               aes_gcm_ghash_mul_next (ctx, d[i + 1]);
192               aes_gcm_ghash_mul_next (ctx, d[i + 2]);
193               aes_gcm_ghash_mul_next (ctx, d[i + 3]);
194               n_left -= 4 * N_AES_BYTES;
195               i += 4;
196             }
197           if (n_left >= 2 * N_AES_BYTES)
198             {
199               aes_gcm_ghash_mul_next (ctx, d[i]);
200               aes_gcm_ghash_mul_next (ctx, d[i + 1]);
201               n_left -= 2 * N_AES_BYTES;
202               i += 2;
203             }
204
205           if (n_left >= N_AES_BYTES)
206             {
207               aes_gcm_ghash_mul_next (ctx, d[i]);
208               n_left -= N_AES_BYTES;
209               i += 1;
210             }
211
212           if (n_left)
213             {
214               clib_memcpy_fast (&r, d + i, n_left);
215               aes_gcm_ghash_mul_next (ctx, r);
216             }
217         }
218
219       if (ctx->operation == AES_GCM_OP_GMAC)
220         aes_gcm_ghash_mul_final_block (ctx);
221       aes_gcm_ghash_reduce (ctx);
222       aes_gcm_ghash_reduce2 (ctx);
223       aes_gcm_ghash_final (ctx);
224     }
225   else if (ctx->operation == AES_GCM_OP_GMAC)
226     ctx->T =
227       ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]);
228
229 done:
230   /* encrypt counter 0 E(Y0, k) */
231   if (ctx->operation == AES_GCM_OP_GMAC)
232     for (int i = 0; i < ctx->rounds + 1; i += 1)
233       aes_gcm_enc_ctr0_round (ctx, i);
234 }
235
236 static_always_inline void
237 aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
238 {
239   const aes_expaned_key_t Ke0 = ctx->Ke[0];
240   uword i = 0;
241
242   /* As counter is stored in network byte order for performance reasons we
243      are incrementing least significant byte only except in case where we
244      overlow. As we are processing four 128, 256 or 512-blocks in parallel
245      except the last round, overflow can happen only when n_blocks == 4 */
246
247 #if N_AES_LANES == 4
248   const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
249                                 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
250
251   const u32x16 ctr_4444 = {
252     4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
253   };
254
255   if (n_blocks == 4)
256     for (; i < 2; i++)
257       {
258         r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
259         ctx->Y += ctr_inv_4444;
260       }
261
262   if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 242))
263     {
264       u32x16 Yr = (u32x16) aes_reflect ((u8x64) ctx->Y);
265
266       for (; i < n_blocks; i++)
267         {
268           r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
269           Yr += ctr_4444;
270           ctx->Y = (u32x16) aes_reflect ((u8x64) Yr);
271         }
272     }
273   else
274     {
275       for (; i < n_blocks; i++)
276         {
277           r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
278           ctx->Y += ctr_inv_4444;
279         }
280     }
281   ctx->counter += n_blocks * 4;
282 #elif N_AES_LANES == 2
283   const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
284   const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
285
286   if (n_blocks == 4)
287     for (; i < 2; i++)
288       {
289         r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
290         ctx->Y += ctr_inv_22;
291       }
292
293   if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 250))
294     {
295       u32x8 Yr = (u32x8) aes_reflect ((u8x32) ctx->Y);
296
297       for (; i < n_blocks; i++)
298         {
299           r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
300           Yr += ctr_22;
301           ctx->Y = (u32x8) aes_reflect ((u8x32) Yr);
302         }
303     }
304   else
305     {
306       for (; i < n_blocks; i++)
307         {
308           r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
309           ctx->Y += ctr_inv_22;
310         }
311     }
312   ctx->counter += n_blocks * 2;
313 #else
314   const u32x4 ctr_inv_1 = { 0, 0, 0, 1 << 24 };
315
316   if (PREDICT_TRUE ((u8) ctx->counter < 0xfe) || n_blocks < 3)
317     {
318       for (; i < n_blocks; i++)
319         {
320           r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
321           ctx->Y += ctr_inv_1;
322         }
323       ctx->counter += n_blocks;
324     }
325   else
326     {
327       r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
328       ctx->Y += ctr_inv_1;
329       ctx->counter += 1;
330
331       for (; i < n_blocks; i++)
332         {
333           r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
334           ctx->counter++;
335           ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
336         }
337     }
338 #endif
339 }
340
341 static_always_inline void
342 aes_gcm_enc_last_round (aes_gcm_ctx_t *ctx, aes_data_t *r, aes_data_t *d,
343                         const aes_expaned_key_t *Ke, uword n_blocks)
344 {
345   /* additional ronuds for AES-192 and AES-256 */
346   for (int i = 10; i < ctx->rounds; i++)
347     aes_enc_round (r, Ke + i, n_blocks);
348
349   aes_enc_last_round (r, d, Ke + ctx->rounds, n_blocks);
350 }
351
352 static_always_inline void
353 aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
354               u32 n_bytes, int with_ghash)
355 {
356   const aes_expaned_key_t *k = ctx->Ke;
357   const aes_mem_t *sv = (aes_mem_t *) src;
358   aes_mem_t *dv = (aes_mem_t *) dst;
359   uword ghash_blocks, gc = 1;
360   aes_data_t r[4];
361   u32 i, n_lanes;
362
363   if (ctx->operation == AES_GCM_OP_ENCRYPT)
364     {
365       ghash_blocks = 4;
366       n_lanes = N_AES_LANES * 4;
367     }
368   else
369     {
370       ghash_blocks = n;
371       n_lanes = n * N_AES_LANES;
372 #if N_AES_LANES != 1
373       if (ctx->last)
374         n_lanes = (n_bytes + 15) / 16;
375 #endif
376     }
377
378   n_bytes -= (n - 1) * N_AES_BYTES;
379
380   /* AES rounds 0 and 1 */
381   aes_gcm_enc_first_round (ctx, r, n);
382   aes_enc_round (r, k + 1, n);
383
384   /* load data - decrypt round */
385   if (ctx->operation == AES_GCM_OP_DECRYPT)
386     {
387       for (i = 0; i < n - ctx->last; i++)
388         d[i] = sv[i];
389
390       if (ctx->last)
391         d[n - 1] = aes_load_partial ((u8 *) (sv + n - 1), n_bytes);
392     }
393
394   /* GHASH multiply block 0 */
395   if (with_ghash)
396     aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
397
398   /* AES rounds 2 and 3 */
399   aes_enc_round (r, k + 2, n);
400   aes_enc_round (r, k + 3, n);
401
402   /* GHASH multiply block 1 */
403   if (with_ghash && gc++ < ghash_blocks)
404     aes_gcm_ghash_mul_next (ctx, (d[1]));
405
406   /* AES rounds 4 and 5 */
407   aes_enc_round (r, k + 4, n);
408   aes_enc_round (r, k + 5, n);
409
410   /* GHASH multiply block 2 */
411   if (with_ghash && gc++ < ghash_blocks)
412     aes_gcm_ghash_mul_next (ctx, (d[2]));
413
414   /* AES rounds 6 and 7 */
415   aes_enc_round (r, k + 6, n);
416   aes_enc_round (r, k + 7, n);
417
418   /* GHASH multiply block 3 */
419   if (with_ghash && gc++ < ghash_blocks)
420     aes_gcm_ghash_mul_next (ctx, (d[3]));
421
422   /* load 4 blocks of data - decrypt round */
423   if (ctx->operation == AES_GCM_OP_ENCRYPT)
424     {
425       for (i = 0; i < n - ctx->last; i++)
426         d[i] = sv[i];
427
428       if (ctx->last)
429         d[n - 1] = aes_load_partial (sv + n - 1, n_bytes);
430     }
431
432   /* AES rounds 8 and 9 */
433   aes_enc_round (r, k + 8, n);
434   aes_enc_round (r, k + 9, n);
435
436   /* AES last round(s) */
437   aes_gcm_enc_last_round (ctx, r, d, k, n);
438
439   /* store data */
440   for (i = 0; i < n - ctx->last; i++)
441     dv[i] = d[i];
442
443   if (ctx->last)
444     aes_store_partial (d[n - 1], dv + n - 1, n_bytes);
445
446   /* GHASH reduce 1st step */
447   aes_gcm_ghash_reduce (ctx);
448
449   /* GHASH reduce 2nd step */
450   if (with_ghash)
451     aes_gcm_ghash_reduce2 (ctx);
452
453   /* GHASH final step */
454   if (with_ghash)
455     aes_gcm_ghash_final (ctx);
456 }
457
458 static_always_inline void
459 aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst)
460 {
461   const aes_expaned_key_t *k = ctx->Ke;
462   const aes_mem_t *sv = (aes_mem_t *) src;
463   aes_mem_t *dv = (aes_mem_t *) dst;
464   aes_data_t r[4];
465
466   /* AES rounds 0 and 1 */
467   aes_gcm_enc_first_round (ctx, r, 4);
468   aes_enc_round (r, k + 1, 4);
469
470   /* load 4 blocks of data - decrypt round */
471   if (ctx->operation == AES_GCM_OP_DECRYPT)
472     for (int i = 0; i < 4; i++)
473       d[i] = sv[i];
474
475   /* GHASH multiply block 0 */
476   aes_gcm_ghash_mul_first (ctx, d[0], N_AES_LANES * 8);
477
478   /* AES rounds 2 and 3 */
479   aes_enc_round (r, k + 2, 4);
480   aes_enc_round (r, k + 3, 4);
481
482   /* GHASH multiply block 1 */
483   aes_gcm_ghash_mul_next (ctx, (d[1]));
484
485   /* AES rounds 4 and 5 */
486   aes_enc_round (r, k + 4, 4);
487   aes_enc_round (r, k + 5, 4);
488
489   /* GHASH multiply block 2 */
490   aes_gcm_ghash_mul_next (ctx, (d[2]));
491
492   /* AES rounds 6 and 7 */
493   aes_enc_round (r, k + 6, 4);
494   aes_enc_round (r, k + 7, 4);
495
496   /* GHASH multiply block 3 */
497   aes_gcm_ghash_mul_next (ctx, (d[3]));
498
499   /* AES rounds 8 and 9 */
500   aes_enc_round (r, k + 8, 4);
501   aes_enc_round (r, k + 9, 4);
502
503   /* load 4 blocks of data - encrypt round */
504   if (ctx->operation == AES_GCM_OP_ENCRYPT)
505     for (int i = 0; i < 4; i++)
506       d[i] = sv[i];
507
508   /* AES last round(s) */
509   aes_gcm_enc_last_round (ctx, r, d, k, 4);
510
511   /* store 4 blocks of data */
512   for (int i = 0; i < 4; i++)
513     dv[i] = d[i];
514
515   /* load next 4 blocks of data data - decrypt round */
516   if (ctx->operation == AES_GCM_OP_DECRYPT)
517     for (int i = 0; i < 4; i++)
518       d[i] = sv[i + 4];
519
520   /* GHASH multiply block 4 */
521   aes_gcm_ghash_mul_next (ctx, (d[0]));
522
523   /* AES rounds 0 and 1 */
524   aes_gcm_enc_first_round (ctx, r, 4);
525   aes_enc_round (r, k + 1, 4);
526
527   /* GHASH multiply block 5 */
528   aes_gcm_ghash_mul_next (ctx, (d[1]));
529
530   /* AES rounds 2 and 3 */
531   aes_enc_round (r, k + 2, 4);
532   aes_enc_round (r, k + 3, 4);
533
534   /* GHASH multiply block 6 */
535   aes_gcm_ghash_mul_next (ctx, (d[2]));
536
537   /* AES rounds 4 and 5 */
538   aes_enc_round (r, k + 4, 4);
539   aes_enc_round (r, k + 5, 4);
540
541   /* GHASH multiply block 7 */
542   aes_gcm_ghash_mul_next (ctx, (d[3]));
543
544   /* AES rounds 6 and 7 */
545   aes_enc_round (r, k + 6, 4);
546   aes_enc_round (r, k + 7, 4);
547
548   /* GHASH reduce 1st step */
549   aes_gcm_ghash_reduce (ctx);
550
551   /* AES rounds 8 and 9 */
552   aes_enc_round (r, k + 8, 4);
553   aes_enc_round (r, k + 9, 4);
554
555   /* GHASH reduce 2nd step */
556   aes_gcm_ghash_reduce2 (ctx);
557
558   /* load 4 blocks of data - encrypt round */
559   if (ctx->operation == AES_GCM_OP_ENCRYPT)
560     for (int i = 0; i < 4; i++)
561       d[i] = sv[i + 4];
562
563   /* AES last round(s) */
564   aes_gcm_enc_last_round (ctx, r, d, k, 4);
565
566   /* store data */
567   for (int i = 0; i < 4; i++)
568     dv[i + 4] = d[i];
569
570   /* GHASH final step */
571   aes_gcm_ghash_final (ctx);
572 }
573
574 static_always_inline void
575 aes_gcm_mask_bytes (aes_data_t *d, uword n_bytes)
576 {
577   const union
578   {
579     u8 b[64];
580     aes_data_t r;
581   } scale = {
582     .b = { 0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
583            16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
584            32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
585            48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 },
586   };
587
588   d[0] &= (aes_gcm_splat (n_bytes) > scale.r);
589 }
590
591 static_always_inline void
592 aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
593                    u32 n_bytes)
594 {
595   int n_lanes = (N_AES_LANES == 1 ? n_blocks : (n_bytes + 15) / 16) + 1;
596   n_bytes -= (n_blocks - 1) * N_AES_BYTES;
597   int i;
598
599   aes_gcm_enc_ctr0_round (ctx, 0);
600   aes_gcm_enc_ctr0_round (ctx, 1);
601
602   if (n_bytes != N_AES_BYTES)
603     aes_gcm_mask_bytes (d + n_blocks - 1, n_bytes);
604
605   aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
606
607   aes_gcm_enc_ctr0_round (ctx, 2);
608   aes_gcm_enc_ctr0_round (ctx, 3);
609
610   if (n_blocks > 1)
611     aes_gcm_ghash_mul_next (ctx, d[1]);
612
613   aes_gcm_enc_ctr0_round (ctx, 4);
614   aes_gcm_enc_ctr0_round (ctx, 5);
615
616   if (n_blocks > 2)
617     aes_gcm_ghash_mul_next (ctx, d[2]);
618
619   aes_gcm_enc_ctr0_round (ctx, 6);
620   aes_gcm_enc_ctr0_round (ctx, 7);
621
622   if (n_blocks > 3)
623     aes_gcm_ghash_mul_next (ctx, d[3]);
624
625   aes_gcm_enc_ctr0_round (ctx, 8);
626   aes_gcm_enc_ctr0_round (ctx, 9);
627
628   aes_gcm_ghash_mul_final_block (ctx);
629   aes_gcm_ghash_reduce (ctx);
630
631   for (i = 10; i < ctx->rounds; i++)
632     aes_gcm_enc_ctr0_round (ctx, i);
633
634   aes_gcm_ghash_reduce2 (ctx);
635
636   aes_gcm_ghash_final (ctx);
637
638   aes_gcm_enc_ctr0_round (ctx, i);
639 }
640
641 static_always_inline void
642 aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
643 {
644   aes_data_t d[4];
645
646   if (PREDICT_FALSE (n_left == 0))
647     {
648       int i;
649       for (i = 0; i < ctx->rounds + 1; i++)
650         aes_gcm_enc_ctr0_round (ctx, i);
651       return;
652     }
653
654   if (n_left < 4 * N_AES_BYTES)
655     {
656       ctx->last = 1;
657       if (n_left > 3 * N_AES_BYTES)
658         {
659           aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 0);
660           aes_gcm_calc_last (ctx, d, 4, n_left);
661         }
662       else if (n_left > 2 * N_AES_BYTES)
663         {
664           aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 0);
665           aes_gcm_calc_last (ctx, d, 3, n_left);
666         }
667       else if (n_left > N_AES_BYTES)
668         {
669           aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 0);
670           aes_gcm_calc_last (ctx, d, 2, n_left);
671         }
672       else
673         {
674           aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 0);
675           aes_gcm_calc_last (ctx, d, 1, n_left);
676         }
677       return;
678     }
679
680   aes_gcm_calc (ctx, d, src, dst, 4, 4 * N_AES_BYTES, /* with_ghash */ 0);
681
682   /* next */
683   n_left -= 4 * N_AES_BYTES;
684   dst += 4 * N_AES_BYTES;
685   src += 4 * N_AES_BYTES;
686
687   for (int n = 8 * N_AES_BYTES; n_left >= n; n_left -= n, src += n, dst += n)
688     aes_gcm_calc_double (ctx, d, src, dst);
689
690   if (n_left >= 4 * N_AES_BYTES)
691     {
692       aes_gcm_calc (ctx, d, src, dst, 4, 4 * N_AES_BYTES, /* with_ghash */ 1);
693
694       /* next */
695       n_left -= 4 * N_AES_BYTES;
696       dst += 4 * N_AES_BYTES;
697       src += 4 * N_AES_BYTES;
698     }
699
700   if (n_left == 0)
701     {
702       aes_gcm_calc_last (ctx, d, 4, 4 * N_AES_BYTES);
703       return;
704     }
705
706   ctx->last = 1;
707
708   if (n_left > 3 * N_AES_BYTES)
709     {
710       aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
711       aes_gcm_calc_last (ctx, d, 4, n_left);
712     }
713   else if (n_left > 2 * N_AES_BYTES)
714     {
715       aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
716       aes_gcm_calc_last (ctx, d, 3, n_left);
717     }
718   else if (n_left > N_AES_BYTES)
719     {
720       aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
721       aes_gcm_calc_last (ctx, d, 2, n_left);
722     }
723   else
724     {
725       aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
726       aes_gcm_calc_last (ctx, d, 1, n_left);
727     }
728 }
729
730 static_always_inline void
731 aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
732 {
733   aes_data_t d[4] = {};
734   ghash_ctx_t gd;
735
736   /* main encryption loop */
737   for (int n = 8 * N_AES_BYTES; n_left >= n; n_left -= n, dst += n, src += n)
738     aes_gcm_calc_double (ctx, d, src, dst);
739
740   if (n_left >= 4 * N_AES_BYTES)
741     {
742       aes_gcm_calc (ctx, d, src, dst, 4, 4 * N_AES_BYTES, /* with_ghash */ 1);
743
744       /* next */
745       n_left -= 4 * N_AES_BYTES;
746       dst += N_AES_BYTES * 4;
747       src += N_AES_BYTES * 4;
748     }
749
750   if (n_left)
751     {
752       ctx->last = 1;
753
754       if (n_left > 3 * N_AES_BYTES)
755         aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
756       else if (n_left > 2 * N_AES_BYTES)
757         aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
758       else if (n_left > N_AES_BYTES)
759         aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
760       else
761         aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
762     }
763
764   /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM
765    * (bit length) block */
766
767   aes_gcm_enc_ctr0_round (ctx, 0);
768   aes_gcm_enc_ctr0_round (ctx, 1);
769
770   ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T,
771                    ctx->Hi[NUM_HI - 1]);
772
773   aes_gcm_enc_ctr0_round (ctx, 2);
774   aes_gcm_enc_ctr0_round (ctx, 3);
775
776   ghash_reduce (&gd);
777
778   aes_gcm_enc_ctr0_round (ctx, 4);
779   aes_gcm_enc_ctr0_round (ctx, 5);
780
781   ghash_reduce2 (&gd);
782
783   aes_gcm_enc_ctr0_round (ctx, 6);
784   aes_gcm_enc_ctr0_round (ctx, 7);
785
786   ctx->T = ghash_final (&gd);
787
788   aes_gcm_enc_ctr0_round (ctx, 8);
789   aes_gcm_enc_ctr0_round (ctx, 9);
790
791   for (int i = 10; i < ctx->rounds + 1; i += 1)
792     aes_gcm_enc_ctr0_round (ctx, i);
793 }
794
795 static_always_inline int
796 aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
797          u32 data_bytes, u32 aad_bytes, u8 tag_len,
798          const aes_gcm_key_data_t *kd, int aes_rounds, aes_gcm_op_t op)
799 {
800   u8 *addt = (u8 *) aad;
801   u32x4 Y0;
802
803   aes_gcm_ctx_t _ctx = { .counter = 2,
804                          .rounds = aes_rounds,
805                          .operation = op,
806                          .data_bytes = data_bytes,
807                          .aad_bytes = aad_bytes,
808                          .Ke = kd->Ke,
809                          .Hi = kd->Hi },
810                 *ctx = &_ctx;
811
812   /* initalize counter */
813   Y0 = (u32x4) (u64x2){ *(u64u *) ivp, 0 };
814   Y0[2] = *(u32u *) (ivp + 8);
815   Y0[3] = 1 << 24;
816   ctx->EY0 = (u8x16) Y0;
817
818 #if N_AES_LANES == 4
819   ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
820     0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
821   };
822 #elif N_AES_LANES == 2
823   ctx->Y =
824     u32x8_splat_u32x4 (Y0) + (u32x8){ 0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24 };
825 #else
826   ctx->Y = Y0 + (u32x4){ 0, 0, 0, 1 << 24 };
827 #endif
828
829   /* calculate ghash for AAD */
830   aes_gcm_ghash (ctx, addt, aad_bytes);
831
832   /* ghash and encrypt/edcrypt  */
833   if (op == AES_GCM_OP_ENCRYPT)
834     aes_gcm_enc (ctx, src, dst, data_bytes);
835   else if (op == AES_GCM_OP_DECRYPT)
836     aes_gcm_dec (ctx, src, dst, data_bytes);
837
838   /* final tag is */
839   ctx->T = u8x16_reflect (ctx->T) ^ ctx->EY0;
840
841   /* tag_len 16 -> 0 */
842   tag_len &= 0xf;
843
844   if (op == AES_GCM_OP_ENCRYPT || op == AES_GCM_OP_GMAC)
845     {
846       /* store tag */
847       if (tag_len)
848         u8x16_store_partial (ctx->T, tag, tag_len);
849       else
850         ((u8x16u *) tag)[0] = ctx->T;
851     }
852   else
853     {
854       /* check tag */
855       if (tag_len)
856         {
857           u16 mask = pow2_mask (tag_len);
858           u8x16 expected = u8x16_load_partial (tag, tag_len);
859           if ((u8x16_msb_mask (expected == ctx->T) & mask) == mask)
860             return 1;
861         }
862       else
863         {
864           if (u8x16_is_equal (ctx->T, *(u8x16u *) tag))
865             return 1;
866         }
867     }
868   return 0;
869 }
870
871 static_always_inline void
872 clib_aes_gcm_key_expand (aes_gcm_key_data_t *kd, const u8 *key,
873                          aes_key_size_t ks)
874 {
875   u8x16 H;
876   u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
877   aes_expaned_key_t *Ke = (aes_expaned_key_t *) kd->Ke;
878
879   /* expand AES key */
880   aes_key_expand (ek, key, ks);
881   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
882     Ke[i].lanes[0] = Ke[i].lanes[1] = Ke[i].lanes[2] = Ke[i].lanes[3] = ek[i];
883
884   /* pre-calculate H */
885   H = aes_encrypt_block (u8x16_zero (), ek, ks);
886   H = u8x16_reflect (H);
887   ghash_precompute (H, (u8x16 *) kd->Hi, ARRAY_LEN (kd->Hi));
888 }
889
890 static_always_inline void
891 clib_aes128_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
892                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
893                      const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
894 {
895   aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
896            tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_ENCRYPT);
897 }
898
899 static_always_inline void
900 clib_aes256_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
901                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
902                      const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
903 {
904   aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
905            tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_ENCRYPT);
906 }
907
908 static_always_inline int
909 clib_aes128_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
910                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
911                      const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
912 {
913   return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
914                   data_bytes, aad_bytes, tag_bytes, kd,
915                   AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_DECRYPT);
916 }
917
918 static_always_inline int
919 clib_aes256_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
920                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
921                      const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
922 {
923   return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
924                   data_bytes, aad_bytes, tag_bytes, kd,
925                   AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_DECRYPT);
926 }
927
928 static_always_inline void
929 clib_aes128_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
930                   const u8 *iv, u32 tag_bytes, u8 *tag)
931 {
932   aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
933            AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_GMAC);
934 }
935
936 static_always_inline void
937 clib_aes256_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
938                   const u8 *iv, u32 tag_bytes, u8 *tag)
939 {
940   aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
941            AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_GMAC);
942 }
943
944 #endif /* __crypto_aes_gcm_h__ */