API: Use string type instead of u8.
[vpp.git] / test / test_reassembly.py
1 #!/usr/bin/env python
2
3 import six
4 import unittest
5 from random import shuffle
6
7 from framework import VppTestCase, VppTestRunner, is_skip_aarch64_set,\
8     is_platform_aarch64
9
10 from scapy.packet import Raw
11 from scapy.layers.l2 import Ether, GRE
12 from scapy.layers.inet import IP, UDP, ICMP
13 from util import ppp, fragment_rfc791, fragment_rfc8200
14 from scapy.layers.inet6 import IPv6, IPv6ExtHdrFragment, ICMPv6ParamProblem,\
15     ICMPv6TimeExceeded
16 from vpp_gre_interface import VppGreInterface, VppGre6Interface
17 from vpp_ip import DpoProto
18 from vpp_ip_route import VppIpRoute, VppRoutePath
19
# Default number of packets generated by each create_stream() helper below.
test_packet_count = 257
21
22
class TestIPv4Reassembly(VppTestCase):
    """ IPv4 Reassembly """

    @classmethod
    def setUpClass(cls):
        super(TestIPv4Reassembly, cls).setUpClass()

        cls.create_pg_interfaces([0, 1])
        cls.src_if = cls.pg0
        cls.dst_if = cls.pg1

        # setup all interfaces
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()

        # packet sizes
        cls.packet_sizes = [64, 512, 1518, 9018]
        cls.padding = " abcdefghijklmn"
        cls.create_stream(cls.packet_sizes)
        cls.create_fragments()

    def setUp(self):
        """ Test setup - force timeout on existing reassemblies """
        super(TestIPv4Reassembly, self).setUp()
        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.src_if.sw_if_index, enable_ip4=True)
        # zero timeout plus a short expire-walk interval, then wait, so that
        # reassemblies left over from a previous test time out ...
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10)
        self.sleep(.25)
        # ... then restore a long timeout for the test itself
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000)

    def tearDown(self):
        super(TestIPv4Reassembly, self).tearDown()
        self.logger.debug(self.vapi.ppcli("show ip4-reassembly details"))

    @classmethod
    def create_stream(cls, packet_sizes, packet_count=test_packet_count):
        """Create input packet stream for defined interface.

        Each generated packet is stored in its packet info as ``info.data``.

        :param list packet_sizes: Required packet sizes; each size is used
            for two consecutive packets, cycling through the list.
        :param int packet_count: Number of packets to generate.
        """
        for i in range(0, packet_count):
            info = cls.create_packet_info(cls.src_if, cls.src_if)
            payload = cls.info_to_payload(info)
            p = (Ether(dst=cls.src_if.local_mac, src=cls.src_if.remote_mac) /
                 IP(id=info.index, src=cls.src_if.remote_ip4,
                    dst=cls.dst_if.remote_ip4) /
                 UDP(sport=1234, dport=5678) /
                 Raw(payload))
            size = packet_sizes[(i // 2) % len(packet_sizes)]
            cls.extend_packet(p, size, cls.padding)
            info.data = p

    @classmethod
    def create_fragments(cls):
        """Fragment every generated packet at several fragment sizes.

        Populates ``cls.pkt_infos`` with one
        ``(index, fragments_400, fragments_300, fragments_200)`` tuple per
        packet, plus the flattened lists ``cls.fragments_400``,
        ``cls.fragments_300`` and ``cls.fragments_200``. The 200-byte
        fragments are produced by re-fragmenting the 400-byte fragments.
        """
        infos = cls._packet_infos
        cls.pkt_infos = []
        for index, info in six.iteritems(infos):
            p = info.data
            fragments_400 = fragment_rfc791(p, 400)
            fragments_300 = fragment_rfc791(p, 300)
            fragments_200 = [
                x for f in fragments_400 for x in fragment_rfc791(f, 200)]
            cls.pkt_infos.append(
                (index, fragments_400, fragments_300, fragments_200))
        cls.fragments_400 = [
            x for (_, frags, _, _) in cls.pkt_infos for x in frags]
        cls.fragments_300 = [
            x for (_, _, frags, _) in cls.pkt_infos for x in frags]
        cls.fragments_200 = [
            x for (_, _, _, frags) in cls.pkt_infos for x in frags]
        cls.logger.debug("Fragmented %s packets into %s 400-byte fragments, "
                         "%s 300-byte fragments and %s 200-byte fragments" %
                         (len(infos), len(cls.fragments_400),
                             len(cls.fragments_300), len(cls.fragments_200)))

    def verify_capture(self, capture, dropped_packet_indexes=None):
        """Verify captured packet stream.

        Every captured packet must match a generated packet (addresses and
        UDP payload), must not be one expected to be dropped, and must not
        be a duplicate. Every generated packet must be either captured or
        listed as dropped.

        :param list capture: Captured packet stream.
        :param dropped_packet_indexes: Indexes of packets expected to be
            dropped (any container supporting ``in``); defaults to none.
        """
        # avoid a shared mutable default argument
        if dropped_packet_indexes is None:
            dropped_packet_indexes = ()
        seen = set()
        for packet in capture:
            try:
                self.logger.debug(ppp("Got packet:", packet))
                ip = packet[IP]
                udp = packet[UDP]
                payload_info = self.payload_to_info(str(packet[Raw]))
                packet_index = payload_info.index
                self.assertNotIn(
                    packet_index, dropped_packet_indexes,
                    ppp("Packet received, but should be dropped:", packet))
                if packet_index in seen:
                    raise Exception(ppp("Duplicate packet received", packet))
                seen.add(packet_index)
                self.assertEqual(payload_info.dst, self.src_if.sw_if_index)
                info = self._packet_infos[packet_index]
                self.assertIsNotNone(info)
                self.assertEqual(packet_index, info.index)
                saved_packet = info.data
                self.assertEqual(ip.src, saved_packet[IP].src)
                self.assertEqual(ip.dst, saved_packet[IP].dst)
                self.assertEqual(udp.payload, saved_packet[UDP].payload)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise
        for index in self._packet_infos:
            self.assertTrue(index in seen or index in dropped_packet_indexes,
                            "Packet with packet_index %d not received" % index)

    def test_reassembly(self):
        """ basic reassembly """

        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_200)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_200)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_reversed(self):
        """ reverse order reassembly """

        fragments = list(self.fragments_200)
        fragments.reverse()

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        # one reassembled packet is expected per entry of pkt_infos
        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    @unittest.skipIf(is_skip_aarch64_set() and is_platform_aarch64(),
                     "test doesn't work on aarch64")
    def test_random(self):
        """ random order reassembly """

        fragments = list(self.fragments_200)
        shuffle(fragments)

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        # one reassembled packet is expected per entry of pkt_infos
        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_duplicates(self):
        """ duplicate fragments """

        # every fragment of a multi-fragment packet is sent twice in a row;
        # single-fragment packets are sent once
        fragments = [
            x for (_, frags, _, _) in self.pkt_infos
            for x in frags
            for _ in range(0, min(2, len(frags)))
        ]

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_overlap1(self):
        """ overlapping fragments case #1 """

        fragments = []
        for _, _, frags_300, frags_200 in self.pkt_infos:
            if len(frags_300) == 1:
                fragments.extend(frags_300)
            else:
                for i, j in zip(frags_200, frags_300):
                    fragments.extend(i)
                    fragments.extend(j)

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_overlap2(self):
        """ overlapping fragments case #2 """

        fragments = []
        for _, _, frags_300, frags_200 in self.pkt_infos:
            if len(frags_300) == 1:
                fragments.extend(frags_300)
            else:
                # care must be taken here so that there are no fragments
                # received by vpp after reassembly is finished, otherwise
                # new reassemblies will be started and packet generator will
                # freak out when it detects unfreed buffers
                # zip() is an iterator on Python 3; materialize it so it can
                # be sliced and indexed below
                zipped = list(zip(frags_300, frags_200))
                for i, j in zipped[:-1]:
                    fragments.extend(i)
                    fragments.extend(j)
                fragments.append(zipped[-1][0])

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_timeout_inline(self):
        """ timeout (inline) """

        # with a zero timeout every multi-fragment packet is expected to be
        # dropped; single-fragment packets still get through
        dropped_packet_indexes = set(
            index for (index, frags, _, _) in self.pkt_infos if len(frags) > 1
        )

        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000)

        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_400)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        self.src_if.assert_nothing_captured()

    def test_timeout_cleanup(self):
        """ timeout (cleanup) """

        # whole packets + fragmented packets sans last fragment
        fragments = [
            x for (_, frags_400, _, _) in self.pkt_infos
            for x in frags_400[:-1 if len(frags_400) > 1 else None]
        ]

        # last fragments for fragmented packets
        fragments2 = [frags_400[-1]
                      for (_, frags_400, _, _) in self.pkt_infos
                      if len(frags_400) > 1]

        dropped_packet_indexes = set(
            index for (index, frags_400, _, _) in self.pkt_infos
            if len(frags_400) > 1)

        self.vapi.ip_reassembly_set(timeout_ms=100, max_reassemblies=1000,
                                    expire_walk_interval_ms=50)

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        # sleep longer than the 100 ms timeout before sending the last
        # fragments, so the incomplete reassemblies expire first
        self.sleep(.25, "wait before sending rest of fragments")

        self.src_if.add_stream(fragments2)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        self.src_if.assert_nothing_captured()

    def test_disabled(self):
        """ reassembly disabled """

        dropped_packet_indexes = set(
            index for (index, frags_400, _, _) in self.pkt_infos
            if len(frags_400) > 1)

        # max_reassemblies=0 leaves no room for any reassembly
        self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0,
                                    expire_walk_interval_ms=10000)

        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_400)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        self.src_if.assert_nothing_captured()
