aacbf8ae34dd3702e5bfcec58b26131885fac142
[vpp.git] / src / plugins / crypto_native / aes.h
1 /*
2  *------------------------------------------------------------------
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *------------------------------------------------------------------
16  */
17
18 #ifndef __aesni_h__
19 #define __aesni_h__
20
21 typedef enum
22 {
23   AES_KEY_128 = 0,
24   AES_KEY_192 = 1,
25   AES_KEY_256 = 2,
26 } aes_key_size_t;
27
28 #define AES_KEY_ROUNDS(x)               (10 + x * 2)
29 #define AES_KEY_BYTES(x)                (16 + x * 8)
30
31 #ifdef __x86_64__
32
33 static const u8x16 byte_mask_scale = {
34   0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
35 };
36
37 static_always_inline u8x16
38 aes_block_load (u8 * p)
39 {
40   return (u8x16) _mm_loadu_si128 ((__m128i *) p);
41 }
42
43 static_always_inline u8x16
44 aes_enc_round (u8x16 a, u8x16 k)
45 {
46   return (u8x16) _mm_aesenc_si128 ((__m128i) a, (__m128i) k);
47 }
48
49 static_always_inline u8x16
50 aes_enc_last_round (u8x16 a, u8x16 k)
51 {
52   return (u8x16) _mm_aesenclast_si128 ((__m128i) a, (__m128i) k);
53 }
54
55 static_always_inline u8x16
56 aes_dec_round (u8x16 a, u8x16 k)
57 {
58   return (u8x16) _mm_aesdec_si128 ((__m128i) a, (__m128i) k);
59 }
60
61 static_always_inline u8x16
62 aes_dec_last_round (u8x16 a, u8x16 k)
63 {
64   return (u8x16) _mm_aesdeclast_si128 ((__m128i) a, (__m128i) k);
65 }
66
67 static_always_inline void
68 aes_block_store (u8 * p, u8x16 r)
69 {
70   _mm_storeu_si128 ((__m128i *) p, (__m128i) r);
71 }
72
73 static_always_inline u8x16
74 aes_byte_mask (u8x16 x, u8 n_bytes)
75 {
76   return x & u8x16_is_greater (u8x16_splat (n_bytes), byte_mask_scale);
77 }
78
79 static_always_inline u8x16
80 aes_load_partial (u8x16u * p, int n_bytes)
81 {
82   ASSERT (n_bytes <= 16);
83 #ifdef __AVX512F__
84   __m128i zero = { };
85   return (u8x16) _mm_mask_loadu_epi8 (zero, (1 << n_bytes) - 1, p);
86 #else
87   return aes_byte_mask (CLIB_MEM_OVERFLOW_LOAD (*, p), n_bytes);
88 #endif
89 }
90
91 static_always_inline void
92 aes_store_partial (void *p, u8x16 r, int n_bytes)
93 {
94 #ifdef __AVX512F__
95   _mm_mask_storeu_epi8 (p, (1 << n_bytes) - 1, (__m128i) r);
96 #else
97   u8x16 mask = u8x16_is_greater (u8x16_splat (n_bytes), byte_mask_scale);
98   _mm_maskmoveu_si128 ((__m128i) r, (__m128i) mask, p);
99 #endif
100 }
101
102
103 static_always_inline u8x16
104 aes_encrypt_block (u8x16 block, const u8x16 * round_keys, aes_key_size_t ks)
105 {
106   int i;
107   block ^= round_keys[0];
108   for (i = 1; i < AES_KEY_ROUNDS (ks); i += 1)
109     block = aes_enc_round (block, round_keys[i]);
110   return aes_enc_last_round (block, round_keys[i]);
111 }
112
113 static_always_inline u8x16
114 aes_inv_mix_column (u8x16 a)
115 {
116   return (u8x16) _mm_aesimc_si128 ((__m128i) a);
117 }
118
119 #define aes_keygen_assist(a, b) \
120   (u8x16) _mm_aeskeygenassist_si128((__m128i) a, b)
121
122 /* AES-NI based AES key expansion based on code samples from
123    Intel(r) Advanced Encryption Standard (AES) New Instructions White Paper
124    (323641-001) */
125
126 static_always_inline void
127 aes128_key_assist (u8x16 * rk, u8x16 r)
128 {
129   u8x16 t = rk[-1];
130   t ^= u8x16_word_shift_left (t, 4);
131   t ^= u8x16_word_shift_left (t, 4);
132   t ^= u8x16_word_shift_left (t, 4);
133   rk[0] = t ^ (u8x16) u32x4_shuffle ((u32x4) r, 3, 3, 3, 3);
134 }
135
136 static_always_inline void
137 aes128_key_expand (u8x16 * rk, u8x16 const *k)
138 {
139   rk[0] = k[0];
140   aes128_key_assist (rk + 1, aes_keygen_assist (rk[0], 0x01));
141   aes128_key_assist (rk + 2, aes_keygen_assist (rk[1], 0x02));
142   aes128_key_assist (rk + 3, aes_keygen_assist (rk[2], 0x04));
143   aes128_key_assist (rk + 4, aes_keygen_assist (rk[3], 0x08));
144   aes128_key_assist (rk + 5, aes_keygen_assist (rk[4], 0x10));
145   aes128_key_assist (rk + 6, aes_keygen_assist (rk[5], 0x20));
146   aes128_key_assist (rk + 7, aes_keygen_assist (rk[6], 0x40));
147   aes128_key_assist (rk + 8, aes_keygen_assist (rk[7], 0x80));
148   aes128_key_assist (rk + 9, aes_keygen_assist (rk[8], 0x1b));
149   aes128_key_assist (rk + 10, aes_keygen_assist (rk[9], 0x36));
150 }
151
152 static_always_inline void
153 aes192_key_assist (u8x16 * r1, u8x16 * r2, u8x16 key_assist)
154 {
155   u8x16 t;
156   r1[0] ^= t = u8x16_word_shift_left (r1[0], 4);
157   r1[0] ^= t = u8x16_word_shift_left (t, 4);
158   r1[0] ^= u8x16_word_shift_left (t, 4);
159   r1[0] ^= (u8x16) _mm_shuffle_epi32 ((__m128i) key_assist, 0x55);
160   r2[0] ^= u8x16_word_shift_left (r2[0], 4);
161   r2[0] ^= (u8x16) _mm_shuffle_epi32 ((__m128i) r1[0], 0xff);
162 }
163
164 static_always_inline void
165 aes192_key_expand (u8x16 * rk, u8x16u const *k)
166 {
167   u8x16 r1, r2;
168
169   rk[0] = r1 = k[0];
170   /* *INDENT-OFF* */
171   rk[1] = r2 = (u8x16) (u64x2) { *(u64 *) (k + 1), 0 };
172   /* *INDENT-ON* */
173
174   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x1));
175   rk[1] = (u8x16) _mm_shuffle_pd ((__m128d) rk[1], (__m128d) r1, 0);
176   rk[2] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
177
178   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x2));
179   rk[3] = r1;
180   rk[4] = r2;
181
182   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x4));
183   rk[4] = (u8x16) _mm_shuffle_pd ((__m128d) rk[4], (__m128d) r1, 0);
184   rk[5] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
185
186   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x8));
187   rk[6] = r1;
188   rk[7] = r2;
189
190   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x10));
191   rk[7] = (u8x16) _mm_shuffle_pd ((__m128d) rk[7], (__m128d) r1, 0);
192   rk[8] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
193
194   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x20));
195   rk[9] = r1;
196   rk[10] = r2;
197
198   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x40));
199   rk[10] = (u8x16) _mm_shuffle_pd ((__m128d) rk[10], (__m128d) r1, 0);
200   rk[11] = (u8x16) _mm_shuffle_pd ((__m128d) r1, (__m128d) r2, 1);
201
202   aes192_key_assist (&r1, &r2, aes_keygen_assist (r2, 0x80));
203   rk[12] = r1;
204 }
205
206 static_always_inline void
207 aes256_key_assist (u8x16 * rk, int i, u8x16 key_assist)
208 {
209   u8x16 r, t;
210   rk += i;
211   r = rk[-2];
212   r ^= t = u8x16_word_shift_left (r, 4);
213   r ^= t = u8x16_word_shift_left (t, 4);
214   r ^= u8x16_word_shift_left (t, 4);
215   r ^= (u8x16) u32x4_shuffle ((u32x4) key_assist, 3, 3, 3, 3);
216   rk[0] = r;
217
218   if (i >= 14)
219     return;
220
221   key_assist = aes_keygen_assist (rk[0], 0x0);
222   r = rk[-1];
223   r ^= t = u8x16_word_shift_left (r, 4);
224   r ^= t = u8x16_word_shift_left (t, 4);
225   r ^= u8x16_word_shift_left (t, 4);
226   r ^= (u8x16) u32x4_shuffle ((u32x4) key_assist, 2, 2, 2, 2);
227   rk[1] = r;
228 }
229
230 static_always_inline void
231 aes256_key_expand (u8x16 * rk, u8x16u const *k)
232 {
233   rk[0] = k[0];
234   rk[1] = k[1];
235   aes256_key_assist (rk, 2, aes_keygen_assist (rk[1], 0x01));
236   aes256_key_assist (rk, 4, aes_keygen_assist (rk[3], 0x02));
237   aes256_key_assist (rk, 6, aes_keygen_assist (rk[5], 0x04));
238   aes256_key_assist (rk, 8, aes_keygen_assist (rk[7], 0x08));
239   aes256_key_assist (rk, 10, aes_keygen_assist (rk[9], 0x10));
240   aes256_key_assist (rk, 12, aes_keygen_assist (rk[11], 0x20));
241   aes256_key_assist (rk, 14, aes_keygen_assist (rk[13], 0x40));
242 }
243 #endif
244
245 #ifdef __aarch64__
246
247 static_always_inline u8x16
248 aes_inv_mix_column (u8x16 a)
249 {
250   return vaesimcq_u8 (a);
251 }
252
253 static const u8x16 aese_prep_mask1 =
254   { 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12 };
255 static const u8x16 aese_prep_mask2 =
256   { 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 };
257
258 static_always_inline void
259 aes128_key_expand_round_neon (u8x16 * rk, u32 rcon)
260 {
261   u8x16 r, t, last_round = rk[-1], z = { };
262   r = vqtbl1q_u8 (last_round, aese_prep_mask1);
263   r = vaeseq_u8 (r, z);
264   r ^= (u8x16) vdupq_n_u32 (rcon);
265   r ^= last_round;
266   r ^= t = vextq_u8 (z, last_round, 12);
267   r ^= t = vextq_u8 (z, t, 12);
268   r ^= vextq_u8 (z, t, 12);
269   rk[0] = r;
270 }
271
272 static_always_inline void
273 aes128_key_expand (u8x16 * rk, const u8x16 * k)
274 {
275   rk[0] = k[0];
276   aes128_key_expand_round_neon (rk + 1, 0x01);
277   aes128_key_expand_round_neon (rk + 2, 0x02);
278   aes128_key_expand_round_neon (rk + 3, 0x04);
279   aes128_key_expand_round_neon (rk + 4, 0x08);
280   aes128_key_expand_round_neon (rk + 5, 0x10);
281   aes128_key_expand_round_neon (rk + 6, 0x20);
282   aes128_key_expand_round_neon (rk + 7, 0x40);
283   aes128_key_expand_round_neon (rk + 8, 0x80);
284   aes128_key_expand_round_neon (rk + 9, 0x1b);
285   aes128_key_expand_round_neon (rk + 10, 0x36);
286 }
287
288 static_always_inline void
289 aes192_key_expand_round_neon (u8x8 * rk, u32 rcon)
290 {
291   u8x8 r, last_round = rk[-1], z = { };
292   u8x16 r2, z2 = { };
293
294   r2 = (u8x16) vdupq_lane_u64 ((uint64x1_t) last_round, 0);
295   r2 = vqtbl1q_u8 (r2, aese_prep_mask1);
296   r2 = vaeseq_u8 (r2, z2);
297   r2 ^= (u8x16) vdupq_n_u32 (rcon);
298
299   r = (u8x8) vdup_laneq_u64 ((u64x2) r2, 0);
300   r ^= rk[-3];
301   r ^= vext_u8 (z, rk[-3], 4);
302   rk[0] = r;
303
304   r = rk[-2] ^ vext_u8 (r, z, 4);
305   r ^= vext_u8 (z, r, 4);
306   rk[1] = r;
307
308   if (rcon == 0x80)
309     return;
310
311   r = rk[-1] ^ vext_u8 (r, z, 4);
312   r ^= vext_u8 (z, r, 4);
313   rk[2] = r;
314 }
315
316 static_always_inline void
317 aes192_key_expand (u8x16 * ek, const u8x16u * k)
318 {
319   u8x8 *rk = (u8x8 *) ek;
320   ek[0] = k[0];
321   rk[2] = *(u8x8u *) (k + 1);
322   aes192_key_expand_round_neon (rk + 3, 0x01);
323   aes192_key_expand_round_neon (rk + 6, 0x02);
324   aes192_key_expand_round_neon (rk + 9, 0x04);
325   aes192_key_expand_round_neon (rk + 12, 0x08);
326   aes192_key_expand_round_neon (rk + 15, 0x10);
327   aes192_key_expand_round_neon (rk + 18, 0x20);
328   aes192_key_expand_round_neon (rk + 21, 0x40);
329   aes192_key_expand_round_neon (rk + 24, 0x80);
330 }
331
332
333 static_always_inline void
334 aes256_key_expand_round_neon (u8x16 * rk, u32 rcon)
335 {
336   u8x16 r, t, z = { };
337
338   r = vqtbl1q_u8 (rk[-1], rcon ? aese_prep_mask1 : aese_prep_mask2);
339   r = vaeseq_u8 (r, z);
340   if (rcon)
341     r ^= (u8x16) vdupq_n_u32 (rcon);
342   r ^= rk[-2];
343   r ^= t = vextq_u8 (z, rk[-2], 12);
344   r ^= t = vextq_u8 (z, t, 12);
345   r ^= vextq_u8 (z, t, 12);
346   rk[0] = r;
347 }
348
349 static_always_inline void
350 aes256_key_expand (u8x16 * rk, u8x16 const *k)
351 {
352   rk[0] = k[0];
353   rk[1] = k[1];
354   aes256_key_expand_round_neon (rk + 2, 0x01);
355   aes256_key_expand_round_neon (rk + 3, 0);
356   aes256_key_expand_round_neon (rk + 4, 0x02);
357   aes256_key_expand_round_neon (rk + 5, 0);
358   aes256_key_expand_round_neon (rk + 6, 0x04);
359   aes256_key_expand_round_neon (rk + 7, 0);
360   aes256_key_expand_round_neon (rk + 8, 0x08);
361   aes256_key_expand_round_neon (rk + 9, 0);
362   aes256_key_expand_round_neon (rk + 10, 0x10);
363   aes256_key_expand_round_neon (rk + 11, 0);
364   aes256_key_expand_round_neon (rk + 12, 0x20);
365   aes256_key_expand_round_neon (rk + 13, 0);
366   aes256_key_expand_round_neon (rk + 14, 0x40);
367 }
368
369 #endif
370
371 static_always_inline void
372 aes_key_expand (u8x16 * key_schedule, u8 const *key, aes_key_size_t ks)
373 {
374   switch (ks)
375     {
376     case AES_KEY_128:
377       aes128_key_expand (key_schedule, (u8x16u const *) key);
378       break;
379     case AES_KEY_192:
380       aes192_key_expand (key_schedule, (u8x16u const *) key);
381       break;
382     case AES_KEY_256:
383       aes256_key_expand (key_schedule, (u8x16u const *) key);
384       break;
385     }
386 }
387
388 static_always_inline void
389 aes_key_enc_to_dec (u8x16 * ke, u8x16 * kd, aes_key_size_t ks)
390 {
391   int rounds = AES_KEY_ROUNDS (ks);
392
393   kd[rounds] = ke[0];
394   kd[0] = ke[rounds];
395
396   for (int i = 1; i < (rounds / 2); i++)
397     {
398       kd[rounds - i] = aes_inv_mix_column (ke[i]);
399       kd[i] = aes_inv_mix_column (ke[rounds - i]);
400     }
401
402   kd[rounds / 2] = aes_inv_mix_column (ke[rounds / 2]);
403 }
404
405 #endif /* __aesni_h__ */
406
407 /*
408  * fd.io coding-style-patch-verification: ON
409  *
410  * Local Variables:
411  * eval: (c-set-style "gnu")
412  * End:
413  */