ipsec: split ipsec nodes into ip4/ip6 nodes
[vpp.git] / src / vnet / ipsec / ah_decrypt.c
index abe2e6f..e3e0071 100644 (file)
 #include <vnet/ipsec/esp.h>
 #include <vnet/ipsec/ah.h>
 
-#define foreach_ah_decrypt_next                \
-_(DROP, "error-drop")                           \
-_(IP4_INPUT, "ip4-input")                       \
-_(IP6_INPUT, "ip6-input")                       \
-_(IPSEC_GRE_INPUT, "ipsec-gre-input")
+#define foreach_ah_decrypt_next \
+  _ (DROP, "error-drop")        \
+  _ (IP4_INPUT, "ip4-input")    \
+  _ (IP6_INPUT, "ip6-input")    \
+  _ (IPSEC_GRE_INPUT, "ipsec-gre-input")
 
 #define _(v, s) AH_DECRYPT_NEXT_##v,
 typedef enum
@@ -37,14 +37,11 @@ typedef enum
     AH_DECRYPT_N_NEXT,
 } ah_decrypt_next_t;
 
-
-#define foreach_ah_decrypt_error                   \
- _(RX_PKTS, "AH pkts received")                    \
- _(DECRYPTION_FAILED, "AH decryption failed")      \
- _(INTEG_ERROR, "Integrity check failed")           \
- _(REPLAY, "SA replayed packet")                    \
- _(NOT_IP, "Not IP packet (dropped)")
-
+#define foreach_ah_decrypt_error                \
+  _ (RX_PKTS, "AH pkts received")               \
+  _ (DECRYPTION_FAILED, "AH decryption failed") \
+  _ (INTEG_ERROR, "Integrity check failed")     \
+  _ (REPLAY, "SA replayed packet")
 
 typedef enum
 {
@@ -77,9 +74,10 @@ format_ah_decrypt_trace (u8 * s, va_list * args)
   return s;
 }
 