358         self.src_if.assert_nothing_captured()
359
360
class TestIPv6Reassembly(VppTestCase):
    """ IPv6 Reassembly """

    @classmethod
    def setUpClass(cls):
        super(TestIPv6Reassembly, cls).setUpClass()

        cls.create_pg_interfaces([0, 1])
        cls.src_if = cls.pg0
        cls.dst_if = cls.pg1

        # setup all interfaces
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip6()
            i.resolve_ndp()

        # packet sizes
        cls.packet_sizes = [64, 512, 1518, 9018]
        cls.padding = " abcdefghijklmn"
        cls.create_stream(cls.packet_sizes)
        cls.create_fragments()

    def setUp(self):
        """ Test setup - force timeout on existing reassemblies """
        super(TestIPv6Reassembly, self).setUp()
        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.src_if.sw_if_index, enable_ip6=True)
        # zero timeout plus a short expire-walk interval, then wait, so that
        # reassemblies left over from a previous test time out ...
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10, is_ip6=1)
        self.sleep(.25)
        # ... then restore a long timeout for the test itself
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000, is_ip6=1)
        self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))

    def tearDown(self):
        super(TestIPv6Reassembly, self).tearDown()
        self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))

    @classmethod
    def create_stream(cls, packet_sizes, packet_count=test_packet_count):
        """Create input packet stream for defined interface.

        Each generated packet is stored in its packet info as ``info.data``.

        :param list packet_sizes: Required packet sizes; each size is used
            for two consecutive packets, cycling through the list.
        :param int packet_count: Number of packets to generate.
        """
        for i in range(0, packet_count):
            info = cls.create_packet_info(cls.src_if, cls.src_if)
            payload = cls.info_to_payload(info)
            p = (Ether(dst=cls.src_if.local_mac, src=cls.src_if.remote_mac) /
                 IPv6(src=cls.src_if.remote_ip6,
                      dst=cls.dst_if.remote_ip6) /
                 UDP(sport=1234, dport=5678) /
                 Raw(payload))
            size = packet_sizes[(i // 2) % len(packet_sizes)]
            cls.extend_packet(p, size, cls.padding)
            info.data = p

    @classmethod
    def create_fragments(cls):
        """Fragment every generated packet at 400 and 300 byte sizes.

        Populates ``cls.pkt_infos`` with one
        ``(index, fragments_400, fragments_300)`` tuple per packet, plus the
        flattened lists ``cls.fragments_400`` and ``cls.fragments_300``.
        The packet index is passed to fragment_rfc8200 as the second
        argument; the timeout tests compare the fragment header id of
        returned ICMP errors against packet indexes.
        """
        infos = cls._packet_infos
        cls.pkt_infos = []
        for index, info in six.iteritems(infos):
            p = info.data
            fragments_400 = fragment_rfc8200(p, info.index, 400)
            fragments_300 = fragment_rfc8200(p, info.index, 300)
            cls.pkt_infos.append((index, fragments_400, fragments_300))
        cls.fragments_400 = [
            x for _, frags, _ in cls.pkt_infos for x in frags]
        cls.fragments_300 = [
            x for _, _, frags in cls.pkt_infos for x in frags]
        cls.logger.debug("Fragmented %s packets into %s 400-byte fragments, "
                         "and %s 300-byte fragments" %
                         (len(infos), len(cls.fragments_400),
                             len(cls.fragments_300)))

    def verify_capture(self, capture, dropped_packet_indexes=None):
        """Verify captured packet stream.

        Every captured packet must match a generated packet (addresses and
        UDP payload), must not be one expected to be dropped, and must not
        be a duplicate. Every generated packet must be either captured or
        listed as dropped.

        :param list capture: Captured packet stream.
        :param dropped_packet_indexes: Indexes of packets expected to be
            dropped (any container supporting ``in``); defaults to none.
        """
        # avoid a shared mutable default argument
        if dropped_packet_indexes is None:
            dropped_packet_indexes = ()
        seen = set()
        for packet in capture:
            try:
                self.logger.debug(ppp("Got packet:", packet))
                ip = packet[IPv6]
                udp = packet[UDP]
                payload_info = self.payload_to_info(str(packet[Raw]))
                packet_index = payload_info.index
                self.assertNotIn(
                    packet_index, dropped_packet_indexes,
                    ppp("Packet received, but should be dropped:", packet))
                if packet_index in seen:
                    raise Exception(ppp("Duplicate packet received", packet))
                seen.add(packet_index)
                self.assertEqual(payload_info.dst, self.src_if.sw_if_index)
                info = self._packet_infos[packet_index]
                self.assertIsNotNone(info)
                self.assertEqual(packet_index, info.index)
                saved_packet = info.data
                self.assertEqual(ip.src, saved_packet[IPv6].src)
                self.assertEqual(ip.dst, saved_packet[IPv6].dst)
                self.assertEqual(udp.payload, saved_packet[UDP].payload)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise
        for index in self._packet_infos:
            self.assertTrue(index in seen or index in dropped_packet_indexes,
                            "Packet with packet_index %d not received" % index)

    def test_reassembly(self):
        """ basic reassembly """

        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_400)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_400)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_reversed(self):
        """ reverse order reassembly """

        fragments = list(self.fragments_400)
        fragments.reverse()

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_random(self):
        """ random order reassembly """

        fragments = list(self.fragments_400)
        shuffle(fragments)

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_duplicates(self):
        """ duplicate fragments """

        # every fragment of a multi-fragment packet is sent twice in a row;
        # single-fragment packets are sent once
        fragments = [
            x for (_, frags, _) in self.pkt_infos
            for x in frags
            for _ in range(0, min(2, len(frags)))
        ]

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
        self.src_if.assert_nothing_captured()

    def test_overlap1(self):
        """ overlapping fragments case #1 """

        fragments = []
        for _, frags_400, frags_300 in self.pkt_infos:
            if len(frags_300) == 1:
                fragments.extend(frags_400)
            else:
                for i, j in zip(frags_300, frags_400):
                    fragments.extend(i)
                    fragments.extend(j)

        # packets that were sent interleaved with overlapping fragments are
        # expected to be dropped
        dropped_packet_indexes = set(
            index for (index, _, frags) in self.pkt_infos if len(frags) > 1
        )

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        self.src_if.assert_nothing_captured()

    def test_overlap2(self):
        """ overlapping fragments case #2 """

        fragments = []
        for _, frags_400, frags_300 in self.pkt_infos:
            if len(frags_400) == 1:
                fragments.extend(frags_400)
            else:
                # care must be taken here so that there are no fragments
                # received by vpp after reassembly is finished, otherwise
                # new reassemblies will be started and packet generator will
                # freak out when it detects unfreed buffers
                # zip() is an iterator on Python 3; materialize it so it can
                # be sliced and indexed below
                zipped = list(zip(frags_400, frags_300))
                for i, j in zipped[:-1]:
                    fragments.extend(i)
                    fragments.extend(j)
                fragments.append(zipped[-1][0])

        # expected drops must use the same condition as the branch above:
        # only packets with more than one 400-byte fragment were sent with
        # overlapping fragments
        dropped_packet_indexes = set(
            index for (index, frags, _) in self.pkt_infos if len(frags) > 1
        )

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        self.src_if.assert_nothing_captured()

    def test_timeout_inline(self):
        """ timeout (inline) """

        # with a zero timeout every multi-fragment packet is expected to be
        # dropped, each producing an ICMPv6 time-exceeded back to the source
        dropped_packet_indexes = set(
            index for (index, frags, _) in self.pkt_infos if len(frags) > 1
        )

        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000, is_ip6=1)

        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_400)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        pkts = self.src_if.get_capture(
            expected_count=len(dropped_packet_indexes))
        for icmp in pkts:
            self.assertIn(ICMPv6TimeExceeded, icmp)
            self.assertIn(IPv6ExtHdrFragment, icmp)
            self.assertIn(icmp[IPv6ExtHdrFragment].id, dropped_packet_indexes)
            # removing each id ensures no dropped packet is reported twice
            dropped_packet_indexes.remove(icmp[IPv6ExtHdrFragment].id)

    def test_timeout_cleanup(self):
        """ timeout (cleanup) """

        # whole packets + fragmented packets sans last fragment
        fragments = [
            x for (_, frags_400, _) in self.pkt_infos
            for x in frags_400[:-1 if len(frags_400) > 1 else None]
        ]

        # last fragments for fragmented packets
        fragments2 = [frags_400[-1]
                      for (_, frags_400, _) in self.pkt_infos
                      if len(frags_400) > 1]

        dropped_packet_indexes = set(
            index for (index, frags_400, _) in self.pkt_infos
            if len(frags_400) > 1)

        # short timeout for both the IPv4 and the IPv6 reassembly settings
        self.vapi.ip_reassembly_set(timeout_ms=100, max_reassemblies=1000,
                                    expire_walk_interval_ms=50)

        self.vapi.ip_reassembly_set(timeout_ms=100, max_reassemblies=1000,
                                    expire_walk_interval_ms=50, is_ip6=1)

        self.pg_enable_capture()
        self.src_if.add_stream(fragments)
        self.pg_start()

        # sleep longer than the 100 ms timeout before sending the last
        # fragments, so the incomplete reassemblies expire first
        self.sleep(.25, "wait before sending rest of fragments")

        self.src_if.add_stream(fragments2)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        pkts = self.src_if.get_capture(
            expected_count=len(dropped_packet_indexes))
        for icmp in pkts:
            self.assertIn(ICMPv6TimeExceeded, icmp)
            self.assertIn(IPv6ExtHdrFragment, icmp)
            self.assertIn(icmp[IPv6ExtHdrFragment].id, dropped_packet_indexes)
            dropped_packet_indexes.remove(icmp[IPv6ExtHdrFragment].id)

    def test_disabled(self):
        """ reassembly disabled """

        dropped_packet_indexes = set(
            index for (index, frags_400, _) in self.pkt_infos
            if len(frags_400) > 1)

        # max_reassemblies=0 leaves no room for any reassembly
        self.vapi.ip_reassembly_set(timeout_ms=1000, max_reassemblies=0,
                                    expire_walk_interval_ms=10000, is_ip6=1)

        self.pg_enable_capture()
        self.src_if.add_stream(self.fragments_400)
        self.pg_start()

        packets = self.dst_if.get_capture(
            len(self.pkt_infos) - len(dropped_packet_indexes))
        self.verify_capture(packets, dropped_packet_indexes)
        self.src_if.assert_nothing_captured()

    def test_missing_upper(self):
        """ missing upper layer """
        p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IPv6(src=self.src_if.remote_ip6,
                  dst=self.src_if.local_ip6) /
             UDP(sport=1234, dport=5678) /
             Raw())
        self.extend_packet(p, 1000, self.padding)
        fragments = fragment_rfc8200(p, 1, 500)
        # take a non-first fragment, rebuild it from its serialized form and
        # turn it into a first fragment whose next header is 59
        bad_fragment = p.__class__(str(fragments[1]))
        bad_fragment[IPv6ExtHdrFragment].nh = 59
        bad_fragment[IPv6ExtHdrFragment].offset = 0
        self.pg_enable_capture()
        self.src_if.add_stream([bad_fragment])
        self.pg_start()
        pkts = self.src_if.get_capture(expected_count=1)
        icmp = pkts[0]
        self.assertIn(ICMPv6ParamProblem, icmp)
        self.assert_equal(icmp[ICMPv6ParamProblem].code, 3, "ICMP code")

    def test_invalid_frag_size(self):
        """ fragment size not a multiple of 8 """
        p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IPv6(src=self.src_if.remote_ip6,
                  dst=self.src_if.local_ip6) /
             UDP(sport=1234, dport=5678) /
             Raw())
        self.extend_packet(p, 1000, self.padding)
        fragments = fragment_rfc8200(p, 1, 500)
        bad_fragment = fragments[0]
        # grow the first fragment by 5 bytes so its size is no longer a
        # multiple of 8
        self.extend_packet(bad_fragment, len(bad_fragment) + 5)
        self.pg_enable_capture()
        self.src_if.add_stream([bad_fragment])
        self.pg_start()
        pkts = self.src_if.get_capture(expected_count=1)
        icmp = pkts[0]
        self.assertIn(ICMPv6ParamProblem, icmp)
        self.assert_equal(icmp[ICMPv6ParamProblem].code, 0, "ICMP code")

    def test_invalid_packet_size(self):
        """ total packet size > 65535 """
        p = (Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IPv6(src=self.src_if.remote_ip6,
                  dst=self.src_if.local_ip6) /
             UDP(sport=1234, dport=5678) /
             Raw())
        self.extend_packet(p, 1000, self.padding)
        fragments = fragment_rfc8200(p, 1, 500)
        bad_fragment = fragments[1]
        bad_fragment[IPv6ExtHdrFragment].offset = 65500
        self.pg_enable_capture()
        self.src_if.add_stream([bad_fragment])
        self.pg_start()
        pkts = self.src_if.get_capture(expected_count=1)
        icmp = pkts[0]
        self.assertIn(ICMPv6ParamProblem, icmp)
        self.assert_equal(icmp[ICMPv6ParamProblem].code, 0, "ICMP code")
