vppinfra: small improvement and polishing of AES GCM code 44/38544/2
authorDamjan Marion <damarion@cisco.com>
Thu, 23 Mar 2023 13:44:01 +0000 (13:44 +0000)
committerDave Wallace <dwallacelf@gmail.com>
Mon, 27 Mar 2023 10:09:50 +0000 (10:09 +0000)
Type: improvement
Change-Id: Ie9661792ec68d4ea3c62ee9eb31b455d3b2b0a42
Signed-off-by: Damjan Marion <damarion@cisco.com>
src/vppinfra/crypto/aes_gcm.h
src/vppinfra/crypto/ghash.h

index 8a5f76c..3d1b220 100644 (file)
@@ -103,9 +103,15 @@ typedef struct
   aes_gcm_counter_t Y;
 
   /* ghash */
-  ghash_data_t gd;
+  ghash_ctx_t gd;
 } aes_gcm_ctx_t;
 
+static_always_inline u8x16
+aes_gcm_final_block (aes_gcm_ctx_t *ctx)
+{
+  return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
+}
+
 static_always_inline void
 aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
 {
@@ -137,19 +143,18 @@ aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
 }
 
 static_always_inline void
-aes_gcm_ghash_mul_bit_len (aes_gcm_ctx_t *ctx)
+aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx)
 {
-  u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
 #if N_LANES == 4
   u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
-  u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), r, 0);
+  u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0);
   ghash4_mul_next (&ctx->gd, r4, h);
 #elif N_LANES == 2
   u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
-  u8x32 r2 = u8x32_insert_lo (u8x32_zero (), r);
+  u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx));
   ghash2_mul_next (&ctx->gd, r2, h);
 #else
-  ghash_mul_next (&ctx->gd, r, ctx->Hi[NUM_HI - 1]);
+  ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]);
 #endif
 }
 
@@ -178,7 +183,7 @@ aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
          aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1);
          for (i = 1; i < 8; i++)
            aes_gcm_ghash_mul_next (ctx, d[i]);
-         aes_gcm_ghash_mul_bit_len (ctx);
+         aes_gcm_ghash_mul_final_block (ctx);
          aes_gcm_ghash_reduce (ctx);
          aes_gcm_ghash_reduce2 (ctx);
          aes_gcm_ghash_final (ctx);
@@ -243,16 +248,14 @@ aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
        }
 
       if (ctx->operation == AES_GCM_OP_GMAC)
-       aes_gcm_ghash_mul_bit_len (ctx);
+       aes_gcm_ghash_mul_final_block (ctx);
       aes_gcm_ghash_reduce (ctx);
       aes_gcm_ghash_reduce2 (ctx);
       aes_gcm_ghash_final (ctx);
     }
   else if (ctx->operation == AES_GCM_OP_GMAC)
-    {
-      u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
-      ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
-    }
+    ctx->T =
+      ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]);
 
 done:
   /* encrypt counter 0 E(Y0, k) */
@@ -267,6 +270,11 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
   const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0];
   uword i = 0;
 
+  /* As counter is stored in network byte order for performance reasons we
+     are incrementing least significant byte only except in case where we
+     overlow. As we are processing four 128, 256 or 512-blocks in parallel
+     except the last round, overflow can happen only when n_blocks == 4 */
+
 #if N_LANES == 4
   const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
                                0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
@@ -275,15 +283,10 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
   };
 
-  /* As counter is stored in network byte order for performance reasons we
-     are incrementing least significant byte only except in case where we
-     overlow. As we are processing four 512-blocks in parallel except the
-     last round, overflow can happen only when n == 4 */
-
   if (n_blocks == 4)
     for (; i < 2; i++)
       {
-       r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+       r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
        ctx->Y += ctr_inv_4444;
       }
 
@@ -293,7 +296,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
 
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+         r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
          Yr += ctr_4444;
          ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr);
        }
@@ -302,7 +305,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     {
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+         r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_4444;
        }
     }
@@ -311,15 +314,10 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
   const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
   const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
 