-static uword
-ah_decrypt_node_fn (vlib_main_t * vm,
-                   vlib_node_runtime_t * node, vlib_frame_t * from_frame)
+always_inline uword
+ah_decrypt_inline (vlib_main_t * vm,
+                  vlib_node_runtime_t * node, vlib_frame_t * from_frame,
+                  int is_ip6)
 {
   u32 n_left_from, *from, next_index, *to_next;
   ipsec_main_t *im = &ipsec_main;
@@ -107,8 +105,6 @@ ah_decrypt_node_fn (vlib_main_t * vm,
          u32 seq;
          ip4_header_t *ih4 = 0, *oh4 = 0;
          ip6_header_t *ih6 = 0, *oh6 = 0;
-         u8 tunnel_mode = 1;
-         u8 transport_ip6 = 0;
          u8 ip_hdr_size = 0;
          u8 tos = 0;
          u8 ttl = 0;
@@ -133,12 +129,7 @@ ah_decrypt_node_fn (vlib_main_t * vm,
          sa_index0 = vnet_buffer (i_b0)->ipsec.sad_index;
          sa0 = pool_elt_at_index (im->sad, sa_index0);
 
-         if ((ih4->ip_version_and_header_length & 0xF0) == 0x40)
-           {
-             ip_hdr_size = ip4_header_bytes (ih4);
-             ah0 = (ah_header_t *) ((u8 *) ih4 + ip_hdr_size);
-           }
-         else if ((ih4->ip_version_and_header_length & 0xF0) == 0x60)
+         if (is_ip6)
            {
              ip6_ext_header_t *prev = NULL;
              ip6_ext_header_find_t (ih6, prev, ah0, IP_PROTOCOL_IPSEC_AH);
@@ -147,9 +138,8 @@ ah_decrypt_node_fn (vlib_main_t * vm,
            }
          else
            {
-             vlib_node_increment_counter (vm, ah_decrypt_node.index,
-                                          AH_DECRYPT_ERROR_NOT_IP, 1);
-             goto trace;
+             ip_hdr_size = ip4_header_bytes (ih4);
+             ah0 = (ah_header_t *) ((u8 *) ih4 + ip_hdr_size);
            }
 
          seq = clib_host_to_net_u32 (ah0->seq_no);
@@ -167,8 +157,14 @@ ah_decrypt_node_fn (vlib_main_t * vm,
              if (PREDICT_FALSE (rv))
                {
                  clib_warning ("anti-replay SPI %u seq %u", sa0->spi, seq);
-                 vlib_node_increment_counter (vm, ah_decrypt_node.index,
-                                              AH_DECRYPT_ERROR_REPLAY, 1);
+                 if (is_ip6)
+                   vlib_node_increment_counter (vm,
+                                                ah6_decrypt_node.index,
+                                                AH_DECRYPT_ERROR_REPLAY, 1);
+                 else
+                   vlib_node_increment_counter (vm,
+                                                ah6_decrypt_node.index,
+                                                AH_DECRYPT_ERROR_REPLAY, 1);
                  to_next[0] = i_bi0;
                  to_next += 1;
                  goto trace;
@@ -189,18 +185,7 @@ ah_decrypt_node_fn (vlib_main_t * vm,
              memcpy (digest, icv, icv_size);
              memset (icv, 0, icv_size);
 
-             if ((ih4->ip_version_and_header_length & 0xF0) == 0x40)
-               {
-                 tos = ih4->tos;
-                 ttl = ih4->ttl;
-                 ih4->tos = 0;
-                 ih4->ttl = 0;
-                 ih4->checksum = 0;
-                 ih4->flags_and_fragment_offset = 0;
-                 icv_padding_len =
-                   ah_calc_icv_padding_len (icv_size, 0 /* is_ipv6 */ );
-               }
-             else
+             if (is_ip6)
                {
                  ip_version_traffic_class_and_flow_label =
                    ih6->ip_version_traffic_class_and_flow_label;
@@ -211,15 +196,33 @@ ah_decrypt_node_fn (vlib_main_t * vm,
                  icv_padding_len =
                    ah_calc_icv_padding_len (icv_size, 1 /* is_ipv6 */ );
                }
+             else
+               {
+                 tos = ih4->tos;
+                 ttl = ih4->ttl;
+                 ih4->tos = 0;
+                 ih4->ttl = 0;
+                 ih4->checksum = 0;
+                 ih4->flags_and_fragment_offset = 0;
+                 icv_padding_len =
+                   ah_calc_icv_padding_len (icv_size, 0 /* is_ipv6 */ );
+               }
              hmac_calc (sa0->integ_alg, sa0->integ_key, sa0->integ_key_len,
                         (u8 *) ih4, i_b0->current_length, sig, sa0->use_esn,
                         sa0->seq_hi);
 
              if (PREDICT_FALSE (memcmp (digest, sig, icv_size)))
                {
-                 vlib_node_increment_counter (vm, ah_decrypt_node.index,
-                                              AH_DECRYPT_ERROR_INTEG_ERROR,
-                                              1);
+                 if (is_ip6)
+                   vlib_node_increment_counter (vm,
+                                                ah6_decrypt_node.index,
+                                                AH_DECRYPT_ERROR_INTEG_ERROR,
+                                                1);
+                 else
+                   vlib_node_increment_counter (vm,
+                                                ah4_decrypt_node.index,
+                                                AH_DECRYPT_ERROR_INTEG_ERROR,
+                                                1);
                  to_next[0] = i_bi0;
                  to_next += 1;
                  goto trace;
@@ -241,30 +244,8 @@ ah_decrypt_node_fn (vlib_main_t * vm,
                               icv_padding_len);
          i_b0->flags |= VLIB_BUFFER_TOTAL_LENGTH_VALID;
 
-         /* transport mode */
-         if (PREDICT_FALSE (!sa0->is_tunnel && !sa0->is_tunnel_ip6))
-           {
-             tunnel_mode = 0;
-
-             if (PREDICT_TRUE
-                 ((ih4->ip_version_and_header_length & 0xF0) != 0x40))
-               {
-                 if (PREDICT_TRUE
-                     ((ih4->ip_version_and_header_length & 0xF0) == 0x60))
-                   transport_ip6 = 1;
-                 else
-                   {
-                     clib_warning ("next header: 0x%x", ah0->nexthdr);
-                     vlib_node_increment_counter (vm, ah_decrypt_node.index,
-                                                  AH_DECRYPT_ERROR_NOT_IP,
-                                                  1);
-                     goto trace;
-                   }
-               }
-           }
-
-         if (PREDICT_TRUE (tunnel_mode))
-           {
+         if (PREDICT_TRUE (sa0->is_tunnel))
+           {                   /* tunnel mode */
              if (PREDICT_TRUE (ah0->nexthdr == IP_PROTOCOL_IP_IN_IP))
                next0 = AH_DECRYPT_NEXT_IP4_INPUT;
              else if (ah0->nexthdr == IP_PROTOCOL_IPV6)
@@ -272,16 +253,22 @@ ah_decrypt_node_fn (vlib_main_t * vm,
              else
                {
                  clib_warning ("next header: 0x%x", ah0->nexthdr);
-                 vlib_node_increment_counter (vm, ah_decrypt_node.index,
-                                              AH_DECRYPT_ERROR_DECRYPTION_FAILED,
-                                              1);
+                 if (is_ip6)
+                   vlib_node_increment_counter (vm,
+                                                ah6_decrypt_node.index,
+                                                AH_DECRYPT_ERROR_DECRYPTION_FAILED,
+                                                1);
+                 else
+                   vlib_node_increment_counter (vm,
+                                                ah4_decrypt_node.index,
+                                                AH_DECRYPT_ERROR_DECRYPTION_FAILED,
+                                                1);
                  goto trace;
                }
            }
-         /* transport mode */
          else
-           {
-             if (PREDICT_FALSE (transport_ip6))
+           {                   /* transport mode */
+             if (is_ip6)
                {
                  vlib_buffer_advance (i_b0, -sizeof (ip6_header_t));
                  oh6 = vlib_buffer_get_current (i_b0);
@@ -337,18 +324,58 @@ ah_decrypt_node_fn (vlib_main_t * vm,
        }
       vlib_put_next_frame (vm, node, next_index, n_left_to_next);
     }
-  vlib_node_increment_counter (vm, ah_decrypt_node.index,
-                              AH_DECRYPT_ERROR_RX_PKTS,
-                              from_frame->n_vectors);
+  if (is_ip6)
+    vlib_node_increment_counter (vm, ah6_decrypt_node.index,
+                                AH_DECRYPT_ERROR_RX_PKTS,
+                                from_frame->n_vectors);
+  else
+    vlib_node_increment_counter (vm, ah4_decrypt_node.index,
+                                AH_DECRYPT_ERROR_RX_PKTS,
+                                from_frame->n_vectors);
 
   return from_frame->n_vectors;
 }
 
+static uword
+ah4_decrypt_node_fn (vlib_main_t * vm,
+                    vlib_node_runtime_t * node, vlib_frame_t * from_frame)
+{
+  return ah_decrypt_inline (vm, node, from_frame, 0 /* is_ip6 */ );
+}
+
+/* *INDENT-OFF* */
+VLIB_REGISTER_NODE (ah4_decrypt_node) = {
+  .function = ah4_decrypt_node_fn,
+  .name = "ah4-decrypt",
+  .vector_size = sizeof (u32),
+  .format_trace = format_ah_decrypt_trace,
+  .type = VLIB_NODE_TYPE_INTERNAL,
+
+  .n_errors = ARRAY_LEN(ah_decrypt_error_strings),
+  .error_strings = ah_decrypt_error_strings,
+
+  .n_next_nodes = AH_DECRYPT_N_NEXT,
+  .next_nodes = {
+#define _(s,n) [AH_DECRYPT_NEXT_##s] = n,
+    foreach_ah_decrypt_next
+#undef _
+  },
+};
+/* *INDENT-ON* */
+
+VLIB_NODE_FUNCTION_MULTIARCH (ah4_decrypt_node, ah4_decrypt_node_fn);
+
+static uword
+ah6_decrypt_node_fn (vlib_main_t * vm,
+                    vlib_node_runtime_t * node, vlib_frame_t * from_frame)
+{
+  return ah_decrypt_inline (vm, node, from_frame, 1 /* is_ip6 */ );
+}
 
 /* *INDENT-OFF* */
-VLIB_REGISTER_NODE (ah_decrypt_node) = {
-  .function = ah_decrypt_node_fn,
-  .name = "ah-decrypt",
+VLIB_REGISTER_NODE (ah6_decrypt_node) = {
+  .function = ah6_decrypt_node_fn,
+  .name = "ah6-decrypt",
   .vector_size = sizeof (u32),
   .format_trace = format_ah_decrypt_trace,
   .type = VLIB_NODE_TYPE_INTERNAL,
@@ -365,7 +392,7 @@ VLIB_REGISTER_NODE (ah_decrypt_node) = {
 };
 /* *INDENT-ON* */
 
-VLIB_NODE_FUNCTION_MULTIARCH (ah_decrypt_node, ah_decrypt_node_fn)
+VLIB_NODE_FUNCTION_MULTIARCH (ah6_decrypt_node, ah6_decrypt_node_fn);
 /*
  * fd.io coding-style-patch-verification: ON
  *