756
757
class TestIPv4ReassemblyLocalNode(VppTestCase):
    """ IPv4 Reassembly for packets coming to ip4-local node """

    @classmethod
    def setUpClass(cls):
        """Create one pg interface, configure IPv4 on it and pre-build the
        fragmented ICMP echo-request stream shared by all test cases."""
        super(TestIPv4ReassemblyLocalNode, cls).setUpClass()

        cls.create_pg_interfaces([0])
        # packets are both sent and answered on the same interface
        cls.src_dst_if = cls.pg0

        # setup all interfaces
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()

        cls.padding = " abcdefghijklmn"
        cls.create_stream()
        cls.create_fragments()

    def setUp(self):
        """ Test setup - force timeout on existing reassemblies """
        super(TestIPv4ReassemblyLocalNode, self).setUp()
        # zero timeout + short expire walk flushes leftover reassemblies...
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10)
        self.sleep(.25)
        # ...then restore a long timeout for the fragments of this test
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000)

    def tearDown(self):
        """Run base teardown, then log the ip4 reassembly state."""
        super(TestIPv4ReassemblyLocalNode, self).tearDown()
        self.logger.debug(self.vapi.ppcli("show ip4-reassembly details"))

    @classmethod
    def create_stream(cls, packet_count=test_packet_count):
        """Create the ICMP echo-request input stream.

        Each request is sent from the remote side of ``src_dst_if`` to its
        local address, carries the packet info as payload, is extended via
        ``extend_packet`` towards size 1518, and is stored in the packet
        info's ``data`` attribute.

        :param int packet_count: Number of packets to create.
        """
        for i in range(0, packet_count):
            info = cls.create_packet_info(cls.src_dst_if, cls.src_dst_if)
            payload = cls.info_to_payload(info)
            p = (Ether(dst=cls.src_dst_if.local_mac,
                       src=cls.src_dst_if.remote_mac) /
                 IP(id=info.index, src=cls.src_dst_if.remote_ip4,
                    dst=cls.src_dst_if.local_ip4) /
                 ICMP(type='echo-request', id=1234) /
                 Raw(payload))
            cls.extend_packet(p, 1518, cls.padding)
            info.data = p

    @classmethod
    def create_fragments(cls):
        """Fragment every stored packet with ``fragment_rfc791(p, 300)``.

        Sets ``cls.pkt_infos`` to a list of ``(index, fragments)`` tuples
        and ``cls.fragments_300`` to the flattened list of all fragments.
        """
        infos = cls._packet_infos
        cls.pkt_infos = []
        for index, info in six.iteritems(infos):
            p = info.data
            fragments_300 = fragment_rfc791(p, 300)
            cls.pkt_infos.append((index, fragments_300))
        cls.fragments_300 = [x for (_, frags) in cls.pkt_infos for x in frags]
        cls.logger.debug("Fragmented %s packets into %s 300-byte fragments" %
                         (len(infos), len(cls.fragments_300)))

    def verify_capture(self, capture, dropped_packet_indexes=None):
        """Verify captured packet stream.

        Each captured packet must be an ICMP echo reply (type 0) to one of
        the generated requests, received at most once, with IP source and
        destination swapped and ICMP id and payload preserved. Afterwards
        every generated packet must have been seen unless its index is
        listed in ``dropped_packet_indexes``.

        :param list capture: Captured packet stream.
        :param dropped_packet_indexes: Indexes of packets that are expected
            to be dropped; defaults to none.
        """
        if dropped_packet_indexes is None:
            dropped_packet_indexes = []
        seen = set()
        for packet in capture:
            try:
                self.logger.debug(ppp("Got packet:", packet))
                ip = packet[IP]
                icmp = packet[ICMP]
                payload_info = self.payload_to_info(str(packet[Raw]))
                packet_index = payload_info.index
                self.assertTrue(
                    packet_index not in dropped_packet_indexes,
                    ppp("Packet received, but should be dropped:", packet))
                if packet_index in seen:
                    raise Exception(ppp("Duplicate packet received", packet))
                seen.add(packet_index)
                self.assertEqual(payload_info.dst, self.src_dst_if.sw_if_index)
                info = self._packet_infos[packet_index]
                self.assertTrue(info is not None)
                self.assertEqual(packet_index, info.index)
                saved_packet = info.data
                # the reply travels back, so src/dst are swapped
                self.assertEqual(ip.src, saved_packet[IP].dst)
                self.assertEqual(ip.dst, saved_packet[IP].src)
                self.assertEqual(icmp.type, 0)  # echo reply
                self.assertEqual(icmp.id, saved_packet[ICMP].id)
                self.assertEqual(icmp.payload, saved_packet[ICMP].payload)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise
        for index in self._packet_infos:
            self.assertTrue(index in seen or index in dropped_packet_indexes,
                            "Packet with packet_index %d not received" % index)

    def test_reassembly(self):
        """ basic reassembly """

        self.pg_enable_capture()
        self.src_dst_if.add_stream(self.fragments_300)
        self.pg_start()

        packets = self.src_dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)

        # run it all again to verify correctness
        self.pg_enable_capture()
        self.src_dst_if.add_stream(self.fragments_300)
        self.pg_start()

        packets = self.src_dst_if.get_capture(len(self.pkt_infos))
        self.verify_capture(packets)