-  /* As counter is stored in network byte order for performance reasons we
-     are incrementing least significant byte only except in case where we
-     overlow. As we are processing four 512-blocks in parallel except the
-     last round, overflow can happen only when n == 4 */
-
   if (n_blocks == 4)
     for (; i < 2; i++)
       {
-       r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+       r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
        ctx->Y += ctr_inv_22;
       }
 
@@ -329,7 +327,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
 
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+         r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
          Yr += ctr_22;
          ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr);
        }
@@ -338,7 +336,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     {
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+         r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_22;
        }
     }
@@ -350,20 +348,20 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     {
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+         r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_1;
        }
       ctx->counter += n_blocks;
     }
   else
     {
-      r[i++] = Ke0.x1 ^ (u8x16) ctx->Y;
+      r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
       ctx->Y += ctr_inv_1;
       ctx->counter += 1;
 
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+         r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
          ctx->counter++;
          ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
        }
@@ -510,8 +508,7 @@ aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
 }
 
 static_always_inline void
-aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst,
-                    int with_ghash)
+aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst)
 {
   const aes_gcm_expaned_key_t *k = ctx->Ke;
   const aes_mem_t *sv = (aes_mem_t *) src;
@@ -680,7 +677,7 @@ aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
   aes_gcm_enc_ctr0_round (ctx, 8);
   aes_gcm_enc_ctr0_round (ctx, 9);
 
-  aes_gcm_ghash_mul_bit_len (ctx);
+  aes_gcm_ghash_mul_final_block (ctx);
   aes_gcm_ghash_reduce (ctx);
 
   for (i = 10; i < ctx->rounds; i++)
@@ -731,6 +728,7 @@ aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
        }
       return;
     }
+
   aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0);
 
   /* next */
@@ -739,7 +737,7 @@ aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
   src += 4 * N;
 
   for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N)
-    aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+    aes_gcm_calc_double (ctx, d, src, dst);
 
   if (n_left >= 4 * N)
     {
@@ -785,8 +783,11 @@ static_always_inline void
 aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
 {
   aes_data_t d[4] = {};
+  ghash_ctx_t gd;
+
+  /* main encryption loop */
   for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N)
-    aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+    aes_gcm_calc_double (ctx, d, src, dst);
 
   if (n_left >= 4 * N)
     {
@@ -798,27 +799,48 @@ aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
       src += N * 4;
     }
 
-  if (n_left == 0)
-    goto done;
+  if (n_left)
+    {
+      ctx->last = 1;
 
-  ctx->last = 1;
+      if (n_left > 3 * N)
+       aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
+      else if (n_left > 2 * N)
+       aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
+      else if (n_left > N)
+       aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
+      else
+       aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+    }
 
-  if (n_left > 3 * N)
-    aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
-  else if (n_left > 2 * N)
-    aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
-  else if (n_left > N)
-    aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
-  else
-    aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+  /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM
+   * (bit length) block */
 
-  u8x16 r;
-done:
-  r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
-  ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
+  aes_gcm_enc_ctr0_round (ctx, 0);
+  aes_gcm_enc_ctr0_round (ctx, 1);
 
-  /* encrypt counter 0 E(Y0, k) */
-  for (int i = 0; i < ctx->rounds + 1; i += 1)
+  ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T,
+                  ctx->Hi[NUM_HI - 1]);
+
+  aes_gcm_enc_ctr0_round (ctx, 2);
+  aes_gcm_enc_ctr0_round (ctx, 3);
+
+  ghash_reduce (&gd);
+
+  aes_gcm_enc_ctr0_round (ctx, 4);
+  aes_gcm_enc_ctr0_round (ctx, 5);
+
+  ghash_reduce2 (&gd);
+
+  aes_gcm_enc_ctr0_round (ctx, 6);
+  aes_gcm_enc_ctr0_round (ctx, 7);
+
+  ctx->T = ghash_final (&gd);
+
+  aes_gcm_enc_ctr0_round (ctx, 8);
+  aes_gcm_enc_ctr0_round (ctx, 9);
+
+  for (int i = 10; i < ctx->rounds + 1; i += 1)
     aes_gcm_enc_ctr0_round (ctx, i);
 }
 
