ipsec: support UDP encap/decap for NAT traversal
[vpp.git] / src / vnet / ipsec / ipsec_input.c
index 9aa5654..08269d0 100644 (file)
@@ -216,7 +216,9 @@ ipsec_input_ip4_node_fn (vlib_main_t * vm,
 
          ip0 = vlib_buffer_get_current (b0);
 
-         if (PREDICT_TRUE (ip0->protocol == IP_PROTOCOL_IPSEC_ESP))
+         if (PREDICT_TRUE
+             (ip0->protocol == IP_PROTOCOL_IPSEC_ESP
+              || ip0->protocol == IP_PROTOCOL_UDP))
            {
 #if 0
              clib_warning
@@ -228,6 +230,13 @@ ipsec_input_ip4_node_fn (vlib_main_t * vm,
 #endif
 
              esp0 = (esp_header_t *) ((u8 *) ip0 + ip4_header_bytes (ip0));
+             if (PREDICT_FALSE (ip0->protocol == IP_PROTOCOL_UDP))
+               {
+                 esp0 =
+                   (esp_header_t *) ((u8 *) esp0 + sizeof (udp_header_t));
+               }
+             /* FIXME TODO missing check whether there is enough data inside
+              * IP/UDP to contain ESP header & stuff ? */
              p0 = ipsec_input_protect_policy_match (spd0,
                                                     clib_net_to_host_u32
                                                     (ip0->src_address.
@@ -245,7 +254,7 @@ ipsec_input_ip4_node_fn (vlib_main_t * vm,
                  vnet_buffer (b0)->ipsec.sad_index = p0->sa_index;
                  vnet_buffer (b0)->ipsec.flags = 0;
                  next0 = im->esp_decrypt_next_index;
-                 vlib_buffer_advance (b0, ip4_header_bytes (ip0));
+                 vlib_buffer_advance (b0, ((u8 *) esp0 - (u8 *) ip0));
                  goto trace0;
                }
 
@@ -255,7 +264,8 @@ ipsec_input_ip4_node_fn (vlib_main_t * vm,
                {
                  ipsec_input_trace_t *tr =
                    vlib_add_trace (vm, node, b0, sizeof (*tr));
-                 if (ip0->protocol == IP_PROTOCOL_IPSEC_ESP)
+                 if (ip0->protocol == IP_PROTOCOL_IPSEC_ESP ||
+                     ip0->protocol == IP_PROTOCOL_UDP)
                    {
                      if (p0)
                        tr->sa_id = p0->sa_id;