861dae63b0ab5b952cefb2845718cdaaed449f2c
[vpp.git] / plugins / snat-plugin / snat / out2in.c
1 /*
2  * Copyright (c) 2016 Cisco and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #include <vlib/vlib.h>
17 #include <vnet/vnet.h>
18 #include <vnet/pg/pg.h>
19
20 #include <vnet/ip/ip.h>
21 #include <vnet/ethernet/ethernet.h>
22 #include <snat/snat.h>
23
24 #include <vppinfra/hash.h>
25 #include <vppinfra/error.h>
26 #include <vppinfra/elog.h>
27
28 vlib_node_registration_t snat_out2in_node;
29
30 typedef struct {
31   u32 sw_if_index;
32   u32 next_index;
33   u32 session_index;
34 } snat_out2in_trace_t;
35
36 /* packet trace format function */
37 static u8 * format_snat_out2in_trace (u8 * s, va_list * args)
38 {
39   CLIB_UNUSED (vlib_main_t * vm) = va_arg (*args, vlib_main_t *);
40   CLIB_UNUSED (vlib_node_t * node) = va_arg (*args, vlib_node_t *);
41   snat_out2in_trace_t * t = va_arg (*args, snat_out2in_trace_t *);
42   
43   s = format (s, "SNAT_OUT2IN: sw_if_index %d, next index %d, session index %d",
44               t->sw_if_index, t->next_index, t->session_index);
45   return s;
46 }
47
/*
 * Error / counter table for the snat-out2in node.
 * Each entry is _(SYMBOL, "counter description"); the X-macro expands once
 * into the snat_out2in_error_t enum and once into the matching string table,
 * so the two always stay index-aligned.
 * (A redundant second declaration of snat_out2in_node that used to precede
 * this table was removed; the node is already declared near the top.)
 */
#define foreach_snat_out2in_error                       \
_(UNSUPPORTED_PROTOCOL, "Unsupported protocol")         \
_(OUT2IN_PACKETS, "Good out2in packets processed")      \
_(BAD_ICMP_TYPE, "icmp type not echo-reply")            \
_(NO_TRANSLATION, "No translation")

/* Error codes, used as indices into node->errors[] */
typedef enum {
#define _(sym,str) SNAT_OUT2IN_ERROR_##sym,
  foreach_snat_out2in_error
#undef _
  SNAT_OUT2IN_N_ERROR,
} snat_out2in_error_t;

/* Counter descriptions, indexed by snat_out2in_error_t */
static char * snat_out2in_error_strings[] = {
#define _(sym,string) string,
  foreach_snat_out2in_error
#undef _
};
68
/*
 * Next-node indices hard-wired in the node registration. Packets that are
 * not dropped leave via the feature-arc next index obtained at runtime.
 */
