Tests: Refactor payload_to_info()
[vpp.git] / test / test_ip_ecmp.py
1 #!/usr/bin/env python
2
3 import unittest
4 import random
5 import socket
6 from ipaddress import IPv4Address, IPv6Address, AddressValueError
7
8
9 from framework import VppTestCase, VppTestRunner
10 from util import ppp
11
12 from scapy.packet import Raw
13 from scapy.layers.l2 import Ether
14 from scapy.layers.inet import IP, UDP
15 from scapy.layers.inet6 import IPv6
16
# Native unicode string type: ``unicode`` on Python 2, ``str`` on Python 3
# (the u'' literal is a str there). get_ip_address() wraps addresses in it
# before handing them to the ipaddress constructors, which presumably need
# unicode input under the Python 2 backport.
text_type = type(u'')

#
# The number of packets to send in each test stream.
#
N_PKTS_IN_STREAM = 300
26
27
class TestECMP(VppTestCase):
    """ Equal-cost multi-path routing Test Case """

    @classmethod
    def setUpClass(cls):
        """
        Perform standard class setup (defined by class method setUpClass in
        class VppTestCase) before running the test case, set test case related
        variables and configure VPP.

        Creates four pg interfaces (pg0..pg3). Each is brought up, given five
        remote hosts, and configured for IPv4 and IPv6: addresses, ARP/ND
        resolution and neighbour entries.
        """
        super(TestECMP, cls).setUpClass()

        # create 4 pg interfaces
        cls.create_pg_interfaces(range(4))

        # packet sizes to test
        cls.pg_if_packet_sizes = [64, 1500, 9018]

        # setup interfaces
        for i in cls.pg_interfaces:
            i.admin_up()
            i.generate_remote_hosts(5)
            i.config_ip4()
            i.resolve_arp()
            i.configure_ipv4_neighbors()
            i.config_ip6()
            i.resolve_ndp()
            i.configure_ipv6_neighbors()

    @classmethod
    def tearDownClass(cls):
        """Undo the interface configuration (if VPP is still alive), then
        run the standard class teardown."""
        if not cls.vpp_dead:
            for i in cls.pg_interfaces:
                i.unconfig_ip4()
                i.unconfig_ip6()
                i.admin_down()

        super(TestECMP, cls).tearDownClass()

    def setUp(self):
        """Standard per-test setup, then clear packet infos left over from
        a previous test."""
        super(TestECMP, self).setUp()
        self.reset_packet_infos()

    def tearDown(self):
        """
        Show various debug prints after each test.
        """
        super(TestECMP, self).tearDown()
        if not self.vpp_dead:
            self.logger.info(self.vapi.ppcli("show ip arp"))
            self.logger.info(self.vapi.ppcli("show ip6 neighbors"))

    def get_ip_address(self, ip_addr_start, ip_prefix_len):
        """Return a random address from the range beginning at ip_addr_start.

        The result is ``ip_addr_start`` plus a random offset in
        ``[0, 2 ** (width - ip_prefix_len) - 2]``, where width is 32 for
        IPv4 and 128 for IPv6. For host-length prefixes the range collapses
        to the start address itself.

        :param str ip_addr_start: Starting IPv4 or IPv6 address.
        :param int ip_prefix_len: IP address prefix length.
        :return: Random IPv4 or IPv6 address from required range.
        """
        try:
            ip_addr = IPv4Address(text_type(ip_addr_start))
            ip_max_len = 32
        except (AttributeError, AddressValueError):
            # Not parseable as IPv4 - treat it as an IPv6 address.
            ip_addr = IPv6Address(text_type(ip_addr_start))
            ip_max_len = 128

        # Clamp at 0: for a /32 (or /128) the unclamped bound is -1 and
        # random.randint(0, -1) raises ValueError.
        max_offset = max(0, 2 ** (ip_max_len - ip_prefix_len) - 2)
        return str(ip_addr + random.randint(0, max_offset))

    def create_stream(self, src_if, src_ip_start, dst_ip_start,
                      ip_prefix_len, packet_sizes, ip_l=IP):
        """Create input packet stream for defined interfaces.

        :param VppInterface src_if: Source Interface for packet stream.
        :param str src_ip_start: Starting source IPv4 or IPv6 address.
        :param str dst_ip_start: Starting destination IPv4 or IPv6 address.
        :param int ip_prefix_len: IP address prefix length.
        :param list packet_sizes: packet size to test.
        :param Scapy ip_l: Required IP layer - IP or IPv6. (Default is IP.)
        :return: List of N_PKTS_IN_STREAM packets with random source and
            destination addresses and randomly chosen sizes.
        """
        pkts = []
        for _ in range(N_PKTS_IN_STREAM):
            info = self.create_packet_info(src_if, src_if)
            payload = self.info_to_payload(info)
            src_ip = self.get_ip_address(src_ip_start, ip_prefix_len)
            dst_ip = self.get_ip_address(dst_ip_start, ip_prefix_len)
            p = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                 ip_l(src=src_ip, dst=dst_ip) /
                 UDP(sport=1234, dport=1234) /
                 Raw(payload))
            # Keep a copy taken before padding; verify_capture() compares the
            # received IP header against it.
            info.data = p.copy()
            size = random.choice(packet_sizes)
            self.extend_packet(p, size)
            pkts.append(p)
        return pkts

    def verify_capture(self, rx_if, capture, ip_l=IP):
        """Verify captured input packet stream for defined interface.

        Every packet must be addressed to one of rx_if's remote hosts, sourced
        from rx_if's MAC and carry the IP addresses it was sent with. Every
        remote host of rx_if must have received at least one packet.

        :param VppInterface rx_if: Interface to verify captured packet stream.
        :param list capture: Captured packet stream.
        :param Scapy ip_l: Required IP layer - IP or IPv6. (Default is IP.)
        :return: Number of packets verified on rx_if.
        """
        self.logger.info("Verifying capture on interface %s" % rx_if.name)

        # One counter per next-hop host MAC of rx_if.
        host_counters = dict.fromkeys(rx_if._hosts_by_mac, 0)

        for packet in capture:
            try:
                ip_received = packet[ip_l]
                payload_info = self.payload_to_info(packet[Raw])
                packet_index = payload_info.index
                ip_sent = self._packet_infos[packet_index].data[ip_l]
                self.logger.debug("Got packet on port %s: src=%u (id=%u)" %
                                  (rx_if.name, payload_info.src, packet_index))
                # Check standard fields
                self.assertIn(packet.dst, rx_if._hosts_by_mac,
                              "Destination MAC address %s shouldn't be routed "
                              "via interface %s" % (packet.dst, rx_if.name))
                self.assertEqual(packet.src, rx_if.local_mac)
                self.assertEqual(ip_received.src, ip_sent.src)
                self.assertEqual(ip_received.dst, ip_sent.dst)
                host_counters[packet.dst] += 1
                # Drop the matched info so the same packet cannot be matched
                # twice (assumes the framework keys _packet_infos by index).
                self._packet_infos.pop(packet_index)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise

        # We expect packet routed via all host of pg interface
        count = 0
        for host_mac, nr in host_counters.items():
            self.assertNotEqual(
                nr, 0, "No packet routed via host %s" % host_mac)
            self.logger.info("%u packets routed via host %s of %s interface" %
                             (nr, host_mac, rx_if.name))
            count += nr
        self.logger.info("Total amount of %u packets routed via %s interface" %
                         (count, rx_if.name))

        return count

    def create_ip_routes(self, dst_ip_net, dst_prefix_len, is_ipv6=0):
        """
        Create IP routes for defined destination IP network.

        One multipath route is added via every remote host of each
        interface except pg0, so VPP has several equal-cost next hops.

        :param str dst_ip_net: Destination IP network.
        :param int dst_prefix_len: IP address prefix length.
        :param int is_ipv6: 0 if an ip4 route, else ip6
        """
        use_ip4 = is_ipv6 == 0
        af = socket.AF_INET if use_ip4 else socket.AF_INET6
        dst_ip = socket.inet_pton(af, dst_ip_net)

        for pg_if in self.pg_interfaces[1:]:
            for nh_host in pg_if.remote_hosts:
                nh_host_ip = nh_host.ip4 if use_ip4 else nh_host.ip6
                next_hop_address = socket.inet_pton(af, nh_host_ip)
                next_hop_sw_if_index = pg_if.sw_if_index
                self.vapi.ip_add_del_route(
                    dst_ip, dst_prefix_len, next_hop_address,
                    next_hop_sw_if_index=next_hop_sw_if_index,
                    is_ipv6=is_ipv6, is_multipath=1)
                self.logger.info("Route via %s on %s created" %
                                 (nh_host_ip, pg_if.name))

        # Dump only the FIB of the address family that was just modified.
        self.logger.debug(
            self.vapi.ppcli("show ip fib" if use_ip4 else "show ip6 fib"))

    def _run_ecmp_test(self, src_ip_net, dst_ip_net, ip_prefix_len, ip_l=IP):
        """Send a random stream from pg0 over ECMP routes and verify it.

        Shared body of the IPv4 and IPv6 tests. It installs multipath routes
        for dst_ip_net via every remote host of pg1..pg3 and sends
        N_PKTS_IN_STREAM packets into pg0. It then checks that every packet
        is received on pg1..pg3, with each next hop used at least once, and
        that nothing is received on pg0.

        :param str src_ip_net: Starting source IPv4 or IPv6 address.
        :param str dst_ip_net: Destination network / starting address.
        :param int ip_prefix_len: Prefix length for routes and addresses.
        :param Scapy ip_l: IP layer - IP or IPv6. (Default is IP.)
        """
        is_ipv6 = 1 if ip_l is IPv6 else 0
        self.create_ip_routes(dst_ip_net, ip_prefix_len, is_ipv6=is_ipv6)

        pkts = self.create_stream(self.pg0, src_ip_net, dst_ip_net,
                                  ip_prefix_len, self.pg_if_packet_sizes,
                                  ip_l=ip_l)
        self.pg0.add_stream(pkts)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        # We expect packets on pg1, pg2 and pg3, but not on pg0
        rx_count = 0
        for pg_if in self.pg_interfaces[1:]:
            capture = pg_if._get_capture(timeout=1)
            self.assertNotEqual(
                len(capture), 0, msg="No packets captured on %s" % pg_if.name)
            rx_count += self.verify_capture(pg_if, capture, ip_l=ip_l)
        self.pg0.assert_nothing_captured(remark="IP packets forwarded on pg0")

        # Check that all packets were forwarded via pg1, pg2 and pg3
        self.assertEqual(rx_count, len(pkts))

    def test_ip_ecmp(self):
        """ IP equal-cost multi-path routing test """
        self._run_ecmp_test(src_ip_net='16.0.0.1',
                            dst_ip_net='32.0.0.1',
                            ip_prefix_len=24)

    def test_ip6_ecmp(self):
        """ IPv6 equal-cost multi-path routing test """
        self._run_ecmp_test(src_ip_net='3ffe:51::1',
                            dst_ip_net='3ffe:71::1',
                            ip_prefix_len=64,
                            ip_l=IPv6)
255
256
if __name__ == "__main__":
    # Executed directly: run this module's tests with the VPP test runner.
    unittest.main(testRunner=VppTestRunner)