vppinfra: add support for precomputed SHA2 HMAC key and chained buffers 65/40665/3
authorDamjan Marion <damarion@cisco.com>
Tue, 9 Apr 2024 12:37:25 +0000 (12:37 +0000)
committerDamjan Marion <damarion@cisco.com>
Tue, 9 Apr 2024 21:09:21 +0000 (21:09 +0000)
Change-Id: Ic1fa3bd164e80c2ca1146be001870da0238a5f2e
Type: improvement
Signed-off-by: Damjan Marion <damarion@cisco.com>
src/vppinfra/crypto/sha2.h

index 5100615..69a24a2 100644 (file)
@@ -1,16 +1,5 @@
-/*
- * Copyright (c) 2019 Cisco and/or its affiliates.
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at:
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+/* SPDX-License-Identifier: Apache-2.0
+ * Copyright(c) 2024 Cisco Systems, Inc.
  */
 
 #ifndef included_sha2_h
@@ -18,12 +7,8 @@
 
 #include <vppinfra/clib.h>
 #include <vppinfra/vector.h>
+#include <vppinfra/string.h>
 
-#define SHA224_DIGEST_SIZE 28
-#define SHA224_BLOCK_SIZE  64
-
-#define SHA256_DIGEST_SIZE  32
-#define SHA256_BLOCK_SIZE   64
 #define SHA256_ROTR(x, y)   ((x >> y) | (x << (32 - y)))
 #define SHA256_CH(a, b, c)  ((a & b) ^ (~a & c))
 #define SHA256_MAJ(a, b, c) ((a & b) ^ (a & c) ^ (b & c))
     s[0] = t1 + t2;                                                           \
   }
 
-#define SHA512_224_DIGEST_SIZE 28
-#define SHA512_224_BLOCK_SIZE  128
-
-#define SHA512_256_DIGEST_SIZE 32
-#define SHA512_256_BLOCK_SIZE  128
-
-#define SHA384_DIGEST_SIZE 48
-#define SHA384_BLOCK_SIZE  128
-
-#define SHA512_DIGEST_SIZE  64
-#define SHA512_BLOCK_SIZE   128
 #define SHA512_ROTR(x, y)   ((x >> y) | (x << (64 - y)))
 #define SHA512_CH(a, b, c)  ((a & b) ^ (~a & c))
 #define SHA512_MAJ(a, b, c) ((a & b) ^ (a & c) ^ (b & c))
@@ -125,7 +99,7 @@ static const u32 sha256_h[8] = { 0x6a09e667, 0xbb67ae85, 0x3c6ef372,
                                 0xa54ff53a, 0x510e527f, 0x9b05688c,
                                 0x1f83d9ab, 0x5be0cd19 };
 
