nat: pnat copy and clear byte instructions
[vpp.git] / src / plugins / nat / pnat / tests / pnat_test.c
similarity index 62%
rename from src/plugins/nat/pnat/pnat_test.c
rename to src/plugins/nat/pnat/tests/pnat_test.c
index 762b4bd..ab55e7e 100644 (file)
@@ -23,7 +23,7 @@
 #include <vppinfra/bihash_16_8.h>
 #include <vppinfra/bihash_template.c>
 #include <vnet/fib/ip4_fib.h>
-#include "pnat.h"
+#include "../pnat.h"
 #include <pnat/pnat.api_enum.h> /* For error counters */
 #include <arpa/inet.h>
 #include "pnat_test_stubs.h"
@@ -50,10 +50,10 @@ static u32 *buffer_init(u32 *vector, int count) {
     }
     return vector;
 }
-#define PNAT_TEST_DEBUG 0
 
 u32 *results_bi = 0; /* global vector of result buffers */
 u16 *results_next = 0;
+
 vlib_node_runtime_t *node;
 
 #define log_info(M, ...)                                                       \
@@ -61,13 +61,18 @@ vlib_node_runtime_t *node;
 #define log_error(M, ...)                                                      \
     fprintf(stderr, "\033[31;1m[ERROR] (%s:%d:) " M "\033[0m\n", __FILE__,     \
             __LINE__, ##__VA_ARGS__)
-#define test_assert(A, M, ...)                                                 \
+#define test_assert_log(A, M, ...)                                             \
     if (!(A)) {                                                                \
         log_error(M, ##__VA_ARGS__);                                           \
         assert(A);                                                             \
     } else {                                                                   \
         log_info(M, ##__VA_ARGS__);                                            \
     }
+#define test_assert(A, M, ...)                                                 \
+    if (!(A)) {                                                                \
+        log_error(M, ##__VA_ARGS__);                                           \
+        assert(A);                                                             \
+    }
 
 /*
  * Always return the frame of generated packets
@@ -113,53 +118,32 @@ vlib_buffer_t *test_vlib_get_buffer(u32 bi) {
 }
 
 /* Must be included here to allow the above functions to override */
-#include "pnat_node.h"
+#include "../pnat_node.h"
 
 /*** TESTS ***/
 
+typedef struct {
+    char *name;
+    int nsend;
+    char *send;
+    int nexpect;
+    char *expect;
+    u32 expect_next_index;
+} test_t;
+#include "test_packets.h"
+
+/* Rules */
 typedef struct {
     char *src;
     char *dst;
     u8 proto;
     u16 sport;
     u16 dport;
+    u8 from_offset;
+    u8 to_offset;
+    u8 clear_offset;
 } test_5tuple_t;
 
-typedef struct {
-    char *name;
-    test_5tuple_t send;
-    test_5tuple_t expect;
-    u32 expect_next_index;
-} test_t;
-
-test_t tests[] = {
-    {
-        .name = "da rewritten",
-        .send = {"1.1.1.1", "2.2.2.2", 17, 80, 6871},
-        .expect = {"1.1.1.1", "1.2.3.4", 17, 80, 6871},
-        .expect_next_index = NEXT_PASSTHROUGH,
-    },
-    {
-        .name = "unchanged",
-        .send = {"1.1.1.1", "2.2.2.2", 17, 80, 8080},
-        .expect = {"1.1.1.1", "2.2.2.2", 17, 80, 8080},
-        .expect_next_index = NEXT_PASSTHROUGH,
-    },
-    {
-        .name = "tcp da",
-        .send = {"1.1.1.1", "2.2.2.2", 6, 80, 6871},
-        .expect = {"1.1.1.1", "1.2.3.4", 6, 80, 6871},
-        .expect_next_index = NEXT_PASSTHROUGH,
-    },
-    {
-        .name = "tcp da ports",
-        .send = {"1.1.1.1", "2.2.2.2", 6, 80, 6872},
-        .expect = {"1.1.1.1", "1.2.3.4", 6, 53, 8000},
-        .expect_next_index = NEXT_PASSTHROUGH,
-    },
-};
-
-/* Rules */
 typedef struct {
     test_5tuple_t match;
     test_5tuple_t rewrite;
@@ -188,55 +172,55 @@ rule_t rules[] = {
         .rewrite = {.dst = "1.2.3.4", .sport = 53, .dport = 8000},
         .in = true,
     },
+    {
+        .match = {.dst = "2.2.2.2", .proto = 17, .dport = 6874},
+        .rewrite = {.from_offset = 15, .to_offset = 19},
+        .in = true,
+    },
+    {
+        .match = {.dst = "2.2.2.2", .proto = 17, .dport = 6875},
+        .rewrite = {.from_offset = 15, .to_offset = 50},
+        .in = true,
+    },
+    {
+        .match = {.dst = "2.2.2.2", .proto = 17, .dport = 6877},
+        .rewrite = {.dst = "1.2.3.4", .from_offset = 15, .to_offset = 35},
+        .in = true,
+    },
+    {
+        .match = {.dst = "2.2.2.2", .proto = 17, .dport = 6876},
+        .rewrite = {.clear_offset = 22},
+        .in = true,
+    },
 };
 
-static int fill_packets(vlib_main_t *vm, vlib_buffer_t *b,
-                        test_5tuple_t *test) {
+static int fill_packets(vlib_main_t *vm, vlib_buffer_t *b, int n, char *test) {
     b->flags |= VLIB_BUFFER_IS_TRACED;
 
     ip4_header_t *ip = (ip4_header_t *)vlib_buffer_get_current(b);
-    memset(ip, 0, sizeof(*ip));
-    ip->ip_version_and_header_length = 0x45;
-    ip->ttl = 64;
-    inet_pton(AF_INET, test->src, &ip->src_address.as_u32);
-    inet_pton(AF_INET, test->dst, &ip->dst_address.as_u32);
-    ip->protocol = test->proto;
-
-    if (test->proto == IP_PROTOCOL_UDP) {
+
+    memcpy(ip, test, n);
+
+    /* Do the work of SVR */
+    vnet_buffer(b)->ip.reass.l4_src_port = 0;
+    vnet_buffer(b)->ip.reass.l4_dst_port = 0;
+    b->current_length = n;
+
+    if (ip4_is_fragment(ip))
+        return 0;
+    if (ip->protocol == IP_PROTOCOL_UDP) {
         udp_header_t *udp = ip4_next_header(ip);
-        memset(udp, 0, sizeof(*udp));
-        udp->dst_port = htons(test->dport);
-        udp->src_port = htons(test->sport);
-        udp->length = htons(8);
         vnet_buffer(b)->ip.reass.l4_src_port = udp->src_port;
         vnet_buffer(b)->ip.reass.l4_dst_port = udp->dst_port;
-        b->current_length = 28;
-        ip->length = htons(b->current_length);
-        ip->checksum = ip4_header_checksum(ip);
-        udp->checksum = ip4_tcp_udp_compute_checksum(vm, b, ip);
-    } else if (test->proto == IP_PROTOCOL_TCP) {
+    } else if (ip->protocol == IP_PROTOCOL_TCP) {
         tcp_header_t *tcp = ip4_next_header(ip);
-        memset(tcp, 0, sizeof(*tcp));
-        tcp->dst_port = htons(test->dport);
-        tcp->src_port = htons(test->sport);
         vnet_buffer(b)->ip.reass.l4_src_port = tcp->src_port;
         vnet_buffer(b)->ip.reass.l4_dst_port = tcp->dst_port;
-        b->current_length = sizeof(ip4_header_t) + sizeof(tcp_header_t);
-        ip->length = htons(b->current_length);
-        ip->checksum = ip4_header_checksum(ip);
-        tcp->checksum = ip4_tcp_udp_compute_checksum(vm, b, ip);
-    } else {
-        b->current_length = sizeof(ip4_header_t);
-        ip->length = htons(b->current_length);
-        ip->checksum = ip4_header_checksum(ip);
-        vnet_buffer(b)->ip.reass.l4_src_port = 0;
-        vnet_buffer(b)->ip.reass.l4_dst_port = 0;
     }
-
     return 0;
 }
 
-static void ruleto5tuple(test_5tuple_t *r, pnat_5tuple_t *t) {
+static void ruletomatch(test_5tuple_t *r, pnat_match_tuple_t *t) {
     if (r->src) {
         inet_pton(AF_INET, r->src, &t->src);
         t->mask |= PNAT_SA;
@@ -256,12 +240,40 @@ static void ruleto5tuple(test_5tuple_t *r, pnat_5tuple_t *t) {
     t->proto = r->proto;
 }
 
+static void ruletorewrite(test_5tuple_t *r, pnat_rewrite_tuple_t *t) {
+    if (r->src) {
+        inet_pton(AF_INET, r->src, &t->src);
+        t->mask |= PNAT_SA;
+    }
+    if (r->dst) {
+        inet_pton(AF_INET, r->dst, &t->dst);
+        t->mask |= PNAT_DA;
+    }
+    if (r->dport) {
+        t->dport = r->dport;
+        t->mask |= PNAT_DPORT;
+    }
+    if (r->sport) {
+        t->sport = r->sport;
+        t->mask |= PNAT_SPORT;
+    }
+    if (r->to_offset || r->from_offset) {
+        t->to_offset = r->to_offset;
+        t->from_offset = r->from_offset;
+        t->mask |= PNAT_COPY_BYTE;
+    }
+    if (r->clear_offset) {
+        t->clear_offset = r->clear_offset;
+        t->mask |= PNAT_CLEAR_BYTE;
+    }
+}
+
 static void add_translation(rule_t *r) {
-    pnat_5tuple_t match = {0};
-    pnat_5tuple_t rewrite = {0};
+    pnat_match_tuple_t match = {0};
+    pnat_rewrite_tuple_t rewrite = {0};
 
-    ruleto5tuple(&r->match, &match);
-    ruleto5tuple(&r->rewrite, &rewrite);
+    ruletomatch(&r->match, &match);
+    ruletorewrite(&r->rewrite, &rewrite);
 
     int rv = pnat_binding_add(&match, &rewrite, &r->index);
     assert(rv == 0);
@@ -287,24 +299,47 @@ static void validate_packet(vlib_main_t *vm, char *name, u32 bi,
     ip4_header_t *expected_ip =
         (ip4_header_t *)vlib_buffer_get_current(expected_b);
 
-#if PNAT_TEST_DEBUG
-    clib_warning("Received packet: %U", format_ip4_header, ip, 20);
-    clib_warning("Expected packet: %U", format_ip4_header, expected_ip, 20);
-    tcp_header_t *tcp = ip4_next_header(ip);
-    clib_warning("IP: %U TCP: %U", format_ip4_header, ip, sizeof(*ip),
-                 format_tcp_header, tcp, sizeof(*tcp));
-    tcp = ip4_next_header(expected_ip);
-    clib_warning("IP: %U TCP: %U", format_ip4_header, expected_ip, sizeof(*ip),
-                 format_tcp_header, tcp, sizeof(*tcp));
-#endif
-
-    u32 flags = ip4_tcp_udp_validate_checksum(vm, b);
-    assert((flags & VNET_BUFFER_F_L4_CHECKSUM_CORRECT) != 0);
-    flags = ip4_tcp_udp_validate_checksum(vm, expected_b);
-    assert((flags & VNET_BUFFER_F_L4_CHECKSUM_CORRECT) != 0);
-    assert(b->current_length == expected_b->current_length);
-
-    test_assert(memcmp(ip, expected_ip, b->current_length) == 0, "%s", name);
+    if (ip->protocol == IP_PROTOCOL_UDP || ip->protocol == IP_PROTOCOL_TCP) {
+        u32 flags = ip4_tcp_udp_validate_checksum(vm, b);
+        test_assert((flags & VNET_BUFFER_F_L4_CHECKSUM_CORRECT) != 0, "%s",
+                    name);
+        flags = ip4_tcp_udp_validate_checksum(vm, expected_b);
+        test_assert((flags & VNET_BUFFER_F_L4_CHECKSUM_CORRECT) != 0, "%s",
+                    name);
+    }
+    test_assert(b->current_length == expected_b->current_length, "%s %d vs %d",
+                name, b->current_length, expected_b->current_length);
+
+    if (memcmp(ip, expected_ip, b->current_length) != 0) {
+        if (ip->protocol == IP_PROTOCOL_UDP) {
+            udp_header_t *udp = ip4_next_header(ip);
+            clib_warning("Received: IP: %U UDP: %U", format_ip4_header, ip,
+                         sizeof(*ip), format_udp_header, udp, sizeof(*udp));
+            udp = ip4_next_header(expected_ip);
+            clib_warning("%U", format_hexdump, ip, b->current_length);
+            clib_warning("Expected: IP: %U UDP: %U", format_ip4_header,
+                         expected_ip, sizeof(*ip), format_udp_header, udp,
+                         sizeof(*udp));
+            clib_warning("%U", format_hexdump, expected_ip,
+                         expected_b->current_length);
+        } else if (ip->protocol == IP_PROTOCOL_TCP) {
+            tcp_header_t *tcp = ip4_next_header(ip);
+            clib_warning("Received IP: %U TCP: %U", format_ip4_header, ip,
+                         sizeof(*ip), format_tcp_header, tcp, sizeof(*tcp));
+            tcp = ip4_next_header(expected_ip);
+            clib_warning("Expected IP: %U TCP: %U", format_ip4_header,
+                         expected_ip, sizeof(*ip), format_tcp_header, tcp,
+                         sizeof(*tcp));
+        } else {
+            clib_warning("Received: IP: %U", format_ip4_header, ip,
+                         sizeof(*ip));
+            clib_warning("Expected: IP: %U", format_ip4_header, expected_ip,
+                         sizeof(*ip));
+        }
+        test_assert_log(0, "%s", name);
+    } else {
+        test_assert_log(1, "%s", name);
+    }
 }
 
 extern vlib_node_registration_t pnat_input_node;
@@ -317,8 +352,9 @@ static void test_table(test_t *t, int no_tests) {
     /* Generate packet data */
     for (i = 0; i < no_tests; i++) {
         // create input buffer(s)
-        fill_packets(vm, (vlib_buffer_t *)&buffers[i], &t[i].send);
-        fill_packets(vm, (vlib_buffer_t *)&expected[i], &t[i].expect);
+        fill_packets(vm, (vlib_buffer_t *)&buffers[i], t[i].nsend, t[i].send);
+        fill_packets(vm, (vlib_buffer_t *)&expected[i], t[i].nexpect,
+                     t[i].expect);
     }
 
     /* send packets through graph node */
@@ -329,7 +365,7 @@ static void test_table(test_t *t, int no_tests) {
 
     /* verify tests */
     for (i = 0; i < no_tests; i++) {
-        assert(t[i].expect_next_index == results_next[i]);
+        test_assert(t[i].expect_next_index == results_next[i], "%s", t[i].name);
         validate_packet(vm, t[i].name, results_bi[i],
                         (vlib_buffer_t *)&expected[i]);
     }
@@ -337,7 +373,7 @@ static void test_table(test_t *t, int no_tests) {
     vec_free(results_bi);
 }
 
-static void test_performance(void) {
+void test_performance(void) {
     pnat_main_t *pm = &pnat_main;
     int i;
     vlib_main_t *vm = &vlib_global_main;
@@ -347,12 +383,13 @@ static void test_performance(void) {
     }
     assert(pool_elts(pm->translations) == sizeof(rules) / sizeof(rules[0]));
 
-    int no_tests = sizeof(tests) / sizeof(tests[0]);
+    int no_tests = sizeof(tests_packets) / sizeof(tests_packets[0]);
     /* Generate packet data */
     for (i = 0; i < VLIB_FRAME_SIZE; i++) {
         // create input buffer(s)
         fill_packets(vm, (vlib_buffer_t *)&buffers[i],
-                     &tests[i % no_tests].send);
+                     tests_packets[i % no_tests].nsend,
+                     tests_packets[i % no_tests].send);
         // fill_packets(vm, (vlib_buffer_t *)&expected[i], &tests[i %
         // no_tests].expect);
     }
@@ -382,7 +419,7 @@ static void test_performance(void) {
     assert(pool_elts(pm->interfaces) == 0);
 }
 
-static void test_packets(void) {
+void test_packets(void) {
     pnat_main_t *pm = &pnat_main;
     int i;
     for (i = 0; i < sizeof(rules) / sizeof(rules[0]); i++) {
@@ -390,7 +427,7 @@ static void test_packets(void) {
     }
     assert(pool_elts(pm->translations) == sizeof(rules) / sizeof(rules[0]));
 
-    test_table(tests, sizeof(tests) / sizeof(tests[0]));
+    test_table(tests_packets, sizeof(tests_packets) / sizeof(tests_packets[0]));
 
     for (i = 0; i < sizeof(rules) / sizeof(rules[0]); i++) {
         del_translation(&rules[i]);
@@ -398,6 +435,7 @@ static void test_packets(void) {
     assert(pool_elts(pm->translations) == 0);
     assert(pool_elts(pm->interfaces) == 0);
 }
+
 static void test_attach(void) {
     pnat_attachment_point_t attachment = PNAT_IP4_INPUT;
     u32 binding_index = 0;
@@ -408,8 +446,8 @@ static void test_attach(void) {
     rv = pnat_binding_detach(sw_if_index, attachment, 1234);
     test_assert(rv == -1, "binding_detach - nothing to detach");
 
-    pnat_5tuple_t match = {.mask = PNAT_SA};
-    pnat_5tuple_t rewrite = {.mask = PNAT_SA};
+    pnat_match_tuple_t match = {.mask = PNAT_SA};
+    pnat_rewrite_tuple_t rewrite = {.mask = PNAT_SA};
     rv = pnat_binding_add(&match, &rewrite, &binding_index);
     assert(rv == 0);
 
@@ -441,24 +479,17 @@ static void test_del_before_detach(void) {
     int rv = pnat_binding_del(binding_index);
     assert(rv == 0);
 
-    test_t test = {
-        .name = "hit missing rule",
-        .send = {"1.1.1.1", "123.123.123.123", 17, 80, 6871},
-        .expect = {"1.1.1.1", "123.123.123.123", 17, 80, 6871},
-        .expect_next_index = PNAT_NEXT_DROP,
-    };
-
-    test_table(&test, 1);
+    test_table(&tests_missing_rule[0], 1);
 
     /* For now if you have deleted before detach, can't find key */
     rv = pnat_binding_detach(sw_if_index, attachment, binding_index);
     test_assert(rv == -1, "binding_detach - failure");
 
     /* Re-add the rule and try again */
-    pnat_5tuple_t match = {0};
-    pnat_5tuple_t rewrite = {0};
-    ruleto5tuple(&rule.match, &match);
-    ruleto5tuple(&rule.rewrite, &rewrite);
+    pnat_match_tuple_t match = {0};
+    pnat_rewrite_tuple_t rewrite = {0};
+    ruletomatch(&rule.match, &match);
+    ruletorewrite(&rule.rewrite, &rewrite);
     rv = pnat_binding_add(&match, &rewrite, &binding_index);
     assert(rv == 0);
     rv = pnat_binding_detach(sw_if_index, attachment, binding_index);
@@ -467,11 +498,56 @@ static void test_del_before_detach(void) {
     assert(rv == 0);
 }
 
-static void test_api(void) {
+void test_api(void) {
     test_attach();
     test_del_before_detach();
 }
 
+void test_checksum(void) {
+    int i;
+    vlib_main_t *vm = &vlib_global_main;
+    pnat_main_t *pm = &pnat_main;
+
+    test_t test = {
+        .name = "checksum",
+        .nsend = 28,
+        .send =
+            (char[]){0x45, 0x00, 0x00, 0x1c, 0x00, 0x01, 0x00, 0x00, 0x40, 0x11,
+                     0x74, 0xcb, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02,
+                     0x00, 0x50, 0x1a, 0xd7, 0x00, 0x08, 0xde, 0xb1},
+    };
+
+    for (i = 0; i < sizeof(rules) / sizeof(rules[0]); i++) {
+        add_translation(&rules[i]);
+    }
+    assert(pool_elts(pm->translations) == sizeof(rules) / sizeof(rules[0]));
+
+    /* send packets through graph node */
+    vlib_frame_t frame = {.n_vectors = 1};
+    node->flags |= VLIB_NODE_FLAG_TRACE;
+
+    ip4_header_t *ip =
+        (ip4_header_t *)vlib_buffer_get_current((vlib_buffer_t *)&buffers[0]);
+
+    for (i = 0; i < 65535; i++) {
+
+        /* Get a buffer. Loop through 64K variations of it to check checksum */
+        memset(&buffers[0], 0, 2048);
+        fill_packets(vm, (vlib_buffer_t *)&buffers[0], test.nsend, test.send);
+
+        ip->src_address.as_u32 = i;
+        ip->checksum = 0;
+        ip->checksum = ip4_header_checksum(ip);
+        pnat_node_inline(vm, node, &frame, PNAT_IP4_INPUT, VLIB_RX);
+    }
+
+    test_assert_log(1, "%s", test.name);
+
+    for (i = 0; i < sizeof(rules) / sizeof(rules[0]); i++) {
+        del_translation(&rules[i]);
+    }
+}
+
 /*
  * Unit testing:
  * 1) Table of packets and expected outcomes. Run through
@@ -497,8 +573,15 @@ int main(int argc, char **argv) {
 
     /* Test API */
     test_api();
-
     test_packets();
-
+    test_checksum();
     test_performance();
 }
+
+/*
+ * NEW TESTS:
+ * - Chained buffers. Only do rewrite in first buffer
+ * - No interface. Can that really happen?
+ * - IP length shorter than buffer.
+ * - IP length longer than buffer.
+ */