make test: improve documentation and PEP8 compliance
[vpp.git] / test / framework.py
index aa4f2fd..b2c6b9e 100644 (file)
@@ -5,16 +5,18 @@ import unittest
 import tempfile
 import time
 import resource
-from time import sleep
-from Queue import Queue
+from collections import deque
 from threading import Thread
 from inspect import getdoc
 from hook import StepHook, PollHook
 from vpp_pg_interface import VppPGInterface
+from vpp_sub_interface import VppSubInterface
 from vpp_lo_interface import VppLoInterface
 from vpp_papi_provider import VppPapiProvider
 from scapy.packet import Raw
+from logging import FileHandler, DEBUG
 from log import *
+from vpp_object import VppObjectRegistry
 
 """
   Test framework module.
@@ -40,10 +42,17 @@ class _PacketInfo(object):
     #: Store the copy of the former packet.
     data = None
 
+    def __eq__(self, other):
+        index = self.index == other.index
+        src = self.src == other.src
+        dst = self.dst == other.dst
+        data = self.data == other.data
+        return index and src and dst and data
 
-def pump_output(out, queue):
+
+def pump_output(out, deque):
     for line in iter(out.readline, b''):
-        queue.put(line)
+        deque.append(line)
 
 
 class VppTestCase(unittest.TestCase):
@@ -56,9 +65,13 @@ class VppTestCase(unittest.TestCase):
         """List of packet infos"""
         return self._packet_infos
 
-    @packet_infos.setter
-    def packet_infos(self, value):
-        self._packet_infos = value
+    @classmethod
+    def get_packet_count_for_if_idx(cls, dst_if_index):
+        """Get the number of packet info for specified destination if index"""
+        if dst_if_index in cls._packet_count_for_dst_if_idx:
+            return cls._packet_count_for_dst_if_idx[dst_if_index]
+        else:
+            return 0
 
     @classmethod
     def instance(cls):
@@ -104,7 +117,8 @@ class VppTestCase(unittest.TestCase):
         debug_cli = ""
         if cls.step or cls.debug_gdb or cls.debug_gdbserver:
             debug_cli = "cli-listen localhost:5002"
-        cls.vpp_cmdline = [cls.vpp_bin, "unix", "{", "nodaemon", debug_cli, "}",
+        cls.vpp_cmdline = [cls.vpp_bin,
+                           "unix", "{", "nodaemon", debug_cli, "}",
                            "api-segment", "{", "prefix", cls.shm_prefix, "}"]
         if cls.plugin_path is not None:
             cls.vpp_cmdline.extend(["plugin_path", cls.plugin_path])
@@ -169,15 +183,20 @@ class VppTestCase(unittest.TestCase):
         cls.logger = getLogger(cls.__name__)
         cls.tempdir = tempfile.mkdtemp(
             prefix='vpp-unittest-' + cls.__name__ + '-')
+        file_handler = FileHandler("%s/log.txt" % cls.tempdir)
+        file_handler.setLevel(DEBUG)
+        cls.logger.addHandler(file_handler)
         cls.shm_prefix = cls.tempdir.split("/")[-1]
         os.chdir(cls.tempdir)
         cls.logger.info("Temporary dir is %s, shm prefix is %s",
                         cls.tempdir, cls.shm_prefix)
         cls.setUpConstants()
-        cls.pg_streams = []
-        cls.packet_infos = {}
+        cls.reset_packet_infos()
+        cls._captures = []
+        cls._zombie_captures = []
         cls.verbose = 0
         cls.vpp_dead = False
+        cls.registry = VppObjectRegistry()
         print(double_line_delim)
         print(colorize(getdoc(cls).splitlines()[0], YELLOW))
         print(double_line_delim)
@@ -185,13 +204,13 @@ class VppTestCase(unittest.TestCase):
         # doesn't get called and we might end with a zombie vpp
         try:
             cls.run_vpp()
-            cls.vpp_stdout_queue = Queue()
+            cls.vpp_stdout_deque = deque()
             cls.vpp_stdout_reader_thread = Thread(target=pump_output, args=(
-                cls.vpp.stdout, cls.vpp_stdout_queue))
+                cls.vpp.stdout, cls.vpp_stdout_deque))
             cls.vpp_stdout_reader_thread.start()
-            cls.vpp_stderr_queue = Queue()
+            cls.vpp_stderr_deque = deque()
             cls.vpp_stderr_reader_thread = Thread(target=pump_output, args=(
-                cls.vpp.stderr, cls.vpp_stderr_queue))
+                cls.vpp.stderr, cls.vpp_stderr_deque))
             cls.vpp_stderr_reader_thread.start()
             cls.vapi = VppPapiProvider(cls.shm_prefix, cls.shm_prefix, cls)
             if cls.step:
@@ -228,8 +247,8 @@ class VppTestCase(unittest.TestCase):
                 print(double_line_delim)
                 print("VPP or GDB server is still running")
                 print(single_line_delim)
-                raw_input("When done debugging, press ENTER to kill the process"
-                          " and finish running the testcase...")
+                raw_input("When done debugging, press ENTER to kill the "
+                          "process and finish running the testcase...")
 
         if hasattr(cls, 'vpp'):
             if hasattr(cls, 'vapi'):
@@ -239,27 +258,26 @@ class VppTestCase(unittest.TestCase):
                 cls.vpp.terminate()
             del cls.vpp
 
-        if hasattr(cls, 'vpp_stdout_queue'):
+        if hasattr(cls, 'vpp_stdout_deque'):
             cls.logger.info(single_line_delim)
             cls.logger.info('VPP output to stdout while running %s:',
                             cls.__name__)
             cls.logger.info(single_line_delim)
             f = open(cls.tempdir + '/vpp_stdout.txt', 'w')
-            while not cls.vpp_stdout_queue.empty():
-                line = cls.vpp_stdout_queue.get_nowait()
-                f.write(line)
-                cls.logger.info('VPP stdout: %s' % line.rstrip('\n'))
+            vpp_output = "".join(cls.vpp_stdout_deque)
+            f.write(vpp_output)
+            cls.logger.info('\n%s', vpp_output)
+            cls.logger.info(single_line_delim)
 
-        if hasattr(cls, 'vpp_stderr_queue'):
+        if hasattr(cls, 'vpp_stderr_deque'):
             cls.logger.info(single_line_delim)
             cls.logger.info('VPP output to stderr while running %s:',
                             cls.__name__)
             cls.logger.info(single_line_delim)
             f = open(cls.tempdir + '/vpp_stderr.txt', 'w')
-            while not cls.vpp_stderr_queue.empty():
-                line = cls.vpp_stderr_queue.get_nowait()
-                f.write(line)
-                cls.logger.info('VPP stderr: %s' % line.rstrip('\n'))
+            vpp_output = "".join(cls.vpp_stderr_deque)
+            f.write(vpp_output)
+            cls.logger.info('\n%s', vpp_output)
             cls.logger.info(single_line_delim)
 
     @classmethod
@@ -275,9 +293,21 @@ class VppTestCase(unittest.TestCase):
             self.logger.info(self.vapi.ppcli("show hardware"))
             self.logger.info(self.vapi.ppcli("show error"))
             self.logger.info(self.vapi.ppcli("show run"))
+            self.registry.remove_vpp_config(self.logger)
 
     def setUp(self):
         """ Clear trace before running each test"""
+        if self.vpp_dead:
+            raise Exception("VPP is dead when setting up the test")
+        time.sleep(.1)
+        self.vpp_stdout_deque.append(
+            "--- test setUp() for %s.%s(%s) starts here ---\n" %
+            (self.__class__.__name__, self._testMethodName,
+             self._testMethodDoc))
+        self.vpp_stderr_deque.append(
+            "--- test setUp() for %s.%s(%s) starts here ---\n" %
+            (self.__class__.__name__, self._testMethodName,
+             self._testMethodDoc))
         self.vapi.cli("clear trace")
         # store the test instance inside the test class - so that objects
         # holding the class can access instance methods (like assertEqual)
@@ -294,25 +324,45 @@ class VppTestCase(unittest.TestCase):
         for i in interfaces:
             i.enable_capture()
 
+    @classmethod
+    def register_capture(cls, cap_name):
+        """ Register a capture in the testclass """
+        # add to the list of captures with current timestamp
+        cls._captures.append((time.time(), cap_name))
+        # filter out from zombies
+        cls._zombie_captures = [(stamp, name)
+                                for (stamp, name) in cls._zombie_captures
+                                if name != cap_name]
+
     @classmethod
     def pg_start(cls):
-        """
-        Enable the packet-generator and send all prepared packet streams
-        Remove the packet streams afterwards
-        """
+        """ Remove any zombie captures and enable the packet generator """
+        # how long before capture is allowed to be deleted - otherwise vpp
+        # crashes - 100ms seems enough (this shouldn't be needed at all)
+        capture_ttl = 0.1
+        now = time.time()
+        for stamp, cap_name in cls._zombie_captures:
+            wait = stamp + capture_ttl - now
+            if wait > 0:
+                cls.logger.debug("Waiting for %ss before deleting capture %s",
+                                 wait, cap_name)
+                time.sleep(wait)
+                now = time.time()
+            cls.logger.debug("Removing zombie capture %s" % cap_name)
+            cls.vapi.cli('packet-generator delete %s' % cap_name)
+
         cls.vapi.cli("trace add pg-input 50")  # 50 is maximum
         cls.vapi.cli('packet-generator enable')
-        sleep(1)  # give VPP some time to process the packets
-        for stream in cls.pg_streams:
-            cls.vapi.cli('packet-generator delete %s' % stream)
-        cls.pg_streams = []
+        cls._zombie_captures = cls._captures
+        cls._captures = []
 
     @classmethod
     def create_pg_interfaces(cls, interfaces):
         """
