IGMP improvements
[vpp.git] / test / framework.py
index 423feaf..4ecb66f 100644 (file)
@@ -17,7 +17,7 @@ from inspect import getdoc, isclass
 from traceback import format_exception
 from logging import FileHandler, DEBUG, Formatter
 from scapy.packet import Raw
-from hook import StepHook, PollHook
+from hook import StepHook, PollHook, VppDiedError
 from vpp_pg_interface import VppPGInterface
 from vpp_sub_interface import VppSubInterface
 from vpp_lo_interface import VppLoInterface
@@ -25,6 +25,10 @@ from vpp_papi_provider import VppPapiProvider
 from log import RED, GREEN, YELLOW, double_line_delim, single_line_delim, \
     getLogger, colorize
 from vpp_object import VppObjectRegistry
+from util import ppp
+from scapy.layers.inet import IPerror, TCPerror, UDPerror, ICMPerror
+from scapy.layers.inet6 import ICMPv6DestUnreach, ICMPv6EchoRequest
+from scapy.layers.inet6 import ICMPv6EchoReply
 if os.name == 'posix' and sys.version_info[0] < 3:
     # using subprocess32 is recommended by python official documentation
     # @ https://docs.python.org/2/library/subprocess.html
@@ -32,6 +36,7 @@ if os.name == 'posix' and sys.version_info[0] < 3:
 else:
     import subprocess
 
+
 debug_framework = False
 if os.getenv('TEST_DEBUG', "0") == "1":
     debug_framework = True
@@ -120,21 +125,13 @@ def pump_output(testclass):
 
 
 def running_extended_tests():
-    try:
-        s = os.getenv("EXTENDED_TESTS")
-        return True if s.lower() in ("y", "yes", "1") else False
-    except:
-        return False
-    return False
+    s = os.getenv("EXTENDED_TESTS", "n")
+    return True if s.lower() in ("y", "yes", "1") else False
 
 
 def running_on_centos():
-    try:
-        os_id = os.getenv("OS_ID")
-        return True if "centos" in os_id.lower() else False
-    except:
-        return False
-    return False
+    os_id = os.getenv("OS_ID", "")
+    return True if "centos" in os_id.lower() else False
 
 
 class KeepAliveReporter(object):
@@ -217,21 +214,11 @@ class VppTestCase(unittest.TestCase):
     @classmethod
     def setUpConstants(cls):
         """ Set-up the test case class based on environment variables """
-        try:
-            s = os.getenv("STEP")
-            cls.step = True if s.lower() in ("y", "yes", "1") else False
-        except:
-            cls.step = False
-        try:
-            d = os.getenv("DEBUG")
-        except:
-            d = None
-        try:
-            c = os.getenv("CACHE_OUTPUT", "1")
-            cls.cache_vpp_output = \
-                False if c.lower() in ("n", "no", "0") else True
-        except:
-            cls.cache_vpp_output = True
+        s = os.getenv("STEP", "n")
+        cls.step = True if s.lower() in ("y", "yes", "1") else False
+        d = os.getenv("DEBUG", None)
+        c = os.getenv("CACHE_OUTPUT", "1")
+        cls.cache_vpp_output = False if c.lower() in ("n", "no", "0") else True
         cls.set_debug_flags(d)
         cls.vpp_bin = os.getenv('VPP_TEST_BIN', "vpp")
         cls.plugin_path = os.getenv('VPP_TEST_PLUGIN_PATH')
@@ -249,12 +236,9 @@ class VppTestCase(unittest.TestCase):
         if cls.step or cls.debug_gdb or cls.debug_gdbserver:
             debug_cli = "cli-listen localhost:5002"
         coredump_size = None
