hs-test: clean up Makefile for compatibility with ci-management
[vpp.git] / test / framework.py
1 #!/usr/bin/env python3
2
3 from __future__ import print_function
4 import logging
5 import sys
6 import os
7 import select
8 import signal
9 import subprocess
10 import unittest
11 import re
12 import time
13 import faulthandler
14 import random
15 import copy
16 import platform
17 import shutil
18 from collections import deque
19 from threading import Thread, Event
20 from inspect import getdoc, isclass
21 from traceback import format_exception
22 from logging import FileHandler, DEBUG, Formatter
23 from enum import Enum
24 from abc import ABC, abstractmethod
25 from struct import pack, unpack
26
27 import scapy.compat
28 from scapy.packet import Raw, Packet
29 from vpp_pg_interface import VppPGInterface
30 from vpp_sub_interface import VppSubInterface
31 from vpp_lo_interface import VppLoInterface
32 from vpp_bvi_interface import VppBviInterface
33 from vpp_papi_provider import VppPapiProvider
34 from vpp_papi import VppEnum
35 import vpp_papi
36 from vpp_object import VppObjectRegistry
37 from util import ppp, is_core_present
38 from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror
39 from scapy.layers.inet6 import ICMPv6DestUnreach, ICMPv6EchoRequest
40 from scapy.layers.inet6 import ICMPv6EchoReply
41 from vpp_running import use_running
42 from asfframework import VppAsfTestCase
43
44
# Packet Generator / Scapy Test framework module.
#
# The module provides a set of tools for constructing and running tests and
# representing the results.
#
# NOTE: kept as a comment on purpose. A string literal placed after the
# imports is not a module docstring (only the first statement of a module is),
# so as a bare string it was a no-op expression that never populated __doc__.
51
52
53 class _PacketInfo(object):
54     """Private class to create packet info object.
55
56     Help process information about the next packet.
57     Set variables to default values.
58     """
59
60     #: Store the index of the packet.
61     index = -1
62     #: Store the index of the source packet generator interface of the packet.
63     src = -1
64     #: Store the index of the destination packet generator interface
65     #: of the packet.
66     dst = -1
67     #: Store expected ip version
68     ip = -1
69     #: Store expected upper protocol
70     proto = -1
71     #: Store the copy of the former packet.
72     data = None
73
74     def __eq__(self, other):
75         index = self.index == other.index
76         src = self.src == other.src
77         dst = self.dst == other.dst
78         data = self.data == other.data
79         return index and src and dst and data
80
81
@use_running
class VppTestCase(VppAsfTestCase):
    """Base class for VPP test cases that drive traffic through the packet
    generator (pg).

    On top of VppAsfTestCase it provides helpers to:

    * create pg, loopback and BVI interfaces and expose each one as a class
      attribute named after the interface,
    * track _PacketInfo records that are serialized into packet payloads so
      that received packets can be matched back to what was sent,
    * start the packet generator and manage its pcap captures,
    * send packets and assert on what is (or is not) captured, and
    * validate checksums of received packets.
    """

    @property
    def packet_infos(self):
        """Mapping of packet index -> _PacketInfo built by create_packet_info."""
        return self._packet_infos

    @classmethod
    def get_packet_count_for_if_idx(cls, dst_if_index):
        """Return how many packet infos were created for a destination interface.

        :param dst_if_index: destination sw_if_index (for sub-interfaces,
                             create_packet_info counts under the parent's index)
        :returns: the number of packet infos, or 0 if there are none
        """
        return cls._packet_count_for_dst_if_idx.get(dst_if_index, 0)

    @classmethod
    def setUpClass(cls):
        """Set up the class and start with empty packet-info and pcap state."""
        super(VppTestCase, cls).setUpClass()
        cls.reset_packet_infos()
        # captures registered for the next pg_start() run
        cls._pcaps = []
        # captures of the previous pg_start() run; their input files are
        # removed at the beginning of the next pg_start()
        cls._old_pcaps = []

    @classmethod
    def tearDownClass(cls):
        """Reset packet-info bookkeeping, then run the parent class teardown."""
        cls.logger.debug("--- tearDownClass() for %s called ---" % cls.__name__)
        cls.reset_packet_infos()
        super(VppTestCase, cls).tearDownClass()

    @classmethod
    def pg_enable_capture(cls, interfaces=None):
        """
        Enable capture on packet-generator interfaces

        :param interfaces: iterable interface indexes (if None,
                           use self.pg_interfaces)

        """
        if interfaces is None:
            interfaces = cls.pg_interfaces
        for i in interfaces:
            i.enable_capture()

    @classmethod
    def register_pcap(cls, intf, worker):
        """Register a pcap in the testclass.

        The (interface, worker) pair is remembered so that the next pg_start()
        issues ``packet-generator delete`` for it once the generator stopped.
        """
        cls._pcaps.append((intf, worker))

    @classmethod
    def pg_start(cls, trace=True, traceFilter=False):
        """Enable the PG, wait till it is done, then clean up.

        :param trace: when true, clear the trace buffer and trace up to 1000
                      packets from the pg-input node
        :param traceFilter: when true (and tracing), append " filter" to the
                            trace command
        """
        # Input files of the previous run are no longer needed.
        for intf, worker in cls._old_pcaps:
            intf.remove_old_pcap_file(intf.get_in_path(worker))
        cls._old_pcaps = []
        if trace:
            cls.vapi.cli("clear trace")
            cls.vapi.cli("trace add pg-input 1000" + (" filter" if traceFilter else ""))
        cls.vapi.cli("packet-generator enable")
        # The PG runs to completion once enabled; to avoid racing with it, poll
        # until "show packet-generator" no longer reports "Yes", giving up
        # after 300 seconds.
        deadline = time.time() + 300
        while cls.vapi.cli("show packet-generator").find("Yes") != -1:
            cls.sleep(0.01)  # yield
            if time.time() > deadline:
                cls.logger.error("Timeout waiting for pg to stop")
                break
        for intf, worker in cls._pcaps:
            cls.vapi.cli("packet-generator delete %s" % intf.get_cap_name(worker))
        cls._old_pcaps = cls._pcaps
        cls._pcaps = []

    @classmethod
    def create_pg_interfaces_internal(cls, interfaces, gso=0, gso_size=0, mode=None):
        """
        Create packet-generator interfaces.

        Each interface is also stored as a class attribute under its name, and
        the returned list replaces cls.pg_interfaces.

        :param interfaces: iterable indexes of the interfaces.
        :param gso: passed through to VppPGInterface
        :param gso_size: passed through to VppPGInterface
        :param mode: pg interface mode passed through to VppPGInterface
        :returns: List of created interfaces.

        """
        result = []
        for i in interfaces:
            intf = VppPGInterface(cls, i, gso, gso_size, mode)
            setattr(cls, intf.name, intf)
            result.append(intf)
        cls.pg_interfaces = result
        return result

    @classmethod
    def _create_pg_interfaces_for_mode(cls, interfaces, gso, gso_size, mode_name):
        """Shared body of the create_pg_*_interfaces helpers.

        If the class has no ``vpp`` attribute, nothing is created and
        cls.pg_interfaces is reset to (and returned as) an empty list.

        :param mode_name: attribute name on VppEnum.vl_api_pg_interface_mode_t
        """
        if not hasattr(cls, "vpp"):
            cls.pg_interfaces = []
            return cls.pg_interfaces
        pgmode = VppEnum.vl_api_pg_interface_mode_t
        return cls.create_pg_interfaces_internal(
            interfaces, gso, gso_size, getattr(pgmode, mode_name)
        )

    @classmethod
    def create_pg_ip4_interfaces(cls, interfaces, gso=0, gso_size=0):
        """Create pg interfaces in PG_API_MODE_IP4 mode."""
        return cls._create_pg_interfaces_for_mode(
            interfaces, gso, gso_size, "PG_API_MODE_IP4"
        )

    @classmethod
    def create_pg_ip6_interfaces(cls, interfaces, gso=0, gso_size=0):
        """Create pg interfaces in PG_API_MODE_IP6 mode."""
        return cls._create_pg_interfaces_for_mode(
            interfaces, gso, gso_size, "PG_API_MODE_IP6"
        )

    @classmethod
    def create_pg_interfaces(cls, interfaces, gso=0, gso_size=0):
        """Create pg interfaces in PG_API_MODE_ETHERNET mode."""
        return cls._create_pg_interfaces_for_mode(
            interfaces, gso, gso_size, "PG_API_MODE_ETHERNET"
        )

    @classmethod
    def create_pg_ethernet_interfaces(cls, interfaces, gso=0, gso_size=0):
        """Create pg interfaces in PG_API_MODE_ETHERNET mode (same as
        create_pg_interfaces)."""
        return cls._create_pg_interfaces_for_mode(
            interfaces, gso, gso_size, "PG_API_MODE_ETHERNET"
        )

    @classmethod
    def create_loopback_interfaces(cls, count):
        """
        Create loopback interfaces.

        If the class has no ``vpp`` attribute, nothing is created and an
        empty list is returned.

        :param count: number of interfaces created.
        :returns: List of created interfaces.
        """
        if not hasattr(cls, "vpp"):
            cls.lo_interfaces = []
            return cls.lo_interfaces
        result = [VppLoInterface(cls) for i in range(count)]
        for intf in result:
            setattr(cls, intf.name, intf)
        cls.lo_interfaces = result
        return result

    @classmethod
    def create_bvi_interfaces(cls, count):
        """
        Create BVI interfaces.

        If the class has no ``vpp`` attribute, nothing is created and an
        empty list is returned.

        :param count: number of interfaces created.
        :returns: List of created interfaces.
        """
        if not hasattr(cls, "vpp"):
            cls.bvi_interfaces = []
            return cls.bvi_interfaces
        result = [VppBviInterface(cls) for i in range(count)]
        for intf in result:
            setattr(cls, intf.name, intf)
        cls.bvi_interfaces = result
        return result

    @staticmethod
    def extend_packet(packet, size, padding=" "):
        """
        Extend packet to given size by padding with spaces or custom padding
        NOTE: Currently works only when Raw layer is present.

        :param packet: packet (modified in place)
        :param size: target size
        :param padding: padding used to extend the payload; must be non-empty
                        and ASCII-encodable

        """
        # NOTE(review): the +4 presumably accounts for the 4-byte Ethernet FCS
        # that is not part of the scapy packet - confirm.
        packet_len = len(packet) + 4
        extend = size - packet_len
        if extend > 0:
            num = (extend // len(padding)) + 1
            packet[Raw].load += (padding * num)[:extend].encode("ascii")

    @classmethod
    def reset_packet_infos(cls):
        """Reset the list of packet info objects and packet counts to zero"""
        cls._packet_infos = {}
        cls._packet_count_for_dst_if_idx = {}

    @classmethod
    def create_packet_info(cls, src_if, dst_if):
        """
        Create packet info object containing the source and destination indexes
        and add it to the testcase's packet info list

        The per-destination counter is keyed by the parent's sw_if_index when
        dst_if is a VppSubInterface.

        :param VppInterface src_if: source interface
        :param VppInterface dst_if: destination interface

        :returns: _PacketInfo object

        """
        info = _PacketInfo()
        info.index = len(cls._packet_infos)
        info.src = src_if.sw_if_index
        info.dst = dst_if.sw_if_index
        if isinstance(dst_if, VppSubInterface):
            dst_idx = dst_if.parent.sw_if_index
        else:
            dst_idx = dst_if.sw_if_index
        counts = cls._packet_count_for_dst_if_idx
        counts[dst_idx] = counts.get(dst_idx, 0) + 1
        cls._packet_infos[info.index] = info
        return info

    @staticmethod
    def info_to_payload(info):
        """
        Convert _PacketInfo object to packet payload

        :param info: _PacketInfo object

        :returns: bytes containing serialized data from packet info
        """

        # serialize payload, currently 18 bytes (4 x ints + 1 short),
        # native byte order and alignment
        return pack("iiiih", info.index, info.src, info.dst, info.ip, info.proto)

    @staticmethod
    def payload_to_info(payload, payload_field="load"):
        """
        Convert packet payload to _PacketInfo object

        :param payload: packet payload
        :type payload:  <class 'scapy.packet.Raw'>
        :param payload_field: packet fieldname of payload "load" for
                <class 'scapy.packet.Raw'>
        :type payload_field: str
        :returns: _PacketInfo object containing de-serialized data from payload
        :raises ValueError: if the decoded index is larger than 0x4000

        """

        # retrieve payload, currently 18 bytes (4 x ints + 1 short)
        payload_b = getattr(payload, payload_field)[:18]

        info = _PacketInfo()
        info.index, info.src, info.dst, info.ip, info.proto = unpack("iiiih", payload_b)

        # some SRv6 TCs depend on get an exception if bad values are detected
        if info.index > 0x4000:
            raise ValueError("Index value is invalid")

        return info

    def get_next_packet_info(self, info):
        """
        Iterate over the packet info list stored in the testcase
        Start iteration with first element if info is None
        Continue based on index in info if info is specified

        :param info: info or None
        :returns: next info in list or None if no more infos
        """
        next_index = 0 if info is None else info.index + 1
        # ">=" rather than "==" so a stale/out-of-range info ends the
        # iteration instead of raising KeyError.
        if next_index >= len(self._packet_infos):
            return None
        return self._packet_infos[next_index]

    def get_next_packet_info_for_interface(self, src_index, info):
        """
        Search the packet info list for the next packet info with same source
        interface index

        :param src_index: source interface index to search for
        :param info: packet info - where to start the search
        :returns: packet info or None

        """
        while True:
            info = self.get_next_packet_info(info)
            if info is None:
                return None
            if info.src == src_index:
                return info

    def get_next_packet_info_for_interface2(self, src_index, dst_index, info):
        """
        Search the packet info list for the next packet info with same source
        and destination interface indexes

        :param src_index: source interface index to search for
        :param dst_index: destination interface index to search for
        :param info: packet info - where to start the search
        :returns: packet info or None

        """
        while True:
            info = self.get_next_packet_info_for_interface(src_index, info)
            if info is None:
                return None
            if info.dst == dst_index:
                return info

    def assert_packet_checksums_valid(self, packet, ignore_zero_udp_checksums=True):
        """Assert that every checksum field of every layer of *packet* is valid.

        For each layer that has a ``cksum`` or ``chksum`` field, the field is
        deleted on a re-parsed copy so scapy recomputes it when the copy is
        rebuilt; the received values are then compared with the recomputed
        ones. A zero checksum on a UDP/UDPerror layer is skipped when
        ignore_zero_udp_checksums is true.
        """
        received = packet.__class__(scapy.compat.raw(packet))
        udp_layers = ["UDP", "UDPerror"]
        checksum_fields = ["cksum", "chksum"]
        checksums = []  # (layer index, field name) pairs that were cleared
        counter = 0
        temp = received.__class__(scapy.compat.raw(received))
        while True:
            layer = temp.getlayer(counter)
            if layer:
                # inspect only this layer's own fields, not its payload
                layer = layer.copy()
                layer.remove_payload()
                for cf in checksum_fields:
                    if hasattr(layer, cf):
                        if (
                            ignore_zero_udp_checksums
                            and 0 == getattr(layer, cf)
                            and layer.name in udp_layers
                        ):
                            continue
                        delattr(temp.getlayer(counter), cf)
                        checksums.append((counter, cf))
            else:
                break
            counter = counter + 1
        if 0 == len(checksums):
            return
        # rebuilding the packet makes scapy recompute the deleted fields
        temp = temp.__class__(scapy.compat.raw(temp))
        for layer, cf in reversed(checksums):
            calc_sum = getattr(temp[layer], cf)
            self.assert_equal(
                getattr(received[layer], cf),
                calc_sum,
                "packet checksum on layer #%d: %s" % (layer, temp[layer].name),
            )
            self.logger.debug(
                "Checksum field `%s` on `%s` layer has correct value `%s`"
                % (cf, temp[layer].name, calc_sum)
            )

    def assert_checksum_valid(
        self,
        received_packet,
        layer,
        checksum_field_names=None,
        ignore_zero_checksum=False,
    ):
        """Check checksum of received packet on given layer.

        :param received_packet: scapy packet to check
        :param layer: layer name/class to check
        :param checksum_field_names: candidate checksum field names, first
                                     match wins (default ["chksum", "cksum"])
        :param ignore_zero_checksum: skip the check if the received checksum
                                     is 0
        :raises Exception: if the layer has none of the candidate fields
        """
        if checksum_field_names is None:
            checksum_field_names = ["chksum", "cksum"]
        layer_copy = received_packet[layer].copy()
        layer_copy.remove_payload()
        field_name = next(
            (f for f in checksum_field_names if hasattr(layer_copy, f)), None
        )
        if field_name is None:
            raise Exception(
                f"Layer `{layer}` has none of checksum fields: `{checksum_field_names}`."
            )
        received_packet_checksum = getattr(received_packet[layer], field_name)
        if ignore_zero_checksum and 0 == received_packet_checksum:
            return
        # delete the field on a re-parsed copy so scapy recomputes it on rebuild
        recalculated = received_packet.__class__(scapy.compat.raw(received_packet))
        delattr(recalculated[layer], field_name)
        recalculated = recalculated.__class__(scapy.compat.raw(recalculated))
        self.assert_equal(
            received_packet_checksum,
            getattr(recalculated[layer], field_name),
            f"packet checksum (field: {field_name}) on layer: %s" % layer,
        )

    def assert_ip_checksum_valid(self, received_packet, ignore_zero_checksum=False):
        """Check the IP header checksum."""
        self.assert_checksum_valid(
            received_packet, "IP", ignore_zero_checksum=ignore_zero_checksum
        )

    def assert_tcp_checksum_valid(self, received_packet, ignore_zero_checksum=False):
        """Check the TCP checksum."""
        self.assert_checksum_valid(
            received_packet, "TCP", ignore_zero_checksum=ignore_zero_checksum
        )

    def assert_udp_checksum_valid(self, received_packet, ignore_zero_checksum=True):
        """Check the UDP checksum (a zero checksum is skipped by default)."""
        self.assert_checksum_valid(
            received_packet, "UDP", ignore_zero_checksum=ignore_zero_checksum
        )

    def assert_embedded_icmp_checksum_valid(self, received_packet):
        """Check checksums of the headers quoted inside an ICMP error, if any."""
        if received_packet.haslayer(IPerror):
            self.assert_checksum_valid(received_packet, "IPerror")
        if received_packet.haslayer(TCPerror):
            self.assert_checksum_valid(received_packet, "TCPerror")
        if received_packet.haslayer(UDPerror):
            self.assert_checksum_valid(
                received_packet, "UDPerror", ignore_zero_checksum=True
            )
        if received_packet.haslayer(ICMPerror):
            self.assert_checksum_valid(received_packet, "ICMPerror")

    def assert_icmp_checksum_valid(self, received_packet):
        """Check the ICMP checksum and any embedded (quoted) header checksums."""
        self.assert_checksum_valid(received_packet, "ICMP")
        self.assert_embedded_icmp_checksum_valid(received_packet)

    def assert_icmpv6_checksum_valid(self, pkt):
        """Check ICMPv6 DestUnreach / EchoRequest / EchoReply checksums present."""
        if pkt.haslayer(ICMPv6DestUnreach):
            self.assert_checksum_valid(pkt, "ICMPv6DestUnreach")
            self.assert_embedded_icmp_checksum_valid(pkt)
        if pkt.haslayer(ICMPv6EchoRequest):
            self.assert_checksum_valid(pkt, "ICMPv6EchoRequest")
        if pkt.haslayer(ICMPv6EchoReply):
            self.assert_checksum_valid(pkt, "ICMPv6EchoReply")

    def assert_packet_counter_equal(self, counter, expected_value):
        """Assert that get_counter(counter) equals expected_value."""
        counter_value = self.get_counter(counter)
        self.assert_equal(
            counter_value, expected_value, "packet counter `%s'" % counter
        )

    @staticmethod
    def _count_packets(pkts):
        """Return the number of packets in *pkts*.

        A lone scapy Packet counts as one packet: len() on a Packet is its
        size in bytes (see extend_packet), not a packet count.
        """
        return 1 if isinstance(pkts, Packet) else len(pkts)

    def pg_send(self, intf, pkts, worker=None, trace=True):
        """Add *pkts* as a stream on *intf*, enable capture on all pg
        interfaces and run the packet generator."""
        intf.add_stream(pkts, worker=worker)
        self.pg_enable_capture(self.pg_interfaces)
        self.pg_start(trace=trace)

    def send_and_assert_no_replies(
        self, intf, pkts, remark="", timeout=None, stats_diff=None, trace=True, msg=None
    ):
        """Send *pkts* on *intf* and assert nothing is captured on any pg
        interface.

        The first interface is given *timeout* seconds (default 1), later ones
        0.1 s. When *trace* is true the trace is logged even if the assertion
        fails; when false, packet tracing is not enabled for the run.
        """
        if stats_diff:
            stats_snapshot = self.snapshot_stats(stats_diff)

        # Forward trace so that trace=False really disables tracing, as in
        # send_and_expect().
        self.pg_send(intf, pkts, trace=trace)

        try:
            if not timeout:
                timeout = 1
            for i in self.pg_interfaces:
                i.assert_nothing_captured(timeout=timeout, remark=remark)
                timeout = 0.1
        finally:
            if trace:
                if msg:
                    self.logger.debug(f"send_and_assert_no_replies: {msg}")
                self.logger.debug(self.vapi.cli("show trace"))

        if stats_diff:
            self.compare_stats_with_snapshot(stats_diff, stats_snapshot)

    def send_and_expect(
        self,
        intf,
        pkts,
        output,
        n_rx=None,
        worker=None,
        trace=True,
        msg=None,
        stats_diff=None,
    ):
        """Send *pkts* on *intf* and return the capture from *output*.

        :param n_rx: expected number of received packets; when falsy,
                     defaults to the number of packets sent
        :returns: the captured packets
        """
        if stats_diff:
            stats_snapshot = self.snapshot_stats(stats_diff)

        if not n_rx:
            n_rx = self._count_packets(pkts)
        self.pg_send(intf, pkts, worker=worker, trace=trace)
        rx = output.get_capture(n_rx)
        if trace:
            if msg:
                self.logger.debug(f"send_and_expect: {msg}")
            self.logger.debug(self.vapi.cli("show trace"))

        if stats_diff:
            self.compare_stats_with_snapshot(stats_diff, stats_snapshot)

        return rx

    def send_and_expect_load_balancing(
        self, input, pkts, outputs, worker=None, trace=True
    ):
        """Send *pkts* on *input* and assert each of *outputs* captured at
        least one packet.

        :returns: list of captures, one per output interface
        """
        self.pg_send(input, pkts, worker=worker, trace=trace)
        rxs = []
        for oo in outputs:
            rx = oo._get_capture(1)
            self.assertNotEqual(0, len(rx), f"0 != len(rx) ({len(rx)})")
            rxs.append(rx)
        if trace:
            self.logger.debug(self.vapi.cli("show trace"))
        return rxs

    def send_and_expect_some(self, intf, pkts, output, worker=None, trace=True):
        """Send *pkts* on *intf* and assert *output* captured between one
        packet and the number of packets sent.

        :returns: the captured packets
        """
        self.pg_send(intf, pkts, worker=worker, trace=trace)
        rx = output._get_capture(1)
        if trace:
            self.logger.debug(self.vapi.cli("show trace"))
        self.assertTrue(len(rx) > 0)
        n_tx = self._count_packets(pkts)
        self.assertTrue(len(rx) <= n_tx, f"len(rx) ({len(rx)}) > len(pkts) ({n_tx})")
        return rx

    def send_and_expect_only(self, intf, pkts, output, timeout=None, stats_diff=None):
        """Send *pkts* on *intf*, expect all of them on *output* and nothing on
        any other pg interface.

        The first other interface is given *timeout* seconds (default 1),
        later ones 0.1 s.

        :returns: the packets captured on *output*
        """
        if stats_diff:
            stats_snapshot = self.snapshot_stats(stats_diff)

        self.pg_send(intf, pkts)
        rx = output.get_capture(self._count_packets(pkts))
        outputs = [output]
        if not timeout:
            timeout = 1
        for i in self.pg_interfaces:
            if i not in outputs:
                i.assert_nothing_captured(timeout=timeout)
                timeout = 0.1

        if stats_diff:
            self.compare_stats_with_snapshot(stats_diff, stats_snapshot)

        return rx
607
608
if __name__ == "__main__":
    # Library module: importing it only defines the test base classes, and
    # executing it directly does nothing.
    ...