Initial commit of vpp code.
[vpp.git] / vnet / vnet / ipsec / esp_decrypt.c
1 /*
2  * esp_decrypt.c : IPSec ESP decrypt node
3  *
4  * Copyright (c) 2015 Cisco and/or its affiliates.
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at:
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include <vnet/vnet.h>
19 #include <vnet/api_errno.h>
20 #include <vnet/ip/ip.h>
21
22 #include <vnet/ipsec/ipsec.h>
23 #include <vnet/ipsec/esp.h>
24
25 #define ESP_WINDOW_SIZE 64
26
27 #define foreach_esp_decrypt_next                \
28 _(DROP, "error-drop")                           \
29 _(IP4_INPUT, "ip4-input")                       \
30 _(IP6_INPUT, "ip6-input")
31
32 #define _(v, s) ESP_DECRYPT_NEXT_##v,
33 typedef enum {
34   foreach_esp_decrypt_next
35 #undef _
36   ESP_DECRYPT_N_NEXT,
37 } esp_decrypt_next_t;
38
39
40 #define foreach_esp_decrypt_error                   \
41  _(RX_PKTS, "ESP pkts received")                    \
42  _(NO_BUFFER, "No buffer (packed dropped)")         \
43  _(DECRYPTION_FAILED, "ESP decryption failed")      \
44  _(INTEG_ERROR, "Integrity check failed")           \
45  _(REPLAY, "SA replayed packet")
46
47
48 typedef enum {
49 #define _(sym,str) ESP_DECRYPT_ERROR_##sym,
50   foreach_esp_decrypt_error
51 #undef _
52   ESP_DECRYPT_N_ERROR,
53 } esp_decrypt_error_t;
54
55 static char * esp_decrypt_error_strings[] = {
56 #define _(sym,string) string,
57   foreach_esp_decrypt_error
58 #undef _
59 };
60
61 typedef struct {
62   ipsec_crypto_alg_t crypto_alg;
63   ipsec_integ_alg_t integ_alg;
64 } esp_decrypt_trace_t;
65
66 /* packet trace format function */
67 static u8 * format_esp_decrypt_trace (u8 * s, va_list * args)
68 {
69   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
70   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
71   esp_decrypt_trace_t * t = va_arg (*args, esp_decrypt_trace_t *);
72
73   s = format (s, "esp: crypto %U integrity %U",
74               format_ipsec_crypto_alg, t->crypto_alg,
75               format_ipsec_integ_alg, t->integ_alg);
76   return s;
77 }
78
79 always_inline void
80 esp_decrypt_aes_cbc(ipsec_crypto_alg_t alg,
81                     u8 * in,
82                     u8 * out,
83                     size_t in_len,
84                     u8 * key,
85                     u8 * iv)
86 {
87   esp_main_t * em = &esp_main;
88   EVP_CIPHER_CTX * ctx = &(em->decrypt_ctx);
89   const EVP_CIPHER * cipher = NULL;
90   int out_len;
91
92   ASSERT(alg < IPSEC_CRYPTO_N_ALG);
93
94   if (PREDICT_FALSE(em->esp_crypto_algs[alg].type == 0))
95     return;
96
97   if (PREDICT_FALSE(alg != em->last_decrytp_alg)) {
98     cipher = em->esp_crypto_algs[alg].type;
99     em->last_decrytp_alg = alg;
100   }
101
102   EVP_DecryptInit_ex(ctx, cipher, NULL, key, iv);
103
104   EVP_DecryptUpdate(ctx, out, &out_len, in, in_len);
105   EVP_DecryptFinal_ex(ctx, out + out_len, &out_len);
106 }
107
108 always_inline int
109 esp_replay_check (ipsec_sa_t * sa, u32 seq)
110 {
111   u32 diff;
112
113   if (PREDICT_TRUE(seq > sa->last_seq))
114     return 0;
115
116   diff = sa->last_seq - seq;
117
118   if (ESP_WINDOW_SIZE > diff)
119     return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
120   else
121     return 1;
122
123   return 0;
124 }
125
126 always_inline int
127 esp_replay_check_esn (ipsec_sa_t * sa, u32 seq)
128 {
129   u32 tl = sa->last_seq;
130   u32 th = sa->last_seq_hi;
131   u32 diff = tl - seq;
132
133   if (PREDICT_TRUE(tl >= (ESP_WINDOW_SIZE - 1)))
134     {
135       if (seq >= (tl - ESP_WINDOW_SIZE + 1))
136         {
137           sa->seq_hi = th;
138           if (seq <= tl)
139             return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
140           else
141             return 0;
142         }
143       else
144         {
145           sa->seq_hi = th + 1;
146           return 0;
147         }
148     }
149   else
150     {
151       if (seq >= (tl - ESP_WINDOW_SIZE + 1))
152         {
153           sa->seq_hi = th - 1;
154           return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
155         }
156       else
157         {
158           sa->seq_hi = th;
159           if (seq <= tl)
160             return (sa->replay_window & (1ULL << diff)) ? 1 : 0;
161           else
162             return 0;
163         }
164     }
165
166   return 0;
167 }
168
169 always_inline void
170 esp_replay_advance (ipsec_sa_t * sa, u32 seq)
171 {
172   u32 pos;
173
174   if (seq > sa->last_seq)
175     {
176       pos = seq - sa->last_seq;
177       if (pos < ESP_WINDOW_SIZE)
178         sa->replay_window = ((sa->replay_window) << pos) | 1;
179       else
180         sa->replay_window = 1;
181       sa->last_seq = seq;
182     }
183   else
184     {
185       pos = sa->last_seq - seq;
186       sa->replay_window |= (1ULL << pos);
187     }
188 }
189
190 always_inline void
191 esp_replay_advance_esn (ipsec_sa_t * sa, u32 seq)
192 {
193   int wrap = sa->seq_hi - sa->last_seq_hi;
194   u32 pos;
195
196   if (wrap == 0 && seq > sa->last_seq)
197     {
198       pos = seq - sa->last_seq;
199       if (pos < ESP_WINDOW_SIZE)
200         sa->replay_window = ((sa->replay_window) << pos) | 1;
201       else
202         sa->replay_window = 1;
203       sa->last_seq = seq;
204     }
205   else if (wrap > 0)
206     {
207       pos = ~seq + sa->last_seq + 1;
208       if (pos < ESP_WINDOW_SIZE)
209         sa->replay_window = ((sa->replay_window) << pos) | 1;
210       else
211         sa->replay_window = 1;
212       sa->last_seq = seq;
213       sa->last_seq_hi = sa->seq_hi;
214     }
215   else if (wrap < 0)
216     {
217       pos = ~seq + sa->last_seq + 1;
218       sa->replay_window |= (1ULL << pos);
219     }
220   else
221     {
222       pos = sa->last_seq - seq;
223       sa->replay_window |= (1ULL << pos);
224     }
225 }
226
227 static uword
228 esp_decrypt_node_fn (vlib_main_t * vm,
229                      vlib_node_runtime_t * node,
230                      vlib_frame_t * from_frame)
231 {
232   u32 n_left_from, *from, next_index, *to_next;
233   ipsec_main_t *im = &ipsec_main;
234   esp_main_t *em = &esp_main;
235   u32 * recycle = 0;
236   from = vlib_frame_vector_args (from_frame);
237   n_left_from = from_frame->n_vectors;
238
239   ipsec_alloc_empty_buffers(vm, im);
240
241   if (PREDICT_FALSE(vec_len (im->empty_buffers) < n_left_from)){
242     vlib_node_increment_counter (vm, esp_decrypt_node.index,
243                                  ESP_DECRYPT_ERROR_NO_BUFFER, n_left_from);
244     goto free_buffers_and_exit;
245   }
246
247   next_index = node->cached_next_index;
248
249   while (n_left_from > 0)
250     {
251       u32 n_left_to_next;
252
253       vlib_get_next_frame (vm, node, next_index, to_next, n_left_to_next);
254
255       while (n_left_from > 0 && n_left_to_next > 0)
256         {
257           u32 i_bi0, o_bi0 = (u32) ~0, next0;
258           vlib_buffer_t * i_b0;
259           vlib_buffer_t * o_b0 = 0;
260           esp_header_t * esp0;
261           ipsec_sa_t * sa0;
262           u32 sa_index0 = ~0;
263           u32 seq;
264
265           i_bi0 = from[0];
266           from += 1;
267           n_left_from -= 1;
268           n_left_to_next -= 1;
269
270           next0 = ESP_DECRYPT_NEXT_DROP;
271
272           i_b0 = vlib_get_buffer (vm, i_bi0);
273           esp0 = vlib_buffer_get_current (i_b0);
274
275           sa_index0 = vnet_buffer(i_b0)->output_features.ipsec_sad_index;
276           sa0 = pool_elt_at_index (im->sad, sa_index0);
277
278           seq = clib_host_to_net_u32(esp0->seq);
279
280           /* anti-replay check */
281           if (sa0->use_anti_replay)
282             {
283               int rv = 0;
284
285               if (PREDICT_TRUE(sa0->use_esn))
286                 rv = esp_replay_check_esn(sa0, seq);
287               else
288                 rv = esp_replay_check(sa0, seq);
289
290               if (PREDICT_FALSE(rv))
291                 {
292                   clib_warning("anti-replay SPI %u seq %u", sa0->spi, seq);
293                   vlib_node_increment_counter (vm, esp_decrypt_node.index,
294                                                ESP_DECRYPT_ERROR_REPLAY, 1);
295                   o_bi0 = i_bi0;
296                   goto trace;
297                 }
298             }
299
300           if (PREDICT_TRUE(sa0->integ_alg != IPSEC_INTEG_ALG_NONE))
301             {
302               u8 sig[64];
303               int icv_size = em->esp_integ_algs[sa0->integ_alg].trunc_size;
304               memset(sig, 0, sizeof(sig));
305               u8 * icv = vlib_buffer_get_current (i_b0) + i_b0->current_length - icv_size;
306               i_b0->current_length -= icv_size;
307
308               hmac_calc(sa0->integ_alg, sa0->integ_key, sa0->integ_key_len,
309                         (u8 *) esp0, i_b0->current_length, sig, sa0->use_esn,
310                         sa0->seq_hi);
311
312               if (PREDICT_FALSE(memcmp(icv, sig, icv_size)))
313                 {
314                   vlib_node_increment_counter (vm, esp_decrypt_node.index,
315                                                ESP_DECRYPT_ERROR_INTEG_ERROR, 1);
316                   o_bi0 = i_bi0;
317                   goto trace;
318                 }
319             }
320
321           if (PREDICT_TRUE(sa0->use_anti_replay))
322             {
323               if (PREDICT_TRUE(sa0->use_esn))
324                 esp_replay_advance_esn(sa0, seq);
325               else
326                 esp_replay_advance(sa0, seq);
327              }
328
329           /* grab free buffer */
330           uword last_empty_buffer = vec_len (im->empty_buffers) - 1;
331           o_bi0 = im->empty_buffers[last_empty_buffer];
332           o_b0 = vlib_get_buffer (vm, o_bi0);
333           vlib_prefetch_buffer_with_index (vm, im->empty_buffers[last_empty_buffer-1], STORE);
334           _vec_len (im->empty_buffers) = last_empty_buffer;
335
336           /* add old buffer to the recycle list */
337           vec_add1(recycle, i_bi0);
338
339           if (sa0->crypto_alg >= IPSEC_CRYPTO_ALG_AES_CBC_128 &&
340               sa0->crypto_alg <= IPSEC_CRYPTO_ALG_AES_CBC_256) {
341             const int BLOCK_SIZE = 16;
342             const int IV_SIZE = 16;
343             esp_footer_t * f0;
344
345             int blocks = (i_b0->current_length - sizeof (esp_header_t) - IV_SIZE) / BLOCK_SIZE;
346
347             o_b0->current_data = sizeof(ethernet_header_t);
348
349             esp_decrypt_aes_cbc(sa0->crypto_alg,
350                                 esp0->data + IV_SIZE,
351                                 (u8 *) vlib_buffer_get_current (o_b0),
352                                 BLOCK_SIZE * blocks,
353                                 sa0->crypto_key,
354                                 esp0->data);
355
356             o_b0->current_length = (blocks * 16) - 2;
357             o_b0->flags = VLIB_BUFFER_TOTAL_LENGTH_VALID;
358             f0 = (esp_footer_t *) ((u8 *) vlib_buffer_get_current (o_b0) + o_b0->current_length);
359             o_b0->current_length -= f0->pad_length;
360             if (PREDICT_TRUE(f0->next_header == IP_PROTOCOL_IP_IN_IP))
361               next0 = ESP_DECRYPT_NEXT_IP4_INPUT;
362             else if (f0->next_header == IP_PROTOCOL_IPV6)
363               next0 = ESP_DECRYPT_NEXT_IP6_INPUT;
364             else
365               {
366                 clib_warning("next header: 0x%x", f0->next_header);
367                 vlib_node_increment_counter (vm, esp_decrypt_node.index,
368                                              ESP_DECRYPT_ERROR_DECRYPTION_FAILED,
369                                              1);
370                 o_b0 = 0;
371                 goto trace;
372               }
373
374             to_next[0] = o_bi0;
375             to_next += 1;
376
377             vnet_buffer (o_b0)->sw_if_index[VLIB_TX] = (u32)~0;
378           }
379
380 trace:
381           if (PREDICT_FALSE(i_b0->flags & VLIB_BUFFER_IS_TRACED)) {
382             if (o_b0) {
383               o_b0->flags |= VLIB_BUFFER_IS_TRACED;
384               o_b0->trace_index = i_b0->trace_index;
385             }
386             esp_decrypt_trace_t *tr = vlib_add_trace (vm, node, o_b0, sizeof (*tr));
387             tr->crypto_alg = sa0->crypto_alg;
388             tr->integ_alg = sa0->integ_alg;
389           }
390
391           vlib_validate_buffer_enqueue_x1 (vm, node, next_index, to_next,
392                                            n_left_to_next, o_bi0, next0);
393         }
394       vlib_put_next_frame (vm, node, next_index, n_left_to_next);
395     }
396   vlib_node_increment_counter (vm, esp_decrypt_node.index,
397                                ESP_DECRYPT_ERROR_RX_PKTS,
398                                from_frame->n_vectors);
399
400 free_buffers_and_exit:
401   vlib_buffer_free (vm, recycle, vec_len(recycle));
402   vec_free(recycle);
403   return from_frame->n_vectors;
404 }
405
406
407 VLIB_REGISTER_NODE (esp_decrypt_node) = {
408   .function = esp_decrypt_node_fn,
409   .name = "esp-decrypt",
410   .vector_size = sizeof (u32),
411   .format_trace = format_esp_decrypt_trace,
412   .type = VLIB_NODE_TYPE_INTERNAL,
413
414   .n_errors = ARRAY_LEN(esp_decrypt_error_strings),
415   .error_strings = esp_decrypt_error_strings,
416
417   .n_next_nodes = ESP_DECRYPT_N_NEXT,
418   .next_nodes = {
419 #define _(s,n) [ESP_DECRYPT_NEXT_##s] = n,
420     foreach_esp_decrypt_next
421 #undef _
422   },
423 };
424