-        try:
-            size = os.getenv("COREDUMP_SIZE")
-            if size is not None:
-                coredump_size = "coredump-size %s" % size
-        except:
-            pass
+        size = os.getenv("COREDUMP_SIZE")
+        if size is not None:
+            coredump_size = "coredump-size %s" % size
         if coredump_size is None:
             coredump_size = "coredump-size unlimited"
         cls.vpp_cmdline = [cls.vpp_bin, "unix",
@@ -368,7 +352,7 @@ class VppTestCase(unittest.TestCase):
             cls.sleep(0.1, "after vpp startup, before initial poll")
             try:
                 hook.poll_vpp()
-            except:
+            except VppDiedError:
                 cls.vpp_startup_failed = True
                 cls.logger.critical(
                     "VPP died shortly after startup, check the"
@@ -376,19 +360,22 @@ class VppTestCase(unittest.TestCase):
                 raise
             try:
                 cls.vapi.connect()
-            except:
+            except Exception:
+                try:
+                    cls.vapi.disconnect()
+                except Exception:
+                    pass
                 if cls.debug_gdbserver:
                     print(colorize("You're running VPP inside gdbserver but "
                                    "VPP-API connection failed, did you forget "
                                    "to 'continue' VPP from within gdb?", RED))
                 raise
-        except:
-            t, v, tb = sys.exc_info()
+        except Exception:
             try:
                 cls.quit()
-            except:
+            except Exception:
                 pass
-            raise (t, v, tb)
+            raise
 
     @classmethod
     def quit(cls):
@@ -472,6 +459,7 @@ class VppTestCase(unittest.TestCase):
             self.logger.info(self.vapi.ppcli("show hardware"))
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show run"))
+            self.logger.info(self.vapi.ppcli("show log"))
             self.registry.remove_vpp_config(self.logger)
             # Save/Dump VPP api trace log
             api_trace = "vpp_api_trace.%s.log" % self._testMethodName
@@ -570,18 +558,16 @@ class VppTestCase(unittest.TestCase):
         return result
 
     @classmethod
-    def create_loopback_interfaces(cls, interfaces):
+    def create_loopback_interfaces(cls, count):
         """
         Create loopback interfaces.
 
-        :param interfaces: iterable indexes of the interfaces.
+        :param count: number of interfaces created.
         :returns: List of created interfaces.
         """
-        result = []
-        for i in interfaces:
-            intf = VppLoInterface(cls, i)
+        result = [VppLoInterface(cls) for i in range(count)]
+        for intf in result:
             setattr(cls, intf.name, intf)
-            result.append(intf)
         cls.lo_interfaces = result
         return result
 
@@ -728,7 +714,7 @@ class VppTestCase(unittest.TestCase):
             msg = msg % (getdoc(name_or_class).strip(),
                          real_value, str(name_or_class(real_value)),
                          expected_value, str(name_or_class(expected_value)))
-        except:
+        except Exception:
             msg = "Invalid %s: %s does not match expected value %s" % (
                 name_or_class, real_value, expected_value)
 
@@ -746,6 +732,95 @@ class VppTestCase(unittest.TestCase):
                 name, real_value, expected_min, expected_max)
         self.assertTrue(expected_min <= real_value <= expected_max, msg)
 
+    def assert_packet_checksums_valid(self, packet,
+                                      ignore_zero_udp_checksums=True):
+        received = packet.__class__(str(packet))
+        self.logger.debug(
+            ppp("Verifying packet checksums for packet:", received))
+        udp_layers = ['UDP', 'UDPerror']
+        checksum_fields = ['cksum', 'chksum']
+        checksums = []
+        counter = 0
+        temp = received.__class__(str(received))
+        while True:
+            layer = temp.getlayer(counter)
+            if layer:
+                for cf in checksum_fields:
+                    if hasattr(layer, cf):
+                        if ignore_zero_udp_checksums and \
+                                0 == getattr(layer, cf) and \
+                                layer.name in udp_layers:
+                            continue
+                        delattr(layer, cf)
+                        checksums.append((counter, cf))
+            else:
+                break
+            counter = counter + 1
+        if 0 == len(checksums):
+            return
+        temp = temp.__class__(str(temp))
+        for layer, cf in checksums:
+            calc_sum = getattr(temp[layer], cf)
+            self.assert_equal(
+                getattr(received[layer], cf), calc_sum,
+                "packet checksum on layer #%d: %s" % (layer, temp[layer].name))
+            self.logger.debug(
+                "Checksum field `%s` on `%s` layer has correct value `%s`" %
+                (cf, temp[layer].name, calc_sum))
+
+    def assert_checksum_valid(self, received_packet, layer,
+                              field_name='chksum',
+                              ignore_zero_checksum=False):
+        """ Check checksum of received packet on given layer """
+        received_packet_checksum = getattr(received_packet[layer], field_name)
+        if ignore_zero_checksum and 0 == received_packet_checksum:
+            return
+        recalculated = received_packet.__class__(str(received_packet))
+        delattr(recalculated[layer], field_name)
+        recalculated = recalculated.__class__(str(recalculated))
+        self.assert_equal(received_packet_checksum,
+                          getattr(recalculated[layer], field_name),
+                          "packet checksum on layer: %s" % layer)
+
+    def assert_ip_checksum_valid(self, received_packet,
+                                 ignore_zero_checksum=False):
+        self.assert_checksum_valid(received_packet, 'IP',
+                                   ignore_zero_checksum=ignore_zero_checksum)
+
+    def assert_tcp_checksum_valid(self, received_packet,
+                                  ignore_zero_checksum=False):
+        self.assert_checksum_valid(received_packet, 'TCP',
+                                   ignore_zero_checksum=ignore_zero_checksum)
+
+    def assert_udp_checksum_valid(self, received_packet,
+                                  ignore_zero_checksum=True):
+        self.assert_checksum_valid(received_packet, 'UDP',
+                                   ignore_zero_checksum=ignore_zero_checksum)
+
+    def assert_embedded_icmp_checksum_valid(self, received_packet):
+        if received_packet.haslayer(IPerror):
+            self.assert_checksum_valid(received_packet, 'IPerror')
+        if received_packet.haslayer(TCPerror):
+            self.assert_checksum_valid(received_packet, 'TCPerror')
+        if received_packet.haslayer(UDPerror):
+            self.assert_checksum_valid(received_packet, 'UDPerror',
+                                       ignore_zero_checksum=True)
+        if received_packet.haslayer(ICMPerror):
+            self.assert_checksum_valid(received_packet, 'ICMPerror')
+
+    def assert_icmp_checksum_valid(self, received_packet):
+        self.assert_checksum_valid(received_packet, 'ICMP')
+        self.assert_embedded_icmp_checksum_valid(received_packet)
+
+    def assert_icmpv6_checksum_valid(self, pkt):
+        if pkt.haslayer(ICMPv6DestUnreach):
+            self.assert_checksum_valid(pkt, 'ICMPv6DestUnreach', 'cksum')
+            self.assert_embedded_icmp_checksum_valid(pkt)
+        if pkt.haslayer(ICMPv6EchoRequest):
+            self.assert_checksum_valid(pkt, 'ICMPv6EchoRequest', 'cksum')
+        if pkt.haslayer(ICMPv6EchoReply):
+            self.assert_checksum_valid(pkt, 'ICMPv6EchoReply', 'cksum')
+
     @classmethod
     def sleep(cls, timeout, remark=None):
         if hasattr(cls, 'logger'):
@@ -762,12 +837,13 @@ class VppTestCase(unittest.TestCase):
                 "Finished sleep (%s) - slept %ss (wanted %ss)" % (
                     remark, after - before, timeout))
 
