crypto-native: refactor GCM code to use generic types
[vpp.git] / src / plugins / crypto_native / aes.h
1 /*
2  *------------------------------------------------------------------
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *------------------------------------------------------------------
16  */
17
18 #ifndef __aesni_h__
19 #define __aesni_h__
20
21 typedef enum
22 {
23   AES_KEY_128 = 0,
24   AES_KEY_192 = 1,
25   AES_KEY_256 = 2,
26 } aes_key_size_t;
27
28 #define AES_KEY_ROUNDS(x)               (10 + x * 2)
29 #define AES_KEY_BYTES(x)                (16 + x * 8)
30
31 #ifdef __x86_64__
32
33 static const u8x16 byte_mask_scale = {
34   0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
35 };
36
37 static_always_inline u8x16
38 aes_block_load (u8 * p)
39 {
40   return (u8x16) _mm_loadu_si128 ((__m128i *) p);
41 }
42
43 static_always_inline u8x16
44 aes_enc_round (u8x16 a, u8x16 k)
45 {
46   return (u8x16) _mm_aesenc_si128 ((__m128i) a, (__m128i) k);
47 }
48
49 static_always_inline u8x16
50 aes_enc_last_round (u8x16 a, u8x16 k)
51 {
52   return (u8x16) _mm_aesenclast_si128 ((__m128i) a, (__m128i) k);
53 }
54
55 static_always_inline u8x16
56 aes_dec_round (u8x16 a, u8x16 k)
57 {
58   return (u8x16) _mm_aesdec_si128 ((__m128i) a, (__m128i) k);
59 }
60
61 static_always_inline u8x16
62 aes_dec_last_round (u8x16 a, u8x16 k)
63 {
64   return (u8x16) _mm_aesdeclast_si128 ((__m128i) a, (__m128i) k);
65 }
66
67 static_always_inline void
68 aes_block_store (u8 * p, u8x16 r)
69 {
70   _mm_storeu_si128 ((__m128i *) p, (__m128i) r);
71 }
72
73 static_always_inline u8x16
74 aes_byte_mask (u8x16 x, u8 n_bytes)
75 {
76   u8x16 mask = u8x16_is_greater (u8x16_splat (n_bytes), byte_mask_scale);
77   __m128i zero = { };
78
79   return (u8x16) _mm_blendv_epi8 (zero, (__m128i) x, (__m128i) mask);
80 }
81
82 static_always_inline u8x16
83 aes_load_partial (u8x16u * p, int n_bytes)
84 {
85   ASSERT (n_bytes <= 16);
86 #ifdef __AVX512F__
87   __m128i zero = { };
88   return (u8x16) _mm_mask_loadu_epi8 (zero, (1 << n_bytes) - 1, p);
89 #else
90   return aes_byte_mask (CLIB_MEM_OVERFLOW_LOAD (*, p), n_bytes);
91 #endif
92 }
93
94 static_always_inline void
95 aes_store_partial (void *p, u8x16 r, int n_bytes)
96 {
97 #ifdef __AVX512F__
98   _mm_mask_storeu_epi8 (p, (1 << n_bytes) - 1, (__m128i) r);
99 #else
100   u8x16 mask = u8x16_is_greater (u8x16_splat (n_bytes), byte_mask_scale);
101   _mm_maskmoveu_si128 ((__m128i) r, (__m128i) mask, p);
102 #endif
103 }
104
105
106 static_always_inline u8x16
107 aes_encrypt_block (u8x16 block, const u8x16 * round_keys, aes_key_size_t ks)
108 {
109   int i;
110   block ^= round_keys[0];
111   for (i = 1; i < AES_KEY_ROUNDS (ks); i += 1)
112     block = aes_enc_round (block, round_keys[i]);
113   return aes_enc_last_round (block, round_keys[i]);
114 }
115
116 static_always_inline u8x16
117 aes_inv_mix_column (u8x16 a)
118 {
119   return (u8x16) _mm_aesimc_si128 ((__m128i) a);
120 }
121
122 #define aes_keygen_assist(a, b) \
123   (u8x16) _mm_aeskeygenassist_si128((__m128i) a, b)
124
125 /* AES-NI based AES key expansion based on code samples from
126    Intel(r) Advanced Encryption Standard (AES) New Instructions White Paper
127    (323641-001) */
128
129 static_always_inline void
130 aes128_key_assist (u8x16 * rk, u8x16 r)
131 {
132   u8x16 t = rk[-1];
133   t ^= u8x16_word_shift_left (t, 4);
134   t ^= u8x16_word_shift_left (t, 4);
135   t ^= u8x16_word_shift_left (t, 4);
136   rk[0] = t ^ (u8x16) u32x4_shuffle ((u32x4) r, 3, 3, 3, 3);
137 }
138
139 static_always_inline void
140 aes128_key_expand (u8x16 * rk, u8x16 const *k)
141 {
142   rk[0] = k[0];
143   aes128_key_assist (rk + 1, aes_keygen_assist (rk[0], 0x01));
144   aes128_key_assist (rk + 2, aes_keygen_assist (rk[1], 0x02));
145   aes128_key_assist (rk + 3, aes_keygen_assist (rk[2], 0x04));
146   aes128_key_assist (rk + 4, aes_keygen_assist (rk[3], 0x08));
147   aes128_key_assist (rk + 5, aes_keygen_assist (rk[4], 0x10));
148   aes128_key_assist (rk + 6, aes_keygen_assist (rk[5], 0x20));
149   aes128_key_assist (rk + 7, aes_keygen_assist (rk[6], 0x40));
150   aes128_key_assist (rk + 8, aes_keygen_assist (rk[7], 0x80));
151   aes128_key_assist (rk + 9, aes_keygen_assist (rk[8], 0x1b));
152   aes128_key_assist (rk + 10, aes_keygen_assist (rk[9], 0x36));
153 }
154
155 static_always_inline void
156 aes192_key_assist (u8x16 * r1, u8x16 * r2, u8x16 key_assist)
157 {
158   u8x16 t;
159   r1[0] ^= t = u8x16_word_shift_left (r1[0], 4);
160   r1[0] ^= t = u8x16_word_shift_left (t, 4);
161   r1[0] ^= u8x16_word_shift_left (t, 4);
162   r1[0] ^= (u8x16) _mm_shuffle_epi32 ((__m128i) key_assist, 0x55);
163   r2[0] ^= u8x16_word_shift_left (r2[0], 4);
164   r2[0] ^= (u8x16) _mm_shuffle_epi32 ((__m128i) r1[0], 0xff);
165 }
166
167 static_always_inline void
168 aes192_key_expand (u8x16 * rk, u8x16u const *k)
169 {
170   u8x16 r1, r2;
171
172   rk[0] = r1 = k[0];
173   /* *INDENT-OFF* */
174   rk[1] = r2 = (u8x16) (u64x2) { *(u64 *) (k + 1), 0 };
175   /* *INDENT-ON* */
176
177   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x1));
178   rk[1] = (u8x16) _mm_shuffle_pd ((__m128d) rk[1], (__m128d) r1, 0);
179   rk[2] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
180
181   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x2));
182   rk[3] = r1;
183   rk[4] = r2;
184
185   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x4));
186   rk[4] = (u8x16) _mm_shuffle_pd ((__m128d) rk[4], (__m128d) r1, 0);
187   rk[5] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
188
189   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x8));
190   rk[6] = r1;
191   rk[7] = r2;
192
193   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x10));
194   rk[7] = (u8x16) _mm_shuffle_pd ((__m128d) rk[7], (__m128d) r1, 0);
195   rk[8] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
196
197   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x20));
198   rk[9] = r1;
199   rk[10] = r2;
200
201   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x40));
202   rk[10] = (u8x16) _mm_shuffle_pd ((__m128d) rk[10], (__m128d) r1, 0);
203   rk[11] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
204
205   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x80));
206   rk[12] = r1;
207 }
208
209 static_always_inline void
210 aes256_key_assist (u8x16 * rk, int i, u8x16 key_assist)
211 {
212   u8x16 r, t;
213   rk += i;
214   r = rk[-2];
215   r ^= t = u8x16_word_shift_left (r, 4);
216   r ^= t = u8x16_word_shift_left (t, 4);
217   r ^= u8x16_word_shift_left (t, 4);
218   r ^= (u8x16) u32x4_shuffle ((u32x4) key_assist, 3, 3, 3, 3);
219   rk[0] = r;
220
221   if (i >= 14)
222     return;
223
224   key_assist = aes_keygen_assist (rk[0], 0x0);
225   r = rk[-1];
226   r ^= t = u8x16_word_shift_left (r, 4);
227   r ^= t = u8x16_word_shift_left (t, 4);
228   r ^= u8x16_word_shift_left (t, 4);
229   r ^= (u8x16) u32x4_shuffle ((u32x4) key_assist, 2, 2, 2, 2);
230   rk[1] = r;
231 }
232
233 static_always_inline void
234 aes256_key_expand (u8x16 * rk, u8x16u const *k)
235 {
236   rk[0] = k[0];
237   rk[1] = k[1];
238   aes256_key_assist (rk, 2, aes_keygen_assist (rk[1], 0x01));
239   aes256_key_assist (rk, 4, aes_keygen_assist (rk[3], 0x02));
240   aes256_key_assist (rk, 6, aes_keygen_assist (rk[5], 0x04));
241   aes256_key_assist (rk, 8, aes_keygen_assist (rk[7], 0x08));
242   aes256_key_assist (rk, 10, aes_keygen_assist (rk[9], 0x10));
243   aes256_key_assist (rk, 12, aes_keygen_assist (rk[11], 0x20));
244   aes256_key_assist (rk, 14, aes_keygen_assist (rk[13], 0x40));
245 }
246 #endif
247
248 #ifdef __aarch64__
249
250 static_always_inline u8x16
251 aes_inv_mix_column (u8x16 a)
252 {
253   return vaesimcq_u8 (a);
254 }
255
256 static const u8x16 aese_prep_mask1 =
257   { 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12 };
258 static const u8x16 aese_prep_mask2 =
259   { 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 };
260
261 static inline void
262 aes128_key_expand_round_neon (u8x16 * rk, u32 rcon)
263 {
264   u8x16 r, t, last_round = rk[-1], z = { };
265   r = vqtbl1q_u8 (last_round, aese_prep_mask1);
266   r = vaeseq_u8 (r, z);
267   r ^= (u8x16) vdupq_n_u32 (rcon);
268   r ^= last_round;
269   r ^= t = vextq_u8 (z, last_round, 12);
270   r ^= t = vextq_u8 (z, t, 12);
271   r ^= vextq_u8 (z, t, 12);
272   rk[0] = r;
273 }
274
275 void
276 aes128_key_expand (u8x16 * rk, const u8x16 * k)
277 {
278   rk[0] = k[0];
279   aes128_key_expand_round_neon (rk + 1, 0x01);
280   aes128_key_expand_round_neon (rk + 2, 0x02);
281   aes128_key_expand_round_neon (rk + 3, 0x04);
282   aes128_key_expand_round_neon (rk + 4, 0x08);
283   aes128_key_expand_round_neon (rk + 5, 0x10);
284   aes128_key_expand_round_neon (rk + 6, 0x20);
285   aes128_key_expand_round_neon (rk + 7, 0x40);
286   aes128_key_expand_round_neon (rk + 8, 0x80);
287   aes128_key_expand_round_neon (rk + 9, 0x1b);
288   aes128_key_expand_round_neon (rk + 10, 0x36);
289 }
290
291 static inline void
292 aes192_key_expand_round_neon (u8x8 * rk, u32 rcon)
293 {
294   u8x8 r, last_round = rk[-1], z = { };
295   u8x16 r2, z2 = { };
296
297   r2 = (u8x16) vdupq_lane_u64 ((uint64x1_t) last_round, 0);
298   r2 = vqtbl1q_u8 (r2, aese_prep_mask1);
299   r2 = vaeseq_u8 (r2, z2);
300   r2 ^= (u8x16) vdupq_n_u32 (rcon);
301
302   r = (u8x8) vdup_laneq_u64 ((u64x2) r2, 0);
303   r ^= rk[-3];
304   r ^= vext_u8 (z, rk[-3], 4);
305   rk[0] = r;
306
307   r = rk[-2] ^ vext_u8 (r, z, 4);
308   r ^= vext_u8 (z, r, 4);
309   rk[1] = r;
310
311   if (rcon == 0x80)
312     return;
313
314   r = rk[-1] ^ vext_u8 (r, z, 4);
315   r ^= vext_u8 (z, r, 4);
316   rk[2] = r;
317 }
318
319 void
320 aes192_key_expand (u8x16 * ek, const u8x16u * k)
321 {
322   u8x8 *rk = (u8x8 *) ek;
323   ek[0] = k[0];
324   rk[2] = *(u8x8u *) (k + 1);
325   aes192_key_expand_round_neon (rk + 3, 0x01);
326   aes192_key_expand_round_neon (rk + 6, 0x02);
327   aes192_key_expand_round_neon (rk + 9, 0x04);
328   aes192_key_expand_round_neon (rk + 12, 0x08);
329   aes192_key_expand_round_neon (rk + 15, 0x10);
330   aes192_key_expand_round_neon (rk + 18, 0x20);
331   aes192_key_expand_round_neon (rk + 21, 0x40);
332   aes192_key_expand_round_neon (rk + 24, 0x80);
333 }
334
335
336 static inline void
337 aes256_key_expand_round_neon (u8x16 * rk, u32 rcon)
338 {
339   u8x16 r, t, z = { };
340
341   r = vqtbl1q_u8 (rk[-1], rcon ? aese_prep_mask1 : aese_prep_mask2);
342   r = vaeseq_u8 (r, z);
343   if (rcon)
344     r ^= (u8x16) vdupq_n_u32 (rcon);
345   r ^= rk[-2];
346   r ^= t = vextq_u8 (z, rk[-2], 12);
347   r ^= t = vextq_u8 (z, t, 12);
348   r ^= vextq_u8 (z, t, 12);
349   rk[0] = r;
350 }
351
352 void
353 aes256_key_expand (u8x16 * rk, u8x16 const *k)
354 {
355   rk[0] = k[0];
356   rk[1] = k[1];
357   aes256_key_expand_round_neon (rk + 2, 0x01);
358   aes256_key_expand_round_neon (rk + 3, 0);
359   aes256_key_expand_round_neon (rk + 4, 0x02);
360   aes256_key_expand_round_neon (rk + 5, 0);
361   aes256_key_expand_round_neon (rk + 6, 0x04);
362   aes256_key_expand_round_neon (rk + 7, 0);
363   aes256_key_expand_round_neon (rk + 8, 0x08);
364   aes256_key_expand_round_neon (rk + 9, 0);
365   aes256_key_expand_round_neon (rk + 10, 0x10);
366   aes256_key_expand_round_neon (rk + 11, 0);
367   aes256_key_expand_round_neon (rk + 12, 0x20);
368   aes256_key_expand_round_neon (rk + 13, 0);
369   aes256_key_expand_round_neon (rk + 14, 0x40);
370 }
371
372 #endif
373
374 static_always_inline void
375 aes_key_expand (u8x16 * key_schedule, u8 const *key, aes_key_size_t ks)
376 {
377   switch (ks)
378     {
379     case AES_KEY_128:
380       aes128_key_expand (key_schedule, (u8x16u const *) key);
381       break;
382     case AES_KEY_192:
383       aes192_key_expand (key_schedule, (u8x16u const *) key);
384       break;
385     case AES_KEY_256:
386       aes256_key_expand (key_schedule, (u8x16u const *) key);
387       break;
388     }
389 }
390
391 static_always_inline void
392 aes_key_enc_to_dec (u8x16 * ke, u8x16 * kd, aes_key_size_t ks)
393 {
394   int rounds = AES_KEY_ROUNDS (ks);
395
396   kd[rounds] = ke[0];
397   kd[0] = ke[rounds];
398
399   for (int i = 1; i < (rounds / 2); i++)
400     {
401       kd[rounds - i] = aes_inv_mix_column (ke[i]);
402       kd[i] = aes_inv_mix_column (ke[rounds - i]);
403     }
404
405   kd[rounds / 2] = aes_inv_mix_column (ke[rounds / 2]);
406 }
407
408 #endif /* __aesni_h__ */
409
410 /*
411  * fd.io coding-style-patch-verification: ON
412  *
413  * Local Variables:
414  * eval: (c-set-style "gnu")
415  * End:
416  */