vppinfra: SHA2-256 ARM ISA support 50/40450/2
authorDamjan Marion <damarion@cisco.com>
Sun, 3 Mar 2024 22:12:21 +0000 (22:12 +0000)
committerDamjan Marion <damarion@cisco.com>
Mon, 4 Mar 2024 14:07:25 +0000 (14:07 +0000)
Change-Id: I0fcda3e7afaab037bd12d0079d6639c6cbe8647e
Type: improvement
Signed-off-by: Damjan Marion <damarion@cisco.com>
src/vppinfra/crypto/sha2.h

index ce99fa3..5100615 100644 (file)
@@ -17,6 +17,7 @@
 #define included_sha2_h
 
 #include <vppinfra/clib.h>
+#include <vppinfra/vector.h>
 
 #define SHA224_DIGEST_SIZE 28
 #define SHA224_BLOCK_SIZE  64
     s[0] = t1 + t2;                                                           \
   }
 
+#if defined(__SHA__) && defined(__x86_64__)
+#define CLIB_SHA256_ISA_INTEL
+#define CLIB_SHA256_ISA
+#endif
+
+#ifdef __ARM_FEATURE_SHA2
+#define CLIB_SHA256_ISA_ARM
+#define CLIB_SHA256_ISA
+#endif
+
 static const u32 sha224_h[8] = { 0xc1059ed8, 0x367cd507, 0x3070dd17,
                                 0xf70e5939, 0xffc00b31, 0x68581511,
                                 0x64f98fa7, 0xbefa4fa4 };
@@ -201,7 +212,7 @@ typedef struct
   {
     u32 h32[8];
     u64 h64[8];
-#if defined(__SHA__) && defined(__x86_64__)
+#ifdef CLIB_SHA256_ISA
     u32x4 h32x4[2];
 #endif
   };
@@ -264,94 +275,132 @@ clib_sha2_init (clib_sha2_ctx_t *ctx, clib_sha2_type_t type)
       ctx->h64[i] = h64[i];
 }
 
-#if defined(__SHA__) && defined(__x86_64__)
+#ifdef CLIB_SHA256_ISA
 static inline void
-shani_sha256_cycle_w (u32x4 cw[], u8 a, u8 b, u8 c, u8 d)
+clib_sha256_vec_cycle_w (u32x4 w[], u8 i)
 {
-  cw[a] = (u32x4) _mm_sha256msg1_epu32 ((__m128i) cw[a], (__m128i) cw[b]);
-  cw[a] += (u32x4) _mm_alignr_epi8 ((__m128i) cw[d], (__m128i) cw[c], 4);
-  cw[a] = (u32x4) _mm_sha256msg2_epu32 ((__m128i) cw[a], (__m128i) cw[d]);
+  u8 j = (i + 1) % 4;
+  u8 k = (i + 2) % 4;
+  u8 l = (i + 3) % 4;
+#ifdef CLIB_SHA256_ISA_INTEL
+  w[i] = (u32x4) _mm_sha256msg1_epu32 ((__m128i) w[i], (__m128i) w[j]);
+  w[i] += (u32x4) _mm_alignr_epi8 ((__m128i) w[l], (__m128i) w[k], 4);
+  w[i] = (u32x4) _mm_sha256msg2_epu32 ((__m128i) w[i], (__m128i) w[l]);
+#elif defined(CLIB_SHA256_ISA_ARM)
+  w[i] = vsha256su1q_u32 (vsha256su0q_u32 (w[i], w[j]), w[k], w[l]);
+#endif
 }
 
 static inline void
