vppinfra: small improvement and polishing of AES GCM code
[vpp.git] / src / vppinfra / crypto / aes_gcm.h
index 8a5f76c..3d1b220 100644 (file)
@@ -103,9 +103,15 @@ typedef struct
   aes_gcm_counter_t Y;
 
   /* ghash */
-  ghash_data_t gd;
+  ghash_ctx_t gd;
 } aes_gcm_ctx_t;
 
+static_always_inline u8x16
+aes_gcm_final_block (aes_gcm_ctx_t *ctx)
+{
+  return (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
+}
+
 static_always_inline void
 aes_gcm_ghash_mul_first (aes_gcm_ctx_t *ctx, aes_data_t data, u32 n_lanes)
 {
@@ -137,19 +143,18 @@ aes_gcm_ghash_mul_next (aes_gcm_ctx_t *ctx, aes_data_t data)
 }
 
 static_always_inline void
-aes_gcm_ghash_mul_bit_len (aes_gcm_ctx_t *ctx)
+aes_gcm_ghash_mul_final_block (aes_gcm_ctx_t *ctx)
 {
-  u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
 #if N_LANES == 4
   u8x64 h = u8x64_insert_u8x16 (u8x64_zero (), ctx->Hi[NUM_HI - 1], 0);
-  u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), r, 0);
+  u8x64 r4 = u8x64_insert_u8x16 (u8x64_zero (), aes_gcm_final_block (ctx), 0);
   ghash4_mul_next (&ctx->gd, r4, h);
 #elif N_LANES == 2
   u8x32 h = u8x32_insert_lo (u8x32_zero (), ctx->Hi[NUM_HI - 1]);
-  u8x32 r2 = u8x32_insert_lo (u8x32_zero (), r);
+  u8x32 r2 = u8x32_insert_lo (u8x32_zero (), aes_gcm_final_block (ctx));
   ghash2_mul_next (&ctx->gd, r2, h);
 #else
-  ghash_mul_next (&ctx->gd, r, ctx->Hi[NUM_HI - 1]);
+  ghash_mul_next (&ctx->gd, aes_gcm_final_block (ctx), ctx->Hi[NUM_HI - 1]);
 #endif
 }
 
@@ -178,7 +183,7 @@ aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
          aes_gcm_ghash_mul_first (ctx, d[0], 8 * N_LANES + 1);
          for (i = 1; i < 8; i++)
            aes_gcm_ghash_mul_next (ctx, d[i]);
-         aes_gcm_ghash_mul_bit_len (ctx);
+         aes_gcm_ghash_mul_final_block (ctx);
          aes_gcm_ghash_reduce (ctx);
          aes_gcm_ghash_reduce2 (ctx);
          aes_gcm_ghash_final (ctx);
@@ -243,16 +248,14 @@ aes_gcm_ghash (aes_gcm_ctx_t *ctx, u8 *data, u32 n_left)
        }
 
       if (ctx->operation == AES_GCM_OP_GMAC)
-       aes_gcm_ghash_mul_bit_len (ctx);
+       aes_gcm_ghash_mul_final_block (ctx);
       aes_gcm_ghash_reduce (ctx);
       aes_gcm_ghash_reduce2 (ctx);
       aes_gcm_ghash_final (ctx);
     }
   else if (ctx->operation == AES_GCM_OP_GMAC)
-    {
-      u8x16 r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
-      ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
-    }
+    ctx->T =
+      ghash_mul (aes_gcm_final_block (ctx) ^ ctx->T, ctx->Hi[NUM_HI - 1]);
 
 done:
   /* encrypt counter 0 E(Y0, k) */
@@ -267,6 +270,11 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
   const aes_gcm_expaned_key_t Ke0 = ctx->Ke[0];
   uword i = 0;
 
+  /* As counter is stored in network byte order for performance reasons we
+     are incrementing least significant byte only except in case where we
+     overlow. As we are processing four 128, 256 or 512-blocks in parallel
+     except the last round, overflow can happen only when n_blocks == 4 */
+
 #if N_LANES == 4
   const u32x16 ctr_inv_4444 = { 0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24,
                                0, 0, 0, 4 << 24, 0, 0, 0, 4 << 24 };
@@ -275,15 +283,10 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0,
   };
 
