VPP-1522: harden reassembly code
[vpp.git] / test / test_reassembly.py
1 #!/usr/bin/env python
2
3 import six
4 import unittest
5 from random import shuffle
6
7 from framework import VppTestCase, VppTestRunner, is_skip_aarch64_set,\
8     is_platform_aarch64
9
10 from scapy.packet import Raw
11 from scapy.layers.l2 import Ether, GRE
12 from scapy.layers.inet import IP, UDP, ICMP
13 from util import ppp, fragment_rfc791, fragment_rfc8200
14 from scapy.layers.inet6 import IPv6, IPv6ExtHdrFragment, ICMPv6ParamProblem,\
15     ICMPv6TimeExceeded
16 from vpp_gre_interface import VppGreInterface, VppGre6Interface
17 from vpp_ip import DpoProto
18 from vpp_ip_route import VppIpRoute, VppRoutePath
19
# default number of packets each test class generates in create_stream()
test_packet_count = 257
21
22
23 class TestIPv4Reassembly(VppTestCase):
24     """ IPv4 Reassembly """
25
26     @classmethod
27     def setUpClass(cls):
28         super(TestIPv4Reassembly, cls).setUpClass()
29
30         cls.create_pg_interfaces([0, 1])
31         cls.src_if = cls.pg0
32         cls.dst_if = cls.pg1
33
34         # setup all interfaces
35         for i in cls.pg_interfaces:
36             i.admin_up()
37             i.config_ip4()
38             i.resolve_arp()
39
40         # packet sizes
41         cls.packet_sizes = [64, 512, 1518, 9018]
42         cls.padding = " abcdefghijklmn"
43         cls.create_stream(cls.packet_sizes)
44         cls.create_fragments()
45
46     def setUp(self):
47         """ Test setup - force timeout on existing reassemblies """
48         super(TestIPv4Reassembly, self).setUp()
49         self.vapi.ip_reassembly_enable_disable(
50             sw_if_index=self.src_if.sw_if_index, enable_ip4=True)
51         self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
52                                     expire_walk_interval_ms=10)
53         self.sleep(.25)
54         self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
55                                     expire_walk_interval_ms=10000)
56
57     def tearDown(self):
58         super(TestIPv4Reassembly, self).tearDown()
59         self.logger.debug(self.vapi.ppcli("show ip4-reassembly details"))
60
61     @classmethod
62     def create_stream(cls, packet_sizes, packet_count=test_packet_count):
63         """Create input packet stream for defined interface.
64
65         :param list packet_sizes: Required packet sizes.
66         """
67         for i in range(0, packet_count):
68             info = cls.create_packet_info(cls.src_if, cls.src_if)
69             payload = cls.info_to_payload(info)
70             p = (Ether(dst=cls.src_if.local_mac, src=cls.src_if.remote_mac) /
71                  IP(id=info.index, src=cls.src_if.remote_ip4,
72                     dst=cls.dst_if.remote_ip4) /
73                  UDP(sport=1234, dport=5678) /
74                  Raw(payload))
75             size = packet_sizes[(i // 2) % len(packet_sizes)]
76             cls.extend_packet(p, size, cls.padding)
77             info.data = p
78
79     @classmethod
80     def create_fragments(cls):
81         infos = cls._packet_infos
82         cls.pkt_infos = []
83         for index, info in six.iteritems(infos):
84             p = info.data
85             # cls.logger.debug(ppp("Packet:", p.__class__(str(p))))
86             fragments_400 = fragment_rfc791(p, 400)
87             fragments_300 = fragment_rfc791(p, 300)
88             fragments_200 = [
89                 x for f in fragments_400 for x in fragment_rfc791(f, 200)]
90             cls.pkt_infos.append(
91                 (index, fragments_400, fragments_300, fragments_200))
92         cls.fragments_400 = [
93             x for (_, frags, _, _) in cls.pkt_infos for x in frags]
94         cls.fragments_300 = [
95             x for (_, _, frags, _) in cls.pkt_infos for x in frags]
96         cls.fragments_200 = [
97             x for (_, _, _, frags) in cls.pkt_infos for x in frags]
98         cls.logger.debug("Fragmented %s packets into %s 400-byte fragments, "
99                          "%s 300-byte fragments and %s 200-byte fragments" %
100                          (len(infos), len(cls.fragments_400),
101                              len(cls.fragments_300), len(cls.fragments_200)))
102
103     def verify_capture(self, capture, dropped_packet_indexes=[]):
104         """Verify captured packet stream.
105
106         :param list capture: Captured packet stream.
107         """
108         info = None
109         seen = set()
110         for packet in capture:
111             try:
112                 self.logger.debug(ppp("Got packet:", packet))
113                 ip = packet[IP]
114                 udp = packet[UDP]
115                 payload_info = self.payload_to_info(str(packet[Raw]))
116                 packet_index = payload_info.index
117                 self.assertTrue(
118                     packet_index not in dropped_packet_indexes,
119                     ppp("Packet received, but should be dropped:", packet))
120                 if packet_index in seen:
121                     raise Exception(ppp("Duplicate packet received", packet))
122                 seen.add(packet_index)
123                 self.assertEqual(payload_info.dst, self.src_if.sw_if_index)
124                 info = self._packet_infos[packet_index]
125                 self.assertTrue(info is not None)
126                 self.assertEqual(packet_index, info.index)
127                 saved_packet = info.data
128                 self.assertEqual(ip.src, saved_packet[IP].src)
129                 self.assertEqual(ip.dst, saved_packet[IP].dst)
130                 self.assertEqual(udp.payload, saved_packet[UDP].payload)
131             except Exception:
132                 self.logger.error(ppp("Unexpected or invalid packet:", packet))
133                 raise
134         for index in self._packet_infos:
135             self.assertTrue(index in seen or index in dropped_packet_indexes,
136                             "Packet with packet_index %d not received" % index)
137
138     def test_reassembly(self):
139         """ basic reassembly """
140
141         self.pg_enable_capture()
142         self.src_if.add_stream(self.fragments_200)
143         self.pg_start()
144
145         packets = self.dst_if.get_capture(len(self.pkt_infos))
146         self.verify_capture(packets)
147         self.src_if.assert_nothing_captured()
148
149         # run it all again to verify correctness
150         self.pg_enable_capture()
151         self.src_if.add_stream(self.fragments_200)
152         self.pg_start()
153
154         packets = self.dst_if.get_capture(len(self.pkt_infos))
155         self.verify_capture(packets)
156         self.src_if.assert_nothing_captured()
157
158     def test_reversed(self):
159         """ reverse order reassembly """
160
161         fragments = list(self.fragments_200)
162         fragments.reverse()
163
164         self.pg_enable_capture()
165         self.src_if.add_stream(fragments)
166         self.pg_start()
167
168         packets = self.dst_if.get_capture(len(self.packet_infos))
169         self.verify_capture(packets)
170         self.src_if.assert_nothing_captured()
171
172         # run it all again to verify correctness
173         self.pg_enable_capture()
174         self.src_if.add_stream(fragments)
175         self.pg_start()
176
177         packets = self.dst_if.get_capture(len(self.packet_infos))
178         self.verify_capture(packets)
179         self.src_if.assert_nothing_captured()
180
181     def test_5737(self):
182         """ fragment length + ip header size > 65535 """
183         raw = ('E\x00\x00\x88,\xf8\x1f\xfe@\x01\x98\x00\xc0\xa8\n-\xc0\xa8\n'
184                '\x01\x08\x00\xf0J\xed\xcb\xf1\xf5Test-group: IPv4.IPv4.ipv4-'
185                'message.Ethernet-Payload.IPv4-Packet.IPv4-Header.Fragment-Of'
186                'fset; Test-case: 5737')
187
188         malformed_packet = (Ether(dst=self.src_if.local_mac,
189                                   src=self.src_if.remote_mac) /
190                             IP(raw))
191         p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
192              IP(id=1000, src=self.src_if.remote_ip4,
193                 dst=self.dst_if.remote_ip4) /
194              UDP(sport=1234, dport=5678) /
195              Raw("X" * 1000))
196         valid_fragments = fragment_rfc791(p, 400)
197
198         self.pg_enable_capture()
199         self.src_if.add_stream([malformed_packet] + valid_fragments)
200         self.pg_start()
201
202         self.dst_if.get_capture(1)
203         self.assert_packet_counter_equal(
204             "/err/ip4-reassembly-feature/malformed packets", 1)
205
206     @unittest.skipIf(is_skip_aarch64_set() and is_platform_aarch64(),
207                      "test doesn't work on aarch64")
208     def test_random(self):
209         """ random order reassembly """
210
211         fragments = list(self.fragments_200)
212         shuffle(fragments)
213
214         self.pg_enable_capture()
215         self.src_if.add_stream(fragments)
216         self.pg_start()
217
218         packets = self.dst_if.get_capture(len(self.packet_infos))
219         self.verify_capture(packets)
220         self.src_if.assert_nothing_captured()
221
222         # run it all again to verify correctness
223         self.pg_enable_capture()
224         self.src_if.add_stream(fragments)
225         self.pg_start()
226
227         packets = self.dst_if.get_capture(len(self.packet_infos))
228         self.verify_capture(packets)
229         self.src_if.assert_nothing_captured()
230
231     def test_duplicates(self):
232         """ duplicate fragments """
233
234         fragments = [
235             x for (_, frags, _, _) in self.pkt_infos
236             for x in frags
237             for _ in range(0, min(2, len(frags)))
238         ]
239
240         self.pg_enable_capture()
241         self.src_if.add_stream(fragments)
242         self.pg_start()
243
244         packets = self.dst_if.get_capture(len(self.pkt_infos))
245         self.verify_capture(packets)
246         self.src_if.assert_nothing_captured()
247
248     def test_overlap1(self):
249         """ overlapping fragments case #1 """
250
251         fragments = []
252         for _, _, frags_300, frags_200 in self.pkt_infos:
253             if len(frags_300) == 1:
254                 fragments.extend(frags_300)
255             else:
256                 for i, j in zip(frags_200, frags_300):
257                     fragments.extend(i)
258                     fragments.extend(j)
259
260         self.pg_enable_capture()
261         self.src_if.add_stream(fragments)
262         self.pg_start()
263
264         packets = self.dst_if.get_capture(len(self.pkt_infos))
265         self.verify_capture(packets)
266         self.src_if.assert_nothing_captured()
267
268         # run it all to verify correctness
269         self.pg_enable_capture()
270         self.src_if.add_stream(fragments)
271         self.pg_start()
272
273         packets = self.dst_if.get_capture(len(self.pkt_infos))
274         self.verify_capture(packets)
275         self.src_if.assert_nothing_captured()
276
277     def test_overlap2(self):
278         """ overlapping fragments case #2 """
279
280         fragments = []
281         for _, _, frags_300, frags_200 in self.pkt_infos:
282             if len(frags_300) == 1:
283                 fragments.extend(frags_300)
284             else:
285                 # care must be taken here so that there are no fragments
286                 # received by vpp after reassembly is finished, otherwise
287                 # new reassemblies will be started and packet generator will
288                 # freak out when it detects unfreed buffers
289                 zipped = zip(frags_300, frags_200)
290                 for i, j in zipped[:-1]:
291                     fragments.extend(i)
292                     fragments.extend(j)
293                 fragments.append(zipped[-1][0])
294
295         self.pg_enable_capture()
296         self.src_if.add_stream(fragments)
297         self.pg_start()
298
299         packets = self.dst_if.get_capture(len(self.pkt_infos))
300         self.verify_capture(packets)
301         self.src_if.assert_nothing_captured()
302
303         # run it all to verify correctness
304         self.pg_enable_capture()
305         self.src_if.add_stream(fragments)
306         self.pg_start()
307
308         packets = self.dst_if.get_capture(len(self.pkt_infos))
309         self.verify_capture(packets)
310         self.src_if.assert_nothing_captured()
311
312     def test_timeout_inline(self):
313         """ timeout (inline) """
314
315         dropped_packet_indexes = set(
316             index for (index, frags, _, _) in self.pkt_infos if len(frags) > 1
317         )
318
319         self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
320                                     expire_walk_interval_ms=10000)
321
322         self.pg_enable_capture()
323         self.src_if.add_stream(self.fragments_400)
324         self.pg_start()
325
326         packets = self.dst_if.get_capture(
327             len(self.pkt_infos) - len(dropped_packet_indexes))
328         self.verify_capture(packets, dropped_packet_indexes)
329         self.src_if.assert_nothing_captured()
330
331     def test_timeout_cleanup(self):
332         """ timeout (cleanup) """
333
334         # whole packets + fragmented packets sans last fragment
335         fragments = [
336             x for (_, frags_400, _, _) in self.pkt_infos
337             for x in frags_400[:-1 if len(frags_400) > 1 else None]
338         ]
339
340         # last fragments for fragmented packets
341         fragments2 = [frags_400[-1]
342                       for (_, frags_400, _, _) in self.pkt_infos
343                       if len(frags_400) > 1]
344
345         dropped_packet_indexes = set(
346             index for (index, frags_400, _, _) in self.pkt_infos
347             if len(frags_400) > 1)
348
349         self.vapi.ip_reassembly_set(timeout_ms=100, max_reassemblies=1000,
350                                     expire_walk_interval_ms=50)
351
352         self.pg_enable_capture()
353         self.src_if.add_stream(fragments)
354         self.pg_start()
355
356         self.sleep(.25, "wait before sending rest of fragments")
357
358         self.src_if.add_stream(fragments2)
359         self.pg_start()
360
361         packets = self.dst_if.get_capture(
362             len(self.pkt_infos) - len(dropped_packet_indexes))
363         self.verify_capture(packets, dropped_packet_indexes)
364         self.src_if.assert_nothing_captured()
365
366     def test_disabled(self):
367         """ reassembly disabled """
368
369         dropped_packet_indexes = set(
370             index for (index, frags_400, _, _) in self.pkt_infos
371             if len(frags_400) > 1)
372
373         self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0,
374                                     expire_walk_interval_ms=10000)
375
376         self.pg_enable_capture()
377         self.src_if.add_stream(self.fragments_400)
378         self.pg_start()
379
380         packets = self.dst_if.get_capture(
381             len(self.pkt_infos) - len(dropped_packet_indexes))
382         self.verify_capture(packets, dropped_packet_indexes)
383         self.src_if.assert_nothing_captured()
384
385
386 class TestIPv6Reassembly(VppTestCase):
387     """ IPv6 Reassembly """
388
389     @classmethod
390     def setUpClass(cls):
391         super(TestIPv6Reassembly, cls).setUpClass()
392
393         cls.create_pg_interfaces([0, 1])
394         cls.src_if = cls.pg0
395         cls.dst_if = cls.pg1
396
397         # setup all interfaces
398         for i in cls.pg_interfaces:
399             i.admin_up()
400             i.config_ip6()
401             i.resolve_ndp()
402
403         # packet sizes
404         cls.packet_sizes = [64, 512, 1518, 9018]
405         cls.padding = " abcdefghijklmn"
406         cls.create_stream(cls.packet_sizes)
407         cls.create_fragments()
408
409     def setUp(self):
410         """ Test setup - force timeout on existing reassemblies """
411         super(TestIPv6Reassembly, self).setUp()
412         self.vapi.ip_reassembly_enable_disable(
413             sw_if_index=self.src_if.sw_if_index, enable_ip6=True)
414         self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
415                                     expire_walk_interval_ms=10, is_ip6=1)
416         self.sleep(.25)
417         self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
418                                     expire_walk_interval_ms=10000, is_ip6=1)
419         self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))
420
421     def tearDown(self):
422         super(TestIPv6Reassembly, self).tearDown()
423         self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))
424
425     @classmethod
426     def create_stream(cls, packet_sizes, packet_count=test_packet_count):
427         """Create input packet stream for defined interface.
428
429         :param list packet_sizes: Required packet sizes.
430         """
431         for i in range(0, packet_count):
432             info = cls.create_packet_info(cls.src_if, cls.src_if)
433             payload = cls.info_to_payload(info)
434             p = (Ether(dst=cls.src_if.local_mac, src=cls.src_if.remote_mac) /
435                  IPv6(src=cls.src_if.remote_ip6,
436                       dst=cls.dst_if.remote_ip6) /
437                  UDP(sport=1234, dport=5678) /
438                  Raw(payload))
439             size = packet_sizes[(i // 2) % len(packet_sizes)]
440             cls.extend_packet(p, size, cls.padding)
441             info.data = p
442
443     @classmethod
444     def create_fragments(cls):
445         infos = cls._packet_infos
446         cls.pkt_infos = []
447         for index, info in six.iteritems(infos):
448             p = info.data
449             # cls.logger.debug(ppp("Packet:", p.__class__(str(p))))
450             fragments_400 = fragment_rfc8200(p, info.index, 400)
451             fragments_300 = fragment_rfc8200(p, info.index, 300)
452             cls.pkt_infos.append((index, fragments_400, fragments_300))
453         cls.fragments_400 = [
454             x for _, frags, _ in cls.pkt_infos for x in frags]
455         cls.fragments_300 = [
456             x for _, _, frags in cls.pkt_infos for x in frags]
457         cls.logger.debug("Fragmented %s packets into %s 400-byte fragments, "
458                          "and %s 300-byte fragments" %
459                          (len(infos), len(cls.fragments_400),
460                              len(cls.fragments_300)))
461
462     def verify_capture(self, capture, dropped_packet_indexes=[]):
463         """Verify captured packet strea .
464
465         :param list capture: Captured packet stream.
466         """
467         info = None
468         seen = set()
469         for packet in capture:
470             try:
471                 self.logger.debug(ppp("Got packet:", packet))
472                 ip = packet[IPv6]
473                 udp = packet[UDP]
474                 payload_info = self.payload_to_info(str(packet[Raw]))
475                 packet_index = payload_info.index
476                 self.assertTrue(
477                     packet_index not in dropped_packet_indexes,
478                     ppp("Packet received, but should be dropped:", packet))
479                 if packet_index in seen:
480                     raise Exception(ppp("Duplicate packet received", packet))
481                 seen.add(packet_index)
482                 self.assertEqual(payload_info.dst, self.src_if.sw_if_index)
483                 info = self._packet_infos[packet_index]
484                 self.assertTrue(info is not None)
485                 self.assertEqual(packet_index, info.index)
486                 saved_packet = info.data
487                 self.assertEqual(ip.src, saved_packet[IPv6].src)
488                 self.assertEqual(ip.dst, saved_packet[IPv6].dst)
489                 self.assertEqual(udp.payload, saved_packet[UDP].payload)
490             except Exception:
491                 self.logger.error(ppp("Unexpected or invalid packet:", packet))
492                 raise
493         for index in self._packet_infos:
494             self.assertTrue(index in seen or index in dropped_packet_indexes,
495                             "Packet with packet_index %d not received" % index)
496
497     def test_reassembly(self):
498         """ basic reassembly """
499
500         self.pg_enable_capture()
501         self.src_if.add_stream(self.fragments_400)
502         self.pg_start()
503
504         packets = self.dst_if.get_capture(len(self.pkt_infos))
505         self.verify_capture(packets)
506         self.src_if.assert_nothing_captured()
507
508         # run it all again to verify correctness
509         self.pg_enable_capture()
510         self.src_if.add_stream(self.fragments_400)
511         self.pg_start()
512
513         packets = self.dst_if.get_capture(len(self.pkt_infos))
514         self.verify_capture(packets)
515         self.src_if.assert_nothing_captured()
516
517     def test_reversed(self):
518         """ reverse order reassembly """
519
520         fragments = list(self.fragments_400)
521         fragments.reverse()
522
523         self.pg_enable_capture()
524         self.src_if.add_stream(fragments)
525         self.pg_start()
526
527         packets = self.dst_if.get_capture(len(self.pkt_infos))
528         self.verify_capture(packets)
529         self.src_if.assert_nothing_captured()
530
531         # run it all again to verify correctness
532         self.pg_enable_capture()
533         self.src_if.add_stream(fragments)
534         self.pg_start()
535
536         packets = self.dst_if.get_capture(len(self.pkt_infos))
537         self.verify_capture(packets)
538         self.src_if.assert_nothing_captured()
539
540     def test_random(self):
541         """ random order reassembly """
542
543         fragments = list(self.fragments_400)
544         shuffle(fragments)
545
546         self.pg_enable_capture()
547         self.src_if.add_stream(fragments)
548         self.pg_start()
549
550         packets = self.dst_if.get_capture(len(self.pkt_infos))
551         self.verify_capture(packets)
552         self.src_if.assert_nothing_captured()
553
554         # run it all again to verify correctness
555         self.pg_enable_capture()
556         self.src_if.add_stream(fragments)
557         self.pg_start()
558
559         packets = self.dst_if.get_capture(len(self.pkt_infos))
560         self.verify_capture(packets)
561         self.src_if.assert_nothing_captured()
562
563     def test_duplicates(self):
564         """ duplicate fragments """
565
566         fragments = [
567             x for (_, frags, _) in self.pkt_infos
568             for x in frags
569             for _ in range(0, min(2, len(frags)))
570         ]
571
572         self.pg_enable_capture()
573         self.src_if.add_stream(fragments)
574         self.pg_start()
575
576         packets = self.dst_if.get_capture(len(self.pkt_infos))
577         self.verify_capture(packets)
578         self.src_if.assert_nothing_captured()
579
580     def test_overlap1(self):
581         """ overlapping fragments case #1 """
582
583         fragments = []
584         for _, frags_400, frags_300 in self.pkt_infos:
585             if len(frags_300) == 1:
586                 fragments.extend(frags_400)
587             else:
588                 for i, j in zip(frags_300, frags_400):
589                     fragments.extend(i)
590                     fragments.extend(j)
591
592         dropped_packet_indexes = set(
593             index for (index, _, frags) in self.pkt_infos if len(frags) > 1
594         )
595
596         self.pg_enable_capture()
597         self.src_if.add_stream(fragments)
598         self.pg_start()
599
600         packets = self.dst_if.get_capture(
601             len(self.pkt_infos) - len(dropped_packet_indexes))
602         self.verify_capture(packets, dropped_packet_indexes)
603         self.src_if.assert_nothing_captured()
604
605     def test_overlap2(self):
606         """ overlapping fragments case #2 """
607
608         fragments = []
609         for _, frags_400, frags_300 in self.pkt_infos:
610             if len(frags_400) == 1:
611                 fragments.extend(frags_400)
612             else:
613                 # care must be taken here so that there are no fragments
614                 # received by vpp after reassembly is finished, otherwise
615                 # new reassemblies will be started and packet generator will
616                 # freak out when it detects unfreed buffers
617                 zipped = zip(frags_400, frags_300)
618                 for i, j in zipped[:-1]:
619                     fragments.extend(i)
620                     fragments.extend(j)
621                 fragments.append(zipped[-1][0])
622
623         dropped_packet_indexes = set(
624             index for (index, _, frags) in self.pkt_infos if len(frags) > 1
625         )
626
627         self.pg_enable_capture()
628         self.src_if.add_stream(fragments)
629         self.pg_start()
630
631         packets = self.dst_if.get_capture(
632             len(self.pkt_infos) - len(dropped_packet_indexes))
633         self.verify_capture(packets, dropped_packet_indexes)
634         self.src_if.assert_nothing_captured()
635
636     def test_timeout_inline(self):
637         """ timeout (inline) """
638
639         dropped_packet_indexes = set(
640             index for (index, frags, _) in self.pkt_infos if len(frags) > 1
641         )
642
643         self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
644                                     expire_walk_interval_ms=10000, is_ip6=1)
645
646         self.pg_enable_capture()
647         self.src_if.add_stream(self.fragments_400)
648         self.pg_start()
649
650         packets = self.dst_if.get_capture(
651             len(self.pkt_infos) - len(dropped_packet_indexes))
652         self.verify_capture(packets, dropped_packet_indexes)
653         pkts = self.src_if.get_capture(
654             expected_count=len(dropped_packet_indexes))
655         for icmp in pkts:
656             self.assertIn(ICMPv6TimeExceeded, icmp)
657             self.assertIn(IPv6ExtHdrFragment, icmp)
658             self.assertIn(icmp[IPv6ExtHdrFragment].id, dropped_packet_indexes)
659             dropped_packet_indexes.remove(icmp[IPv6ExtHdrFragment].id)
660
661     def test_timeout_cleanup(self):
662         """ timeout (cleanup) """
663
664         # whole packets + fragmented packets sans last fragment
665         fragments = [
666             x for (_, frags_400, _) in self.pkt_infos
667             for x in frags_400[:-1 if len(frags_400) > 1 else None]
668         ]
669
670         # last fragments for fragmented packets
671         fragments2 = [frags_400[-1]
672                       for (_, frags_400, _) in self.pkt_infos
673                       if len(frags_400) > 1]
674
675         dropped_packet_indexes = set(
676             index for (index, frags_400, _) in self.pkt_infos
677             if len(frags_400) > 1)
678
679         self.vapi.ip_reassembly_set(timeout_ms=100, max_reassemblies=1000,
680                                     expire_walk_interval_ms=50)
681
682         self.vapi.ip_reassembly_set(timeout_ms=100, max_reassemblies=1000,
683                                     expire_walk_interval_ms=50, is_ip6=1)
684
685         self.pg_enable_capture()
686         self.src_if.add_stream(fragments)
687         self.pg_start()
688
689         self.sleep(.25, "wait before sending rest of fragments")
690
691         self.src_if.add_stream(fragments2)
692         self.pg_start()
693
694         packets = self.dst_if.get_capture(
695             len(self.pkt_infos) - len(dropped_packet_indexes))
696         self.verify_capture(packets, dropped_packet_indexes)
697         pkts = self.src_if.get_capture(
698             expected_count=len(dropped_packet_indexes))
699         for icmp in pkts:
700             self.assertIn(ICMPv6TimeExceeded, icmp)
701             self.assertIn(IPv6ExtHdrFragment, icmp)
702             self.assertIn(icmp[IPv6ExtHdrFragment].id, dropped_packet_indexes)
703             dropped_packet_indexes.remove(icmp[IPv6ExtHdrFragment].id)
704
705     def test_disabled(self):
706         """ reassembly disabled """
707
708         dropped_packet_indexes = set(
709             index for (index, frags_400, _) in self.pkt_infos
710             if len(frags_400) > 1)
711
712         self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0,
713                                     expire_walk_interval_ms=10000, is_ip6=1)
714
715         self.pg_enable_capture()
716         self.src_if.add_stream(self.fragments_400)
717         self.pg_start()
718
719         packets = self.dst_if.get_capture(
720             len(self.pkt_infos) - len(dropped_packet_indexes))
721         self.verify_capture(packets, dropped_packet_indexes)
722         self.src_if.assert_nothing_captured()
723
724     def test_missing_upper(self):
725         """ missing upper layer """
726         p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
727              IPv6(src=self.src_if.remote_ip6,
728                   dst=self.src_if.local_ip6) /
729              UDP(sport=1234, dport=5678) /
730              Raw())
731         self.extend_packet(p, 1000, self.padding)
732         fragments = fragment_rfc8200(p, 1, 500)
733         bad_fragment = p.__class__(str(fragments[1]))
734         bad_fragment[IPv6ExtHdrFragment].nh = 59
735         bad_fragment[IPv6ExtHdrFragment].offset = 0
736         self.pg_enable_capture()
737         self.src_if.add_stream([bad_fragment])
738         self.pg_start()
739         pkts = self.src_if.get_capture(expected_count=1)
740         icmp = pkts[0]
741         self.assertIn(ICMPv6ParamProblem, icmp)
742         self.assert_equal(icmp[ICMPv6ParamProblem].code, 3, "ICMP code")
743
744     def test_invalid_frag_size(self):
745         """ fragment size not a multiple of 8 """
746         p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
747              IPv6(src=self.src_if.remote_ip6,
748                   dst=self.src_if.local_ip6) /
749              UDP(sport=1234, dport=5678) /
750              Raw())
751         self.extend_packet(p, 1000, self.padding)
752         fragments = fragment_rfc8200(p, 1, 500)
753         bad_fragment = fragments[0]
754         self.extend_packet(bad_fragment, len(bad_fragment) + 5)
755         self.pg_enable_capture()
756         self.src_if.add_stream([bad_fragment])
757         self.pg_start()
758         pkts = self.src_if.get_capture(expected_count=1)
759         icmp = pkts[0]
760         self.assertIn(ICMPv6ParamProblem, icmp)
761         self.assert_equal(icmp[ICMPv6ParamProblem].code, 0, "ICMP code")
762
763     def test_invalid_packet_size(self):
764         """ total packet size > 65535 """
765         p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
766              IPv6(src=self.src_if.remote_ip6,
767                   dst=self.src_if.local_ip6) /
768              UDP(sport=1234, dport=5678) /
769              Raw())
770         self.extend_packet(p, 1000, self.padding)
771         fragments = fragment_rfc8200(p, 1, 500)
772         bad_fragment = fragments[1]
773         bad_fragment[IPv6ExtHdrFragment].offset = 65500
774         self.pg_enable_capture()
775         self.src_if.add_stream([bad_fragment])
776         self.pg_start()
777         pkts = self.src_if.get_capture(expected_count=1)
778         icmp = pkts[0]
779         self.assertIn(ICMPv6ParamProblem, icmp)
780         self.assert_equal(icmp[ICMPv6ParamProblem].code, 0, "ICMP code")
781
782
class TestIPv4ReassemblyLocalNode(VppTestCase):
    """ IPv4 Reassembly for packets coming to ip4-local node """

    @classmethod
    def setUpClass(cls):
        """Bring up one pg interface and pre-build the fragmented traffic."""
        super(TestIPv4ReassemblyLocalNode, cls).setUpClass()

        cls.create_pg_interfaces([0])
        # a single interface is used as both source and destination
        cls.src_dst_if = cls.pg0

        for pg_if in cls.pg_interfaces:
            pg_if.admin_up()
            pg_if.config_ip4()
            pg_if.resolve_arp()

        cls.padding = " abcdefghijklmn"
        cls.create_stream()
        cls.create_fragments()

    def setUp(self):
        """ Test setup - force timeout on existing reassemblies """
        super(TestIPv4ReassemblyLocalNode, self).setUp()
        set_reassembly = self.vapi.ip_reassembly_set
        # zero timeout and a short expire-walk interval first ...
        set_reassembly(timeout_ms=0, max_reassemblies=1000,
                       expire_walk_interval_ms=10)
        self.sleep(.25)
        # ... then long values for the duration of the test itself
        set_reassembly(timeout_ms=1000000, max_reassemblies=1000,
                       expire_walk_interval_ms=10000)

    def tearDown(self):
        """Run the base tearDown, then log the IPv4 reassembly details."""
        super(TestIPv4ReassemblyLocalNode, self).tearDown()
        details = self.vapi.ppcli("show ip4-reassembly details")
        self.logger.debug(details)

    @classmethod
    def create_stream(cls, packet_count=test_packet_count):
        """Create the ICMP echo-request packets used as test input.

        Every packet is passed through extend_packet() with a size of 1518
        and stored in the ``data`` attribute of its packet info.

        :param int packet_count: Number of packets to create.
        """
        pg_if = cls.src_dst_if
        for _ in range(packet_count):
            info = cls.create_packet_info(pg_if, pg_if)
            payload = cls.info_to_payload(info)
            ether = Ether(dst=pg_if.local_mac, src=pg_if.remote_mac)
            ip = IP(id=info.index, src=pg_if.remote_ip4,
                    dst=pg_if.local_ip4)
            echo_request = ICMP(type='echo-request', id=1234)
            packet = ether / ip / echo_request / Raw(payload)
            cls.extend_packet(packet, 1518, cls.padding)
            info.data = packet

    @classmethod
    def create_fragments(cls):
        """Fragment every stored packet via fragment_rfc791(packet, 300).

        Fills ``cls.pkt_infos`` with (packet index, fragments) pairs and
        ``cls.fragments_300`` with all fragments flattened into one list.
        """
        infos = cls._packet_infos
        cls.pkt_infos = [(index, fragment_rfc791(info.data, 300))
                         for index, info in six.iteritems(infos)]
        cls.fragments_300 = [frag
                             for _, frags in cls.pkt_infos
                             for frag in frags]
        cls.logger.debug("Fragmented %s packets into %s 300-byte fragments" %
                         (len(infos), len(cls.fragments_300)))

    def verify_capture(self, capture):
        """Check captured echo replies against the packets that were sent.

        Each captured packet must be an echo reply (ICMP type 0) with
        swapped IP addresses, the original ICMP id and payload, and must
        not be a duplicate.  Every created packet index must be seen.

        :param list capture: Captured packet stream.
        """
        received = set()
        for rx in capture:
            try:
                self.logger.debug(ppp("Got packet:", rx))
                rx_ip = rx[IP]
                rx_icmp = rx[ICMP]
                rx_info = self.payload_to_info(str(rx[Raw]))
                rx_index = rx_info.index
                if rx_index in received:
                    raise Exception(ppp("Duplicate packet received", rx))
                received.add(rx_index)
                self.assertEqual(rx_info.dst, self.src_dst_if.sw_if_index)
                sent_info = self._packet_infos[rx_index]
                self.assertIsNotNone(sent_info)
                self.assertEqual(rx_index, sent_info.index)
                sent = sent_info.data
                # a reply travels in the opposite direction of the request
                self.assertEqual(rx_ip.src, sent[IP].dst)
                self.assertEqual(rx_ip.dst, sent[IP].src)
                self.assertEqual(rx_icmp.type, 0)  # echo reply
                self.assertEqual(rx_icmp.id, sent[ICMP].id)
                self.assertEqual(rx_icmp.payload, sent[ICMP].payload)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", rx))
                raise
        for index in self._packet_infos:
            self.assertIn(index, received,
                          "Packet with packet_index %d not received" % index)

    def test_reassembly(self):
        """ basic reassembly """

        # the second pass runs it all again to verify correctness
        for _ in range(2):
            self.pg_enable_capture()
            self.src_dst_if.add_stream(self.fragments_300)
            self.pg_start()

            packets = self.src_dst_if.get_capture(len(self.pkt_infos))
            self.verify_capture(packets)
