vppinfra: native AES-CTR implementation
[vpp.git] / src / vppinfra / crypto / aes.h
index a5e286e..0aa1541 100644 (file)
@@ -15,8 +15,8 @@
  *------------------------------------------------------------------
  */
 
-#ifndef __aesni_h__
-#define __aesni_h__
+#ifndef __aes_h__
+#define __aes_h__
 
 typedef enum
 {
@@ -35,7 +35,7 @@ aes_block_load (u8 * p)
 }
 
 static_always_inline u8x16
-aes_enc_round (u8x16 a, u8x16 k)
+aes_enc_round_x1 (u8x16 a, u8x16 k)
 {
 #if defined (__AES__)
   return (u8x16) _mm_aesenc_si128 ((__m128i) a, (__m128i) k);
@@ -97,7 +97,7 @@ aes_dec_last_round_x2 (u8x32 a, u8x32 k)
 #endif
 
 static_always_inline u8x16
-aes_enc_last_round (u8x16 a, u8x16 k)
+aes_enc_last_round_x1 (u8x16 a, u8x16 k)
 {
 #if defined (__AES__)
   return (u8x16) _mm_aesenclast_si128 ((__m128i) a, (__m128i) k);
@@ -109,13 +109,13 @@ aes_enc_last_round (u8x16 a, u8x16 k)
 #ifdef __x86_64__
 
 static_always_inline u8x16
-aes_dec_round (u8x16 a, u8x16 k)
+aes_dec_round_x1 (u8x16 a, u8x16 k)
 {
   return (u8x16) _mm_aesdec_si128 ((__m128i) a, (__m128i) k);
 }
 
 static_always_inline u8x16
-aes_dec_last_round (u8x16 a, u8x16 k)
+aes_dec_last_round_x1 (u8x16 a, u8x16 k)
 {
   return (u8x16) _mm_aesdeclast_si128 ((__m128i) a, (__m128i) k);
 }
@@ -133,8 +133,8 @@ aes_encrypt_block (u8x16 block, const u8x16 * round_keys, aes_key_size_t ks)
   int rounds = AES_KEY_ROUNDS (ks);
   block ^= round_keys[0];
   for (int i = 1; i < rounds; i += 1)
-    block = aes_enc_round (block, round_keys[i]);
-  return aes_enc_last_round (block, round_keys[rounds]);
+    block = aes_enc_round_x1 (block, round_keys[i]);
+  return aes_enc_last_round_x1 (block, round_keys[rounds]);
 }
 
 static_always_inline u8x16
@@ -427,13 +427,67 @@ aes_key_enc_to_dec (u8x16 * ke, u8x16 * kd, aes_key_size_t ks)
 
   kd[rounds / 2] = aes_inv_mix_column (ke[rounds / 2]);
 }
+#if defined(__VAES__) && defined(__AVX512F__)
+#define N_AES_LANES               4
+#define aes_load_partial(p, n)    u8x64_load_partial ((u8 *) (p), n)
+#define aes_store_partial(v, p, n) u8x64_store_partial (v, (u8 *) (p), n)
+#define aes_reflect(r)            u8x64_reflect_u8x16 (r)
+typedef u8x64 aes_data_t;
+typedef u8x64u aes_mem_t;
+typedef u32x16 aes_counter_t;
+#elif defined(__VAES__)
+#define N_AES_LANES               2
+#define aes_load_partial(p, n)    u8x32_load_partial ((u8 *) (p), n)
+#define aes_store_partial(v, p, n) u8x32_store_partial (v, (u8 *) (p), n)
+#define aes_reflect(r)            u8x32_reflect_u8x16 (r)
+typedef u8x32 aes_data_t;
+typedef u8x32u aes_mem_t;
+typedef u32x8 aes_counter_t;
+#else
+#define N_AES_LANES               1
+#define aes_load_partial(p, n)    u8x16_load_partial ((u8 *) (p), n)
+#define aes_store_partial(v, p, n) u8x16_store_partial (v, (u8 *) (p), n)
+#define aes_reflect(r)            u8x16_reflect (r)
+typedef u8x16 aes_data_t;
+typedef u8x16u aes_mem_t;
+typedef u32x4 aes_counter_t;
+#endif
 
-#endif /* __aesni_h__ */
+#define N_AES_BYTES (N_AES_LANES * 16)
 
-/*
- * fd.io coding-style-patch-verification: ON
- *
- * Local Variables:
- * eval: (c-set-style "gnu")
- * End:
- */
+typedef union
+{
+  u8x16 x1;
+  u8x32 x2;
+  u8x64 x4;
+  u8x16 lanes[4];
+} aes_expaned_key_t;
+
+static_always_inline void
+aes_enc_round (aes_data_t *r, const aes_expaned_key_t *ek, uword n_blocks)
+{
+  for (int i = 0; i < n_blocks; i++)
+#if N_AES_LANES == 4
+    r[i] = aes_enc_round_x4 (r[i], ek->x4);
+#elif N_AES_LANES == 2
+    r[i] = aes_enc_round_x2 (r[i], ek->x2);
+#else
+    r[i] = aes_enc_round_x1 (r[i], ek->x1);
+#endif
+}
+
+static_always_inline void
+aes_enc_last_round (aes_data_t *r, aes_data_t *d, const aes_expaned_key_t *ek,
+                   uword n_blocks)
+{
+  for (int i = 0; i < n_blocks; i++)
+#if N_AES_LANES == 4
+    d[i] ^= r[i] = aes_enc_last_round_x4 (r[i], ek->x4);
+#elif N_AES_LANES == 2
+    d[i] ^= r[i] = aes_enc_last_round_x2 (r[i], ek->x2);
+#else
+    d[i] ^= r[i] = aes_enc_last_round_x1 (r[i], ek->x1);
+#endif
+}
+
+#endif /* __aes_h__ */