-  /* As counter is stored in network byte order for performance reasons we
-     are incrementing least significant byte only except in case where we
-     overlow. As we are processing four 512-blocks in parallel except the
-     last round, overflow can happen only when n == 4 */
-
   if (n_blocks == 4)
     for (; i < 2; i++)
       {
-       r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+       r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
        ctx->Y += ctr_inv_4444;
       }
 
@@ -293,7 +296,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
 
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+         r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
          Yr += ctr_4444;
          ctx->Y = (u32x16) aes_gcm_reflect ((u8x64) Yr);
        }
@@ -302,7 +305,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     {
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x4 ^ (u8x64) ctx->Y;
+         r[i] = Ke0.x4 ^ (u8x64) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_4444;
        }
     }
@@ -311,15 +314,10 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
   const u32x8 ctr_inv_22 = { 0, 0, 0, 2 << 24, 0, 0, 0, 2 << 24 };
   const u32x8 ctr_22 = { 2, 0, 0, 0, 2, 0, 0, 0 };
 
-  /* As counter is stored in network byte order for performance reasons we
-     are incrementing least significant byte only except in case where we
-     overlow. As we are processing four 512-blocks in parallel except the
-     last round, overflow can happen only when n == 4 */
-
   if (n_blocks == 4)
     for (; i < 2; i++)
       {
-       r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+       r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
        ctx->Y += ctr_inv_22;
       }
 
@@ -329,7 +327,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
 
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+         r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
          Yr += ctr_22;
          ctx->Y = (u32x8) aes_gcm_reflect ((u8x32) Yr);
        }
@@ -338,7 +336,7 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     {
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x2 ^ (u8x32) ctx->Y;
+         r[i] = Ke0.x2 ^ (u8x32) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_22;
        }
     }
@@ -350,20 +348,20 @@ aes_gcm_enc_first_round (aes_gcm_ctx_t *ctx, aes_data_t *r, uword n_blocks)
     {
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+         r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
          ctx->Y += ctr_inv_1;
        }
       ctx->counter += n_blocks;
     }
   else
     {
-      r[i++] = Ke0.x1 ^ (u8x16) ctx->Y;
+      r[i++] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
       ctx->Y += ctr_inv_1;
       ctx->counter += 1;
 
       for (; i < n_blocks; i++)
        {
-         r[i] = Ke0.x1 ^ (u8x16) ctx->Y;
+         r[i] = Ke0.x1 ^ (u8x16) ctx->Y; /* Initial AES round */
          ctx->counter++;
          ctx->Y[3] = clib_host_to_net_u32 (ctx->counter);
        }
@@ -510,8 +508,7 @@ aes_gcm_calc (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst, u32 n,
 }
 
 static_always_inline void
-aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst,
-                    int with_ghash)
+aes_gcm_calc_double (aes_gcm_ctx_t *ctx, aes_data_t *d, const u8 *src, u8 *dst)
 {
   const aes_gcm_expaned_key_t *k = ctx->Ke;
   const aes_mem_t *sv = (aes_mem_t *) src;
@@ -680,7 +677,7 @@ aes_gcm_calc_last (aes_gcm_ctx_t *ctx, aes_data_t *d, int n_blocks,
   aes_gcm_enc_ctr0_round (ctx, 8);
   aes_gcm_enc_ctr0_round (ctx, 9);
 
-  aes_gcm_ghash_mul_bit_len (ctx);
+  aes_gcm_ghash_mul_final_block (ctx);
   aes_gcm_ghash_reduce (ctx);
 
   for (i = 10; i < ctx->rounds; i++)
@@ -731,6 +728,7 @@ aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
        }
       return;
     }
+
   aes_gcm_calc (ctx, d, src, dst, 4, 4 * N, /* with_ghash */ 0);
 
   /* next */
@@ -739,7 +737,7 @@ aes_gcm_enc (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, u32 n_left)
   src += 4 * N;
 
   for (; n_left >= 8 * N; n_left -= 8 * N, src += 8 * N, dst += 8 * N)
