vppinfra: native AES-CTR implementation
[vpp.git] / src / vppinfra / crypto / aes_ctr.h
1 /* SPDX-License-Identifier: Apache-2.0
2  * Copyright(c) 2024 Cisco Systems, Inc.
3  */
4
5 #ifndef __crypto_aes_ctr_h__
6 #define __crypto_aes_ctr_h__
7
8 #include <vppinfra/clib.h>
9 #include <vppinfra/vector.h>
10 #include <vppinfra/cache.h>
11 #include <vppinfra/string.h>
12 #include <vppinfra/crypto/aes.h>
13
14 typedef struct
15 {
16   const aes_expaned_key_t exp_key[AES_KEY_ROUNDS (AES_KEY_256) + 1];
17 } aes_ctr_key_data_t;
18
19 typedef struct
20 {
21   const aes_expaned_key_t exp_key[AES_KEY_ROUNDS (AES_KEY_256) + 1];
22   aes_counter_t ctr;               /* counter (reflected) */
23   u8 keystream_bytes[N_AES_BYTES]; /* keystream leftovers */
24   u32 n_keystream_bytes;           /* number of keystream leftovers */
25 } aes_ctr_ctx_t;
26
27 static_always_inline aes_counter_t
28 aes_ctr_one_block (aes_ctr_ctx_t *ctx, aes_counter_t ctr, const u8 *src,
29                    u8 *dst, u32 n_parallel, u32 n_bytes, int rounds, int last)
30 {
31   u32 __clib_aligned (N_AES_BYTES)
32   inc[] = { N_AES_LANES, 0, 0, 0, N_AES_LANES, 0, 0, 0,
33             N_AES_LANES, 0, 0, 0, N_AES_LANES, 0, 0, 0 };
34   const aes_expaned_key_t *k = ctx->exp_key;
35   const aes_mem_t *sv = (aes_mem_t *) src;
36   aes_mem_t *dv = (aes_mem_t *) dst;
37   aes_data_t d[4], t[4];
38   u32 r;
39
40   n_bytes -= (n_parallel - 1) * N_AES_BYTES;
41
42   /* AES First Round */
43   for (int i = 0; i < n_parallel; i++)
44     {
45 #if N_AES_LANES == 4
46       t[i] = k[0].x4 ^ (u8x64) aes_reflect ((u8x64) ctr);
47 #elif N_AES_LANES == 2
48       t[i] = k[0].x2 ^ (u8x32) aes_reflect ((u8x32) ctr);
49 #else
50       t[i] = k[0].x1 ^ (u8x16) aes_reflect ((u8x16) ctr);
51 #endif
52       ctr += *(aes_counter_t *) inc;
53     }
54
55   /* Load Data */
56   for (int i = 0; i < n_parallel - last; i++)
57     d[i] = sv[i];
58
59   if (last)
60     d[n_parallel - 1] =
61       aes_load_partial ((u8 *) (sv + n_parallel - 1), n_bytes);
62
63   /* AES Intermediate Rounds */
64   for (r = 1; r < rounds; r++)
65     aes_enc_round (t, k + r, n_parallel);
66
67   /* AES Last Round */
68   aes_enc_last_round (t, d, k + r, n_parallel);
69
70   /* Store Data */
71   for (int i = 0; i < n_parallel - last; i++)
72     dv[i] = d[i];
73
74   if (last)
75     {
76       aes_store_partial (d[n_parallel - 1], dv + n_parallel - 1, n_bytes);
77       *(aes_data_t *) ctx->keystream_bytes = t[n_parallel - 1];
78       ctx->n_keystream_bytes = N_AES_BYTES - n_bytes;
79     }
80
81   return ctr;
82 }
83
84 static_always_inline void
85 clib_aes_ctr_init (aes_ctr_ctx_t *ctx, const aes_ctr_key_data_t *kd,
86                    const u8 *iv, aes_key_size_t ks)
87 {
88   u32x4 ctr = (u32x4) u8x16_reflect (*(u8x16u *) iv);
89 #if N_AES_LANES == 4
90   ctx->ctr = (aes_counter_t) u32x16_splat_u32x4 (ctr) +
91              (u32x16){ 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0 };
92 #elif N_AES_LANES == 2
93   ctx->ctr = (aes_counter_t) u32x8_splat_u32x4 (ctr) +
94              (u32x8){ 0, 0, 0, 0, 1, 0, 0, 0 };
95 #else
96   ctx->ctr = ctr;
97 #endif
98   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
99     ((aes_expaned_key_t *) ctx->exp_key)[i] = kd->exp_key[i];
100   ctx->n_keystream_bytes = 0;
101 }
102
103 static_always_inline void
104 clib_aes_ctr_transform (aes_ctr_ctx_t *ctx, const u8 *src, u8 *dst,
105                         u32 n_bytes, aes_key_size_t ks)
106 {
107   int r = AES_KEY_ROUNDS (ks);
108   aes_counter_t ctr = ctx->ctr;
109
110   if (ctx->n_keystream_bytes)
111     {
112       u8 *ks = ctx->keystream_bytes + N_AES_BYTES - ctx->n_keystream_bytes;
113
114       if (ctx->n_keystream_bytes >= n_bytes)
115         {
116           for (int i = 0; i < n_bytes; i++)
117             dst[i] = src[i] ^ ks[i];
118           ctx->n_keystream_bytes -= n_bytes;
119           return;
120         }
121
122       for (int i = 0; i < ctx->n_keystream_bytes; i++)
123         dst++[0] = src++[0] ^ ks[i];
124
125       n_bytes -= ctx->n_keystream_bytes;
126       ctx->n_keystream_bytes = 0;
127     }
128
129   /* main loop */
130   for (int n = 4 * N_AES_BYTES; n_bytes >= n; n_bytes -= n, dst += n, src += n)
131     ctr = aes_ctr_one_block (ctx, ctr, src, dst, 4, n, r, 0);
132
133   if (n_bytes)
134     {
135       if (n_bytes > 3 * N_AES_BYTES)
136         ctr = aes_ctr_one_block (ctx, ctr, src, dst, 4, n_bytes, r, 1);
137       else if (n_bytes > 2 * N_AES_BYTES)
138         ctr = aes_ctr_one_block (ctx, ctr, src, dst, 3, n_bytes, r, 1);
139       else if (n_bytes > N_AES_BYTES)
140         ctr = aes_ctr_one_block (ctx, ctr, src, dst, 2, n_bytes, r, 1);
141       else
142         ctr = aes_ctr_one_block (ctx, ctr, src, dst, 1, n_bytes, r, 1);
143     }
144   else
145     ctx->n_keystream_bytes = 0;
146
147   ctx->ctr = ctr;
148 }
149
150 static_always_inline void
151 clib_aes_ctr_key_expand (aes_ctr_key_data_t *kd, const u8 *key,
152                          aes_key_size_t ks)
153 {
154   u8x16 ek[AES_KEY_ROUNDS (AES_KEY_256) + 1];
155   aes_expaned_key_t *k = (aes_expaned_key_t *) kd->exp_key;
156
157   /* expand AES key */
158   aes_key_expand (ek, key, ks);
159   for (int i = 0; i < AES_KEY_ROUNDS (ks) + 1; i++)
160     k[i].lanes[0] = k[i].lanes[1] = k[i].lanes[2] = k[i].lanes[3] = ek[i];
161 }
162
163 static_always_inline void
164 clib_aes128_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
165                  const u8 *iv, u8 *dst)
166 {
167   aes_ctr_ctx_t ctx;
168   clib_aes_ctr_init (&ctx, kd, iv, AES_KEY_128);
169   clib_aes_ctr_transform (&ctx, src, dst, n_bytes, AES_KEY_128);
170 }
171
172 static_always_inline void
173 clib_aes192_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
174                  const u8 *iv, u8 *dst)
175 {
176   aes_ctr_ctx_t ctx;
177   clib_aes_ctr_init (&ctx, kd, iv, AES_KEY_192);
178   clib_aes_ctr_transform (&ctx, src, dst, n_bytes, AES_KEY_192);
179 }
180
181 static_always_inline void
182 clib_aes256_ctr (const aes_ctr_key_data_t *kd, const u8 *src, u32 n_bytes,
183                  const u8 *iv, u8 *dst)
184 {
185   aes_ctr_ctx_t ctx;
186   clib_aes_ctr_init (&ctx, kd, iv, AES_KEY_256);
187   clib_aes_ctr_transform (&ctx, src, dst, n_bytes, AES_KEY_256);
188 }
189
190 #endif /* __crypto_aes_ctr_h__ */