-        Create packet-generator interfaces
+        Create packet-generator interfaces.
 
-        :param interfaces: iterable indexes of the interfaces
+        :param interfaces: iterable indexes of the interfaces.
+        :returns: List of created interfaces.
 
         """
         result = []
@@ -326,10 +376,10 @@ class VppTestCase(unittest.TestCase):
     @classmethod
     def create_loopback_interfaces(cls, interfaces):
         """
-        Create loopback interfaces
-
-        :param interfaces: iterable indexes of the interfaces
+        Create loopback interfaces.
 
+        :param interfaces: iterable indexes of the interfaces.
+        :returns: List of created interfaces.
         """
         result = []
         for i in interfaces:
@@ -354,31 +404,37 @@ class VppTestCase(unittest.TestCase):
         if extend > 0:
             packet[Raw].load += ' ' * extend
 
-    def add_packet_info_to_list(self, info):
-        """
-        Add packet info to the testcase's packet info list
-
-        :param info: packet info
-
-        """
-        info.index = len(self.packet_infos)
-        self.packet_infos[info.index] = info
+    @classmethod
+    def reset_packet_infos(cls):
+        """ Reset the list of packet info objects and packet counts to zero """
+        cls._packet_infos = {}
+        cls._packet_count_for_dst_if_idx = {}
 
-    def create_packet_info(self, src_pg_index, dst_pg_index):
+    @classmethod
+    def create_packet_info(cls, src_if, dst_if):
         """
         Create packet info object containing the source and destination indexes
         and add it to the testcase's packet info list
 