-shani_sha256_4_rounds (u32x4 cw, u8 n, u32x4 s[])
+clib_sha256_vec_4_rounds (u32x4 w, u8 n, u32x4 s[])
 {
-  u32x4 r = *(u32x4 *) (sha256_k + 4 * n) + cw;
+#ifdef CLIB_SHA256_ISA_INTEL
+  u32x4 r = *(u32x4 *) (sha256_k + 4 * n) + w;
   s[0] = (u32x4) _mm_sha256rnds2_epu32 ((__m128i) s[0], (__m128i) s[1],
                                        (__m128i) r);
   r = (u32x4) u64x2_interleave_hi ((u64x2) r, (u64x2) r);
   s[1] = (u32x4) _mm_sha256rnds2_epu32 ((__m128i) s[1], (__m128i) s[0],
                                        (__m128i) r);
+#elif defined(CLIB_SHA256_ISA_ARM)
+  u32x4 r0, s0;
+  const u32x4u *k = (u32x4u *) sha256_k;
+
+  r0 = w + k[n];
+  s0 = s[0];
+  s[0] = vsha256hq_u32 (s[0], s[1], r0);
+  s[1] = vsha256h2q_u32 (s[1], s0, r0);
+#endif
+}
+#endif
+
+#if defined(CLIB_SHA256_ISA)
+static inline u32x4
+clib_sha256_vec_load (u32x4 r)
+{
+#if defined(CLIB_SHA256_ISA_INTEL)
+  return u32x4_byte_swap (r);
+#elif defined(CLIB_SHA256_ISA_ARM)
+  return vreinterpretq_u32_u8 (vrev32q_u8 (vreinterpretq_u8_u32 (r)));
+#endif
 }
 
 static inline void