898
899
class TestFIFReassembly(VppTestCase):
    """ Fragments in fragments reassembly """

    @classmethod
    def setUpClass(cls):
        """Create two pg interfaces configured for both IPv4 and IPv6."""
        super(TestFIFReassembly, cls).setUpClass()

        cls.create_pg_interfaces([0, 1])
        cls.src_if = cls.pg0
        cls.dst_if = cls.pg1
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()
            i.config_ip6()
            i.resolve_ndp()

        cls.packet_sizes = [64, 512, 1518, 9018]
        cls.padding = " abcdefghijklmn"

    def setUp(self):
        """ Test setup - force timeout on existing reassemblies """
        super(TestFIFReassembly, self).setUp()
        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.src_if.sw_if_index, enable_ip4=True,
            enable_ip6=True)
        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.dst_if.sw_if_index, enable_ip4=True,
            enable_ip6=True)
        # zero timeout and a short expire-walk interval for IPv4 and IPv6 ...
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10)
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10, is_ip6=1)
        self.sleep(.25)
        # ... then long values for the duration of the test itself
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000)
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000, is_ip6=1)

    def tearDown(self):
        """Log IPv4 and IPv6 reassembly details, then run base tearDown."""
        self.logger.debug(self.vapi.ppcli("show ip4-reassembly details"))
        self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))
        super(TestFIFReassembly, self).tearDown()

    def verify_capture(self, capture, ip_class, dropped_packet_indexes=None):
        """Verify captured packet stream.

        :param list capture: Captured packet stream.
        :param ip_class: Scapy layer class (IP or IPv6) used to look up
            the IP header in captured and saved packets.
        :param dropped_packet_indexes: Indexes of packets which are
            expected to be dropped; defaults to none.
        """
        # None instead of a shared mutable [] default
        if dropped_packet_indexes is None:
            dropped_packet_indexes = []
        seen = set()
        for packet in capture:
            try:
                self.logger.debug(ppp("Got packet:", packet))
                ip = packet[ip_class]
                udp = packet[UDP]
                payload_info = self.payload_to_info(str(packet[Raw]))
                packet_index = payload_info.index
                self.assertNotIn(
                    packet_index, dropped_packet_indexes,
                    ppp("Packet received, but should be dropped:", packet))
                if packet_index in seen:
                    raise Exception(ppp("Duplicate packet received", packet))
                seen.add(packet_index)
                self.assertEqual(payload_info.dst, self.dst_if.sw_if_index)
                info = self._packet_infos[packet_index]
                self.assertIsNotNone(info)
                self.assertEqual(packet_index, info.index)
                saved_packet = info.data
                self.assertEqual(ip.src, saved_packet[ip_class].src)
                self.assertEqual(ip.dst, saved_packet[ip_class].dst)
                self.assertEqual(udp.payload, saved_packet[UDP].payload)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise
        for index in self._packet_infos:
            self.assertTrue(index in seen or index in dropped_packet_indexes,
                            "Packet with packet_index %d not received" % index)

    def test_fif4(self):
        """ Fragments in fragments (4o4) """

        # TODO this should be ideally in setUpClass, but then we hit a bug
        # with VppIpRoute incorrectly reporting it's present when it's not
        # so we need to manually remove the vpp config, thus we cannot have
        # it shared for multiple test cases
        self.tun_ip4 = "1.1.1.2"

        self.gre4 = VppGreInterface(self, self.src_if.local_ip4, self.tun_ip4)
        self.gre4.add_vpp_config()
        self.gre4.admin_up()
        self.gre4.config_ip4()

        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.gre4.sw_if_index, enable_ip4=True)

        self.route4 = VppIpRoute(self, self.tun_ip4, 32,
                                 [VppRoutePath(self.src_if.remote_ip4,
                                               self.src_if.sw_if_index)])
        self.route4.add_vpp_config()

        self.reset_packet_infos()
        for i in range(test_packet_count):
            info = self.create_packet_info(self.src_if, self.dst_if)
            payload = self.info_to_payload(info)
            # Ethernet header here is only for size calculation, thus it
            # doesn't matter how it's initialized. This is to ensure that
            # reassembled packet is not > 9000 bytes, so that it's not dropped
            p = (Ether() /
                 IP(id=i, src=self.src_if.remote_ip4,
                    dst=self.dst_if.remote_ip4) /
                 UDP(sport=1234, dport=5678) /
                 Raw(payload))
            size = self.packet_sizes[(i // 2) % len(self.packet_sizes)]
            self.extend_packet(p, size, self.padding)
            info.data = p[IP]  # use only IP part, without ethernet header

        # inner fragmentation: split every original packet
        fragments = [x for _, info in six.iteritems(self._packet_infos)
                     for x in fragment_rfc791(info.data, 400)]

        # wrap each inner fragment into a GRE tunnel packet
        encapped_fragments = \
            [Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IP(src=self.tun_ip4, dst=self.src_if.local_ip4) /
                GRE() /
                p
                for p in fragments]

        # outer fragmentation: split the tunnel packets again
        fragmented_encapped_fragments = \
            [x for p in encapped_fragments
             for x in fragment_rfc791(p, 200)]

        self.src_if.add_stream(fragmented_encapped_fragments)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        self.src_if.assert_nothing_captured()
        packets = self.dst_if.get_capture(len(self._packet_infos))
        self.verify_capture(packets, IP)

        # TODO remove gre vpp config by hand until VppIpRoute gets fixed
        # so that its query_vpp_config() works as it should
        self.gre4.remove_vpp_config()
        self.logger.debug(self.vapi.ppcli("show interface"))

    def test_fif6(self):
        """ Fragments in fragments (6o6) """
        # TODO this should be ideally in setUpClass, but then we hit a bug
        # with VppIpRoute incorrectly reporting it's present when it's not
        # so we need to manually remove the vpp config, thus we cannot have
        # it shared for multiple test cases
        self.tun_ip6 = "1002::1"

        self.gre6 = VppGre6Interface(self, self.src_if.local_ip6, self.tun_ip6)
        self.gre6.add_vpp_config()
        self.gre6.admin_up()
        self.gre6.config_ip6()

        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.gre6.sw_if_index, enable_ip6=True)

        self.route6 = VppIpRoute(self, self.tun_ip6, 128,
                                 [VppRoutePath(self.src_if.remote_ip6,
                                               self.src_if.sw_if_index,
                                               proto=DpoProto.DPO_PROTO_IP6)],
                                 is_ip6=1)
        self.route6.add_vpp_config()

        self.reset_packet_infos()
        for i in range(test_packet_count):
            info = self.create_packet_info(self.src_if, self.dst_if)
            payload = self.info_to_payload(info)
            # Ethernet header here is only for size calculation, thus it
            # doesn't matter how it's initialized. This is to ensure that
            # reassembled packet is not > 9000 bytes, so that it's not dropped
            p = (Ether() /
                 IPv6(src=self.src_if.remote_ip6, dst=self.dst_if.remote_ip6) /
                 UDP(sport=1234, dport=5678) /
                 Raw(payload))
            size = self.packet_sizes[(i // 2) % len(self.packet_sizes)]
            self.extend_packet(p, size, self.padding)
            info.data = p[IPv6]  # use only IPv6 part, without ethernet header

        # inner fragmentation: the packet index doubles as the fragment id
        fragments = [x for _, info in six.iteritems(self._packet_infos)
                     for x in fragment_rfc8200(
                         info.data, info.index, 400)]

        # wrap each inner fragment into a GRE tunnel packet
        encapped_fragments = \
            [Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IPv6(src=self.tun_ip6, dst=self.src_if.local_ip6) /
                GRE() /
                p
                for p in fragments]

        # outer fragmentation: ids are offset by twice the packet count;
        # packets without a fragment header are passed through unchanged
        fragmented_encapped_fragments = \
            [x for p in encapped_fragments for x in (
                fragment_rfc8200(
                    p,
                    2 * len(self._packet_infos) + p[IPv6ExtHdrFragment].id,
                    200)
                if IPv6ExtHdrFragment in p else [p]
            )
            ]

        self.src_if.add_stream(fragmented_encapped_fragments)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        self.src_if.assert_nothing_captured()
        packets = self.dst_if.get_capture(len(self._packet_infos))
        self.verify_capture(packets, IPv6)

        # TODO remove gre vpp config by hand until VppIpRoute gets fixed
        # so that its query_vpp_config() works as it should
        self.gre6.remove_vpp_config()
1116
1117
# When executed directly as a script, run this module's tests through
# unittest using VppTestRunner as the test runner.
if __name__ == '__main__':
    unittest.main(testRunner=VppTestRunner)