-        :param src_pg_index: source packet-generator index
-        :param dst_pg_index: destination packet-generator index
+        :param VppInterface src_if: source interface
+        :param VppInterface dst_if: destination interface
 
         :returns: _PacketInfo object
 
         """
         info = _PacketInfo()
-        self.add_packet_info_to_list(info)
-        info.src = src_pg_index
-        info.dst = dst_pg_index
+        info.index = len(cls._packet_infos)
+        info.src = src_if.sw_if_index
+        info.dst = dst_if.sw_if_index
+        if isinstance(dst_if, VppSubInterface):
+            dst_idx = dst_if.parent.sw_if_index
+        else:
+            dst_idx = dst_if.sw_if_index
+        if dst_idx in cls._packet_count_for_dst_if_idx:
+            cls._packet_count_for_dst_if_idx[dst_idx] += 1
+        else:
+            cls._packet_count_for_dst_if_idx[dst_idx] = 1
+        cls._packet_infos[info.index] = info
         return info
 
     @staticmethod
@@ -422,10 +478,10 @@ class VppTestCase(unittest.TestCase):
             next_index = 0
         else:
             next_index = info.index + 1
-        if next_index == len(self.packet_infos):
+        if next_index == len(self._packet_infos):
             return None
         else:
-            return self.packet_infos[next_index]
+            return self._packet_infos[next_index]
 
     def get_next_packet_info_for_interface(self, src_index, info):
         """
@@ -508,10 +564,10 @@ class VppTestResult(unittest.TestResult):
 
     def __init__(self, stream, descriptions, verbosity):
         """
-        :param stream File descriptor to store where to report test results. Set
-            to the standard error stream by default.
-        :param descriptions Boolean variable to store information if to use test
-            case descriptions.
+        :param stream File descriptor to store where to report test results.
+            Set to the standard error stream by default.
+        :param descriptions Boolean variable to store information if to use
+            test case descriptions.
         :param verbosity Integer variable to store required verbosity level.
         """
         unittest.TestResult.__init__(self, stream, descriptions, verbosity)
@@ -609,12 +665,12 @@ class VppTestResult(unittest.TestResult):
         unittest.TestResult.stopTest(self, test)
         if self.verbosity > 0:
             self.stream.writeln(single_line_delim)
-            self.stream.writeln("%-60s%s" %
-                                (self.getDescription(test), self.result_string))
+            self.stream.writeln("%-60s%s" % (self.getDescription(test),
+                                             self.result_string))
             self.stream.writeln(single_line_delim)
         else:
-            self.stream.writeln("%-60s%s" %
-                                (self.getDescription(test), self.result_string))
+            self.stream.writeln("%-60s%s" % (self.getDescription(test),
+                                             self.result_string))
 
     def printErrors(self):
         """