vppinfra: small improvement and polishing of AES GCM code
[vpp.git] / src / vppinfra / crypto / aes_gcm.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2023 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_gcm_h__
6 #define __crypto_aes_gcm_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/cache.h>
11 #include <vppinfra/string.h>
12 #include <vppinfra/crypto/aes.h>
13 #include <vppinfra/crypto/ghash.h>
14
15 #define NUM_HI 36
16 #if defined(__VAES__) && defined(__AVX512F__)
17 typedef u8x64 aes_data_t;
18 typedef u8x64u aes_ghash_t;
19 typedef u8x64u aes_mem_t;
20 typedef u32x16 aes_gcm_counter_t;
21 #define N                              64
22 #define aes_gcm_load_partial(p, n)     u8x64_load_partial ((u8 *) (p), n)
23 #define aes_gcm_store_partial(v, p, n) u8x64_store_partial (v, (u8 *) (p), n)
24 #define aes_gcm_splat(v)               u8x64_splat (v)
25 #define aes_gcm_reflect(r)             u8x64_reflect_u8x16 (r)
26 #define aes_gcm_ghash_reduce(c)        ghash4_reduce (&(c)->gd)
27 #define aes_gcm_ghash_reduce2(c)       ghash4_reduce2 (&(c)->gd)
28 #define aes_gcm_ghash_final(c)         (c)->T = ghash4_final (&(c)->gd)
29 #elif defined(__VAES__)
30 typedef u8x32 aes_data_t;
31 typedef u8x32u aes_ghash_t;
32 typedef u8x32u aes_mem_t;
33 typedef u32x8 aes_gcm_counter_t;
34 #define N                              32
35 #define aes_gcm_load_partial(p, n)     u8x32_load_partial ((u8 *) (p), n)
36 #define aes_gcm_store_partial(v, p, n) u8x32_store_partial (v, (u8 *) (p), n)
37 #define aes_gcm_splat(v)               u8x32_splat (v)
38 #define aes_gcm_reflect(r)             u8x32_reflect_u8x16 (r)
39 #define aes_gcm_ghash_reduce(c)        ghash2_reduce (&(c)->gd)
40 #define aes_gcm_ghash_reduce2(c)       ghash2_reduce2 (&(c)->gd)
41 #define aes_gcm_ghash_final(c)         (c)->T = ghash2_final (&(c)->gd)
42 #else
43 typedef u8x16 aes_data_t;
44 typedef u8x16 aes_ghash_t;
45 typedef u8x16u aes_mem_t;
46 typedef u32x4 aes_gcm_counter_t;
47 #define N                              16
48 #define aes_gcm_load_partial(p, n)     u8x16_load_partial ((u8 *) (p), n)
49 #define aes_gcm_store_partial(v, p, n) u8x16_store_partial (v, (u8 *) (p), n)
50 #define aes_gcm_splat(v)               u8x16_splat (v)
51 #define aes_gcm_reflect(r)             u8x16_reflect (r)
52 #define aes_gcm_ghash_reduce(c)        ghash_reduce (&(c)->gd)
53 #define aes_gcm_ghash_reduce2(c)       ghash_reduce2 (&(c)->gd)
54 #define aes_gcm_ghash_final(c)         (c)->T = ghash_final (&(c)->gd)
55 #endif
56 #define N_LANES (N / 16)
57
58 typedef enum
59 {
60   AES_GCM_OP_UNKNONW = 0,
61   AES_GCM_OP_ENCRYPT,
62   AES_GCM_OP_DECRYPT,
63   AES_GCM_OP_GMAC
64 } aes_gcm_op_t;
65
66 typedef union
67 {
68   u8x16 x1;
69   u8x32 x2;
70   u8x64 x4;
71   u8x16 lanes[4];
72 } __clib_aligned (64)
73 aes_gcm_expaned_key_t;
74
75 typedef struct
76 {
77   /* pre-calculated hash key values */
78   const u8x16 Hi[NUM_HI];
79   /* extracted AES key */
80   const aes_gcm_expaned_key_t Ke[AES_KEY_ROUNDS (AES_KEY_256) + 1];
81 } aes_gcm_key_data_t;
82
83 typedef struct
84 {
85   aes_gcm_op_t operation;
86   int last;
87   u8 rounds;
88   uword data_bytes;
89   uword aad_bytes;
90
91   u8x16 T;
92
93   /* hash */
94   const u8x16 *Hi;
95   const aes_ghash_t *next_Hi;
96
97   /* expaded keys */
98   const aes_gcm_expaned_key_t *Ke;
99
100   /* counter */
101   u32 counter;
102   u8x16 EY0;
103   aes_gcm_counter_t Y;
104
105   /* ghash */
106   ghash_ctx_t gd;
107 } aes_gcm_ctx_t;
108
109 static_always_inline u8x16
110 aes_gcm_final_block (aes_gcm_ctx_t *ctx)
111 {
112   return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
113 }
114
115 static_always_inline void
116 aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
117 {
118   uword hash_offset = NUM_HI - n_lanes;
119   ctx->next_Hi = (aes_ghash_t *) (ctx->Hi + hash_offset);
120 #if N_LANES == 4
121   u8x64 tag4 = {};
122   tag4 = u8x64_insert_u8x16 (tag4, ctx->T, 0);
123   ghash4_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ tag4, *ctx->next_Hi++);
124 #elif N_LANES == 2
125   u8x32 tag2 = {};
126   tag2 = u8x32_insert_lo (tag2, ctx->T);
127   ghash2_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ tag2, *ctx->next_Hi++);
128 #else
129   ghash_mul_first (&ctx->gd, aes_gcm_reflect (data) ^ ctx->T, *ctx->next_Hi++);
130 #endif
131 }
132
133 static_always_inline void
134 aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
135 {
136 #if N_LANES == 4
137   ghash4_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
138 #elif N_LANES == 2
139   ghash2_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
140 #else
141   ghash_mul_next (&ctx->gd, aes_gcm_reflect (data), *ctx->next_Hi++);
142 #endif
143 }
144
145 static_always_inline void
146 aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx)
147 {
148 #if N_LANES == 4
149   u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
150   u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0);
151   ghash4_mul_next (&ctx->gd, r4, h);
152 #elif N_LANES == 2
153   u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
154   u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx));
155   ghash2_mul_next (&ctx->gd, r2, h);
156 #else
157   ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]);
158 #endif
159 }
160
161 static_always_inline void
162 aes_gcm_enc_ctr0_round (aes_gcm_ctx_t *ctx, int aes_round)
163 {
164   if (aes_round == 0)
165     ctx->EY0 ^= ctx->Ke[0].x1;
166   else if (aes_round == ctx->rounds)
167     ctx->EY0 = aes_enc_last_round (ctx->EY0, ctx->Ke[aes_round].x1);
168   else
169     ctx->EY0 = aes_enc_round (ctx->EY0, ctx->Ke[aes_round].x1);
170 }
171
172 static_always_inline void
173 aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
174 {
175   uword i;
176   aes_data_t r = {};
177   const aes_mem_t *d = (aes_mem_t *) data;
178
179   for (; n_left >= 8 * N; n_left -= 8 * N, d += 8)
180     {
181       if (ctx->operation == AES_GCM_OP_GMAC && n_left == N * 8)
182         {
183           aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1);
184           for (i = 1; i < 8; i++)
185             aes_gcm_ghash_mul_next (ctx, d[i]);
186           aes_gcm_ghash_mul_final_block (ctx);
187           aes_gcm_ghash_reduce (ctx);
188           aes_gcm_ghash_reduce2 (ctx);
189           aes_gcm_ghash_final (ctx);
190           goto done;
191         }
192
193       aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES);
194       for (i = 1; i < 8; i++)
195         aes_gcm_ghash_mul_next (ctx, d[i]);
196       aes_gcm_ghash_reduce (ctx);
197       aes_gcm_ghash_reduce2 (ctx);
198       aes_gcm_ghash_final (ctx);
199     }
200
201   if (n_left > 0)
202     {
203       int n_lanes = (n_left + 15) / 16;
204
205       if (ctx->operation == AES_GCM_OP_GMAC)
206         n_lanes++;
207
208       if (n_left < N)
209         {
210           clib_memcpy_fast (&r, d, n_left);
211           aes_gcm_ghash_mul_first (ctx, r, n_lanes);
212         }
213       else
214         {
215           aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
216           n_left -= N;
217           i = 1;
218
219           if (n_left >= 4 * N)
220             {
221               aes_gcm_ghash_mul_next (ctx, d[i]);
222               aes_gcm_ghash_mul_next (ctx, d[i + 1]);
223               aes_gcm_ghash_mul_next (ctx, d[i + 2]);
224               aes_gcm_ghash_mul_next (ctx, d[i + 3]);
225               n_left -= 4 * N;
226               i += 4;
227             }
228           if (n_left >= 2 * N)
229             {
230               aes_gcm_ghash_mul_next (ctx, d[i]);
231               aes_gcm_ghash_mul_next (ctx, d[i + 1]);
232               n_left -= 2 * N;
233               i += 2;
234             }
235
236           if (n_left >= N)
237             {
238               aes_gcm_ghash_mul_next (ctx, d[i]);
239               n_left -= N;
240               i += 1;
241             }
242
243           if (n_left)
244             {
245               clib_memcpy_fast (&r, d + i, n_left);
246               aes_gcm_ghash_mul_next (ctx, r);
247             }
248         }
249
250       if (ctx->operation == AES_GCM_OP_GMAC)
251         aes_gcm_ghash_mul_final_block (ctx);
252       aes_gcm_ghash_reduce (ctx);
253       aes_gcm_ghash_reduce2 (ctx);
254       aes_gcm_ghash_final (ctx);
255     }
256   else if (ctx->operation == AES_GCM_OP_GMAC)
257     ctx->T =
258       ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]);
259
260 done:
261   /* encrypt counter 0 E(Y0, k) */
262   if (ctx->operation == AES_GCM_OP_GMAC)
263     for (int i = 0; i < ctx->rounds + 1; i += 1)
264       aes_gcm_enc_ctr0_round (ctx, i);
265 }
266
267 static_always_inline void
268 aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
269 {
270   const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0];
271   uword i = 0;
272
273   /* As counter is stored in network byte order for performance reasons we
274      are incrementing least significant byte only except in case where we
275      overlow. As we are processing four 128, 256 or 512-blocks in parallel
276      except the last round, overflow can happen only when n_blocks == 4 */
277
278 #if N_LANES == 4
279   const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
280                                 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
281
282   const u32x16 ctr_4444 = {
283     4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
284   };
285
286   if (n_blocks == 4)
287     for (; i < 2; i++)
288       {
289         r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
290         ctx->Y += ctr_inv_4444;
291       }
292
293   if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 242))
294     {
295       u32x16 Yr = (u32x16) aes_gcm_reflect ((u8x64) ctx->Y);
296
297       for (; i < n_blocks; i++)
298         {
299           r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
300           Yr += ctr_4444;
301           ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr);
302         }
303     }
304   else
305     {
306       for (; i < n_blocks; i++)
307         {
308           r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
309           ctx->Y += ctr_inv_4444;
310         }
311     }
312   ctx->counter += n_blocks * 4;
313 #elif N_LANES == 2
314   const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
315   const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
316
317   if (n_blocks == 4)
318     for (; i < 2; i++)
319       {
320         r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
321         ctx->Y += ctr_inv_22;
322       }
323
324   if (n_blocks == 4 && PREDICT_FALSE ((u8) ctx->counter == 250))
325     {
326       u32x8 Yr = (u32x8) aes_gcm_reflect ((u8x32) ctx->Y);
327
328       for (; i < n_blocks; i++)
329         {
330           r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
331           Yr += ctr_22;
332           ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr);
333         }
334     }
335   else
336     {
337       for (; i < n_blocks; i++)
338         {
339           r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
340           ctx->Y += ctr_inv_22;
341         }
342     }
343   ctx->counter += n_blocks * 2;
344 #else
345   const u32x4 ctr_inv_1 = { 0, 0, 0, 1 << 24 };
346
347   if (PREDICT_TRUE ((u8) ctx->counter < 0xfe) || n_blocks < 3)
348     {
349       for (; i < n_blocks; i++)
350         {
351           r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
352           ctx->Y += ctr_inv_1;
353         }
354       ctx->counter += n_blocks;
355     }
356   else
357     {
358       r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
359       ctx->Y += ctr_inv_1;
360       ctx->counter += 1;
361
362       for (; i < n_blocks; i++)
363         {
364           r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
365           ctx->counter++;
366           ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
367         }
368     }
369 #endif
370 }
371
372 static_always_inline void
373 aes_gcm_enc_round (aes_data_t *r, const aes_gcm_expaned_key_t *Ke,
374                    uword n_blocks)
375 {
376   for (int i = 0; i < n_blocks; i++)
377 #if N_LANES == 4
378     r[i] = aes_enc_round_x4 (r[i], Ke->x4);
379 #elif N_LANES == 2
380     r[i] = aes_enc_round_x2 (r[i], Ke->x2);
381 #else
382     r[i] = aes_enc_round (r[i], Ke->x1);
383 #endif
384 }
385
386 static_always_inline void
387 aes_gcm_enc_last_round (aes_gcm_ctx_t *ctx, aes_data_t *r, aes_data_t *d,
388                         const aes_gcm_expaned_key_t *Ke, uword n_blocks)
389 {
390   /* additional ronuds for AES-192 and AES-256 */
391   for (int i = 10; i < ctx->rounds; i++)
392     aes_gcm_enc_round (r, Ke + i, n_blocks);
393
394   for (int i = 0; i < n_blocks; i++)
395 #if N_LANES == 4
396     d[i] ^= aes_enc_last_round_x4 (r[i], Ke[ctx->rounds].x4);
397 #elif N_LANES == 2
398     d[i] ^= aes_enc_last_round_x2 (r[i], Ke[ctx->rounds].x2);
399 #else
400     d[i] ^= aes_enc_last_round (r[i], Ke[ctx->rounds].x1);
401 #endif
402 }
403
404 static_always_inline void
405 aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
406               u32 n_bytes, int with_ghash)
407 {
408   const aes_gcm_expaned_key_t *k = ctx->Ke;
409   const aes_mem_t *sv = (aes_mem_t *) src;
410   aes_mem_t *dv = (aes_mem_t *) dst;
411   uword ghash_blocks, gc = 1;
412   aes_data_t r[4];
413   u32 i, n_lanes;
414
415   if (ctx->operation == AES_GCM_OP_ENCRYPT)
416     {
417       ghash_blocks = 4;
418       n_lanes = N_LANES * 4;
419     }
420   else
421     {
422       ghash_blocks = n;
423       n_lanes = n * N_LANES;
424 #if N_LANES != 1
425       if (ctx->last)
426         n_lanes = (n_bytes + 15) / 16;
427 #endif
428     }
429
430   n_bytes -= (n - 1) * N;
431
432   /* AES rounds 0 and 1 */
433   aes_gcm_enc_first_round (ctx, r, n);
434   aes_gcm_enc_round (r, k + 1, n);
435
436   /* load data - decrypt round */
437   if (ctx->operation == AES_GCM_OP_DECRYPT)
438     {
439       for (i = 0; i < n - ctx->last; i++)
440         d[i] = sv[i];
441
442       if (ctx->last)
443         d[n - 1] = aes_gcm_load_partial ((u8 *) (sv + n - 1), n_bytes);
444     }
445
446   /* GHASH multiply block 0 */
447   if (with_ghash)
448     aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
449
450   /* AES rounds 2 and 3 */
451   aes_gcm_enc_round (r, k + 2, n);
452   aes_gcm_enc_round (r, k + 3, n);
453
454   /* GHASH multiply block 1 */
455   if (with_ghash && gc++ < ghash_blocks)
456     aes_gcm_ghash_mul_next (ctx, (d[1]));
457
458   /* AES rounds 4 and 5 */
459   aes_gcm_enc_round (r, k + 4, n);
460   aes_gcm_enc_round (r, k + 5, n);
461
462   /* GHASH multiply block 2 */
463   if (with_ghash && gc++ < ghash_blocks)
464     aes_gcm_ghash_mul_next (ctx, (d[2]));
465
466   /* AES rounds 6 and 7 */
467   aes_gcm_enc_round (r, k + 6, n);
468   aes_gcm_enc_round (r, k + 7, n);
469
470   /* GHASH multiply block 3 */
471   if (with_ghash && gc++ < ghash_blocks)
472     aes_gcm_ghash_mul_next (ctx, (d[3]));
473
474   /* load 4 blocks of data - decrypt round */
475   if (ctx->operation == AES_GCM_OP_ENCRYPT)
476     {
477       for (i = 0; i < n - ctx->last; i++)
478         d[i] = sv[i];
479
480       if (ctx->last)
481         d[n - 1] = aes_gcm_load_partial (sv + n - 1, n_bytes);
482     }
483
484   /* AES rounds 8 and 9 */
485   aes_gcm_enc_round (r, k + 8, n);
486   aes_gcm_enc_round (r, k + 9, n);
487
488   /* AES last round(s) */
489   aes_gcm_enc_last_round (ctx, r, d, k, n);
490
491   /* store data */
492   for (i = 0; i < n - ctx->last; i++)
493     dv[i] = d[i];
494
495   if (ctx->last)
496     aes_gcm_store_partial (d[n - 1], dv + n - 1, n_bytes);
497
498   /* GHASH reduce 1st step */
499   aes_gcm_ghash_reduce (ctx);
500
501   /* GHASH reduce 2nd step */
502   if (with_ghash)
503     aes_gcm_ghash_reduce2 (ctx);
504
505   /* GHASH final step */
506   if (with_ghash)
507     aes_gcm_ghash_final (ctx);
508 }
509
510 static_always_inline void
511 aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst)
512 {
513   const aes_gcm_expaned_key_t *k = ctx->Ke;
514   const aes_mem_t *sv = (aes_mem_t *) src;
515   aes_mem_t *dv = (aes_mem_t *) dst;
516   aes_data_t r[4];
517
518   /* AES rounds 0 and 1 */
519   aes_gcm_enc_first_round (ctx, r, 4);
520   aes_gcm_enc_round (r, k + 1, 4);
521
522   /* load 4 blocks of data - decrypt round */
523   if (ctx->operation == AES_GCM_OP_DECRYPT)
524     for (int i = 0; i < 4; i++)
525       d[i] = sv[i];
526
527   /* GHASH multiply block 0 */
528   aes_gcm_ghash_mul_first (ctx, d[0], N_LANES * 8);
529
530   /* AES rounds 2 and 3 */
531   aes_gcm_enc_round (r, k + 2, 4);
532   aes_gcm_enc_round (r, k + 3, 4);
533
534   /* GHASH multiply block 1 */
535   aes_gcm_ghash_mul_next (ctx, (d[1]));
536
537   /* AES rounds 4 and 5 */
538   aes_gcm_enc_round (r, k + 4, 4);
539   aes_gcm_enc_round (r, k + 5, 4);
540
541   /* GHASH multiply block 2 */
542   aes_gcm_ghash_mul_next (ctx, (d[2]));
543
544   /* AES rounds 6 and 7 */
545   aes_gcm_enc_round (r, k + 6, 4);
546   aes_gcm_enc_round (r, k + 7, 4);
547
548   /* GHASH multiply block 3 */
549   aes_gcm_ghash_mul_next (ctx, (d[3]));
550
551   /* AES rounds 8 and 9 */
552   aes_gcm_enc_round (r, k + 8, 4);
553   aes_gcm_enc_round (r, k + 9, 4);
554
555   /* load 4 blocks of data - encrypt round */
556   if (ctx->operation == AES_GCM_OP_ENCRYPT)
557     for (int i = 0; i < 4; i++)
558       d[i] = sv[i];
559
560   /* AES last round(s) */
561   aes_gcm_enc_last_round (ctx, r, d, k, 4);
562
563   /* store 4 blocks of data */
564   for (int i = 0; i < 4; i++)
565     dv[i] = d[i];
566
567   /* load next 4 blocks of data data - decrypt round */
568   if (ctx->operation == AES_GCM_OP_DECRYPT)
569     for (int i = 0; i < 4; i++)
570       d[i] = sv[i + 4];
571
572   /* GHASH multiply block 4 */
573   aes_gcm_ghash_mul_next (ctx, (d[0]));
574
575   /* AES rounds 0 and 1 */
576   aes_gcm_enc_first_round (ctx, r, 4);
577   aes_gcm_enc_round (r, k + 1, 4);
578
579   /* GHASH multiply block 5 */
580   aes_gcm_ghash_mul_next (ctx, (d[1]));
581
582   /* AES rounds 2 and 3 */
583   aes_gcm_enc_round (r, k + 2, 4);
584   aes_gcm_enc_round (r, k + 3, 4);
585
586   /* GHASH multiply block 6 */
587   aes_gcm_ghash_mul_next (ctx, (d[2]));
588
589   /* AES rounds 4 and 5 */
590   aes_gcm_enc_round (r, k + 4, 4);
591   aes_gcm_enc_round (r, k + 5, 4);
592
593   /* GHASH multiply block 7 */
594   aes_gcm_ghash_mul_next (ctx, (d[3]));
595
596   /* AES rounds 6 and 7 */
597   aes_gcm_enc_round (r, k + 6, 4);
598   aes_gcm_enc_round (r, k + 7, 4);
599
600   /* GHASH reduce 1st step */
601   aes_gcm_ghash_reduce (ctx);
602
603   /* AES rounds 8 and 9 */
604   aes_gcm_enc_round (r, k + 8, 4);
605   aes_gcm_enc_round (r, k + 9, 4);
606
607   /* GHASH reduce 2nd step */
608   aes_gcm_ghash_reduce2 (ctx);
609
610   /* load 4 blocks of data - encrypt round */
611   if (ctx->operation == AES_GCM_OP_ENCRYPT)
612     for (int i = 0; i < 4; i++)
613       d[i] = sv[i + 4];
614
615   /* AES last round(s) */
616   aes_gcm_enc_last_round (ctx, r, d, k, 4);
617
618   /* store data */
619   for (int i = 0; i < 4; i++)
620     dv[i + 4] = d[i];
621
622   /* GHASH final step */
623   aes_gcm_ghash_final (ctx);
624 }
625
626 static_always_inline void
627 aes_gcm_mask_bytes (aes_data_t *d, uword n_bytes)
628 {
629   const union
630   {
631     u8 b[64];
632     aes_data_t r;
633   } scale = {
634     .b = { 0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
635            16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
636            32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
637            48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63 },
638   };
639
640   d[0] &= (aes_gcm_splat (n_bytes) > scale.r);
641 }
642
643 static_always_inline void
644 aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
645                    u32 n_bytes)
646 {
647   int n_lanes = (N_LANES == 1 ? n_blocks : (n_bytes + 15) / 16) + 1;
648   n_bytes -= (n_blocks - 1) * N;
649   int i;
650
651   aes_gcm_enc_ctr0_round (ctx, 0);
652   aes_gcm_enc_ctr0_round (ctx, 1);
653
654   if (n_bytes != N)
655     aes_gcm_mask_bytes (d + n_blocks - 1, n_bytes);
656
657   aes_gcm_ghash_mul_first (ctx, d[0], n_lanes);
658
659   aes_gcm_enc_ctr0_round (ctx, 2);
660   aes_gcm_enc_ctr0_round (ctx, 3);
661
662   if (n_blocks > 1)
663     aes_gcm_ghash_mul_next (ctx, d[1]);
664
665   aes_gcm_enc_ctr0_round (ctx, 4);
666   aes_gcm_enc_ctr0_round (ctx, 5);
667
668   if (n_blocks > 2)
669     aes_gcm_ghash_mul_next (ctx, d[2]);
670
671   aes_gcm_enc_ctr0_round (ctx, 6);
672   aes_gcm_enc_ctr0_round (ctx, 7);
673
674   if (n_blocks > 3)
675     aes_gcm_ghash_mul_next (ctx, d[3]);
676
677   aes_gcm_enc_ctr0_round (ctx, 8);
678   aes_gcm_enc_ctr0_round (ctx, 9);
679
680   aes_gcm_ghash_mul_final_block (ctx);
681   aes_gcm_ghash_reduce (ctx);
682
683   for (i = 10; i < ctx->rounds; i++)
684     aes_gcm_enc_ctr0_round (ctx, i);
685
686   aes_gcm_ghash_reduce2 (ctx);
687
688   aes_gcm_ghash_final (ctx);
689
690   aes_gcm_enc_ctr0_round (ctx, i);
691 }
692
693 static_always_inline void
694 aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
695 {
696   aes_data_t d[4];
697
698   if (PREDICT_FALSE (n_left == 0))
699     {
700       int i;
701       for (i = 0; i < ctx->rounds + 1; i++)
702         aes_gcm_enc_ctr0_round (ctx, i);
703       return;
704     }
705
706   if (n_left < 4 * N)
707     {
708       ctx->last = 1;
709       if (n_left > 3 * N)
710         {
711           aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 0);
712           aes_gcm_calc_last (ctx, d, 4, n_left);
713         }
714       else if (n_left > 2 * N)
715         {
716           aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 0);
717           aes_gcm_calc_last (ctx, d, 3, n_left);
718         }
719       else if (n_left > N)
720         {
721           aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 0);
722           aes_gcm_calc_last (ctx, d, 2, n_left);
723         }
724       else
725         {
726           aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 0);
727           aes_gcm_calc_last (ctx, d, 1, n_left);
728         }
729       return;
730     }
731
732   aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0);
733
734   /* next */
735   n_left -= 4 * N;
736   dst += 4 * N;
737   src += 4 * N;
738
739   for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N)
740     aes_gcm_calc_double (ctx, d, src, dst);
741
742   if (n_left >= 4 * N)
743     {
744       aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 1);
745
746       /* next */
747       n_left -= 4 * N;
748       dst += 4 * N;
749       src += 4 * N;
750     }
751
752   if (n_left == 0)
753     {
754       aes_gcm_calc_last (ctx, d, 4, 4 * N);
755       return;
756     }
757
758   ctx->last = 1;
759
760   if (n_left > 3 * N)
761     {
762       aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
763       aes_gcm_calc_last (ctx, d, 4, n_left);
764     }
765   else if (n_left > 2 * N)
766     {
767       aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
768       aes_gcm_calc_last (ctx, d, 3, n_left);
769     }
770   else if (n_left > N)
771     {
772       aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
773       aes_gcm_calc_last (ctx, d, 2, n_left);
774     }
775   else
776     {
777       aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
778       aes_gcm_calc_last (ctx, d, 1, n_left);
779     }
780 }
781
782 static_always_inline void
783 aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
784 {
785   aes_data_t d[4] = {};
786   ghash_ctx_t gd;
787
788   /* main encryption loop */
789   for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N)
790     aes_gcm_calc_double (ctx, d, src, dst);
791
792   if (n_left >= 4 * N)
793     {
794       aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 1);
795
796       /* next */
797       n_left -= 4 * N;
798       dst += N * 4;
799       src += N * 4;
800     }
801
802   if (n_left)
803     {
804       ctx->last = 1;
805
806       if (n_left > 3 * N)
807         aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
808       else if (n_left > 2 * N)
809         aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
810       else if (n_left > N)
811         aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
812       else
813         aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
814     }
815
816   /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM
817    * (bit length) block */
818
819   aes_gcm_enc_ctr0_round (ctx, 0);
820   aes_gcm_enc_ctr0_round (ctx, 1);
821
822   ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T,
823                    ctx->Hi[NUM_HI - 1]);
824
825   aes_gcm_enc_ctr0_round (ctx, 2);
826   aes_gcm_enc_ctr0_round (ctx, 3);
827
828   ghash_reduce (&gd);
829
830   aes_gcm_enc_ctr0_round (ctx, 4);
831   aes_gcm_enc_ctr0_round (ctx, 5);
832
833   ghash_reduce2 (&gd);
834
835   aes_gcm_enc_ctr0_round (ctx, 6);
836   aes_gcm_enc_ctr0_round (ctx, 7);
837
838   ctx->T = ghash_final (&gd);
839
840   aes_gcm_enc_ctr0_round (ctx, 8);
841   aes_gcm_enc_ctr0_round (ctx, 9);
842
843   for (int i = 10; i < ctx->rounds + 1; i += 1)
844     aes_gcm_enc_ctr0_round (ctx, i);
845 }
846
847 static_always_inline int
848 aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
849          u32 data_bytes, u32 aad_bytes, u8 tag_len,
850          const aes_gcm_key_data_t *kd, int aes_rounds, aes_gcm_op_t op)
851 {
852   u8 *addt = (u8 *) aad;
853   u32x4 Y0;
854
855   aes_gcm_ctx_t _ctx = { .counter = 2,
856                          .rounds = aes_rounds,
857                          .operation = op,
858                          .data_bytes = data_bytes,
859                          .aad_bytes = aad_bytes,
860                          .Ke = kd->Ke,
861                          .Hi = kd->Hi },
862                 *ctx = &_ctx;
863
864   /* initalize counter */
865   Y0 = (u32x4) (u64x2){ *(u64u *) ivp, 0 };
866   Y0[2] = *(u32u *) (ivp + 8);
867   Y0[3] = 1 << 24;
868   ctx->EY0 = (u8x16) Y0;
869
870 #if N_LANES == 4
871   ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
872     0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
873   };
874 #elif N_LANES == 2
875   ctx->Y =
876     u32x8_splat_u32x4 (Y0) + (u32x8){ 0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24 };
877 #else
878   ctx->Y = Y0 + (u32x4){ 0, 0, 0, 1 << 24 };
879 #endif
880
881   /* calculate ghash for AAD */
882   aes_gcm_ghash (ctx, addt, aad_bytes);
883
884   /* ghash and encrypt/edcrypt  */
885   if (op == AES_GCM_OP_ENCRYPT)
886     aes_gcm_enc (ctx, src, dst, data_bytes);
887   else if (op == AES_GCM_OP_DECRYPT)
888     aes_gcm_dec (ctx, src, dst, data_bytes);
889
890   /* final tag is */
891   ctx->T = u8x16_reflect (ctx->T) ^ ctx->EY0;
892
893   /* tag_len 16 -> 0 */
894   tag_len &= 0xf;
895
896   if (op == AES_GCM_OP_ENCRYPT || op == AES_GCM_OP_GMAC)
897     {
898       /* store tag */
899       if (tag_len)
900         u8x16_store_partial (ctx->T, tag, tag_len);
901       else
902         ((u8x16u *) tag)[0] = ctx->T;
903     }
904   else
905     {
906       /* check tag */
907       if (tag_len)
908         {
909           u16 mask = pow2_mask (tag_len);
910           u8x16 expected = u8x16_load_partial (tag, tag_len);
911           if ((u8x16_msb_mask (expected == ctx->T) & mask) == mask)
912             return 1;
913         }
914       else
915         {
916           if (u8x16_is_equal (ctx->T, *(u8x16u *) tag))
917             return 1;
918         }
919     }
920   return 0;
921 }
922
923 static_always_inline void
924 clib_aes_gcm_key_expand (aes_gcm_key_data_t *kd, const u8 *key,
925                          aes_key_size_t ks)
926 {
927   u8x16 H;
928   u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
929   aes_gcm_expaned_key_t *Ke = (aes_gcm_expaned_key_t *) kd->Ke;
930
931   /* expand AES key */
932   aes_key_expand (ek, key, ks);
933   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
934     Ke[i].lanes[0] = Ke[i].lanes[1] = Ke[i].lanes[2] = Ke[i].lanes[3] = ek[i];
935
936   /* pre-calculate H */
937   H = aes_encrypt_block (u8x16_zero (), ek, ks);
938   H = u8x16_reflect (H);
939   ghash_precompute (H, (u8x16 *) kd->Hi, ARRAY_LEN (kd->Hi));
940 }
941
942 static_always_inline void
943 clib_aes128_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
944                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
945                      const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
946 {
947   aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
948            tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_ENCRYPT);
949 }
950
951 static_always_inline void
952 clib_aes256_gcm_enc (const aes_gcm_key_data_t *kd, const u8 *plaintext,
953                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
954                      const u8 *iv, u32 tag_bytes, u8 *cyphertext, u8 *tag)
955 {
956   aes_gcm (plaintext, cyphertext, aad, (u8 *) iv, tag, data_bytes, aad_bytes,
957            tag_bytes, kd, AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_ENCRYPT);
958 }
959
960 static_always_inline int
961 clib_aes128_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
962                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
963                      const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
964 {
965   return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
966                   data_bytes, aad_bytes, tag_bytes, kd,
967                   AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_DECRYPT);
968 }
969
970 static_always_inline int
971 clib_aes256_gcm_dec (const aes_gcm_key_data_t *kd, const u8 *cyphertext,
972                      u32 data_bytes, const u8 *aad, u32 aad_bytes,
973                      const u8 *iv, const u8 *tag, u32 tag_bytes, u8 *plaintext)
974 {
975   return aes_gcm (cyphertext, plaintext, aad, (u8 *) iv, (u8 *) tag,
976                   data_bytes, aad_bytes, tag_bytes, kd,
977                   AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_DECRYPT);
978 }
979
980 static_always_inline void
981 clib_aes128_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
982                   const u8 *iv, u32 tag_bytes, u8 *tag)
983 {
984   aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
985            AES_KEY_ROUNDS (AES_KEY_128), AES_GCM_OP_GMAC);
986 }
987
988 static_always_inline void
989 clib_aes256_gmac (const aes_gcm_key_data_t *kd, const u8 *data, u32 data_bytes,
990                   const u8 *iv, u32 tag_bytes, u8 *tag)
991 {
992   aes_gcm (0, 0, data, (u8 *) iv, tag, 0, data_bytes, tag_bytes, kd,
993            AES_KEY_ROUNDS (AES_KEY_256), AES_GCM_OP_GMAC);
994 }
995
996 #endif /* __crypto_aes_gcm_h__ */