@@ -835,6 +857,7 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
                         .operation = op,
                         .data_bytes = data_bytes,
                         .aad_bytes = aad_bytes,
+                        .Ke = kd->Ke,
                         .Hi = kd->Hi },
                *ctx = &_ctx;
 
@@ -843,7 +866,7 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
   Y0[2] = *(u32u *) (ivp + 8);
   Y0[3] = 1 << 24;
   ctx->EY0 = (u8x16) Y0;
-  ctx->Ke = kd->Ke;
+
 #if N_LANES == 4
   ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
     0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
@@ -858,8 +881,6 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
   /* calculate ghash for AAD */
   aes_gcm_ghash (ctx, addt, aad_bytes);
 
-  clib_prefetch_load (tag);
-
   /* ghash and encrypt/edcrypt  */
   if (op == AES_GCM_OP_ENCRYPT)
     aes_gcm_enc (ctx, src, dst, data_bytes);
index bae8bad..66e3f6a 100644 (file)
@@ -89,7 +89,7 @@
  * u8x16 Hi[4];
  * ghash_precompute (H, Hi, 4);
  *
- * ghash_data_t _gd, *gd = &_gd;
+ * ghash_ctx_t _gd, *gd = &_gd;
  * ghash_mul_first (gd, GH ^ b0, Hi[3]);
  * ghash_mul_next (gd, b1, Hi[2]);
  * ghash_mul_next (gd, b2, Hi[1]);
@@ -154,7 +154,7 @@ typedef struct
   u8x32 hi2, lo2, mid2, tmp_lo2, tmp_hi2;
   u8x64 hi4, lo4, mid4, tmp_lo4, tmp_hi4;
   int pending;