-    aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+    aes_gcm_calc_double (ctx, d, src, dst);
 
   if (n_left >= 4 * N)
     {
@@ -785,8 +783,11 @@ static_always_inline void
 aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
 {
   aes_data_t d[4] = {};
+  ghash_ctx_t gd;
+
+  /* main encryption loop */
   for (; n_left >= 8 * N; n_left -= 8 * N, dst += 8 * N, src += 8 * N)
-    aes_gcm_calc_double (ctx, d, src, dst, /* with_ghash */ 1);
+    aes_gcm_calc_double (ctx, d, src, dst);
 
   if (n_left >= 4 * N)
     {
@@ -798,27 +799,48 @@ aes_gcm_dec (aes_gcm_ctx_t *ctx, const u8 *src, u8 *dst, uword n_left)
       src += N * 4;
     }
 
-  if (n_left == 0)
-    goto done;
+  if (n_left)
+    {
+      ctx->last = 1;
 
-  ctx->last = 1;
+      if (n_left > 3 * N)
+       aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
+      else if (n_left > 2 * N)
+       aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
+      else if (n_left > N)
+       aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
+      else
+       aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+    }
 
-  if (n_left > 3 * N)
-    aes_gcm_calc (ctx, d, src, dst, 4, n_left, /* with_ghash */ 1);
-  else if (n_left > 2 * N)
-    aes_gcm_calc (ctx, d, src, dst, 3, n_left, /* with_ghash */ 1);
-  else if (n_left > N)
-    aes_gcm_calc (ctx, d, src, dst, 2, n_left, /* with_ghash */ 1);
-  else
-    aes_gcm_calc (ctx, d, src, dst, 1, n_left, /* with_ghash */ 1);
+  /* interleaved counter 0 encryption E(Y0, k) and ghash of final GCM
+   * (bit length) block */
 
-  u8x16 r;
-done:
-  r = (u8x16) ((u64x2){ ctx->data_bytes, ctx->aad_bytes } << 3);
-  ctx->T = ghash_mul (r ^ ctx->T, ctx->Hi[NUM_HI - 1]);
+  aes_gcm_enc_ctr0_round (ctx, 0);
+  aes_gcm_enc_ctr0_round (ctx, 1);
 
-  /* encrypt counter 0 E(Y0, k) */
-  for (int i = 0; i < ctx->rounds + 1; i += 1)
+  ghash_mul_first (&gd, aes_gcm_final_block (ctx) ^ ctx->T,
+                  ctx->Hi[NUM_HI - 1]);
+
+  aes_gcm_enc_ctr0_round (ctx, 2);
+  aes_gcm_enc_ctr0_round (ctx, 3);
+
+  ghash_reduce (&gd);
+
+  aes_gcm_enc_ctr0_round (ctx, 4);
+  aes_gcm_enc_ctr0_round (ctx, 5);
+
+  ghash_reduce2 (&gd);
+
+  aes_gcm_enc_ctr0_round (ctx, 6);
+  aes_gcm_enc_ctr0_round (ctx, 7);
+
+  ctx->T = ghash_final (&gd);
+
+  aes_gcm_enc_ctr0_round (ctx, 8);
+  aes_gcm_enc_ctr0_round (ctx, 9);
+
+  for (int i = 10; i < ctx->rounds + 1; i += 1)
     aes_gcm_enc_ctr0_round (ctx, i);
 }
 
@@ -835,6 +857,7 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
                         .operation = op,
                         .data_bytes = data_bytes,
                         .aad_bytes = aad_bytes,
+                        .Ke = kd->Ke,
                         .Hi = kd->Hi },
                *ctx = &_ctx;
 
@@ -843,7 +866,7 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
   Y0[2] = *(u32u *) (ivp + 8);
   Y0[3] = 1 << 24;
   ctx->EY0 = (u8x16) Y0;
-  ctx->Ke = kd->Ke;
+
 #if N_LANES == 4
   ctx->Y = u32x16_splat_u32x4 (Y0) + (u32x16){
     0, 0, 0, 1 << 24, 0, 0, 0, 2 << 24, 0, 0, 0, 3 << 24, 0, 0, 0, 4 << 24,
@@ -858,8 +881,6 @@ aes_gcm (const u8 *src, u8 *dst, const u8 *aad, u8 *ivp, u8 *tag,
   /* calculate ghash for AAD */
   aes_gcm_ghash (ctx, addt, aad_bytes);
 
-  clib_prefetch_load (tag);
-
   /* ghash and encrypt/edcrypt  */
   if (op == AES_GCM_OP_ENCRYPT)
     aes_gcm_enc (ctx, src, dst, data_bytes);