-    def send_and_assert_no_replies(self, intf, pkts, remark=""):
+    def send_and_assert_no_replies(self, intf, pkts, remark="", timeout=None):
         self.vapi.cli("clear trace")
         intf.add_stream(pkts)
         self.pg_enable_capture(self.pg_interfaces)
         self.pg_start()
-        timeout = 1
+        if not timeout:
+            timeout = 1
         for i in self.pg_interfaces:
             i.get_capture(0, timeout=timeout)
             i.assert_nothing_captured(remark=remark)
@@ -882,7 +958,25 @@ class VppTestResult(unittest.TestResult):
         if hasattr(self, 'test_framework_failed_pipe'):
             pipe = self.test_framework_failed_pipe
             if pipe:
-                pipe.send(test.__class__)
+                if test.__class__.__name__ == "_ErrorHolder":
+                    x = str(test)
+                    if x.startswith("setUpClass"):
+                        # x looks like setUpClass (test_function.test_class)
+                        cls = x.split(".")[1].split(")")[0]
+                        for t in self.test_suite:
+                            if t.__class__.__name__ == cls:
+                                pipe.send(t.__class__)
+                                break
+                        else:
+                            raise Exception("Can't find class name `%s' "
+                                            "(from ErrorHolder) in test suite "
+                                            "`%s'" % (cls, self.test_suite))
+                    else:
+                        raise Exception("FIXME: unexpected special case - "
+                                        "ErrorHolder description is `%s'" %
+                                        str(test))
+                else:
+                    pipe.send(test.__class__)
 
     def addFailure(self, test, err):
         """
@@ -1047,10 +1141,7 @@ class VppTestRunner(unittest.TextTestRunner):
     test_option = "TEST"
 
     def parse_test_option(self):
-        try:
-            f = os.getenv(self.test_option)
-        except:
-            f = None
+        f = os.getenv(self.test_option, None)
         filter_file_name = None
         filter_class_name = None
         filter_func_name = None
@@ -1120,6 +1211,8 @@ class VppTestRunner(unittest.TextTestRunner):
             filtered.countTestCases(), test.countTestCases()))
         if not running_extended_tests():
             print("Not running extended tests (some tests will be skipped)")
+        # super-ugly hack #2
+        VppTestResult.test_suite = filtered
         return super(VppTestRunner, self).run(filtered)