-shani_sha256_shuffle (u32x4 d[2], u32x4 s[2])
+clib_sha256_vec_shuffle (u32x4 d[2])
 {
+#if defined(CLIB_SHA256_ISA_INTEL)
   /* {0, 1, 2, 3}, {4, 5, 6, 7} -> {7, 6, 3, 2}, {5, 4, 1, 0} */
-  d[0] = (u32x4) _mm_shuffle_ps ((__m128) s[1], (__m128) s[0], 0xbb);
-  d[1] = (u32x4) _mm_shuffle_ps ((__m128) s[1], (__m128) s[0], 0x11);
+  u32x4 r;
+  r = (u32x4) _mm_shuffle_ps ((__m128) d[1], (__m128) d[0], 0xbb);
+  d[1] = (u32x4) _mm_shuffle_ps ((__m128) d[1], (__m128) d[0], 0x11);
+  d[0] = r;
+#endif
 }
 #endif
 
 static inline void
 clib_sha256_block (clib_sha2_ctx_t *ctx, const u8 *msg, uword n_blocks)
 {
-#if defined(__SHA__) && defined(__x86_64__)
-  u32x4 h[2], s[2], w[4];
+#if defined(CLIB_SHA256_ISA)
+  u32x4 h[2];
+  u32x4u *m = (u32x4u *) msg;
 
-  shani_sha256_shuffle (h, ctx->h32x4);
+  h[0] = ctx->h32x4[0];
+  h[1] = ctx->h32x4[1];
 
-  while (n_blocks)
+  clib_sha256_vec_shuffle (h);
+
+  for (; n_blocks; m += 4, n_blocks--)
     {
-      w[0] = u32x4_byte_swap (u32x4_load_unaligned ((u8 *) msg + 0));
-      w[1] = u32x4_byte_swap (u32x4_load_unaligned ((u8 *) msg + 16));
-      w[2] = u32x4_byte_swap (u32x4_load_unaligned ((u8 *) msg + 32));
-      w[3] = u32x4_byte_swap (u32x4_load_unaligned ((u8 *) msg + 48));
+      u32x4 s[2], w[4];
 
       s[0] = h[0];
       s[1] = h[1];
 
-      shani_sha256_4_rounds (w[0], 0, s);
-      shani_sha256_4_rounds (w[1], 1, s);
-      shani_sha256_4_rounds (w[2], 2, s);
-      shani_sha256_4_rounds (w[3], 3, s);
-
-      shani_sha256_cycle_w (w, 0, 1, 2, 3);
-      shani_sha256_4_rounds (w[0], 4, s);
-      shani_sha256_cycle_w (w, 1, 2, 3, 0);
-      shani_sha256_4_rounds (w[1], 5, s);
-      shani_sha256_cycle_w (w, 2, 3, 0, 1);
-      shani_sha256_4_rounds (w[2], 6, s);
-      shani_sha256_cycle_w (w, 3, 0, 1, 2);
-      shani_sha256_4_rounds (w[3], 7, s);
-
-      shani_sha256_cycle_w (w, 0, 1, 2, 3);
-      shani_sha256_4_rounds (w[0], 8, s);
-      shani_sha256_cycle_w (w, 1, 2, 3, 0);
-      shani_sha256_4_rounds (w[1], 9, s);
-      shani_sha256_cycle_w (w, 2, 3, 0, 1);
-      shani_sha256_4_rounds (w[2], 10, s);
-      shani_sha256_cycle_w (w, 3, 0, 1, 2);
-      shani_sha256_4_rounds (w[3], 11, s);
-
-      shani_sha256_cycle_w (w, 0, 1, 2, 3);
-      shani_sha256_4_rounds (w[0], 12, s);
-      shani_sha256_cycle_w (w, 1, 2, 3, 0);
-      shani_sha256_4_rounds (w[1], 13, s);
-      shani_sha256_cycle_w (w, 2, 3, 0, 1);
-      shani_sha256_4_rounds (w[2], 14, s);
-      shani_sha256_cycle_w (w, 3, 0, 1, 2);
-      shani_sha256_4_rounds (w[3], 15, s);
+      w[0] = clib_sha256_vec_load (m[0]);
+      w[1] = clib_sha256_vec_load (m[1]);
+      w[2] = clib_sha256_vec_load (m[2]);
+      w[3] = clib_sha256_vec_load (m[3]);
+
+      clib_sha256_vec_4_rounds (w[0], 0, s);
+      clib_sha256_vec_4_rounds (w[1], 1, s);
+      clib_sha256_vec_4_rounds (w[2], 2, s);
+      clib_sha256_vec_4_rounds (w[3], 3, s);
+
+      clib_sha256_vec_cycle_w (w, 0);
+      clib_sha256_vec_4_rounds (w[0], 4, s);
+      clib_sha256_vec_cycle_w (w, 1);
+      clib_sha256_vec_4_rounds (w[1], 5, s);
+      clib_sha256_vec_cycle_w (w, 2);
+      clib_sha256_vec_4_rounds (w[2], 6, s);
+      clib_sha256_vec_cycle_w (w, 3);
+      clib_sha256_vec_4_rounds (w[3], 7, s);
+
+      clib_sha256_vec_cycle_w (w, 0);
+      clib_sha256_vec_4_rounds (w[0], 8, s);
+      clib_sha256_vec_cycle_w (w, 1);
+      clib_sha256_vec_4_rounds (w[1], 9, s);
+      clib_sha256_vec_cycle_w (w, 2);
+      clib_sha256_vec_4_rounds (w[2], 10, s);
+      clib_sha256_vec_cycle_w (w, 3);
+      clib_sha256_vec_4_rounds (w[3], 11, s);
+
+      clib_sha256_vec_cycle_w (w, 0);
+      clib_sha256_vec_4_rounds (w[0], 12, s);
+      clib_sha256_vec_cycle_w (w, 1);
+      clib_sha256_vec_4_rounds (w[1], 13, s);
+      clib_sha256_vec_cycle_w (w, 2);
+      clib_sha256_vec_4_rounds (w[2], 14, s);
+      clib_sha256_vec_cycle_w (w, 3);
+      clib_sha256_vec_4_rounds (w[3], 15, s);
 
       h[0] += s[0];
       h[1] += s[1];
-
-      /* next */
-      msg += SHA256_BLOCK_SIZE;
-      n_blocks--;
     }
 
-  shani_sha256_shuffle (ctx->h32x4, h);
+  clib_sha256_vec_shuffle (h);
+
+  ctx->h32x4[0] = h[0];
+  ctx->h32x4[1] = h[1];
 #else
   u32 w[64], s[8], i;