MAP: Convert from DPO to input feature.
[vpp.git] / test / test_ip_ecmp.py
1 #!/usr/bin/env python
2
3 import unittest
4 import random
5 import socket
6 from ipaddress import IPv4Address, IPv6Address, AddressValueError
7
8 from framework import VppTestCase, VppTestRunner
9 from util import ppp
10
11 from scapy.packet import Raw
12 from scapy.layers.l2 import Ether
13 from scapy.layers.inet import IP, UDP
14 from scapy.layers.inet6 import IPv6
15
#
# The number of packets to send in each generated test stream.
#
N_PKTS_IN_STREAM = 300
20
21
class TestECMP(VppTestCase):
    """ Equal-cost multi-path routing Test Case """

    @classmethod
    def setUpClass(cls):
        """
        Perform standard class setup (defined by class method setUpClass in
        class VppTestCase) before running the test case, set test case related
        variables and configure VPP.
        """
        super(TestECMP, cls).setUpClass()

        # create 4 pg interfaces: pg0 injects traffic, the remote hosts of
        # pg1-pg3 are used as next hops by create_ip_routes
        cls.create_pg_interfaces(range(4))

        # packet sizes to test
        cls.pg_if_packet_sizes = [64, 1500, 9018]

        # setup interfaces (IPv4 and IPv6, 5 remote hosts each)
        for i in cls.pg_interfaces:
            i.admin_up()
            i.generate_remote_hosts(5)
            i.config_ip4()
            i.resolve_arp()
            i.configure_ipv4_neighbors()
            i.config_ip6()
            i.resolve_ndp()
            i.configure_ipv6_neighbors()

    @classmethod
    def tearDownClass(cls):
        """Unconfigure and shut down the pg interfaces, then run base teardown.

        Interface cleanup is skipped when VPP is already dead.
        """
        if not cls.vpp_dead:
            for i in cls.pg_interfaces:
                i.unconfig_ip4()
                i.unconfig_ip6()
                i.admin_down()

        super(TestECMP, cls).tearDownClass()

    def setUp(self):
        """Run base per-test setup and clear previously created packet infos."""
        super(TestECMP, self).setUp()
        self.reset_packet_infos()

    def tearDown(self):
        """
        Show various debug prints after each test.
        """
        super(TestECMP, self).tearDown()
        if not self.vpp_dead:
            self.logger.info(self.vapi.ppcli("show ip arp"))
            self.logger.info(self.vapi.ppcli("show ip6 neighbors"))

    def get_ip_address(self, ip_addr_start, ip_prefix_len):
        """
        Return a random address at or above ip_addr_start.

        A random offset in [0, 2 ** (W - ip_prefix_len) - 2] is added to the
        start address, where W is the address width (32 for IPv4, 128 for
        IPv6). The address family is detected by first trying to parse the
        start address as IPv4 and falling back to IPv6.

        :param str ip_addr_start: Starting IPv4 or IPv6 address.
        :param int ip_prefix_len: IP address prefix length.
        :return: Random IPv4 or IPv6 address from required range.
        """
        # The ipaddress constructors need a text (unicode) string on
        # Python 2. u"%s" produces one there and a plain str on Python 3,
        # where the ``unicode`` builtin no longer exists (calling it would
        # raise NameError, which the except clause below does not catch).
        ip_text = u"%s" % ip_addr_start
        try:
            ip_addr = IPv4Address(ip_text)
            ip_max_len = 32
        except (AttributeError, AddressValueError):
            ip_addr = IPv6Address(ip_text)
            ip_max_len = 128

        return str(ip_addr +
                   random.randint(0, 2 ** (ip_max_len - ip_prefix_len) - 2))

    def create_stream(self, src_if, src_ip_start, dst_ip_start,
                      ip_prefix_len, packet_sizes, ip_l=IP):
        """Create input packet stream for defined interfaces.

        Builds N_PKTS_IN_STREAM UDP packets with random source/destination
        addresses (see get_ip_address). A copy of each packet, taken before
        it is padded to its random size, is stored in its packet info so
        verify_capture can compare against it.

        :param VppInterface src_if: Source Interface for packet stream.
        :param str src_ip_start: Starting source IPv4 or IPv6 address.
        :param str dst_ip_start: Starting destination IPv4 or IPv6 address.
        :param int ip_prefix_len: IP address prefix length.
        :param list packet_sizes: packet size to test.
        :param Scapy ip_l: Required IP layer - IP or IPv6. (Default is IP.)
        :return: List of generated packets.
        """
        pkts = []
        for _ in range(N_PKTS_IN_STREAM):
            info = self.create_packet_info(src_if, src_if)
            payload = self.info_to_payload(info)
            src_ip = self.get_ip_address(src_ip_start, ip_prefix_len)
            dst_ip = self.get_ip_address(dst_ip_start, ip_prefix_len)
            p = (Ether(dst=src_if.local_mac, src=src_if.remote_mac) /
                 ip_l(src=src_ip, dst=dst_ip) /
                 UDP(sport=1234, dport=1234) /
                 Raw(payload))
            info.data = p.copy()
            size = random.choice(packet_sizes)
            self.extend_packet(p, size)
            pkts.append(p)
        return pkts

    def verify_capture(self, rx_if, capture, ip_l=IP):
        """Verify captured input packet stream for defined interface.

        Each captured packet is matched to the packet info embedded in its
        payload and must be addressed to one of rx_if's remote hosts,
        sourced from rx_if's local MAC, and carry the IP source/destination
        of the packet that was sent. Every remote host of rx_if must have
        received at least one packet.

        :param VppInterface rx_if: Interface to verify captured packet stream.
        :param list capture: Captured packet stream.
        :param Scapy ip_l: Required IP layer - IP or IPv6. (Default is IP.)
        :return: Number of packets verified on rx_if.
        """
        self.logger.info("Verifying capture on interface %s" % rx_if.name)

        # Packets received per next-hop host, keyed by host MAC.
        host_counters = dict.fromkeys(rx_if._hosts_by_mac, 0)

        for packet in capture:
            try:
                ip_received = packet[ip_l]
                payload_info = self.payload_to_info(str(packet[Raw]))
                packet_index = payload_info.index
                ip_sent = self._packet_infos[packet_index].data[ip_l]
                self.logger.debug("Got packet on port %s: src=%u (id=%u)" %
                                  (rx_if.name, payload_info.src, packet_index))
                # Check standard fields
                self.assertIn(packet.dst, rx_if._hosts_by_mac,
                              "Destination MAC address %s shouldn't be routed "
                              "via interface %s" % (packet.dst, rx_if.name))
                self.assertEqual(packet.src, rx_if.local_mac)
                self.assertEqual(ip_received.src, ip_sent.src)
                self.assertEqual(ip_received.dst, ip_sent.dst)
                host_counters[packet.dst] += 1
                self._packet_infos.pop(packet_index)

            except Exception:
                # Log the offending packet for debugging, then propagate.
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise

        # We expect packet routed via all host of pg interface
        for host_mac in host_counters:
            nr = host_counters[host_mac]
            self.assertNotEqual(
                nr, 0, "No packet routed via host %s" % host_mac)
            self.logger.info("%u packets routed via host %s of %s interface" %
                             (nr, host_mac, rx_if.name))
        count = sum(host_counters.values())
        self.logger.info("Total amount of %u packets routed via %s interface" %
                         (count, rx_if.name))

        return count

    def create_ip_routes(self, dst_ip_net, dst_prefix_len, is_ipv6=0):
        """
        Create IP routes for defined destination IP network.

        One multipath route path is added per remote host of each pg
        interface except pg0, giving equal-cost next hops on pg1..pgN.

        :param str dst_ip_net: Destination IP network.
        :param int dst_prefix_len: IP address prefix length.
        :param int is_ipv6: 0 if an ip4 route, else ip6
        """
        af = socket.AF_INET if is_ipv6 == 0 else socket.AF_INET6
        dst_ip = socket.inet_pton(af, dst_ip_net)

        for pg_if in self.pg_interfaces[1:]:
            for nh_host in pg_if.remote_hosts:
                nh_host_ip = nh_host.ip4 if is_ipv6 == 0 else nh_host.ip6
                next_hop_address = socket.inet_pton(af, nh_host_ip)
                next_hop_sw_if_index = pg_if.sw_if_index
                self.vapi.ip_add_del_route(
                    dst_ip, dst_prefix_len, next_hop_address,
                    next_hop_sw_if_index=next_hop_sw_if_index,
                    is_ipv6=is_ipv6, is_multipath=1)
                self.logger.info("Route via %s on %s created" %
                                 (nh_host_ip, pg_if.name))

        self.logger.debug(self.vapi.ppcli("show ip fib"))
        self.logger.debug(self.vapi.ppcli("show ip6 fib"))

    def _run_ecmp_test(self, src_ip_net, dst_ip_net, ip_prefix_len,
                       ip_l=IP, is_ipv6=0):
        """Shared body of the IPv4 and IPv6 ECMP tests.

        Installs the ECMP routes, sends a stream from pg0 and checks that
        every packet leaves via pg1-pg3 (spread over all their next hops)
        and none is sent back out of pg0.

        :param str src_ip_net: Starting source address of the stream.
        :param str dst_ip_net: Destination network (and starting address).
        :param int ip_prefix_len: Prefix length of src/dst ranges and route.
        :param Scapy ip_l: IP layer used to build/verify packets.
        :param int is_ipv6: 0 for IPv4 routes, else IPv6.
        """
        self.create_ip_routes(dst_ip_net, ip_prefix_len, is_ipv6=is_ipv6)

        pkts = self.create_stream(self.pg0, src_ip_net, dst_ip_net,
                                  ip_prefix_len, self.pg_if_packet_sizes,
                                  ip_l=ip_l)
        self.pg0.add_stream(pkts)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        # We expect packets on pg1, pg2 and pg3, but not on pg0
        rx_count = 0
        for pg_if in self.pg_interfaces[1:]:
            capture = pg_if._get_capture(timeout=1)
            self.assertNotEqual(
                len(capture), 0, msg="No packets captured on %s" % pg_if.name)
            rx_count += self.verify_capture(pg_if, capture, ip_l=ip_l)
        self.pg0.assert_nothing_captured(remark="IP packets forwarded on pg0")

        # Check that all packets were forwarded via pg1, pg2 and pg3
        self.assertEqual(rx_count, len(pkts))

    def test_ip_ecmp(self):
        """ IP equal-cost multi-path routing test """
        self._run_ecmp_test('16.0.0.1', '32.0.0.1', 24, ip_l=IP, is_ipv6=0)

    def test_ip6_ecmp(self):
        """ IPv6 equal-cost multi-path routing test """
        self._run_ecmp_test('3ffe:51::1', '3ffe:71::1', 64,
                            ip_l=IPv6, is_ipv6=1)
249
250
if "__main__" == __name__:
    # When executed directly, run this module's test cases with the
    # VPP-specific test runner.
    runner_cls = VppTestRunner
    unittest.main(testRunner=runner_cls)