8d94f20e507c2acefec1c2eac6542955318d3ea1
[vpp.git] / vnet / vnet / ipsec / esp_decrypt.c
1 /*
2  * esp_decrypt.c : IPSec ESP decrypt node
3  *
4  * Copyright (c) 2015 Cisco and/or its affiliates.
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at:
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include <vnet/vnet.h>
19 #include <vnet/api_errno.h>
20 #include <vnet/ip/ip.h>
21
22 #include <vnet/ipsec/ipsec.h>
23 #include <vnet/ipsec/esp.h>
24
25 #define ESP_WINDOW_SIZE 64
26
27 #define foreach_esp_decrypt_next                \
28 _(DROP, "error-drop")                           \
29 _(IP4_INPUT, "ip4-input")                       \
30 _(IP6_INPUT, "ip6-input")
31
32 #define _(v, s) ESP_DECRYPT_NEXT_##v,
33 typedef enum {
34   foreach_esp_decrypt_next
35 #undef _
36   ESP_DECRYPT_N_NEXT,
37 } esp_decrypt_next_t;
38
39
40 #define foreach_esp_decrypt_error                   \
41  _(RX_PKTS, "ESP pkts received")                    \
42  _(NO_BUFFER, "No buffer (packed dropped)")         \
43  _(DECRYPTION_FAILED, "ESP decryption failed")      \
44  _(INTEG_ERROR, "Integrity check failed")           \
45  _(REPLAY, "SA replayed packet")
46
47
48 typedef enum {
49 #define _(sym,str) ESP_DECRYPT_ERROR_##sym,
50   foreach_esp_decrypt_error
51 #undef _
52   ESP_DECRYPT_N_ERROR,
53 } esp_decrypt_error_t;
54
55 static char * esp_decrypt_error_strings[] = {
56 #define _(sym,string) string,
57   foreach_esp_decrypt_error
58 #undef _
59 };
60
61 typedef struct {
62   ipsec_crypto_alg_t crypto_alg;
63   ipsec_integ_alg_t integ_alg;
64 } esp_decrypt_trace_t;
65
66 /* packet trace format function */
67 static u8 * format_esp_decrypt_trace (u8 * s, va_list * args)
68 {
69   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
70   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
71   esp_decrypt_trace_t * t = va_arg (*args, esp_decrypt_trace_t *);
72
73   s = format (s, "esp: crypto %U integrity %U",
74               format_ipsec_crypto_alg, t->crypto_alg,
75               format_ipsec_integ_alg, t->integ_alg);
76   return s;
77 }
78
79 always_inline void
80 esp_decrypt_aes_cbc(ipsec_crypto_alg_t alg,
81                     u8 * in,
82                     u8 * out,
83                     size_t in_len,
84                     u8 * key,
85                     u8 * iv)
86 {
87   esp_main_t * em = &esp_main;
88   u32 cpu_index = os_get_cpu_number();
89   EVP_CIPHER_CTX * ctx = &(em->per_thread_data[cpu_index].decrypt_ctx);
90   const EVP_CIPHER * cipher = NULL;
91   int out_len;
92
93   ASSERT(alg < IPSEC_CRYPTO_N_ALG);
94
95   if (PREDICT_FALSE(em->esp_crypto_algs[alg].type == 0))
96     return;
97
98   if (PREDICT_FALSE(alg != em->per_thread_data[cpu_index].last_decrypt_alg)) {
99     cipher = em->esp_crypto_algs[alg].type;
100     em->per_thread_data[cpu_index].last_decrypt_alg = alg;
101   }
102
103   EVP_DecryptInit_ex(ctx, cipher, NULL, key, iv);
104
105   EVP_DecryptUpdate(ctx, out, &out_len, in, in_len);
106   EVP_DecryptFinal_ex(ctx, out + out_len, &out_len);
107 }
108
109 always_inline int
110 esp_replay_check (ipsec_sa_t * sa, u32 seq)
111 {
112   u32 diff;
113
114   if (PREDICT_TRUE(seq > sa->last_seq))
115     return 0;
116
117   diff = sa->last_seq - seq;
118
119   if (ESP_WINDOW_SIZE > diff)
120     return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
121   else
122     return 1;
123
124   return 0;
125 }
126
127 always_inline int
128 esp_replay_check_esn (ipsec_sa_t * sa, u32 seq)
129 {
130   u32 tl = sa->last_seq;
131   u32 th = sa->last_seq_hi;
132   u32 diff = tl - seq;
133
134   if (PREDICT_TRUE(tl >= (ESP_WINDOW_SIZE - 1)))
135     {
136       if (seq >= (tl - ESP_WINDOW_SIZE + 1))
137         {
138           sa->seq_hi = th;
139           if (seq <= tl)
140             return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
141           else
142             return 0;
143         }
144       else
145         {
146           sa->seq_hi = th + 1;
147           return 0;
148         }
149     }
150   else
151     {
152       if (seq >= (tl - ESP_WINDOW_SIZE + 1))
153         {
154           sa->seq_hi = th - 1;
155           return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
156         }
157       else
158         {
159           sa->seq_hi = th;
160           if (seq <= tl)
161             return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
162           else
163             return 0;
164         }
165     }
166
167   return 0;
168 }
169
170 always_inline void
171 esp_replay_advance (ipsec_sa_t * sa, u32 seq)
172 {
173   u32 pos;
174
175   if (seq > sa->last_seq)
176     {
177       pos = seq - sa->last_seq;
178       if (pos < ESP_WINDOW_SIZE)
179         sa->replay_window = ((sa->replay_window) << pos) | 1;
180       else
181         sa->replay_window = 1;
182       sa->last_seq = seq;
183     }
184   else
185     {
186       pos = sa->last_seq - seq;
187       sa->replay_window |= (1ULL << pos);
188     }
189 }
190
191 always_inline void
192 esp_replay_advance_esn (ipsec_sa_t * sa, u32 seq)
193 {
194   int wrap = sa->seq_hi - sa->last_seq_hi;
195   u32 pos;
196
197   if (wrap == 0 && seq > sa->last_seq)
198     {
199       pos = seq - sa->last_seq;
200       if (pos < ESP_WINDOW_SIZE)
201         sa->replay_window = ((sa->replay_window) << pos) | 1;
202       else
203         sa->replay_window = 1;
204       sa->last_seq = seq;
205     }
206   else if (wrap > 0)
207     {
208       pos = ~seq + sa->last_seq + 1;
209       if (pos < ESP_WINDOW_SIZE)
210         sa->replay_window = ((sa->replay_window) << pos) | 1;
211       else
212         sa->replay_window = 1;
213       sa->last_seq = seq;
214       sa->last_seq_hi = sa->seq_hi;
215     }
216   else if (wrap < 0)
217     {
218       pos = ~seq + sa->last_seq + 1;
219       sa->replay_window |= (1ULL << pos);
220     }
221   else
222     {
223       pos = sa->last_seq - seq;
224       sa->replay_window |= (1ULL << pos);
225     }
226 }
227
228 static uword
229 esp_decrypt_node_fn (vlib_main_t * vm,
230                      vlib_node_runtime_t * node,
231                      vlib_frame_t * from_frame)
232 {
233   u32 n_left_from, *from, next_index, *to_next;
234   ipsec_main_t *im = &ipsec_main;
235   esp_main_t *em = &esp_main;
236   u32 * recycle = 0;
237   from = vlib_frame_vector_args (from_frame);
238   n_left_from = from_frame->n_vectors;
239   u32 cpu_index = os_get_cpu_number();
240
241   ipsec_alloc_empty_buffers(vm, im);
242
243   u32 * empty_buffers = im->empty_buffers[cpu_index];
244
245   if (PREDICT_FALSE(vec_len (empty_buffers) < n_left_from)){
246     vlib_node_increment_counter (vm, esp_decrypt_node.index,
247                                  ESP_DECRYPT_ERROR_NO_BUFFER, n_left_from);
248     goto free_buffers_and_exit;
249   }
250
251   next_index = node->cached_next_index;
252
253   while (n_left_from > 0)
254     {
255       u32 n_left_to_next;
256
257       vlib_get_next_frame (vm, node, next_index, to_next, n_left_to_next);
258
259       while (n_left_from > 0 && n_left_to_next > 0)
260         {
261           u32 i_bi0, o_bi0 = (u32) ~0, next0;
262           vlib_buffer_t * i_b0;
263           vlib_buffer_t * o_b0 = 0;
264           esp_header_t * esp0;
265           ipsec_sa_t * sa0;
266           u32 sa_index0 = ~0;
267           u32 seq;
268
269           i_bi0 = from[0];
270           from += 1;
271           n_left_from -= 1;
272           n_left_to_next -= 1;
273
274           next0 = ESP_DECRYPT_NEXT_DROP;
275
276           i_b0 = vlib_get_buffer (vm, i_bi0);
277           esp0 = vlib_buffer_get_current (i_b0);
278
279           sa_index0 = vnet_buffer(i_b0)->output_features.ipsec_sad_index;
280           sa0 = pool_elt_at_index (im->sad, sa_index0);
281
282           seq = clib_host_to_net_u32(esp0->seq);
283
284           /* anti-replay check */
285           if (sa0->use_anti_replay)
286             {
287               int rv = 0;
288
289               if (PREDICT_TRUE(sa0->use_esn))
290                 rv = esp_replay_check_esn(sa0, seq);
291               else
292                 rv = esp_replay_check(sa0, seq);
293
294               if (PREDICT_FALSE(rv))
295                 {
296                   clib_warning("anti-replay SPI %u seq %u", sa0->spi, seq);
297                   vlib_node_increment_counter (vm, esp_decrypt_node.index,
298                                                ESP_DECRYPT_ERROR_REPLAY, 1);
299                   o_bi0 = i_bi0;
300                   goto trace;
301                 }
302             }
303
304           if (PREDICT_TRUE(sa0->integ_alg != IPSEC_INTEG_ALG_NONE))
305             {
306               u8 sig[64];
307               int icv_size = em->esp_integ_algs[sa0->integ_alg].trunc_size;
308               memset(sig, 0, sizeof(sig));
309               u8 * icv = vlib_buffer_get_current (i_b0) + i_b0->current_length - icv_size;
310               i_b0->current_length -= icv_size;
311
312               hmac_calc(sa0->integ_alg, sa0->integ_key, sa0->integ_key_len,
313                         (u8 *) esp0, i_b0->current_length, sig, sa0->use_esn,
314                         sa0->seq_hi);
315
316               if (PREDICT_FALSE(memcmp(icv, sig, icv_size)))
317                 {
318                   vlib_node_increment_counter (vm, esp_decrypt_node.index,
319                                                ESP_DECRYPT_ERROR_INTEG_ERROR, 1);
320                   o_bi0 = i_bi0;
321                   goto trace;
322                 }
323             }
324
325           if (PREDICT_TRUE(sa0->use_anti_replay))
326             {
327               if (PREDICT_TRUE(sa0->use_esn))
328                 esp_replay_advance_esn(sa0, seq);
329               else
330                 esp_replay_advance(sa0, seq);
331              }
332
333           /* grab free buffer */
334           uword last_empty_buffer = vec_len (empty_buffers) - 1;
335           o_bi0 = empty_buffers[last_empty_buffer];
336           o_b0 = vlib_get_buffer (vm, o_bi0);
337           vlib_prefetch_buffer_with_index (vm, empty_buffers[last_empty_buffer-1], STORE);
338           _vec_len (empty_buffers) = last_empty_buffer;
339
340           /* add old buffer to the recycle list */
341           vec_add1(recycle, i_bi0);
342
343           if (sa0->crypto_alg >= IPSEC_CRYPTO_ALG_AES_CBC_128 &&
344               sa0->crypto_alg <= IPSEC_CRYPTO_ALG_AES_CBC_256) {
345             const int BLOCK_SIZE = 16;
346             const int IV_SIZE = 16;
347             esp_footer_t * f0;
348
349             int blocks = (i_b0->current_length - sizeof (esp_header_t) - IV_SIZE) / BLOCK_SIZE;
350
351             o_b0->current_data = sizeof(ethernet_header_t);
352
353             esp_decrypt_aes_cbc(sa0->crypto_alg,
354                                 esp0->data + IV_SIZE,
355                                 (u8 *) vlib_buffer_get_current (o_b0),
356                                 BLOCK_SIZE * blocks,
357                                 sa0->crypto_key,
358                                 esp0->data);
359
360             o_b0->current_length = (blocks * 16) - 2;
361             o_b0->flags = VLIB_BUFFER_TOTAL_LENGTH_VALID;
362             f0 = (esp_footer_t *) ((u8 *) vlib_buffer_get_current (o_b0) + o_b0->current_length);
363             o_b0->current_length -= f0->pad_length;
364             if (PREDICT_TRUE(f0->next_header == IP_PROTOCOL_IP_IN_IP))
365               next0 = ESP_DECRYPT_NEXT_IP4_INPUT;
366             else if (f0->next_header == IP_PROTOCOL_IPV6)
367               next0 = ESP_DECRYPT_NEXT_IP6_INPUT;
368             else
369               {
370                 clib_warning("next header: 0x%x", f0->next_header);
371                 vlib_node_increment_counter (vm, esp_decrypt_node.index,
372                                              ESP_DECRYPT_ERROR_DECRYPTION_FAILED,
373                                              1);
374                 o_b0 = 0;
375                 goto trace;
376               }
377
378             to_next[0] = o_bi0;
379             to_next += 1;
380
381             vnet_buffer (o_b0)->sw_if_index[VLIB_TX] = (u32)~0;
382           }
383
384 trace:
385           if (PREDICT_FALSE(i_b0->flags & VLIB_BUFFER_IS_TRACED)) {
386             if (o_b0) {
387               o_b0->flags |= VLIB_BUFFER_IS_TRACED;
388               o_b0->trace_index = i_b0->trace_index;
389             }
390             esp_decrypt_trace_t *tr = vlib_add_trace (vm, node, o_b0, sizeof (*tr));
391             tr->crypto_alg = sa0->crypto_alg;
392             tr->integ_alg = sa0->integ_alg;
393           }
394
395           vlib_validate_buffer_enqueue_x1 (vm, node, next_index, to_next,
396                                            n_left_to_next, o_bi0, next0);
397         }
398       vlib_put_next_frame (vm, node, next_index, n_left_to_next);
399     }
400   vlib_node_increment_counter (vm, esp_decrypt_node.index,
401                                ESP_DECRYPT_ERROR_RX_PKTS,
402                                from_frame->n_vectors);
403
404 free_buffers_and_exit:
405   vlib_buffer_free (vm, recycle, vec_len(recycle));
406   vec_free(recycle);
407   return from_frame->n_vectors;
408 }
409
410
411 VLIB_REGISTER_NODE (esp_decrypt_node) = {
412   .function = esp_decrypt_node_fn,
413   .name = "esp-decrypt",
414   .vector_size = sizeof (u32),
415   .format_trace = format_esp_decrypt_trace,
416   .type = VLIB_NODE_TYPE_INTERNAL,
417
418   .n_errors = ARRAY_LEN(esp_decrypt_error_strings),
419   .error_strings = esp_decrypt_error_strings,
420
421   .n_next_nodes = ESP_DECRYPT_N_NEXT,
422   .next_nodes = {
423 #define _(s,n) [ESP_DECRYPT_NEXT_##s] = n,
424     foreach_esp_decrypt_next
425 #undef _
426   },
427 };
428
429 VLIB_NODE_FUNCTION_MULTIARCH (esp_decrypt_node, esp_decrypt_node_fn)
430