nat: pnat copy and clear byte instructions
[vpp.git] / src / plugins / nat / pnat / pnat_node.h
1 /*
2  * Copyright (c) 2021 Cisco and/or its affiliates.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at:
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #ifndef included_pnat_node_h
17 #define included_pnat_node_h
18
19 #include "pnat.h"
20 #include <pnat/pnat.api_enum.h>
21 #include <vnet/feature/feature.h>
22 #include <vnet/udp/udp_packet.h>
23 #include <vnet/ip/format.h>
24
25 /* PNAT next-nodes */
26 typedef enum { PNAT_NEXT_DROP, PNAT_N_NEXT } pnat_next_t;
27
28 u8 *format_pnat_match_tuple(u8 *s, va_list *args);
29 u8 *format_pnat_rewrite_tuple(u8 *s, va_list *args);
30 static inline u8 *format_pnat_trace(u8 *s, va_list *args) {
31     CLIB_UNUSED(vlib_main_t * vm) = va_arg(*args, vlib_main_t *);
32     CLIB_UNUSED(vlib_node_t * node) = va_arg(*args, vlib_node_t *);
33     pnat_trace_t *t = va_arg(*args, pnat_trace_t *);
34
35     s = format(s, "pnat: index %d\n", t->pool_index);
36     if (t->pool_index != ~0) {
37         s = format(s, "        match: %U\n", format_pnat_match_tuple,
38                    &t->match);
39         s = format(s, "        rewrite: %U", format_pnat_rewrite_tuple,
40                    &t->rewrite);
41     }
42     return s;
43 }
44
45 /*
46  * Given a packet and rewrite instructions from a translation modify packet.
47  */
48 // TODO: Generalize to write with mask
49 static u32 pnat_rewrite_ip4(u32 pool_index, ip4_header_t *ip) {
50     pnat_main_t *pm = &pnat_main;
51     if (pool_is_free_index(pm->translations, pool_index))
52         return PNAT_ERROR_REWRITE;
53     pnat_translation_t *t = pool_elt_at_index(pm->translations, pool_index);
54
55     ip_csum_t csumd = 0;
56
57     if (t->instructions & PNAT_INSTR_DESTINATION_ADDRESS) {
58         csumd = ip_csum_sub_even(csumd, ip->dst_address.as_u32);
59         csumd = ip_csum_add_even(csumd, t->post_da.as_u32);
60         ip->dst_address = t->post_da;
61     }
62     if (t->instructions & PNAT_INSTR_SOURCE_ADDRESS) {
63         csumd = ip_csum_sub_even(csumd, ip->src_address.as_u32);
64         csumd = ip_csum_add_even(csumd, t->post_sa.as_u32);
65         ip->src_address = t->post_sa;
66     }
67
68     ip_csum_t csum = ip->checksum;
69     csum = ip_csum_sub_even(csum, csumd);
70     ip->checksum = ip_csum_fold(csum);
71     if (ip->checksum == 0xffff)
72         ip->checksum = 0;
73     ASSERT(ip->checksum == ip4_header_checksum(ip));
74
75     u16 plen = clib_net_to_host_u16(ip->length);
76
77     /* Nothing more to do if this is a fragment. */
78     if (ip4_is_fragment(ip))
79         return PNAT_ERROR_NONE;
80
81     /* L4 ports */
82     if (ip->protocol == IP_PROTOCOL_TCP) {
83         /* Assume IP4 header is 20 bytes */
84         if (plen < sizeof(ip4_header_t) + sizeof(tcp_header_t))
85             return PNAT_ERROR_TOOSHORT;
86
87         tcp_header_t *tcp = ip4_next_header(ip);
88         ip_csum_t l4csum = tcp->checksum;
89         if (t->instructions & PNAT_INSTR_DESTINATION_PORT) {
90             l4csum = ip_csum_sub_even(l4csum, tcp->dst_port);
91             l4csum = ip_csum_add_even(l4csum, clib_net_to_host_u16(t->post_dp));
92             tcp->dst_port = clib_net_to_host_u16(t->post_dp);
93         }
94         if (t->instructions & PNAT_INSTR_SOURCE_PORT) {
95             l4csum = ip_csum_sub_even(l4csum, tcp->src_port);
96             l4csum = ip_csum_add_even(l4csum, clib_net_to_host_u16(t->post_sp));
97             tcp->src_port = clib_net_to_host_u16(t->post_sp);
98         }
99         l4csum = ip_csum_sub_even(l4csum, csumd);
100         tcp->checksum = ip_csum_fold(l4csum);
101     } else if (ip->protocol == IP_PROTOCOL_UDP) {
102         if (plen < sizeof(ip4_header_t) + sizeof(udp_header_t))
103             return PNAT_ERROR_TOOSHORT;
104         udp_header_t *udp = ip4_next_header(ip);
105         ip_csum_t l4csum = udp->checksum;
106         if (t->instructions & PNAT_INSTR_DESTINATION_PORT) {
107             l4csum = ip_csum_sub_even(l4csum, udp->dst_port);
108             l4csum = ip_csum_add_even(l4csum, clib_net_to_host_u16(t->post_dp));
109             udp->dst_port = clib_net_to_host_u16(t->post_dp);
110         }
111         if (t->instructions & PNAT_INSTR_SOURCE_PORT) {
112             l4csum = ip_csum_sub_even(l4csum, udp->src_port);
113             l4csum = ip_csum_add_even(l4csum, clib_net_to_host_u16(t->post_sp));
114             udp->src_port = clib_net_to_host_u16(t->post_sp);
115         }
116         if (udp->checksum) {
117             l4csum = ip_csum_sub_even(l4csum, csumd);
118             udp->checksum = ip_csum_fold(l4csum);
119         }
120     }
121     if (t->instructions & PNAT_INSTR_COPY_BYTE) {
122         /* Copy byte from somewhere in packet to elsewhere */
123
124         if (t->to_offset >= plen || t->from_offset > plen) {
125             return PNAT_ERROR_TOOSHORT;
126         }
127         u8 *p = (u8 *)ip;
128         p[t->to_offset] = p[t->from_offset];
129         ip->checksum = ip4_header_checksum(ip);
130         // TODO: L4 checksum
131     }
132     if (t->instructions & PNAT_INSTR_CLEAR_BYTE) {
133         /* Clear byte at offset */
134         u8 *p = (u8 *)ip;
135         p[t->clear_offset] = 0;
136         ip->checksum = ip4_header_checksum(ip);
137         // TODO: L4 checksum
138     }
139
140     return PNAT_ERROR_NONE;
141 }
142
143 /*
144  * Lookup the packet tuple in the flow cache, given the lookup mask.
145  * If a binding is found, rewrite the packet according to instructions,
146  * otherwise follow configured default action (forward, punt or drop)
147  */
148 // TODO: Make use of SVR configurable
149 static_always_inline uword pnat_node_inline(vlib_main_t *vm,
150                                             vlib_node_runtime_t *node,
151                                             vlib_frame_t *frame,
152                                             pnat_attachment_point_t attachment,
153                                             int dir) {
154     pnat_main_t *pm = &pnat_main;
155     u32 n_left_from, *from;
156     u16 nexts[VLIB_FRAME_SIZE] = {0}, *next = nexts;
157     u32 pool_indicies[VLIB_FRAME_SIZE], *pi = pool_indicies;
158     vlib_buffer_t *bufs[VLIB_FRAME_SIZE], **b = bufs;
159     clib_bihash_kv_16_8_t kv, value;
160     ip4_header_t *ip0;
161
162     from = vlib_frame_vector_args(frame);
163     n_left_from = frame->n_vectors;
164     vlib_get_buffers(vm, from, b, n_left_from);
165     pnat_interface_t *interface;
166
167     /* Stage 1: build vector of flow hash (based on lookup mask) */
168     while (n_left_from > 0) {
169         u32 sw_if_index0 = vnet_buffer(b[0])->sw_if_index[dir];
170         u16 sport0 = vnet_buffer(b[0])->ip.reass.l4_src_port;
171         u16 dport0 = vnet_buffer(b[0])->ip.reass.l4_dst_port;
172         u32 iph_offset =
173             dir == VLIB_TX ? vnet_buffer(b[0])->ip.save_rewrite_length : 0;
174         ip0 = (ip4_header_t *)(vlib_buffer_get_current(b[0]) + iph_offset);
175         interface = pnat_interface_by_sw_if_index(sw_if_index0);
176         ASSERT(interface);
177         pnat_mask_fast_t mask = interface->lookup_mask_fast[attachment];
178         pnat_calc_key(sw_if_index0, attachment, ip0->src_address,
179                       ip0->dst_address, ip0->protocol, sport0, dport0, mask,
180                       &kv);
181         /* By default pass packet to next node in the feature chain */
182         vnet_feature_next_u16(next, b[0]);
183
184         if (clib_bihash_search_16_8(&pm->flowhash, &kv, &value) == 0) {
185             /* Cache hit */
186             *pi = value.value;
187             u32 errno0 = pnat_rewrite_ip4(value.value, ip0);
188             if (PREDICT_FALSE(errno0)) {
189                 next[0] = PNAT_NEXT_DROP;
190                 b[0]->error = node->errors[errno0];
191             }
192         } else {
193             /* Cache miss */
194             *pi = ~0;
195         }
196
197         /*next: */
198         next += 1;
199         n_left_from -= 1;
200         b += 1;
201         pi += 1;
202     }
203
204     /* Packet trace */
205     if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE))) {
206         u32 i;
207         b = bufs;
208         pi = pool_indicies;
209         for (i = 0; i < frame->n_vectors; i++) {
210             if (b[0]->flags & VLIB_BUFFER_IS_TRACED) {
211                 pnat_trace_t *t = vlib_add_trace(vm, node, b[0], sizeof(*t));
212                 if (*pi != ~0) {
213                     if (!pool_is_free_index(pm->translations, *pi)) {
214                         pnat_translation_t *tr =
215                             pool_elt_at_index(pm->translations, *pi);
216                         t->match = tr->match;
217                         t->rewrite = tr->rewrite;
218                     }
219                 }
220                 t->pool_index = *pi;
221                 b += 1;
222                 pi += 1;
223             } else
224                 break;
225         }
226     }
227
228     vlib_buffer_enqueue_to_next(vm, node, from, nexts, frame->n_vectors);
229
230     return frame->n_vectors;
231 }
232 #endif