-static const u32 sha256_k[64] = {
+static const u32 clib_sha2_256_k[64] = {
   0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1,
   0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
   0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786,
@@ -159,7 +133,7 @@ static const u64 sha512_256_h[8] = { 0x22312194fc2bf72c, 0x9f555fa3c84c64c2,
                                     0x96283ee2a88effe3, 0xbe5e1e2553863992,
                                     0x2b0199fc2c85b8aa, 0x0eb72ddc81c52ca2 };
 
-static const u64 sha512_k[80] = {
+static const u64 clib_sha2_512_k[80] = {
   0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
   0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
   0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
@@ -199,80 +173,102 @@ typedef enum
   CLIB_SHA2_512_256,
 } clib_sha2_type_t;
 
-#define SHA2_MAX_BLOCK_SIZE  SHA512_BLOCK_SIZE
-#define SHA2_MAX_DIGEST_SIZE SHA512_DIGEST_SIZE
+#define CLIB_SHA2_256_BLOCK_SIZE 64
+#define CLIB_SHA2_512_BLOCK_SIZE 128
+#define SHA2_MAX_BLOCK_SIZE     CLIB_SHA2_512_BLOCK_SIZE
+#define SHA2_MAX_DIGEST_SIZE    64
 
-typedef struct
+static const struct
 {
-  u64 total_bytes;
-  u16 n_pending;
   u8 block_size;
   u8 digest_size;
-  union
-  {
-    u32 h32[8];
-    u64 h64[8];
+  const u32 *h32;
+  const u64 *h64;
+} clib_sha2_variants[] = {
+  [CLIB_SHA2_224] = {
+    .block_size = CLIB_SHA2_256_BLOCK_SIZE,
+    .digest_size = 28,
+    .h32 = sha224_h,
+  },
+  [CLIB_SHA2_256] = {
+    .block_size = CLIB_SHA2_256_BLOCK_SIZE,
+    .digest_size = 32,
+    .h32 = sha256_h,
+  },
+  [CLIB_SHA2_384] = {
+    .block_size = CLIB_SHA2_512_BLOCK_SIZE,
+    .digest_size = 48,
+    .h64 = sha384_h,
+  },
+  [CLIB_SHA2_512] = {
+    .block_size = CLIB_SHA2_512_BLOCK_SIZE,
+    .digest_size = 64,
+    .h64 = sha512_h,
+  },
+  [CLIB_SHA2_512_224] = {
+    .block_size = CLIB_SHA2_512_BLOCK_SIZE,
+    .digest_size = 28,
+    .h64 = sha512_224_h,
+  },
+  [CLIB_SHA2_512_256] = {
+    .block_size = CLIB_SHA2_512_BLOCK_SIZE,
+    .digest_size = 32,
+    .h64 = sha512_256_h,
+  },
+};
+
+typedef union
+{
+  u32 h32[8];
+  u64 h64[8];
 #ifdef CLIB_SHA256_ISA
-    u32x4 h32x4[2];
+  u32x4 h32x4[2];
 #endif
-  };
+} clib_sha2_h_t;
+
+typedef struct
+{
+  u64 total_bytes;
+  u16 n_pending;
+  clib_sha2_h_t h;
   union
   {
     u8 as_u8[SHA2_MAX_BLOCK_SIZE];
     u64 as_u64[SHA2_MAX_BLOCK_SIZE / sizeof (u64)];
     uword as_uword[SHA2_MAX_BLOCK_SIZE / sizeof (uword)];
   } pending;
+} clib_sha2_state_t;
+
+typedef struct
+{
+  clib_sha2_type_t type;
+  u8 block_size;
+  u8 digest_size;
+  clib_sha2_state_t state;
 } clib_sha2_ctx_t;
 
 static_always_inline void
-clib_sha2_init (clib_sha2_ctx_t *ctx, clib_sha2_type_t type)
+clib_sha2_state_init (clib_sha2_state_t *state, clib_sha2_type_t type)
 {
-  const u32 *h32 = 0;
-  const u64 *h64 = 0;
+  clib_sha2_state_t st = {};
 
-  ctx->total_bytes = 0;
-  ctx->n_pending = 0;
-
-  switch (type)
-    {
-    case CLIB_SHA2_224:
-      h32 = sha224_h;
-      ctx->block_size = SHA224_BLOCK_SIZE;
-      ctx->digest_size = SHA224_DIGEST_SIZE;
-      break;
-    case CLIB_SHA2_256:
-      h32 = sha256_h;
-      ctx->block_size = SHA256_BLOCK_SIZE;
-      ctx->digest_size = SHA256_DIGEST_SIZE;
-      break;
-    case CLIB_SHA2_384:
-      h64 = sha384_h;
-      ctx->block_size = SHA384_BLOCK_SIZE;
-      ctx->digest_size = SHA384_DIGEST_SIZE;
-      break;
-    case CLIB_SHA2_512:
-      h64 = sha512_h;
-      ctx->block_size = SHA512_BLOCK_SIZE;
-      ctx->digest_size = SHA512_DIGEST_SIZE;
-      break;
-    case CLIB_SHA2_512_224:
-      h64 = sha512_224_h;
-      ctx->block_size = SHA512_224_BLOCK_SIZE;
-      ctx->digest_size = SHA512_224_DIGEST_SIZE;
-      break;
-    case CLIB_SHA2_512_256:
-      h64 = sha512_256_h;
-      ctx->block_size = SHA512_256_BLOCK_SIZE;
-      ctx->digest_size = SHA512_256_DIGEST_SIZE;
-      break;
-    }
-  if (h32)
+  if (clib_sha2_variants[type].block_size == CLIB_SHA2_256_BLOCK_SIZE)
     for (int i = 0; i < 8; i++)
-      ctx->h32[i] = h32[i];
-
-  if (h64)
+      st.h.h32[i] = clib_sha2_variants[type].h32[i];
+  else
     for (int i = 0; i < 8; i++)
-      ctx->h64[i] = h64[i];
+      st.h.h64[i] = clib_sha2_variants[type].h64[i];
+
+  *state = st;
+}
+
+static_always_inline void
+clib_sha2_init (clib_sha2_ctx_t *ctx, clib_sha2_type_t type)
+{
+  clib_sha2_state_init (&ctx->state, type);
+  ctx->block_size = clib_sha2_variants[type].block_size;
+  ctx->digest_size = clib_sha2_variants[type].digest_size;
+  ctx->type = type;
 }
 
 #ifdef CLIB_SHA256_ISA
@@ -295,7 +291,7 @@ static inline void
 clib_sha256_vec_4_rounds (u32x4 w, u8 n, u32x4 s[])
 {
 #ifdef CLIB_SHA256_ISA_INTEL
-  u32x4 r = *(u32x4 *) (sha256_k + 4 * n) + w;
+  u32x4 r = *(u32x4 *) (clib_sha2_256_k + 4 * n) + w;
   s[0] = (u32x4) _mm_sha256rnds2_epu32 ((__m128i) s[0], (__m128i) s[1],
                                        (__m128i) r);
   r = (u32x4) u64x2_interleave_hi ((u64x2) r, (u64x2) r);
@@ -303,7 +299,7 @@ clib_sha256_vec_4_rounds (u32x4 w, u8 n, u32x4 s[])
                                        (__m128i) r);
 #elif defined(CLIB_SHA256_ISA_ARM)
   u32x4 r0, s0;
-  const u32x4u *k = (u32x4u *) sha256_k;
+  const u32x4u *k = (u32x4u *) clib_sha2_256_k;
 
   r0 = w + k[n];
   s0 = s[0];
@@ -338,14 +334,14 @@ clib_sha256_vec_shuffle (u32x4 d[2])
 #endif
 
 static inline void
-clib_sha256_block (clib_sha2_ctx_t *ctx, const u8 *msg, uword n_blocks)
+clib_sha256_block (clib_sha2_state_t *st, const u8 *msg, uword n_blocks)
 {
 #if defined(CLIB_SHA256_ISA)
   u32x4 h[2];
   u32x4u *m = (u32x4u *) msg;
 
-  h[0] = ctx->h32x4[0];
-  h[1] = ctx->h32x4[1];
+  h[0] = st->h.h32x4[0];
+  h[1] = st->h.h32x4[1];
 
   clib_sha256_vec_shuffle (h);
 
@@ -399,159 +395,176 @@ clib_sha256_block (clib_sha2_ctx_t *ctx, const u8 *msg, uword n_blocks)
 
   clib_sha256_vec_shuffle (h);
 
-  ctx->h32x4[0] = h[0];
-  ctx->h32x4[1] = h[1];
+  st->h.h32x4[0] = h[0];
+  st->h.h32x4[1] = h[1];
 #else
   u32 w[64], s[8], i;
+  clib_sha2_h_t h;
+
+  h = st->h;
 
-  while (n_blocks)
+  for (; n_blocks; msg += CLIB_SHA2_256_BLOCK_SIZE, n_blocks--)
     {
       for (i = 0; i < 8; i++)
-       s[i] = ctx->h32[i];
+       s[i] = h.h32[i];
 
       for (i = 0; i < 16; i++)
        {
-         w[i] = clib_net_to_host_u32 (*((u32 *) msg + i));
-         SHA256_TRANSFORM (s, w, i, sha256_k[i]);
+         w[i] = clib_net_to_host_u32 ((((u32u *) msg)[i]));
+         SHA256_TRANSFORM (s, w, i, clib_sha2_256_k[i]);
        }
 
       for (i = 16; i < 64; i++)
        {
          SHA256_MSG_SCHED (w, i);
-         SHA256_TRANSFORM (s, w, i, sha256_k[i]);
+         SHA256_TRANSFORM (s, w, i, clib_sha2_256_k[i]);
        }
 
       for (i = 0; i < 8; i++)
-       ctx->h32[i] += s[i];
-
-      /* next */
-      msg += SHA256_BLOCK_SIZE;
-      n_blocks--;
+       h.h32[i] += s[i];
     }
+
+  st->h = h;
 #endif
 }
 
 static_always_inline void
-clib_sha512_block (clib_sha2_ctx_t *ctx, const u8 *msg, uword n_blocks)
+clib_sha512_block (clib_sha2_state_t *st, const u8 *msg, uword n_blocks)
 {
   u64 w[80], s[8], i;
+  clib_sha2_h_t h;
+
+  h = st->h;
 
-  while (n_blocks)
+  for (; n_blocks; msg += CLIB_SHA2_512_BLOCK_SIZE, n_blocks--)
     {
       for (i = 0; i < 8; i++)
-       s[i] = ctx->h64[i];
+       s[i] = h.h64[i];
 
       for (i = 0; i < 16; i++)
        {
-         w[i] = clib_net_to_host_u64 (*((u64 *) msg + i));
-         SHA512_TRANSFORM (s, w, i, sha512_k[i]);
+         w[i] = clib_net_to_host_u64 ((((u64u *) msg)[i]));
+         SHA512_TRANSFORM (s, w, i, clib_sha2_512_k[i]);
        }
 
       for (i = 16; i < 80; i++)
        {
          SHA512_MSG_SCHED (w, i);
-         SHA512_TRANSFORM (s, w, i, sha512_k[i]);
+         SHA512_TRANSFORM (s, w, i, clib_sha2_512_k[i]);
        }
 
       for (i = 0; i < 8; i++)
-       ctx->h64[i] += s[i];
-
-      /* next */
-      msg += SHA512_BLOCK_SIZE;
-      n_blocks--;
+       h.h64[i] += s[i];
     }
+
+  st->h = h;
 }
 
 static_always_inline void
-clib_sha2_update (clib_sha2_ctx_t *ctx, const u8 *msg, uword n_bytes)
+clib_sha2_update_internal (clib_sha2_state_t *st, u8 block_size, const u8 *msg,
+                          uword n_bytes)
 {
   uword n_blocks;
-  if (ctx->n_pending)
+  if (st->n_pending)
     {
-      uword n_left = ctx->block_size - ctx->n_pending;
+      uword n_left = block_size - st->n_pending;
       if (n_bytes < n_left)
        {
-         clib_memcpy_fast (ctx->pending.as_u8 + ctx->n_pending, msg, n_bytes);
-         ctx->n_pending += n_bytes;
+         clib_memcpy_fast (st->pending.as_u8 + st->n_pending, msg, n_bytes);
+         st->n_pending += n_bytes;
          return;
        }
       else
        {
-         clib_memcpy_fast (ctx->pending.as_u8 + ctx->n_pending, msg, n_left);
-         if (ctx->block_size == SHA512_BLOCK_SIZE)
-           clib_sha512_block (ctx, ctx->pending.as_u8, 1);
+         clib_memcpy_fast (st->pending.as_u8 + st->n_pending, msg, n_left);
+         if (block_size == CLIB_SHA2_512_BLOCK_SIZE)
+           clib_sha512_block (st, st->pending.as_u8, 1);
          else
-           clib_sha256_block (ctx, ctx->pending.as_u8, 1);
-         ctx->n_pending = 0;
-         ctx->total_bytes += ctx->block_size;
+           clib_sha256_block (st, st->pending.as_u8, 1);
+         st->n_pending = 0;
+         st->total_bytes += block_size;
          n_bytes -= n_left;
          msg += n_left;
        }
     }
 
-  if ((n_blocks = n_bytes / ctx->block_size))
+  if ((n_blocks = n_bytes / block_size))
     {
-      if (ctx->block_size == SHA512_BLOCK_SIZE)
-       clib_sha512_block (ctx, msg, n_blocks);
+      if (block_size == CLIB_SHA2_512_BLOCK_SIZE)
+       clib_sha512_block (st, msg, n_blocks);
       else
-       clib_sha256_block (ctx, msg, n_blocks);
-      n_bytes -= n_blocks * ctx->block_size;
-      msg += n_blocks * ctx->block_size;
-      ctx->total_bytes += n_blocks * ctx->block_size;
+       clib_sha256_block (st, msg, n_blocks);
+      n_bytes -= n_blocks * block_size;
+      msg += n_blocks * block_size;
+      st->total_bytes += n_blocks * block_size;
     }
 
   if (n_bytes)
     {
-      clib_memset_u8 (ctx->pending.as_u8, 0, ctx->block_size);
-      clib_memcpy_fast (ctx->pending.as_u8, msg, n_bytes);
-      ctx->n_pending = n_bytes;
+      clib_memset_u8 (st->pending.as_u8, 0, block_size);
+      clib_memcpy_fast (st->pending.as_u8, msg, n_bytes);
+      st->n_pending = n_bytes;
     }
   else
-    ctx->n_pending = 0;
+    st->n_pending = 0;
 }
 
 static_always_inline void
-clib_sha2_final (clib_sha2_ctx_t *ctx, u8 *digest)
+clib_sha2_update (clib_sha2_ctx_t *ctx, const u8 *msg, uword n_bytes)
+{
+  clib_sha2_update_internal (&ctx->state, ctx->block_size, msg, n_bytes);
+}
+
+static_always_inline void
+clib_sha2_final_internal (clib_sha2_state_t *st, u8 block_size, u8 digest_size,
+                         u8 *digest)
 {
   int i;
 
-  ctx->total_bytes += ctx->n_pending;
-  if (ctx->n_pending == 0)
+  st->total_bytes += st->n_pending;
+  if (st->n_pending == 0)
     {
-      clib_memset (ctx->pending.as_u8, 0, ctx->block_size);
-      ctx->pending.as_u8[0] = 0x80;
+      clib_memset (st->pending.as_u8, 0, block_size);
+      st->pending.as_u8[0] = 0x80;
     }
-  else if (ctx->n_pending + sizeof (u64) + sizeof (u8) > ctx->block_size)
+  else if (st->n_pending + sizeof (u64) + sizeof (u8) > block_size)
     {
-      ctx->pending.as_u8[ctx->n_pending] = 0x80;
-      if (ctx->block_size == SHA512_BLOCK_SIZE)
-       clib_sha512_block (ctx, ctx->pending.as_u8, 1);
+      st->pending.as_u8[st->n_pending] = 0x80;
+      if (block_size == CLIB_SHA2_512_BLOCK_SIZE)
+       clib_sha512_block (st, st->pending.as_u8, 1);
       else
-       clib_sha256_block (ctx, ctx->pending.as_u8, 1);
-      clib_memset (ctx->pending.as_u8, 0, ctx->block_size);
+       clib_sha256_block (st, st->pending.as_u8, 1);
+      clib_memset (st->pending.as_u8, 0, block_size);
     }
   else
-    ctx->pending.as_u8[ctx->n_pending] = 0x80;
+    st->pending.as_u8[st->n_pending] = 0x80;
 
-  ctx->pending.as_u64[ctx->block_size / 8 - 1] =
-    clib_net_to_host_u64 (ctx->total_bytes * 8);
-  if (ctx->block_size == SHA512_BLOCK_SIZE)
-    clib_sha512_block (ctx, ctx->pending.as_u8, 1);
-  else
-    clib_sha256_block (ctx, ctx->pending.as_u8, 1);
+  st->pending.as_u64[block_size / 8 - 1] =
+    clib_net_to_host_u64 (st->total_bytes * 8);
 
-  if (ctx->block_size == SHA512_BLOCK_SIZE)
+  if (block_size == CLIB_SHA2_512_BLOCK_SIZE)
     {
-      for (i = 0; i < ctx->digest_size / sizeof (u64); i++)
-       *((u64 *) digest + i) = clib_net_to_host_u64 (ctx->h64[i]);
+      clib_sha512_block (st, st->pending.as_u8, 1);
+      for (i = 0; i < digest_size / sizeof (u64); i++)
+       ((u64 *) digest)[i] = clib_net_to_host_u64 (st->h.h64[i]);
 
       /* sha512-224 case - write half of u64 */
-      if (i * sizeof (u64) < ctx->digest_size)
-       *((u32 *) digest + 2 * i) = clib_net_to_host_u32 (ctx->h64[i] >> 32);
+      if (i * sizeof (u64) < digest_size)
+       ((u32 *) digest)[2 * i] = clib_net_to_host_u32 (st->h.h64[i] >> 32);
     }
   else
-    for (i = 0; i < ctx->digest_size / sizeof (u32); i++)
-      *((u32 *) digest + i) = clib_net_to_host_u32 (ctx->h32[i]);
+    {
+      clib_sha256_block (st, st->pending.as_u8, 1);
+      for (i = 0; i < digest_size / sizeof (u32); i++)
+       *((u32 *) digest + i) = clib_net_to_host_u32 (st->h.h32[i]);
+    }
+}
+
+static_always_inline void
+clib_sha2_final (clib_sha2_ctx_t *ctx, u8 *digest)
+{
+  clib_sha2_final_internal (&ctx->state, ctx->block_size, ctx->digest_size,
+                           digest);
 }
 
 static_always_inline void
@@ -570,70 +583,133 @@ clib_sha2 (clib_sha2_type_t type, const u8 *msg, uword len, u8 *digest)
 #define clib_sha512_224(...) clib_sha2 (CLIB_SHA2_512_224, __VA_ARGS__)
 #define clib_sha512_256(...) clib_sha2 (CLIB_SHA2_512_256, __VA_ARGS__)
 
-static_always_inline void
-clib_hmac_sha2 (clib_sha2_type_t type, const u8 *key, uword key_len,
-               const u8 *msg, uword len, u8 *digest)
+/*
+ *  HMAC
+ */
+
+typedef struct
 {
-  clib_sha2_ctx_t _ctx, *ctx = &_ctx;
-  uword key_data[SHA2_MAX_BLOCK_SIZE / sizeof (uword)];
-  u8 i_digest[SHA2_MAX_DIGEST_SIZE];
-  int i, n_words;
+  clib_sha2_h_t ipad_h;
+  clib_sha2_h_t opad_h;
+} clib_sha2_hmac_key_data_t;
+
+typedef struct
+{
+  clib_sha2_type_t type;
+  u8 block_size;
+  u8 digest_size;
+  clib_sha2_state_t ipad_state;
+  clib_sha2_state_t opad_state;
+} clib_sha2_hmac_ctx_t;
 
-  clib_sha2_init (ctx, type);
-  n_words = ctx->block_size / sizeof (uword);
+static_always_inline void
+clib_sha2_hmac_key_data (clib_sha2_type_t type, const u8 *key, uword key_len,
+                        clib_sha2_hmac_key_data_t *kd)
+{
+  u8 block_size = clib_sha2_variants[type].block_size;
+  u8 data[SHA2_MAX_BLOCK_SIZE] = {};
+  u8 ikey[SHA2_MAX_BLOCK_SIZE];
+  u8 okey[SHA2_MAX_BLOCK_SIZE];
+  clib_sha2_state_t ipad_state;
+  clib_sha2_state_t opad_state;
 
   /* key */
-  if (key_len > ctx->block_size)
+  if (key_len > block_size)
     {
       /* key is longer than block, calculate hash of key */
-      clib_sha2_update (ctx, key, key_len);
-      for (i = (ctx->digest_size / sizeof (uword)) / 2; i < n_words; i++)
-       key_data[i] = 0;
-      clib_sha2_final (ctx, (u8 *) key_data);
-      clib_sha2_init (ctx, type);
+      clib_sha2_ctx_t ctx;
+      clib_sha2_init (&ctx, type);
+      clib_sha2_update (&ctx, key, key_len);
+      clib_sha2_final (&ctx, (u8 *) data);
     }
   else
+    clib_memcpy_fast (data, key, key_len);
+
+  for (int i = 0, w = 0; w < block_size; w += sizeof (uword), i++)
     {
-      for (i = 0; i < n_words; i++)
-       key_data[i] = 0;
-      clib_memcpy_fast (key_data, key, key_len);
+      ((uwordu *) ikey)[i] = ((uwordu *) data)[i] ^ 0x3636363636363636UL;
+      ((uwordu *) okey)[i] = ((uwordu *) data)[i] ^ 0x5c5c5c5c5c5c5c5cUL;
     }
 
-  /* ipad */
-  for (i = 0; i < n_words; i++)
-    ctx->pending.as_uword[i] = key_data[i] ^ (uword) 0x3636363636363636;
-  if (ctx->block_size == SHA512_BLOCK_SIZE)
-    clib_sha512_block (ctx, ctx->pending.as_u8, 1);
-  else
-    clib_sha256_block (ctx, ctx->pending.as_u8, 1);
-  ctx->total_bytes += ctx->block_size;
-
-  /* message */
-  clib_sha2_update (ctx, msg, len);
-  clib_sha2_final (ctx, i_digest);
-
-  /* opad */
-  clib_sha2_init (ctx, type);
-  for (i = 0; i < n_words; i++)
-    ctx->pending.as_uword[i] = key_data[i] ^ (uword) 0x5c5c5c5c5c5c5c5c;
-  if (ctx->block_size == SHA512_BLOCK_SIZE)
-    clib_sha512_block (ctx, ctx->pending.as_u8, 1);
+  clib_sha2_state_init (&ipad_state, type);
+  clib_sha2_state_init (&opad_state, type);
+
+  if (block_size == CLIB_SHA2_512_BLOCK_SIZE)
+    {
+      clib_sha512_block (&ipad_state, ikey, 1);
+      clib_sha512_block (&opad_state, okey, 1);
+    }
   else
-    clib_sha256_block (ctx, ctx->pending.as_u8, 1);
-  ctx->total_bytes += ctx->block_size;
+    {
+      clib_sha256_block (&ipad_state, ikey, 1);
+      clib_sha256_block (&opad_state, okey, 1);
+    }
+
+  kd->ipad_h = ipad_state.h;
+  kd->opad_h = opad_state.h;
+}
+
+static_always_inline void
+clib_sha2_hmac_init (clib_sha2_hmac_ctx_t *ctx, clib_sha2_type_t type,
+                    clib_sha2_hmac_key_data_t *kd)
+{
+  u8 block_size = clib_sha2_variants[type].block_size;
+  u8 digest_size = clib_sha2_variants[type].digest_size;
+
+  *ctx = (clib_sha2_hmac_ctx_t) {
+    .type = type,
+    .block_size = block_size,
+    .digest_size = digest_size,
+    .ipad_state = {
+      .h = kd->ipad_h,
+      .total_bytes = block_size,
+    },
+    .opad_state = {
+      .h = kd->opad_h,
+      .total_bytes = block_size,
+    },
+  };
+}
+
+static_always_inline void
+clib_sha2_hmac_update (clib_sha2_hmac_ctx_t *ctx, const u8 *msg, uword len)
+{
+  clib_sha2_update_internal (&ctx->ipad_state, ctx->block_size, msg, len);
+}
+
+static_always_inline void
+clib_sha2_hmac_final (clib_sha2_hmac_ctx_t *ctx, u8 *digest)
+{
+  u8 i_digest[SHA2_MAX_DIGEST_SIZE];
+
+  clib_sha2_final_internal (&ctx->ipad_state, ctx->block_size,
+                           ctx->digest_size, i_digest);
+  clib_sha2_update_internal (&ctx->opad_state, ctx->block_size, i_digest,
+                            ctx->digest_size);
+  clib_sha2_final_internal (&ctx->opad_state, ctx->block_size,
+                           ctx->digest_size, digest);
+}
+
+static_always_inline void
+clib_sha2_hmac (clib_sha2_type_t type, const u8 *key, uword key_len,
+               const u8 *msg, uword len, u8 *digest)
+{
+  clib_sha2_hmac_ctx_t _ctx, *ctx = &_ctx;
+  clib_sha2_hmac_key_data_t kd;
 
-  /* digest */
-  clib_sha2_update (ctx, i_digest, ctx->digest_size);
-  clib_sha2_final (ctx, digest);
+  clib_sha2_hmac_key_data (type, key, key_len, &kd);
+  clib_sha2_hmac_init (ctx, type, &kd);
+  clib_sha2_hmac_update (ctx, msg, len);
+  clib_sha2_hmac_final (ctx, digest);
 }
 
-#define clib_hmac_sha224(...) clib_hmac_sha2 (CLIB_SHA2_224, __VA_ARGS__)
-#define clib_hmac_sha256(...) clib_hmac_sha2 (CLIB_SHA2_256, __VA_ARGS__)
-#define clib_hmac_sha384(...) clib_hmac_sha2 (CLIB_SHA2_384, __VA_ARGS__)
-#define clib_hmac_sha512(...) clib_hmac_sha2 (CLIB_SHA2_512, __VA_ARGS__)
+#define clib_hmac_sha224(...) clib_sha2_hmac (CLIB_SHA2_224, __VA_ARGS__)
+#define clib_hmac_sha256(...) clib_sha2_hmac (CLIB_SHA2_256, __VA_ARGS__)
+#define clib_hmac_sha384(...) clib_sha2_hmac (CLIB_SHA2_384, __VA_ARGS__)
+#define clib_hmac_sha512(...) clib_sha2_hmac (CLIB_SHA2_512, __VA_ARGS__)
 #define clib_hmac_sha512_224(...)                                             \
-  clib_hmac_sha2 (CLIB_SHA2_512_224, __VA_ARGS__)
+  clib_sha2_hmac (CLIB_SHA2_512_224, __VA_ARGS__)
 #define clib_hmac_sha512_256(...)                                             \
-  clib_hmac_sha2 (CLIB_SHA2_512_256, __VA_ARGS__)
+  clib_sha2_hmac (CLIB_SHA2_512_256, __VA_ARGS__)
 
 #endif /* included_sha2_h */