typedef enum {
  SNAT_OUT2IN_NEXT_DROP = 0,   /* "error-drop" */
  SNAT_OUT2IN_N_NEXT,          /* number of static next nodes */
} snat_out2in_next_t;
73
74 static inline u32 icmp_out2in_slow_path (snat_main_t *sm,
75                                          vlib_buffer_t * b0,
76                                          ip4_header_t * ip0,
77                                          icmp46_header_t * icmp0,
78                                          u32 sw_if_index0,
79                                          u32 rx_fib_index0,
80                                          vlib_node_runtime_t * node,
81                                          u32 next0, f64 now)
82 {
83   snat_session_key_t key0;
84   icmp_echo_header_t *echo0;
85   clib_bihash_kv_8_8_t kv0, value0;
86   snat_session_t * s0;
87   u32 new_addr0, old_addr0;
88   u16 old_id0, new_id0;
89   ip_csum_t sum0;
90   snat_runtime_t * rt = (snat_runtime_t *)node->runtime_data;
91
92   echo0 = (icmp_echo_header_t *)(icmp0+1);
93
94   key0.addr = ip0->dst_address;
95   key0.port = echo0->identifier;
96   key0.protocol = SNAT_PROTOCOL_ICMP;
97   key0.fib_index = rx_fib_index0;
98   
99   kv0.key = key0.as_u64;
100   
101   if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
102     {
103       ip4_address_t * first_int_addr;
104
105       if (PREDICT_FALSE(rt->cached_sw_if_index != sw_if_index0))
106         {
107           first_int_addr = 
108             ip4_interface_first_address (sm->ip4_main, sw_if_index0,
109                                          0 /* just want the address */);
110           rt->cached_sw_if_index = sw_if_index0;
111           rt->cached_ip4_address = first_int_addr->as_u32;
112         }
113       
114       /* Don't NAT packet aimed at the intfc address */
115       if (PREDICT_FALSE(ip0->dst_address.as_u32 ==
116                         rt->cached_ip4_address))
117         return next0;
118
119       b0->error = node->errors[SNAT_OUT2IN_ERROR_NO_TRANSLATION];
120       return SNAT_OUT2IN_NEXT_DROP;
121     }
122   else
123     s0 = pool_elt_at_index (sm->sessions, value0.value);
124
125   old_addr0 = ip0->dst_address.as_u32;
126   ip0->dst_address = s0->in2out.addr;
127   new_addr0 = ip0->dst_address.as_u32;
128   vnet_buffer(b0)->sw_if_index[VLIB_TX] = s0->in2out.fib_index;
129   
130   sum0 = ip0->checksum;
131   sum0 = ip_csum_update (sum0, old_addr0, new_addr0,
132                          ip4_header_t,
133                          dst_address /* changed member */);
134   ip0->checksum = ip_csum_fold (sum0);
135   
136   old_id0 = echo0->identifier;
137   new_id0 = s0->in2out.port;
138   echo0->identifier = new_id0;
139
140   sum0 = icmp0->checksum;
141   sum0 = ip_csum_update (sum0, old_id0, new_id0, icmp_echo_header_t,
142                          identifier);
143   icmp0->checksum = ip_csum_fold (sum0);
144
145   /* Accounting, per-user LRU list maintenance */
146   s0->last_heard = now;
147   s0->total_pkts++;
148   s0->total_bytes += vlib_buffer_length_in_chain (sm->vlib_main, b0);
149   clib_dlist_remove (sm->list_pool, s0->per_user_index);
150   clib_dlist_addtail (sm->list_pool, s0->per_user_list_head_index,
151                       s0->per_user_index);
152
153   return next0;
154 }
155
156 static uword
157 snat_out2in_node_fn (vlib_main_t * vm,
158                   vlib_node_runtime_t * node,
159                   vlib_frame_t * frame)
160 {
161   u32 n_left_from, * from, * to_next;
162   snat_out2in_next_t next_index;
163   u32 pkts_processed = 0;
164   snat_main_t * sm = &snat_main;
165   ip_lookup_main_t * lm = sm->ip4_lookup_main;
166   ip_config_main_t * cm = &lm->feature_config_mains[VNET_IP_RX_UNICAST_FEAT];
167   f64 now = vlib_time_now (vm);
168
169   from = vlib_frame_vector_args (frame);
170   n_left_from = frame->n_vectors;
171   next_index = node->cached_next_index;
172
173   while (n_left_from > 0)
174     {
175       u32 n_left_to_next;
176
177       vlib_get_next_frame (vm, node, next_index,
178                            to_next, n_left_to_next);
179
180       while (n_left_from >= 4 && n_left_to_next >= 2)
181         {
182           u32 bi0, bi1;
183           vlib_buffer_t * b0, * b1;
184           u32 next0 = SNAT_OUT2IN_NEXT_DROP;
185           u32 next1 = SNAT_OUT2IN_NEXT_DROP;
186           u32 sw_if_index0, sw_if_index1;
187           ip4_header_t * ip0, *ip1;
188           ip_csum_t sum0, sum1;
189           u32 new_addr0, old_addr0;
190           u16 new_port0, old_port0;
191           u32 new_addr1, old_addr1;
192           u16 new_port1, old_port1;
193           udp_header_t * udp0, * udp1;
194           tcp_header_t * tcp0, * tcp1;
195           icmp46_header_t * icmp0, * icmp1;
196           snat_session_key_t key0, key1;
197           u32 rx_fib_index0, rx_fib_index1;
198           u32 proto0, proto1;
199           snat_session_t * s0 = 0, * s1 = 0;
200           clib_bihash_kv_8_8_t kv0, kv1, value0, value1;
201           
202           /* Prefetch next iteration. */
203           {
204             vlib_buffer_t * p2, * p3;
205             
206             p2 = vlib_get_buffer (vm, from[2]);
207             p3 = vlib_get_buffer (vm, from[3]);
208             
209             vlib_prefetch_buffer_header (p2, LOAD);
210             vlib_prefetch_buffer_header (p3, LOAD);
211
212             CLIB_PREFETCH (p2->data, CLIB_CACHE_LINE_BYTES, STORE);
213             CLIB_PREFETCH (p3->data, CLIB_CACHE_LINE_BYTES, STORE);
214           }
215
216           /* speculatively enqueue b0 and b1 to the current next frame */
217           to_next[0] = bi0 = from[0];
218           to_next[1] = bi1 = from[1];
219           from += 2;
220           to_next += 2;
221           n_left_from -= 2;
222           n_left_to_next -= 2;
223
224           b0 = vlib_get_buffer (vm, bi0);
225           b1 = vlib_get_buffer (vm, bi1);
226             
227           ip0 = vlib_buffer_get_current (b0);
228           udp0 = ip4_next_header (ip0);
229           tcp0 = (tcp_header_t *) udp0;
230           icmp0 = (icmp46_header_t *) udp0;
231
232           sw_if_index0 = vnet_buffer(b0)->sw_if_index[VLIB_RX];
233           rx_fib_index0 = vec_elt (sm->ip4_main->fib_index_by_sw_if_index, 
234                                    sw_if_index0);
235
236           vnet_get_config_data (&cm->config_main,
237                                 &b0->current_config_index,
238                                 &next0,
239                                 0 /* sizeof config data */);
240           proto0 = ~0;
241           proto0 = (ip0->protocol == IP_PROTOCOL_UDP) 
242             ? SNAT_PROTOCOL_UDP : proto0;
243           proto0 = (ip0->protocol == IP_PROTOCOL_TCP) 
244             ? SNAT_PROTOCOL_TCP : proto0;
245           proto0 = (ip0->protocol == IP_PROTOCOL_ICMP) 
246             ? SNAT_PROTOCOL_ICMP : proto0;
247
248           if (PREDICT_FALSE (proto0 == ~0))
249               goto trace0;
250
251           if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
252             {
253               next0 = icmp_out2in_slow_path 
254                 (sm, b0, ip0, icmp0, sw_if_index0, rx_fib_index0, node, 
255                  next0, now);
256               goto trace0;
257             }
258
259           key0.addr = ip0->dst_address;
260           key0.port = udp0->dst_port;
261           key0.protocol = proto0;
262           key0.fib_index = rx_fib_index0;
263           
264           kv0.key = key0.as_u64;
265
266           if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
267             goto trace0;
268           else
269             s0 = pool_elt_at_index (sm->sessions, value0.value);
270
271           old_addr0 = ip0->dst_address.as_u32;
272           ip0->dst_address = s0->in2out.addr;
273           new_addr0 = ip0->dst_address.as_u32;
274           vnet_buffer(b0)->sw_if_index[VLIB_TX] = s0->out2in.fib_index;
275
276           sum0 = ip0->checksum;
277           sum0 = ip_csum_update (sum0, old_addr0, new_addr0,
278                                  ip4_header_t,
279                                  dst_address /* changed member */);
280           ip0->checksum = ip_csum_fold (sum0);
281
282           if (PREDICT_TRUE(proto0 == SNAT_PROTOCOL_TCP))
283             {
284               old_port0 = tcp0->ports.dst;
285               tcp0->ports.dst = s0->in2out.port;
286               new_port0 = tcp0->ports.dst;
287
288               sum0 = tcp0->checksum;
289               sum0 = ip_csum_update (sum0, old_addr0, new_addr0,
290                                      ip4_header_t,
291                                      dst_address /* changed member */);
292
293               sum0 = ip_csum_update (sum0, old_port0, new_port0,
294                                      ip4_header_t /* cheat */,
295                                      length /* changed member */);
296               tcp0->checksum = ip_csum_fold(sum0);
297             }
298           else
299             {
300               old_port0 = udp0->dst_port;
301               udp0->dst_port = s0->in2out.port;
302               udp0->checksum = 0;
303             }
304
305           /* Accounting, per-user LRU list maintenance */
306           s0->last_heard = now;
307           s0->total_pkts++;
308           s0->total_bytes += vlib_buffer_length_in_chain (vm, b0);
309           clib_dlist_remove (sm->list_pool, s0->per_user_index);
310           clib_dlist_addtail (sm->list_pool, s0->per_user_list_head_index,
311                               s0->per_user_index);
312         trace0:
313
314           if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE) 
315                             && (b0->flags & VLIB_BUFFER_IS_TRACED))) 
316             {
317               snat_out2in_trace_t *t = 
318                  vlib_add_trace (vm, node, b0, sizeof (*t));
319               t->sw_if_index = sw_if_index0;
320               t->next_index = next0;
321               t->session_index = ~0;
322               if (s0)
323                   t->session_index = s0 - sm->sessions;
324             }
325
326           pkts_processed += next0 != SNAT_OUT2IN_NEXT_DROP;
327
328
329           ip1 = vlib_buffer_get_current (b1);
330           udp1 = ip4_next_header (ip1);
331           tcp1 = (tcp_header_t *) udp1;
332           icmp1 = (icmp46_header_t *) udp1;
333
334           sw_if_index1 = vnet_buffer(b1)->sw_if_index[VLIB_RX];
335           rx_fib_index1 = vec_elt (sm->ip4_main->fib_index_by_sw_if_index, 
336                                    sw_if_index1);
337
338           vnet_get_config_data (&cm->config_main,
339                                 &b1->current_config_index,
340                                 &next1,
341                                 0 /* sizeof config data */);
342           proto1 = ~0;
343           proto1 = (ip1->protocol == IP_PROTOCOL_UDP) 
344             ? SNAT_PROTOCOL_UDP : proto1;
345           proto1 = (ip1->protocol == IP_PROTOCOL_TCP) 
346             ? SNAT_PROTOCOL_TCP : proto1;
347           proto1 = (ip1->protocol == IP_PROTOCOL_ICMP) 
348             ? SNAT_PROTOCOL_ICMP : proto1;
349
350           if (PREDICT_FALSE (proto1 == ~0))
351               goto trace1;
352
353           if (PREDICT_FALSE (proto1 == SNAT_PROTOCOL_ICMP))
354             {
355               next1 = icmp_out2in_slow_path 
356                 (sm, b1, ip1, icmp1, sw_if_index1, rx_fib_index1, node, 
357                  next1, now);
358               goto trace1;
359             }
360
361           key1.addr = ip1->dst_address;
362           key1.port = udp1->dst_port;
363           key1.protocol = proto1;
364           key1.fib_index = rx_fib_index1;
365           
366           kv1.key = key1.as_u64;
367
368           if (clib_bihash_search_8_8 (&sm->out2in, &kv1, &value1))
369             goto trace1;
370           else
371             s1 = pool_elt_at_index (sm->sessions, value1.value);
372
373           old_addr1 = ip1->dst_address.as_u32;
374           ip1->dst_address = s1->in2out.addr;
375           new_addr1 = ip1->dst_address.as_u32;
376           vnet_buffer(b1)->sw_if_index[VLIB_TX] = s1->out2in.fib_index;
377
378           sum1 = ip1->checksum;
379           sum1 = ip_csum_update (sum1, old_addr1, new_addr1,
380                                  ip4_header_t,
381                                  dst_address /* changed member */);
382           ip1->checksum = ip_csum_fold (sum1);
383
384           if (PREDICT_TRUE(proto1 == SNAT_PROTOCOL_TCP))
385             {
386               old_port1 = tcp1->ports.dst;
387               tcp1->ports.dst = s1->in2out.port;
388               new_port1 = tcp1->ports.dst;
389
390               sum1 = tcp1->checksum;
391               sum1 = ip_csum_update (sum1, old_addr1, new_addr1,
392                                      ip4_header_t,
393                                      dst_address /* changed member */);
394
395               sum1 = ip_csum_update (sum1, old_port1, new_port1,
396                                      ip4_header_t /* cheat */,
397                                      length /* changed member */);
398               tcp1->checksum = ip_csum_fold(sum1);
399             }
400           else
401             {
402               old_port1 = udp1->dst_port;
403               udp1->dst_port = s1->in2out.port;
404               udp1->checksum = 0;
405             }
406
407           /* Accounting, per-user LRU list maintenance */
408           s1->last_heard = now;
409           s1->total_pkts++;
410           s1->total_bytes += vlib_buffer_length_in_chain (vm, b1);
411           clib_dlist_remove (sm->list_pool, s1->per_user_index);
412           clib_dlist_addtail (sm->list_pool, s1->per_user_list_head_index,
413                               s1->per_user_index);
414         trace1:
415
416           if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE) 
417                             && (b1->flags & VLIB_BUFFER_IS_TRACED))) 
418             {
419               snat_out2in_trace_t *t = 
420                  vlib_add_trace (vm, node, b1, sizeof (*t));
421               t->sw_if_index = sw_if_index1;
422               t->next_index = next1;
423               t->session_index = ~0;
424               if (s1)
425                   t->session_index = s1 - sm->sessions;
426             }
427
428           pkts_processed += next0 != SNAT_OUT2IN_NEXT_DROP;
429           pkts_processed += next1 != SNAT_OUT2IN_NEXT_DROP;
430
431           /* verify speculative enqueues, maybe switch current next frame */
432           vlib_validate_buffer_enqueue_x2 (vm, node, next_index,
433                                            to_next, n_left_to_next,
434                                            bi0, bi1, next0, next1);
435         }
436
437       while (n_left_from > 0 && n_left_to_next > 0)
438         {
439           u32 bi0;
440           vlib_buffer_t * b0;
441           u32 next0 = SNAT_OUT2IN_NEXT_DROP;
442           u32 sw_if_index0;
443           ip4_header_t * ip0;
444           ip_csum_t sum0;
445           u32 new_addr0, old_addr0;
446           u16 new_port0, old_port0;
447           udp_header_t * udp0;
448           tcp_header_t * tcp0;
449           icmp46_header_t * icmp0;
450           snat_session_key_t key0;
451           u32 rx_fib_index0;
452           u32 proto0;
453           snat_session_t * s0 = 0;
454           clib_bihash_kv_8_8_t kv0, value0;
455           
456           /* speculatively enqueue b0 to the current next frame */
457           bi0 = from[0];
458           to_next[0] = bi0;
459           from += 1;
460           to_next += 1;
461           n_left_from -= 1;
462           n_left_to_next -= 1;
463
464           b0 = vlib_get_buffer (vm, bi0);
465
466           ip0 = vlib_buffer_get_current (b0);
467           udp0 = ip4_next_header (ip0);
468           tcp0 = (tcp_header_t *) udp0;
469           icmp0 = (icmp46_header_t *) udp0;
470
471           sw_if_index0 = vnet_buffer(b0)->sw_if_index[VLIB_RX];
472           rx_fib_index0 = vec_elt (sm->ip4_main->fib_index_by_sw_if_index, 
473                                    sw_if_index0);
474
475           vnet_get_config_data (&cm->config_main,
476                                 &b0->current_config_index,
477                                 &next0,
478                                 0 /* sizeof config data */);
479           proto0 = ~0;
480           proto0 = (ip0->protocol == IP_PROTOCOL_UDP) 
481             ? SNAT_PROTOCOL_UDP : proto0;
482           proto0 = (ip0->protocol == IP_PROTOCOL_TCP) 
483             ? SNAT_PROTOCOL_TCP : proto0;
484           proto0 = (ip0->protocol == IP_PROTOCOL_ICMP) 
485             ? SNAT_PROTOCOL_ICMP : proto0;
486
487           if (PREDICT_FALSE (proto0 == ~0))
488               goto trace00;
489
490           if (PREDICT_FALSE (proto0 == SNAT_PROTOCOL_ICMP))
491             {
492               next0 = icmp_out2in_slow_path 
493                 (sm, b0, ip0, icmp0, sw_if_index0, rx_fib_index0, node, 
494                  next0, now);
495               goto trace00;
496             }
497
498           key0.addr = ip0->dst_address;
499           key0.port = udp0->dst_port;
500           key0.protocol = proto0;
501           key0.fib_index = rx_fib_index0;
502           
503           kv0.key = key0.as_u64;
504
505           if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
506             goto trace00;
507           else
508             s0 = pool_elt_at_index (sm->sessions, value0.value);
509
510           old_addr0 = ip0->dst_address.as_u32;
511           ip0->dst_address = s0->in2out.addr;
512           new_addr0 = ip0->dst_address.as_u32;
513           vnet_buffer(b0)->sw_if_index[VLIB_TX] = s0->out2in.fib_index;
514
515           sum0 = ip0->checksum;
516           sum0 = ip_csum_update (sum0, old_addr0, new_addr0,
517                                  ip4_header_t,
518                                  dst_address /* changed member */);
519           ip0->checksum = ip_csum_fold (sum0);
520
521           if (PREDICT_TRUE(proto0 == SNAT_PROTOCOL_TCP))
522             {
523               old_port0 = tcp0->ports.dst;
524               tcp0->ports.dst = s0->in2out.port;
525               new_port0 = tcp0->ports.dst;
526
527               sum0 = tcp0->checksum;
528               sum0 = ip_csum_update (sum0, old_addr0, new_addr0,
529                                      ip4_header_t,
530                                      dst_address /* changed member */);
531
532               sum0 = ip_csum_update (sum0, old_port0, new_port0,
533                                      ip4_header_t /* cheat */,
534                                      length /* changed member */);
535               tcp0->checksum = ip_csum_fold(sum0);
536             }
537           else
538             {
539               old_port0 = udp0->dst_port;
540               udp0->dst_port = s0->in2out.port;
541               udp0->checksum = 0;
542             }
543
544           /* Accounting, per-user LRU list maintenance */
545           s0->last_heard = now;
546           s0->total_pkts++;
547           s0->total_bytes += vlib_buffer_length_in_chain (vm, b0);
548           clib_dlist_remove (sm->list_pool, s0->per_user_index);
549           clib_dlist_addtail (sm->list_pool, s0->per_user_list_head_index,
550                               s0->per_user_index);
551         trace00:
552
553           if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE) 
554                             && (b0->flags & VLIB_BUFFER_IS_TRACED))) 
555             {
556               snat_out2in_trace_t *t = 
557                  vlib_add_trace (vm, node, b0, sizeof (*t));
558               t->sw_if_index = sw_if_index0;
559               t->next_index = next0;
560               t->session_index = ~0;
561               if (s0)
562                   t->session_index = s0 - sm->sessions;
563             }
564
565           pkts_processed += next0 != SNAT_OUT2IN_NEXT_DROP;
566
567           /* verify speculative enqueue, maybe switch current next frame */
568           vlib_validate_buffer_enqueue_x1 (vm, node, next_index,
569                                            to_next, n_left_to_next,
570                                            bi0, next0);
571         }
572
573       vlib_put_next_frame (vm, node, next_index, n_left_to_next);
574     }
575
576   vlib_node_increment_counter (vm, snat_out2in_node.index, 
577                                SNAT_OUT2IN_ERROR_OUT2IN_PACKETS, 
578                                pkts_processed);
579   return frame->n_vectors;
580 }
581
582 VLIB_REGISTER_NODE (snat_out2in_node) = {
583   .function = snat_out2in_node_fn,
584   .name = "snat-out2in",
585   .vector_size = sizeof (u32),
586   .format_trace = format_snat_out2in_trace,
587   .type = VLIB_NODE_TYPE_INTERNAL,
588   
589   .n_errors = ARRAY_LEN(snat_out2in_error_strings),
590   .error_strings = snat_out2in_error_strings,
591
592   .runtime_data_bytes = sizeof (snat_runtime_t),
593   
594   .n_next_nodes = SNAT_OUT2IN_N_NEXT,
595
596   /* edit / add dispositions here */
597   .next_nodes = {
598     [SNAT_OUT2IN_NEXT_DROP] = "error-drop",
599   },
600 };
601 VLIB_NODE_FUNCTION_MULTIARCH (snat_out2in_node, snat_out2in_node_fn);