NAT: session number limitation to avoid running out of memory crash (VPP-984)
[vpp.git] / src / plugins / nat / out2in.c
old mode 100644 (file)
new mode 100755 (executable)
index 6795006..e5426c1
@@ -87,7 +87,8 @@ vlib_node_registration_t snat_det_out2in_node;
 _(UNSUPPORTED_PROTOCOL, "Unsupported protocol")         \
 _(OUT2IN_PACKETS, "Good out2in packets processed")      \
 _(BAD_ICMP_TYPE, "unsupported ICMP type")               \
-_(NO_TRANSLATION, "No translation")
+_(NO_TRANSLATION, "No translation")                     \
+_(MAX_SESSIONS_EXCEEDED, "Maximum sessions exceeded")
 
 typedef enum {
 #define _(sym,str) SNAT_OUT2IN_ERROR_##sym,
@@ -139,6 +140,12 @@ create_session_for_static_mapping (snat_main_t *sm,
   dlist_elt_t * per_user_list_head_elt;
   ip4_header_t *ip0;
 
+  if (PREDICT_FALSE (maximum_sessions_exceeded(sm, thread_index)))
+    {
+      b0->error = node->errors[SNAT_OUT2IN_ERROR_MAX_SESSIONS_EXCEEDED];
+      return 0;
+    }
+
   ip0 = vlib_buffer_get_current (b0);
 
   user_key.addr = in2out.addr;
@@ -146,7 +153,8 @@ create_session_for_static_mapping (snat_main_t *sm,
   kv0.key = user_key.as_u64;
 
   /* Ever heard of the "user" = inside ip4 address before? */
-  if (clib_bihash_search_8_8 (&sm->user_hash, &kv0, &value0))
+  if (clib_bihash_search_8_8 (&sm->per_thread_data[thread_index].user_hash,
+                              &kv0, &value0))
     {
       /* no, make a new one */
       pool_get (sm->per_thread_data[thread_index].users, u);
@@ -166,7 +174,8 @@ create_session_for_static_mapping (snat_main_t *sm,
       kv0.value = u - sm->per_thread_data[thread_index].users;
 
       /* add user */
-      clib_bihash_add_del_8_8 (&sm->user_hash, &kv0, 1 /* is_add */);
+      clib_bihash_add_del_8_8 (&sm->per_thread_data[thread_index].user_hash,
+                               &kv0, 1 /* is_add */);
 
       /* add non-traslated packets worker lookup */
       kv0.value = thread_index;
@@ -211,13 +220,15 @@ create_session_for_static_mapping (snat_main_t *sm,
   /* Add to translation hashes */
   kv0.key = s->in2out.as_u64;
   kv0.value = s - sm->per_thread_data[thread_index].sessions;
-  if (clib_bihash_add_del_8_8 (&sm->in2out, &kv0, 1 /* is_add */))
+  if (clib_bihash_add_del_8_8 (&sm->per_thread_data[thread_index].in2out, &kv0,
+                               1 /* is_add */))
       clib_warning ("in2out key add failed");
 
   kv0.key = s->out2in.as_u64;
   kv0.value = s - sm->per_thread_data[thread_index].sessions;
 
-  if (clib_bihash_add_del_8_8 (&sm->out2in, &kv0, 1 /* is_add */))
+  if (clib_bihash_add_del_8_8 (&sm->per_thread_data[thread_index].out2in, &kv0,
+                               1 /* is_add */))
       clib_warning ("out2in key add failed");
 
   /* log NAT event */
@@ -325,7 +336,8 @@ u32 icmp_match_out2in_slow(snat_main_t *sm, vlib_node_runtime_t *node,
 
   kv0.key = key0.as_u64;
 
-  if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
+  if (clib_bihash_search_8_8 (&sm->per_thread_data[thread_index].out2in, &kv0,
+                              &value0))
     {
       /* Try to match static mapping by external address and port,
          destination address and port in packet */
@@ -615,14 +627,15 @@ static inline u32 icmp_out2in_slow_path (snat_main_t *sm,
   return next0;
 }
 
-static void
+static snat_session_t *
 snat_out2in_unknown_proto (snat_main_t *sm,
                            vlib_buffer_t * b,
                            ip4_header_t * ip,
                            u32 rx_fib_index,
                            u32 thread_index,
                            f64 now,
-                           vlib_main_t * vm)
+                           vlib_main_t * vm,
+                           vlib_node_runtime_t * node)
 {
   clib_bihash_kv_8_8_t kv, value;
   clib_bihash_kv_16_8_t s_kv, s_value;
@@ -630,7 +643,7 @@ snat_out2in_unknown_proto (snat_main_t *sm,
   snat_session_key_t m_key;
   u32 old_addr, new_addr;
   ip_csum_t sum;
-  snat_unk_proto_ses_key_t key;
+  nat_ed_ses_key_t key;
   snat_session_t * s;
   snat_main_per_thread_data_t *tsm = &sm->per_thread_data[thread_index];
   snat_user_key_t u_key;
@@ -643,24 +656,34 @@ snat_out2in_unknown_proto (snat_main_t *sm,
   key.r_addr = ip->src_address;
   key.fib_index = rx_fib_index;
   key.proto = ip->protocol;
-  key.rsvd[0] = key.rsvd[1] = key.rsvd[2] = 0;
+  key.rsvd = 0;
+  key.l_port = 0;
   s_kv.key[0] = key.as_u64[0];
   s_kv.key[1] = key.as_u64[1];
 
-  if (!clib_bihash_search_16_8 (&sm->out2in_unk_proto, &s_kv, &s_value))
+  if (!clib_bihash_search_16_8 (&sm->out2in_ed, &s_kv, &s_value))
     {
       s = pool_elt_at_index (tsm->sessions, s_value.value);
       new_addr = ip->dst_address.as_u32 = s->in2out.addr.as_u32;
     }
   else
     {
+      if (PREDICT_FALSE (maximum_sessions_exceeded(sm, thread_index)))
+        {
+          b->error = node->errors[SNAT_OUT2IN_ERROR_MAX_SESSIONS_EXCEEDED];
+          return 0;
+        }
+
       m_key.addr = ip->dst_address;
       m_key.port = 0;
       m_key.protocol = 0;
       m_key.fib_index = rx_fib_index;
       kv.key = m_key.as_u64;
       if (clib_bihash_search_8_8 (&sm->static_mapping_by_external, &kv, &value))
-        return;
+        {
+          b->error = node->errors[SNAT_OUT2IN_ERROR_NO_TRANSLATION];
+          return 0;
+        }
 
       m = pool_elt_at_index (sm->static_mappings, value.value);
 
@@ -671,7 +694,7 @@ snat_out2in_unknown_proto (snat_main_t *sm,
       kv.key = u_key.as_u64;
 
       /* Ever heard of the "user" = src ip4 address before? */
-      if (clib_bihash_search_8_8 (&sm->user_hash, &kv, &value))
+      if (clib_bihash_search_8_8 (&tsm->user_hash, &kv, &value))
         {
           /* no, make a new one */
           pool_get (tsm->users, u);
@@ -688,7 +711,7 @@ snat_out2in_unknown_proto (snat_main_t *sm,
           kv.value = u - tsm->users;
 
           /* add user */
-          clib_bihash_add_del_8_8 (&sm->user_hash, &kv, 1);
+          clib_bihash_add_del_8_8 (&tsm->user_hash, &kv, 1);
         }
       else
         {
@@ -721,14 +744,14 @@ snat_out2in_unknown_proto (snat_main_t *sm,
 
       /* Add to lookup tables */
       s_kv.value = s - tsm->sessions;
-      if (clib_bihash_add_del_16_8 (&sm->out2in_unk_proto, &s_kv, 1))
+      if (clib_bihash_add_del_16_8 (&sm->out2in_ed, &s_kv, 1))
         clib_warning ("out2in key add failed");
 
       key.l_addr = ip->dst_address;
       key.fib_index = m->fib_index;
       s_kv.key[0] = key.as_u64[0];
       s_kv.key[1] = key.as_u64[1];
-      if (clib_bihash_add_del_16_8 (&sm->in2out_unk_proto, &s_kv, 1))
+      if (clib_bihash_add_del_16_8 (&sm->in2out_ed, &s_kv, 1))
         clib_warning ("in2out key add failed");
    }
 
@@ -747,6 +770,162 @@ snat_out2in_unknown_proto (snat_main_t *sm,
   clib_dlist_remove (tsm->list_pool, s->per_user_index);
   clib_dlist_addtail (tsm->list_pool, s->per_user_list_head_index,
                       s->per_user_index);
+
+  return s;
+}
+
+/*
+ * Out2in translation of one packet through the endpoint-dependent session
+ * tables (sm->out2in_ed / sm->in2out_ed).
+ *
+ * The lookup key is built from the packet's destination address and
+ * destination port, its source address, the IP protocol and rx_fib_index.
+ * On a hit the existing session is reused.  On a miss a new session is
+ * created from the result of snat_static_mapping_match(), linked to its
+ * user (created on demand) and added to both tables.  The packet is then
+ * rewritten in place: destination address and port, IP checksum, TCP
+ * checksum (incrementally) or UDP checksum (cleared), and the TX fib index
+ * kept in the buffer metadata.
+ *
+ * sm            NAT main structure
+ * b             buffer holding the packet; b->error is set on failure
+ * ip            IPv4 header of the packet (modified in place)
+ * rx_fib_index  fib index used in the lookup keys
+ * thread_index  selects the per-thread pools and user hash
+ * now           timestamp stored in s->last_heard
+ * vm            vlib main, used to compute the buffer chain length
+ * node          node runtime supplying the error counters
+ *
+ * Returns the session used for the translation, or 0 when the packet has to
+ * be dropped.  In that case the packet is left unmodified and b->error
+ * carries the reason.
+ */
+static snat_session_t *
+snat_out2in_lb (snat_main_t *sm,
+                vlib_buffer_t * b,
+                ip4_header_t * ip,
+                u32 rx_fib_index,
+                u32 thread_index,
+                f64 now,
+                vlib_main_t * vm,
+                vlib_node_runtime_t * node)
+{
+  nat_ed_ses_key_t key;
+  clib_bihash_kv_16_8_t s_kv, s_value;
+  /* The same L4 header is viewed as UDP or TCP depending on proto */
+  udp_header_t *udp = ip4_next_header (ip);
+  tcp_header_t *tcp = (tcp_header_t *) udp;
+  snat_session_t *s = 0;
+  snat_main_per_thread_data_t *tsm = &sm->per_thread_data[thread_index];
+  snat_session_key_t e_key, l_key;
+  clib_bihash_kv_8_8_t kv, value;
+  u32 old_addr, new_addr;
+  u32 proto = ip_proto_to_snat_proto (ip->protocol);
+  u16 new_port, old_port;
+  ip_csum_t sum;
+  snat_user_key_t u_key;
+  snat_user_t *u;
+  dlist_elt_t *head, *elt;
+
+  /* Remember the original destination for the checksum updates below */
+  old_addr = ip->dst_address.as_u32;
+
+  /* Out2in key: local side = packet destination, remote side = source */
+  key.l_addr = ip->dst_address;
+  key.r_addr = ip->src_address;
+  key.fib_index = rx_fib_index;
+  key.proto = ip->protocol;
+  key.rsvd = 0;
+  key.l_port = udp->dst_port;
+  s_kv.key[0] = key.as_u64[0];
+  s_kv.key[1] = key.as_u64[1];
+
+  if (!clib_bihash_search_16_8 (&sm->out2in_ed, &s_kv, &s_value))
+    {
+      /* Known flow: reuse its session */
+      s = pool_elt_at_index (tsm->sessions, s_value.value);
+    }
+  else
+    {
+      /* New flow: refuse to grow past the configured session limit */
+      if (PREDICT_FALSE (maximum_sessions_exceeded(sm, thread_index)))
+        {
+          b->error = node->errors[SNAT_OUT2IN_ERROR_MAX_SESSIONS_EXCEEDED];
+          return 0;
+        }
+
+      e_key.addr = ip->dst_address;
+      e_key.port = udp->dst_port;
+      e_key.protocol = proto;
+      e_key.fib_index = rx_fib_index;
+      if (snat_static_mapping_match(sm, e_key, &l_key, 1, 0))
+        {
+          /* Give the drop a reason: the callers only select the drop next
+             node when 0 is returned and do not set b->error themselves. */
+          b->error = node->errors[SNAT_OUT2IN_ERROR_NO_TRANSLATION];
+          return 0;
+        }
+
+      u_key.addr = l_key.addr;
+      u_key.fib_index = l_key.fib_index;
+      kv.key = u_key.as_u64;
+
+      /* Is there already a user for the mapped (l_key) address? */
+      if (clib_bihash_search_8_8 (&tsm->user_hash, &kv, &value))
+        {
+          /* no, make a new one */
+          pool_get (tsm->users, u);
+          memset (u, 0, sizeof (*u));
+          u->addr = l_key.addr;
+          u->fib_index = l_key.fib_index;
+
+          /* Head element of this user's session list */
+          pool_get (tsm->list_pool, head);
+          u->sessions_per_user_list_head_index = head - tsm->list_pool;
+
+          clib_dlist_init (tsm->list_pool,
+                           u->sessions_per_user_list_head_index);
+
+          kv.value = u - tsm->users;
+
+          /* add user */
+          if (clib_bihash_add_del_8_8 (&tsm->user_hash, &kv, 1))
+            clib_warning ("user key add failed");
+        }
+      else
+        {
+          u = pool_elt_at_index (tsm->users, value.value);
+        }
+
+      /* Create a new session */
+      pool_get (tsm->sessions, s);
+      memset (s, 0, sizeof (*s));
+
+      s->ext_host_addr.as_u32 = ip->src_address.as_u32;
+      s->flags |= SNAT_SESSION_FLAG_STATIC_MAPPING;
+      s->flags |= SNAT_SESSION_FLAG_LOAD_BALANCING;
+      s->outside_address_index = ~0;
+      s->out2in = e_key;
+      s->in2out = l_key;
+      u->nstaticsessions++;
+
+      /* Create list elts */
+      pool_get (tsm->list_pool, elt);
+      clib_dlist_init (tsm->list_pool, elt - tsm->list_pool);
+      elt->value = s - tsm->sessions;
+      s->per_user_index = elt - tsm->list_pool;
+      s->per_user_list_head_index = u->sessions_per_user_list_head_index;
+      clib_dlist_addtail (tsm->list_pool, s->per_user_list_head_index,
+                          s->per_user_index);
+
+      /* Add to lookup tables; s_kv still holds the out2in key built above */
+      s_kv.value = s - tsm->sessions;
+      if (clib_bihash_add_del_16_8 (&sm->out2in_ed, &s_kv, 1))
+        clib_warning ("out2in-ed key add failed");
+
+      /* In2out key: same remote address and protocol, local side replaced
+         by the mapped address, port and fib index */
+      key.l_addr = l_key.addr;
+      key.fib_index = l_key.fib_index;
+      key.l_port = l_key.port;
+      s_kv.key[0] = key.as_u64[0];
+      s_kv.key[1] = key.as_u64[1];
+      if (clib_bihash_add_del_16_8 (&sm->in2out_ed, &s_kv, 1))
+        clib_warning ("in2out-ed key add failed");
+    }
+
+  new_addr = ip->dst_address.as_u32 = s->in2out.addr.as_u32;
+
+  /* Update IP checksum */
+  sum = ip->checksum;
+  sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, dst_address);
+  ip->checksum = ip_csum_fold (sum);
+
+  if (PREDICT_TRUE(proto == SNAT_PROTOCOL_TCP))
+    {
+      old_port = tcp->dst_port;
+      tcp->dst_port = s->in2out.port;
+      new_port = tcp->dst_port;
+
+      /* NOTE(review): ip4_header_t.length looks like a stand-in 16-bit
+         field for the port update - confirm against ip_csum_update. */
+      sum = tcp->checksum;
+      sum = ip_csum_update (sum, old_addr, new_addr, ip4_header_t, dst_address);
+      sum = ip_csum_update (sum, old_port, new_port, ip4_header_t, length);
+      tcp->checksum = ip_csum_fold(sum);
+    }
+  else
+    {
+      /* Every non-TCP protocol is handled as UDP here.  The checksum is
+         cleared instead of recomputed; zero means "no checksum" for UDP
+         over IPv4. */
+      udp->dst_port = s->in2out.port;
+      udp->checksum = 0;
+    }
+
+  vnet_buffer(b)->sw_if_index[VLIB_TX] = s->in2out.fib_index;
+
+  /* Accounting */
+  s->last_heard = now;
+  s->total_pkts++;
+  s->total_bytes += vlib_buffer_length_in_chain (vm, b);
+  return s;
+}
 
 static uword
@@ -845,8 +1024,10 @@ snat_out2in_node_fn (vlib_main_t * vm,
 
           if (PREDICT_FALSE (proto0 == ~0))
             {
-              snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0,
-                                        thread_index, now, vm);
+              s0 = snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0,
+                                             thread_index, now, vm, node);
+              if (!s0)
+                next0 = SNAT_OUT2IN_NEXT_DROP;
               goto trace0;
             }
 
@@ -865,7 +1046,8 @@ snat_out2in_node_fn (vlib_main_t * vm,
 
           kv0.key = key0.as_u64;
 
-          if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
+          if (clib_bihash_search_8_8 (&sm->per_thread_data[thread_index].out2in,
+                                      &kv0, &value0))
             {
               /* Try to match static mapping by external address and port,
                  destination address and port in packet */
@@ -888,14 +1070,27 @@ snat_out2in_node_fn (vlib_main_t * vm,
                                                      thread_index);
               if (!s0)
                 {
-                  b0->error = node->errors[SNAT_OUT2IN_ERROR_NO_TRANSLATION];
                   next0 = SNAT_OUT2IN_NEXT_DROP;
                   goto trace0;
                 }
             }
           else
-            s0 = pool_elt_at_index (sm->per_thread_data[thread_index].sessions,
-                                    value0.value);
+            {
+              if (PREDICT_FALSE (value0.value == ~0ULL))
+                {
+                  s0 = snat_out2in_lb(sm, b0, ip0, rx_fib_index0, thread_index,
+                                      now, vm, node);
+                  if (!s0)
+                    next0 = SNAT_OUT2IN_NEXT_DROP;
+                  goto trace0;
+                }
+              else
+                {
+                  s0 = pool_elt_at_index (
+                    sm->per_thread_data[thread_index].sessions,
+                    value0.value);
+                }
+            }
 
           old_addr0 = ip0->dst_address.as_u32;
           ip0->dst_address = s0->in2out.addr;
@@ -984,8 +1179,10 @@ snat_out2in_node_fn (vlib_main_t * vm,
 
           if (PREDICT_FALSE (proto1 == ~0))
             {
-              snat_out2in_unknown_proto(sm, b1, ip1, rx_fib_index1,
-                                        thread_index, now, vm);
+              s1 = snat_out2in_unknown_proto(sm, b1, ip1, rx_fib_index1,
+                                             thread_index, now, vm, node);
+              if (!s1)
+                next1 = SNAT_OUT2IN_NEXT_DROP;
               goto trace1;
             }
 
@@ -1004,7 +1201,8 @@ snat_out2in_node_fn (vlib_main_t * vm,
 
           kv1.key = key1.as_u64;
 
-          if (clib_bihash_search_8_8 (&sm->out2in, &kv1, &value1))
+          if (clib_bihash_search_8_8 (&sm->per_thread_data[thread_index].out2in,
+                                      &kv1, &value1))
             {
               /* Try to match static mapping by external address and port,
                  destination address and port in packet */
@@ -1027,14 +1225,27 @@ snat_out2in_node_fn (vlib_main_t * vm,
                                                      thread_index);
               if (!s1)
                 {
-                  b1->error = node->errors[SNAT_OUT2IN_ERROR_NO_TRANSLATION];
                   next1 = SNAT_OUT2IN_NEXT_DROP;
                   goto trace1;
                 }
             }
           else
-            s1 = pool_elt_at_index (sm->per_thread_data[thread_index].sessions,
-                                    value1.value);
+            {
+              if (PREDICT_FALSE (value1.value == ~0ULL))
+                {
+                  s1 = snat_out2in_lb(sm, b1, ip1, rx_fib_index1, thread_index,
+                                      now, vm, node);
+                  if (!s1)
+                    next1 = SNAT_OUT2IN_NEXT_DROP;
+                  goto trace1;
+                }
+              else
+                {
+                  s1 = pool_elt_at_index (
+                    sm->per_thread_data[thread_index].sessions,
+                    value1.value);
+                }
+            }
 
           old_addr1 = ip1->dst_address.as_u32;
           ip1->dst_address = s1->in2out.addr;
@@ -1149,8 +1360,10 @@ snat_out2in_node_fn (vlib_main_t * vm,
 
           if (PREDICT_FALSE (proto0 == ~0))
             {
-              snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0,
-                                        thread_index, now, vm);
+              s0 = snat_out2in_unknown_proto(sm, b0, ip0, rx_fib_index0,
+                                             thread_index, now, vm, node);
+              if (!s0)
+                next0 = SNAT_OUT2IN_NEXT_DROP;
               goto trace00;
             }
 
@@ -1179,7 +1392,8 @@ snat_out2in_node_fn (vlib_main_t * vm,
 
           kv0.key = key0.as_u64;
 
-          if (clib_bihash_search_8_8 (&sm->out2in, &kv0, &value0))
+          if (clib_bihash_search_8_8 (&sm->per_thread_data[thread_index].out2in,
+                                      &kv0, &value0))
             {
               /* Try to match static mapping by external address and port,
                  destination address and port in packet */
@@ -1203,14 +1417,27 @@ snat_out2in_node_fn (vlib_main_t * vm,
                                                      thread_index);
               if (!s0)
                 {
-                  b0->error = node->errors[SNAT_OUT2IN_ERROR_NO_TRANSLATION];
-                    next0 = SNAT_OUT2IN_NEXT_DROP;
+                  next0 = SNAT_OUT2IN_NEXT_DROP;
                   goto trace00;
                 }
             }
           else
-            s0 = pool_elt_at_index (sm->per_thread_data[thread_index].sessions,
-                                    value0.value);
+            {
+              if (PREDICT_FALSE (value0.value == ~0ULL))
+                {
+                  s0 = snat_out2in_lb(sm, b0, ip0, rx_fib_index0, thread_index,
+                                      now, vm, node);
+                  if (!s0)
+                    next0 = SNAT_OUT2IN_NEXT_DROP;
+                  goto trace00;
+                }
+              else
+                {
+                  s0 = pool_elt_at_index (
+                    sm->per_thread_data[thread_index].sessions,
+                    value0.value);
+                }
+            }
 
           old_addr0 = ip0->dst_address.as_u32;
           ip0->dst_address = s0->in2out.addr;