873
874
class TestFIFReassembly(VppTestCase):
    """ Fragments in fragments reassembly """

    @classmethod
    def setUpClass(cls):
        """Create two pg interfaces with IPv4 and IPv6 configured and
        neighbours resolved; pg0 is the source, pg1 the destination."""
        super(TestFIFReassembly, cls).setUpClass()

        cls.create_pg_interfaces([0, 1])
        cls.src_if = cls.pg0
        cls.dst_if = cls.pg1
        for i in cls.pg_interfaces:
            i.admin_up()
            i.config_ip4()
            i.resolve_arp()
            i.config_ip6()
            i.resolve_ndp()

        cls.packet_sizes = [64, 512, 1518, 9018]
        cls.padding = " abcdefghijklmn"

    def setUp(self):
        """ Test setup - force timeout on existing reassemblies """
        super(TestFIFReassembly, self).setUp()
        # enable both ip4 and ip6 reassembly on source and destination
        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.src_if.sw_if_index, enable_ip4=True,
            enable_ip6=True)
        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.dst_if.sw_if_index, enable_ip4=True,
            enable_ip6=True)
        # zero timeout + short expire walk flushes leftover reassemblies...
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10)
        self.vapi.ip_reassembly_set(timeout_ms=0, max_reassemblies=1000,
                                    expire_walk_interval_ms=10, is_ip6=1)
        self.sleep(.25)
        # ...then restore a long timeout for the fragments of this test
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000)
        self.vapi.ip_reassembly_set(timeout_ms=1000000, max_reassemblies=1000,
                                    expire_walk_interval_ms=10000, is_ip6=1)

    def tearDown(self):
        """Log ip4/ip6 reassembly state, then run base teardown."""
        self.logger.debug(self.vapi.ppcli("show ip4-reassembly details"))
        self.logger.debug(self.vapi.ppcli("show ip6-reassembly details"))
        super(TestFIFReassembly, self).tearDown()

    def verify_capture(self, capture, ip_class, dropped_packet_indexes=None):
        """Verify captured packet stream.

        Each captured packet must carry a known packet info addressed to
        ``dst_if``, must not be expected-dropped, must be received at most
        once, and must match the saved packet's addresses and UDP payload.
        Afterwards every generated packet must have been seen unless its
        index is listed in ``dropped_packet_indexes``.

        :param list capture: Captured packet stream.
        :param ip_class: scapy IP layer class to inspect (``IP``/``IPv6``).
        :param dropped_packet_indexes: Indexes of packets that are expected
            to be dropped; defaults to none.
        """
        # None sentinel instead of a shared mutable default list
        if dropped_packet_indexes is None:
            dropped_packet_indexes = []
        seen = set()
        for packet in capture:
            try:
                self.logger.debug(ppp("Got packet:", packet))
                ip = packet[ip_class]
                udp = packet[UDP]
                payload_info = self.payload_to_info(str(packet[Raw]))
                packet_index = payload_info.index
                self.assertTrue(
                    packet_index not in dropped_packet_indexes,
                    ppp("Packet received, but should be dropped:", packet))
                if packet_index in seen:
                    raise Exception(ppp("Duplicate packet received", packet))
                seen.add(packet_index)
                self.assertEqual(payload_info.dst, self.dst_if.sw_if_index)
                info = self._packet_infos[packet_index]
                self.assertTrue(info is not None)
                self.assertEqual(packet_index, info.index)
                saved_packet = info.data
                self.assertEqual(ip.src, saved_packet[ip_class].src)
                self.assertEqual(ip.dst, saved_packet[ip_class].dst)
                self.assertEqual(udp.payload, saved_packet[UDP].payload)
            except Exception:
                self.logger.error(ppp("Unexpected or invalid packet:", packet))
                raise
        for index in self._packet_infos:
            self.assertTrue(index in seen or index in dropped_packet_indexes,
                            "Packet with packet_index %d not received" % index)

    def test_fif4(self):
        """ Fragments in fragments (4o4) """

        # TODO this should be ideally in setUpClass, but then we hit a bug
        # with VppIpRoute incorrectly reporting it's present when it's not
        # so we need to manually remove the vpp config, thus we cannot have
        # it shared for multiple test cases
        self.tun_ip4 = "1.1.1.2"

        self.gre4 = VppGreInterface(self, self.src_if.local_ip4, self.tun_ip4)
        self.gre4.add_vpp_config()
        self.gre4.admin_up()
        self.gre4.config_ip4()

        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.gre4.sw_if_index, enable_ip4=True)

        self.route4 = VppIpRoute(self, self.tun_ip4, 32,
                                 [VppRoutePath(self.src_if.remote_ip4,
                                               self.src_if.sw_if_index)])
        self.route4.add_vpp_config()

        self.reset_packet_infos()
        for i in range(test_packet_count):
            info = self.create_packet_info(self.src_if, self.dst_if)
            payload = self.info_to_payload(info)
            # Ethernet header here is only for size calculation, thus it
            # doesn't matter how it's initialized. This is to ensure that
            # reassembled packet is not > 9000 bytes, so that it's not dropped
            p = (Ether() /
                 IP(id=i, src=self.src_if.remote_ip4,
                    dst=self.dst_if.remote_ip4) /
                 UDP(sport=1234, dport=5678) /
                 Raw(payload))
            size = self.packet_sizes[(i // 2) % len(self.packet_sizes)]
            self.extend_packet(p, size, self.padding)
            info.data = p[IP]  # use only IP part, without ethernet header

        # inner fragmentation of the original IPv4 packets
        fragments = [x for _, p in six.iteritems(self._packet_infos)
                     for x in fragment_rfc791(p.data, 400)]

        # GRE-encapsulate each inner fragment towards the tunnel endpoint
        encapped_fragments = \
            [Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IP(src=self.tun_ip4, dst=self.src_if.local_ip4) /
                GRE() /
                p
                for p in fragments]

        # outer fragmentation of the GRE packets themselves
        fragmented_encapped_fragments = \
            [x for p in encapped_fragments
             for x in fragment_rfc791(p, 200)]

        self.src_if.add_stream(fragmented_encapped_fragments)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        # expect fully reassembled inner packets on dst_if only
        self.src_if.assert_nothing_captured()
        packets = self.dst_if.get_capture(len(self._packet_infos))
        self.verify_capture(packets, IP)

        # TODO remove gre vpp config by hand until VppIpRoute gets fixed
        # so that it's query_vpp_config() works as it should
        self.gre4.remove_vpp_config()
        self.logger.debug(self.vapi.ppcli("show interface"))

    def test_fif6(self):
        """ Fragments in fragments (6o6) """
        # TODO this should be ideally in setUpClass, but then we hit a bug
        # with VppIpRoute incorrectly reporting it's present when it's not
        # so we need to manually remove the vpp config, thus we cannot have
        # it shared for multiple test cases
        self.tun_ip6 = "1002::1"

        self.gre6 = VppGre6Interface(self, self.src_if.local_ip6, self.tun_ip6)
        self.gre6.add_vpp_config()
        self.gre6.admin_up()
        self.gre6.config_ip6()

        self.vapi.ip_reassembly_enable_disable(
            sw_if_index=self.gre6.sw_if_index, enable_ip6=True)

        self.route6 = VppIpRoute(self, self.tun_ip6, 128,
                                 [VppRoutePath(self.src_if.remote_ip6,
                                               self.src_if.sw_if_index,
                                               proto=DpoProto.DPO_PROTO_IP6)],
                                 is_ip6=1)
        self.route6.add_vpp_config()

        self.reset_packet_infos()
        for i in range(test_packet_count):
            info = self.create_packet_info(self.src_if, self.dst_if)
            payload = self.info_to_payload(info)
            # Ethernet header here is only for size calculation, thus it
            # doesn't matter how it's initialized. This is to ensure that
            # reassembled packet is not > 9000 bytes, so that it's not dropped
            p = (Ether() /
                 IPv6(src=self.src_if.remote_ip6, dst=self.dst_if.remote_ip6) /
                 UDP(sport=1234, dport=5678) /
                 Raw(payload))
            size = self.packet_sizes[(i // 2) % len(self.packet_sizes)]
            self.extend_packet(p, size, self.padding)
            info.data = p[IPv6]  # use only IPv6 part, without ethernet header

        # inner fragmentation; the packet index is used as fragment id
        fragments = [x for _, i in six.iteritems(self._packet_infos)
                     for x in fragment_rfc8200(
                         i.data, i.index, 400)]

        # GRE-encapsulate each inner fragment towards the tunnel endpoint
        encapped_fragments = \
            [Ether(dst=self.src_if.local_mac, src=self.src_if.remote_mac) /
             IPv6(src=self.tun_ip6, dst=self.src_if.local_ip6) /
                GRE() /
                p
                for p in fragments]

        # outer fragmentation; the id offset presumably keeps outer ids
        # distinct from inner ones. Packets without a fragment header are
        # passed through unchanged.
        fragmented_encapped_fragments = \
            [x for p in encapped_fragments for x in (
                fragment_rfc8200(
                    p,
                    2 * len(self._packet_infos) + p[IPv6ExtHdrFragment].id,
                    200)
                if IPv6ExtHdrFragment in p else [p]
            )
            ]

        self.src_if.add_stream(fragmented_encapped_fragments)

        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start()

        # expect fully reassembled inner packets on dst_if only
        self.src_if.assert_nothing_captured()
        packets = self.dst_if.get_capture(len(self._packet_infos))
        self.verify_capture(packets, IPv6)

        # TODO remove gre vpp config by hand until VppIpRoute gets fixed
        # so that it's query_vpp_config() works as it should
        self.gre6.remove_vpp_config()
1091
1092
if __name__ == '__main__':
    # Run this module's test cases through the framework's VppTestRunner.
    unittest.main(testRunner=VppTestRunner)