-} ghash_data_t;
+} ghash_ctx_t;
 
 static const u8x16 ghash_poly = {
   0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
@@ -167,7 +167,7 @@ static const u8x16 ghash_poly2 = {
 };
 
 static_always_inline void
-ghash_mul_first (ghash_data_t * gd, u8x16 a, u8x16 b)
+ghash_mul_first (ghash_ctx_t *gd, u8x16 a, u8x16 b)
 {
   /* a1 * b1 */
   gd->hi = gmul_hi_hi (a, b);
@@ -182,7 +182,7 @@ ghash_mul_first (ghash_data_t * gd, u8x16 a, u8x16 b)
 }
 
 static_always_inline void
-ghash_mul_next (ghash_data_t * gd, u8x16 a, u8x16 b)
+ghash_mul_next (ghash_ctx_t *gd, u8x16 a, u8x16 b)
 {
   /* a1 * b1 */
   u8x16 hi = gmul_hi_hi (a, b);
@@ -211,7 +211,7 @@ ghash_mul_next (ghash_data_t * gd, u8x16 a, u8x16 b)
 }
 
 static_always_inline void
-ghash_reduce (ghash_data_t * gd)
+ghash_reduce (ghash_ctx_t *gd)
 {
   u8x16 r;
 
@@ -236,14 +236,14 @@ ghash_reduce (ghash_data_t * gd)
 }
 
 static_always_inline void
-ghash_reduce2 (ghash_data_t * gd)
+ghash_reduce2 (ghash_ctx_t *gd)
 {
   gd->tmp_lo = gmul_lo_lo (ghash_poly2, gd->lo);
   gd->tmp_hi = gmul_lo_hi (ghash_poly2, gd->lo);
 }
 
 static_always_inline u8x16
-ghash_final (ghash_data_t * gd)
+ghash_final (ghash_ctx_t *gd)
 {
   return u8x16_xor3 (gd->hi, u8x16_word_shift_right (gd->tmp_lo, 4),
                     u8x16_word_shift_left (gd->tmp_hi, 4));
@@ -252,7 +252,7 @@ ghash_final (ghash_data_t * gd)
 static_always_inline u8x16
 ghash_mul (u8x16 a, u8x16 b)
 {
-  ghash_data_t _gd, *gd = &_gd;
+  ghash_ctx_t _gd, *gd = &_gd;
   ghash_mul_first (gd, a, b);
   ghash_reduce (gd);
   ghash_reduce2 (gd);
@@ -297,7 +297,7 @@ gmul4_hi_hi (u8x64 a, u8x64 b)
 }
 
 static_always_inline void
-ghash4_mul_first (ghash_data_t *gd, u8x64 a, u8x64 b)
+ghash4_mul_first (ghash_ctx_t *gd, u8x64 a, u8x64 b)
 {
   gd->hi4 = gmul4_hi_hi (a, b);
   gd->lo4 = gmul4_lo_lo (a, b);
@@ -306,7 +306,7 @@ ghash4_mul_first (ghash_data_t *gd, u8x64 a, u8x64 b)
 }
 
 static_always_inline void
-ghash4_mul_next (ghash_data_t *gd, u8x64 a, u8x64 b)
+ghash4_mul_next (ghash_ctx_t *gd, u8x64 a, u8x64 b)
 {
   u8x64 hi = gmul4_hi_hi (a, b);
   u8x64 lo = gmul4_lo_lo (a, b);
@@ -329,7 +329,7 @@ ghash4_mul_next (ghash_data_t *gd, u8x64 a, u8x64 b)
 }
 
 static_always_inline void
-ghash4_reduce (ghash_data_t *gd)
+ghash4_reduce (ghash_ctx_t *gd)
 {
   u8x64 r;
 
@@ -356,14 +356,14 @@ ghash4_reduce (ghash_data_t *gd)
 }
 
 static_always_inline void
-ghash4_reduce2 (ghash_data_t *gd)
+ghash4_reduce2 (ghash_ctx_t *gd)
 {
   gd->tmp_lo4 = gmul4_lo_lo (ghash4_poly2, gd->lo4);
   gd->tmp_hi4 = gmul4_lo_hi (ghash4_poly2, gd->lo4);
 }
 
 static_always_inline u8x16
-ghash4_final (ghash_data_t *gd)
+ghash4_final (ghash_ctx_t *gd)
 {
   u8x64 r;
   u8x32 t;
@@ -410,7 +410,7 @@ gmul2_hi_hi (u8x32 a, u8x32 b)
 }
 
 static_always_inline void
-ghash2_mul_first (ghash_data_t *gd, u8x32 a, u8x32 b)
+ghash2_mul_first (ghash_ctx_t *gd, u8x32 a, u8x32 b)
 {
   gd->hi2 = gmul2_hi_hi (a, b);
   gd->lo2 = gmul2_lo_lo (a, b);
@@ -419,7 +419,7 @@ ghash2_mul_first (ghash_data_t *gd, u8x32 a, u8x32 b)
 }
 
 static_always_inline void
-ghash2_mul_next (ghash_data_t *gd, u8x32 a, u8x32 b)
+ghash2_mul_next (ghash_ctx_t *gd, u8x32 a, u8x32 b)
 {
   u8x32 hi = gmul2_hi_hi (a, b);
   u8x32 lo = gmul2_lo_lo (a, b);
@@ -442,7 +442,7 @@ ghash2_mul_next (ghash_data_t *gd, u8x32 a, u8x32 b)
 }
 
 static_always_inline void
-ghash2_reduce (ghash_data_t *gd)
+ghash2_reduce (ghash_ctx_t *gd)
 {
   u8x32 r;
 
@@ -469,14 +469,14 @@ ghash2_reduce (ghash_data_t *gd)
 }
 
 static_always_inline void
-ghash2_reduce2 (ghash_data_t *gd)
+ghash2_reduce2 (ghash_ctx_t *gd)
 {
   gd->tmp_lo2 = gmul2_lo_lo (ghash2_poly2, gd->lo2);
   gd->tmp_hi2 = gmul2_lo_hi (ghash2_poly2, gd->lo2);
 }
 
 static_always_inline u8x16
-ghash2_final (ghash_data_t *gd)
+ghash2_final (ghash_ctx_t *gd)
 {
   u8x32 r;