make test: split into basic and extended tests
[vpp.git] / test / framework.py
1 #!/usr/bin/env python
2
3 from __future__ import print_function
4 import gc
5 import sys
6 import os
7 import select
8 import unittest
9 import tempfile
10 import time
11 import resource
12 from collections import deque
13 from threading import Thread, Event
14 from inspect import getdoc
15 from traceback import format_exception
16 from logging import FileHandler, DEBUG, Formatter
17 from scapy.packet import Raw
18 from hook import StepHook, PollHook
19 from vpp_pg_interface import VppPGInterface
20 from vpp_sub_interface import VppSubInterface
21 from vpp_lo_interface import VppLoInterface
22 from vpp_papi_provider import VppPapiProvider
23 from log import *
24 from vpp_object import VppObjectRegistry
25 if os.name == 'posix' and sys.version_info[0] < 3:
26     # using subprocess32 is recommended by python official documentation
27     # @ https://docs.python.org/2/library/subprocess.html
28     import subprocess32 as subprocess
29 else:
30     import subprocess
31
32 """
33   Test framework module.
34
35   The module provides a set of tools for constructing and running tests and
36   representing the results.
37 """
38
39
40 class _PacketInfo(object):
41     """Private class to create packet info object.
42
43     Help process information about the next packet.
44     Set variables to default values.
45     """
46     #: Store the index of the packet.
47     index = -1
48     #: Store the index of the source packet generator interface of the packet.
49     src = -1
50     #: Store the index of the destination packet generator interface
51     #: of the packet.
52     dst = -1
53     #: Store the copy of the former packet.
54     data = None
55
56     def __eq__(self, other):
57         index = self.index == other.index
58         src = self.src == other.src
59         dst = self.dst == other.dst
60         data = self.data == other.data
61         return index and src and dst and data
62
63
def pump_output(testclass):
    """Pump output from VPP's stdout/stderr into the test class's deques.

    Runs until ``testclass.pump_thread_stop_flag`` is set. Each chunk (up to
    1024 bytes) read from ``testclass.vpp.stdout`` is appended to
    ``testclass.vpp_stdout_deque`` and each chunk from ``testclass.vpp.stderr``
    to ``testclass.vpp_stderr_deque``. A stream that reaches EOF (an empty
    read) is removed from the select() set, so a dead VPP does not turn this
    loop into a busy spin appending empty strings. Writing anything to
    ``testclass.pump_thread_wakeup_pipe[1]`` wakes up select() so the stop
    flag gets re-checked.

    :param testclass: class holding the attributes listed above
    """
    stdout_fd = testclass.vpp.stdout.fileno()
    stderr_fd = testclass.vpp.stderr.fileno()
    wakeup_fd = testclass.pump_thread_wakeup_pipe[0]
    # stdout is processed before stderr, as in the original polling order
    sinks = ((stdout_fd, testclass.vpp_stdout_deque),
             (stderr_fd, testclass.vpp_stderr_deque))
    watched = [stdout_fd, stderr_fd]
    while not testclass.pump_thread_stop_flag.wait(0):
        readable = select.select(watched + [wakeup_fd], [], [])[0]
        for fd, sink in sinks:
            if fd not in readable:
                continue
            data = os.read(fd, 1024)
            if data:
                sink.append(data)
            else:
                # EOF - vpp closed this stream, stop polling it
                watched.remove(fd)
        # ignoring the dummy pipe here intentionally - the flag will take care
        # of properly terminating the loop
79
80
def running_extended_tests():
    """Return True if the EXTENDED_TESTS environment variable is enabled.

    The variable counts as enabled when it is "y", "yes" or "1"
    (case-insensitive); an unset variable or any other value yields False.
    """
    return os.getenv("EXTENDED_TESTS", "").lower() in ("y", "yes", "1")
88
89
class VppTestCase(unittest.TestCase):
    """Base class for VPP test cases implemented as classes.

    setUpClass() starts one VPP process per test class, pumps its
    stdout/stderr into in-memory deques on a background thread and connects
    the VPP API; quit() (called from tearDownClass()) tears all of that down
    again. The class also provides helpers for packet-generator and loopback
    interfaces and for tracking the packet infos embedded in test packets.
    """

    @property
    def packet_infos(self):
        """Dict mapping packet index to its _PacketInfo"""
        return self._packet_infos

    @classmethod
    def get_packet_count_for_if_idx(cls, dst_if_index):
        """Return how many packet infos were created for destination
        interface index ``dst_if_index`` (0 if none)."""
        return cls._packet_count_for_dst_if_idx.get(dst_if_index, 0)

    @classmethod
    def instance(cls):
        """Return the test instance most recently stored by setUp()"""
        return cls.test_instance

    @staticmethod
    def _env_flag(name):
        """Return True if environment variable ``name`` is "y", "yes" or "1"
        (case-insensitive); False otherwise, including when it is unset."""
        return os.getenv(name, "").lower() in ("y", "yes", "1")

    @classmethod
    def set_debug_flags(cls, d):
        """Set debug_core/debug_gdb/debug_gdbserver from DEBUG option ``d``.

        :param d: "core", "gdb" or "gdbserver" (case-insensitive), or None
            to leave all three flags False
        :raises Exception: for any other value (flags are cleared first)
        """
        cls.debug_core = False
        cls.debug_gdb = False
        cls.debug_gdbserver = False
        if d is None:
            return
        dl = d.lower()
        if dl == "core":
            cls.debug_core = True
        elif dl == "gdb":
            cls.debug_gdb = True
        elif dl == "gdbserver":
            cls.debug_gdbserver = True
        else:
            raise Exception("Unrecognized DEBUG option: '%s'" % d)

    @classmethod
    def setUpConstants(cls):
        """ Set-up the test case class based on environment variables

        Reads STEP, DEBUG, VPP_TEST_BIN, VPP_TEST_PLUGIN_PATH and
        COREDUMP_SIZE and builds ``cls.vpp_cmdline``. Requires ``cls.logger``
        and ``cls.shm_prefix`` to be set already.
        """
        cls.step = cls._env_flag("STEP")
        cls.set_debug_flags(os.getenv("DEBUG"))
        cls.vpp_bin = os.getenv('VPP_TEST_BIN', "vpp")
        cls.plugin_path = os.getenv('VPP_TEST_PLUGIN_PATH')
        debug_cli = ""
        if cls.step or cls.debug_gdb or cls.debug_gdbserver:
            # expose a CLI listener when stepping or debugging
            debug_cli = "cli-listen localhost:5002"
        size = os.getenv("COREDUMP_SIZE")
        if size is None:
            coredump_size = "coredump-size unlimited"
        else:
            coredump_size = "coredump-size %s" % size
        cls.vpp_cmdline = [cls.vpp_bin, "unix",
                           "{", "nodaemon", debug_cli, coredump_size, "}",
                           "api-trace", "{", "on", "}",
                           "api-segment", "{", "prefix", cls.shm_prefix, "}"]
        if cls.plugin_path is not None:
            cls.vpp_cmdline.extend(["plugin_path", cls.plugin_path])
        cls.logger.info("vpp_cmdline: %s" % cls.vpp_cmdline)

    @classmethod
    def wait_for_enter(cls):
        """Report the spawned VPP PID.

        Under DEBUG=gdb or DEBUG=gdbserver, print instructions for attaching
        gdb and block until ENTER is pressed; otherwise only log the PID.
        """
        if cls.debug_gdbserver:
            print(double_line_delim)
            print("Spawned GDB server with PID: %d" % cls.vpp.pid)
        elif cls.debug_gdb:
            print(double_line_delim)
            print("Spawned VPP with PID: %d" % cls.vpp.pid)
        else:
            cls.logger.debug("Spawned VPP with PID: %d" % cls.vpp.pid)
            return
        print(single_line_delim)
        print("You can debug the VPP using e.g.:")
        if cls.debug_gdbserver:
            print("gdb " + cls.vpp_bin + " -ex 'target remote localhost:7777'")
            print("Now is the time to attach a gdb by running the above "
                  "command, set up breakpoints etc. and then resume VPP from "
                  "within gdb by issuing the 'continue' command")
        elif cls.debug_gdb:
            print("gdb " + cls.vpp_bin + " -ex 'attach %s'" % cls.vpp.pid)
            print("Now is the time to attach a gdb by running the above "
                  "command and set up breakpoints etc.")
        print(single_line_delim)
        raw_input("Press ENTER to continue running the testcase...")

    @classmethod
    def run_vpp(cls):
        """Spawn VPP with stdout/stderr piped, then call wait_for_enter().

        With debug_gdbserver set, VPP is started under /usr/bin/gdbserver
        listening on localhost:7777.

        :raises Exception: if the gdbserver binary is missing/not executable
        :raises: whatever subprocess.Popen raises if VPP cannot be started
        """
        cmdline = cls.vpp_cmdline

        if cls.debug_gdbserver:
            gdbserver = '/usr/bin/gdbserver'
            if not os.path.isfile(gdbserver) or \
                    not os.access(gdbserver, os.X_OK):
                raise Exception("gdbserver binary '%s' does not exist or is "
                                "not executable" % gdbserver)

            cmdline = [gdbserver, 'localhost:7777'] + cls.vpp_cmdline
            cls.logger.info("Gdbserver cmdline is %s", " ".join(cmdline))

        try:
            cls.vpp = subprocess.Popen(cmdline,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE,
                                       bufsize=1)
        except Exception as e:
            cls.logger.critical("Couldn't start vpp: %s" % e)
            raise

        cls.wait_for_enter()

    @classmethod
    def setUpClass(cls):
        """
        Perform class setup before running the testcase

        Create a temporary directory with a log file, start vpp, start the
        output pump thread and connect the vpp-api. If anything after the
        basic setup fails, quit() is called before the error is re-raised.
        """
        gc.collect()  # run garbage collection first
        cls.logger = getLogger(cls.__name__)
        cls.tempdir = tempfile.mkdtemp(
            prefix='vpp-unittest-' + cls.__name__ + '-')
        file_handler = FileHandler("%s/log.txt" % cls.tempdir)
        file_handler.setFormatter(
            Formatter(fmt='%(asctime)s,%(msecs)03d %(message)s',
                      datefmt="%H:%M:%S"))
        file_handler.setLevel(DEBUG)
        cls.logger.addHandler(file_handler)
        cls.shm_prefix = cls.tempdir.split("/")[-1]
        os.chdir(cls.tempdir)
        cls.logger.info("Temporary dir is %s, shm prefix is %s",
                        cls.tempdir, cls.shm_prefix)
        cls.setUpConstants()
        cls.reset_packet_infos()
        cls._captures = []
        cls._zombie_captures = []
        cls.verbose = 0
        cls.vpp_dead = False
        cls.registry = VppObjectRegistry()
        # If setUpClass() raises, tearDownClass() is never called, so we
        # must clean up ourselves or we might end with a zombie vpp.
        setup_done = False
        try:
            cls.run_vpp()
            cls.vpp_stdout_deque = deque()
            cls.vpp_stderr_deque = deque()
            cls.pump_thread_stop_flag = Event()
            cls.pump_thread_wakeup_pipe = os.pipe()
            cls.pump_thread = Thread(target=pump_output, args=(cls,))
            cls.pump_thread.daemon = True
            cls.pump_thread.start()
            cls.vapi = VppPapiProvider(cls.shm_prefix, cls.shm_prefix, cls)
            if cls.step:
                hook = StepHook(cls)
            else:
                hook = PollHook(cls)
            cls.vapi.register_hook(hook)
            cls.sleep(0.1, "after vpp startup, before initial poll")
            hook.poll_vpp()
            try:
                cls.vapi.connect()
            except Exception:
                if cls.debug_gdbserver:
                    print(colorize("You're running VPP inside gdbserver but "
                                   "VPP-API connection failed, did you forget "
                                   "to 'continue' VPP from within gdb?", RED))
                raise
            setup_done = True
        finally:
            if not setup_done:
                # Best-effort cleanup. The pending setup exception (with its
                # original traceback) propagates once this finally block
                # ends, in both Python 2 and 3.
                try:
                    cls.quit()
                except Exception as e:
                    cls.logger.debug("Ignoring error from quit() after failed "
                                     "setUpClass: %s", e)

    @classmethod
    def _stop_pump_thread(cls):
        """Stop the output pump thread and release its wakeup pipe.

        Every step is skipped when the attribute it needs was never created,
        so this is safe after a partially completed setUpClass().
        """
        # set the flag before waking the thread so it exits on first check
        if hasattr(cls, 'pump_thread_stop_flag'):
            cls.pump_thread_stop_flag.set()
        if hasattr(cls, 'pump_thread_wakeup_pipe'):
            os.write(cls.pump_thread_wakeup_pipe[1], b'ding dong wake up')
        if hasattr(cls, 'pump_thread') and cls.pump_thread.is_alive():
            cls.logger.debug("Waiting for pump thread to stop")
            cls.pump_thread.join()
        if hasattr(cls, 'vpp_stderr_reader_thread'):
            cls.logger.debug("Waiting for stderr pump to stop")
            cls.vpp_stderr_reader_thread.join()
        if hasattr(cls, 'pump_thread_wakeup_pipe'):
            for fd in cls.pump_thread_wakeup_pipe:
                os.close(fd)
            del cls.pump_thread_wakeup_pipe

    @classmethod
    def _dump_vpp_output(cls, stream_name, output):
        """Write captured VPP output to <tempdir>/vpp_<stream_name>.txt and log it.

        :param stream_name: "stdout" or "stderr"
        :param output: iterable of captured output chunks
        """
        cls.logger.info(single_line_delim)
        cls.logger.info('VPP output to %s while running %s:',
                        stream_name, cls.__name__)
        cls.logger.info(single_line_delim)
        vpp_output = "".join(output)
        with open('%s/vpp_%s.txt' % (cls.tempdir, stream_name), 'w') as f:
            f.write(vpp_output)
        cls.logger.info('\n%s', vpp_output)
        cls.logger.info(single_line_delim)

    @classmethod
    def quit(cls):
        """
        Stop the output pump, disconnect vpp-api, terminate vpp and dump the
        captured vpp stdout/stderr to the temp dir and the log.

        Safe to call after a partially completed setUpClass(): steps whose
        state was never created are skipped.
        """
        debugging = (getattr(cls, 'debug_gdbserver', False) or
                     getattr(cls, 'debug_gdb', False))
        if debugging and hasattr(cls, 'vpp'):
            cls.vpp.poll()
            if cls.vpp.returncode is None:
                print(double_line_delim)
                print("VPP or GDB server is still running")
                print(single_line_delim)
                raw_input("When done debugging, press ENTER to kill the "
                          "process and finish running the testcase...")

        cls._stop_pump_thread()

        if hasattr(cls, 'vpp'):
            if hasattr(cls, 'vapi'):
                cls.vapi.disconnect()
                del cls.vapi
            cls.vpp.poll()
            if cls.vpp.returncode is None:
                cls.logger.debug("Sending TERM to vpp")
                cls.vpp.terminate()
                cls.logger.debug("Waiting for vpp to die")
                cls.vpp.communicate()
            del cls.vpp

        if hasattr(cls, 'vpp_stdout_deque'):
            cls._dump_vpp_output('stdout', cls.vpp_stdout_deque)
        if hasattr(cls, 'vpp_stderr_deque'):
            cls._dump_vpp_output('stderr', cls.vpp_stderr_deque)

    @classmethod
    def tearDownClass(cls):
        """ Perform final cleanup after running all tests in this test-case """
        cls.quit()

    def tearDown(self):
        """ Show various debug prints after each test

        If vpp is alive: log trace/interface/error/runtime output, remove the
        registered vpp config and save the api trace into the temp dir.
        Otherwise just unregister all objects from the registry.
        """
        self.logger.debug("--- tearDown() for %s.%s(%s) called ---" %
                          (self.__class__.__name__, self._testMethodName,
                           self._testMethodDoc))
        if not self.vpp_dead:
            self.logger.debug(self.vapi.cli("show trace"))
            self.logger.info(self.vapi.ppcli("show int"))
            self.logger.info(self.vapi.ppcli("show hardware"))
            self.logger.info(self.vapi.ppcli("show error"))
            self.logger.info(self.vapi.ppcli("show run"))
            self.registry.remove_vpp_config(self.logger)
            # Save/Dump VPP api trace log
            api_trace = "vpp_api_trace.%s.log" % self._testMethodName
            # this code expects vpp to have saved the trace under /tmp
            tmp_api_trace = "/tmp/%s" % api_trace
            vpp_api_trace_log = "%s/%s" % (self.tempdir, api_trace)
            self.logger.info(self.vapi.ppcli("api trace save %s" % api_trace))
            self.logger.info("Moving %s to %s\n" % (tmp_api_trace,
                                                    vpp_api_trace_log))
            os.rename(tmp_api_trace, vpp_api_trace_log)
            self.logger.info(self.vapi.ppcli("api trace dump %s" %
                                             vpp_api_trace_log))
        else:
            self.registry.unregister_all(self.logger)

    def setUp(self):
        """ Clear trace before running each test

        Also marks the start of the test in the captured vpp output and
        stores this instance as ``type(self).test_instance``.

        :raises Exception: if vpp is already dead
        """
        self.logger.debug("--- setUp() for %s.%s(%s) called ---" %
                          (self.__class__.__name__, self._testMethodName,
                           self._testMethodDoc))
        if self.vpp_dead:
            raise Exception("VPP is dead when setting up the test")
        self.sleep(.1, "during setUp")
        marker = ("--- test setUp() for %s.%s(%s) starts here ---\n" %
                  (self.__class__.__name__, self._testMethodName,
                   self._testMethodDoc))
        self.vpp_stdout_deque.append(marker)
        self.vpp_stderr_deque.append(marker)
        self.vapi.cli("clear trace")
        # store the test instance inside the test class - so that objects
        # holding the class can access instance methods (like assertEqual)
        type(self).test_instance = self

    @classmethod
    def pg_enable_capture(cls, interfaces):
        """
        Enable capture on packet-generator interfaces

        :param interfaces: iterable of interface objects; enable_capture()
            is called on each

        """
        for i in interfaces:
            i.enable_capture()

    @classmethod
    def register_capture(cls, cap_name):
        """ Register a capture in the testclass """
        # add to the list of captures with current timestamp
        cls._captures.append((time.time(), cap_name))
        # filter out from zombies
        cls._zombie_captures = [(stamp, name)
                                for (stamp, name) in cls._zombie_captures
                                if name != cap_name]

    @classmethod
    def pg_start(cls):
        """ Remove any zombie captures and enable the packet generator """
        # how long before capture is allowed to be deleted - otherwise vpp
        # crashes - 100ms seems enough (this shouldn't be needed at all)
        capture_ttl = 0.1
        now = time.time()
        for stamp, cap_name in cls._zombie_captures:
            wait = stamp + capture_ttl - now
            if wait > 0:
                cls.sleep(wait, "before deleting capture %s" % cap_name)
                now = time.time()
            cls.logger.debug("Removing zombie capture %s" % cap_name)
            cls.vapi.cli('packet-generator delete %s' % cap_name)

        cls.vapi.cli("trace add pg-input 50")  # 50 is maximum
        cls.vapi.cli('packet-generator enable')
        # captures registered so far become zombies for the next pg_start()
        cls._zombie_captures = cls._captures
        cls._captures = []

    @classmethod
    def create_pg_interfaces(cls, interfaces):
        """
        Create packet-generator interfaces.

        Each interface is also stored as a class attribute named after it.

        :param interfaces: iterable indexes of the interfaces.
        :returns: List of created interfaces (also kept in cls.pg_interfaces).

        """
        result = []
        for i in interfaces:
            intf = VppPGInterface(cls, i)
            setattr(cls, intf.name, intf)
            result.append(intf)
        cls.pg_interfaces = result
        return result

    @classmethod
    def create_loopback_interfaces(cls, interfaces):
        """
        Create loopback interfaces.

        Each interface is also stored as a class attribute named after it.

        :param interfaces: iterable indexes of the interfaces.
        :returns: List of created interfaces (also kept in cls.lo_interfaces).
        """
        result = []
        for i in interfaces:
            intf = VppLoInterface(cls, i)
            setattr(cls, intf.name, intf)
            result.append(intf)
        cls.lo_interfaces = result
        return result

    @staticmethod
    def extend_packet(packet, size):
        """
        Extend packet to given size by padding with spaces
        NOTE: Currently works only when Raw layer is present.

        :param packet: packet
        :param size: target size

        """
        # NOTE(review): +4 presumably accounts for the Ethernet FCS - confirm
        packet_len = len(packet) + 4
        extend = size - packet_len
        if extend > 0:
            packet[Raw].load += ' ' * extend

    @classmethod
    def reset_packet_infos(cls):
        """ Reset the list of packet info objects and packet counts to zero """
        cls._packet_infos = {}
        cls._packet_count_for_dst_if_idx = {}

    @classmethod
    def create_packet_info(cls, src_if, dst_if):
        """
        Create packet info object containing the source and destination indexes
        and add it to the testcase's packet info list

        The per-destination packet count is kept against the parent interface
        when ``dst_if`` is a VppSubInterface.

        :param VppInterface src_if: source interface
        :param VppInterface dst_if: destination interface

        :returns: _PacketInfo object

        """
        info = _PacketInfo()
        info.index = len(cls._packet_infos)
        info.src = src_if.sw_if_index
        info.dst = dst_if.sw_if_index
        if isinstance(dst_if, VppSubInterface):
            dst_idx = dst_if.parent.sw_if_index
        else:
            dst_idx = dst_if.sw_if_index
        counts = cls._packet_count_for_dst_if_idx
        counts[dst_idx] = counts.get(dst_idx, 0) + 1
        cls._packet_infos[info.index] = info
        return info

    @staticmethod
    def info_to_payload(info):
        """
        Convert _PacketInfo object to packet payload

        :param info: _PacketInfo object

        :returns: string "<index> <src> <dst>" serialized from packet info
        """
        return "%d %d %d" % (info.index, info.src, info.dst)

    @staticmethod
    def payload_to_info(payload):
        """
        Convert packet payload to _PacketInfo object

        :param payload: packet payload as produced by info_to_payload()

        :returns: _PacketInfo object containing de-serialized data from payload

        """
        numbers = payload.split()
        info = _PacketInfo()
        info.index = int(numbers[0])
        info.src = int(numbers[1])
        info.dst = int(numbers[2])
        return info

    def get_next_packet_info(self, info):
        """
        Iterate over the packet info list stored in the testcase
        Start iteration with first element if info is None
        Continue based on index in info if info is specified

        :param info: info or None
        :returns: next info in list or None if no more infos
        """
        if info is None:
            next_index = 0
        else:
            next_index = info.index + 1
        if next_index == len(self._packet_infos):
            return None
        else:
            return self._packet_infos[next_index]

    def get_next_packet_info_for_interface(self, src_index, info):
        """
        Search the packet info list for the next packet info with same source
        interface index

        :param src_index: source interface index to search for
        :param info: packet info - where to start the search
        :returns: packet info or None

        """
        while True:
            info = self.get_next_packet_info(info)
            if info is None:
                return None
            if info.src == src_index:
                return info

    def get_next_packet_info_for_interface2(self, src_index, dst_index, info):
        """
        Search the packet info list for the next packet info with same source
        and destination interface indexes

        :param src_index: source interface index to search for
        :param dst_index: destination interface index to search for
        :param info: packet info - where to start the search
        :returns: packet info or None

        """
        while True:
            info = self.get_next_packet_info_for_interface(src_index, info)
            if info is None:
                return None
            if info.dst == dst_index:
                return info

    def assert_equal(self, real_value, expected_value, name_or_class=None):
        """Assert equality, with a descriptive message when name_or_class is
        given.

        If ``name_or_class`` is a callable with a docstring (e.g. an enum-like
        class), the message shows both numeric values and their symbolic
        names; otherwise it falls back to a plain message.
        """
        if name_or_class is None:
            self.assertEqual(real_value, expected_value)
            return
        try:
            msg = "Invalid %s: %d('%s') does not match expected value %d('%s')"
            msg = msg % (getdoc(name_or_class).strip(),
                         real_value, str(name_or_class(real_value)),
                         expected_value, str(name_or_class(expected_value)))
        except Exception:
            msg = "Invalid %s: %s does not match expected value %s" % (
                name_or_class, real_value, expected_value)

        self.assertEqual(real_value, expected_value, msg)

    def assert_in_range(self,
                        real_value,
                        expected_min,
                        expected_max,
                        name=None):
        """Assert expected_min <= real_value <= expected_max (inclusive)."""
        if name is None:
            msg = None
        else:
            msg = "Invalid %s: %s out of range <%s,%s>" % (
                name, real_value, expected_min, expected_max)
        self.assertTrue(expected_min <= real_value <= expected_max, msg)

    @classmethod
    def sleep(cls, timeout, remark=None):
        """Sleep for ``timeout`` seconds, logging the reason if a logger exists."""
        if hasattr(cls, 'logger'):
            cls.logger.debug("Sleeping for %ss (%s)" % (timeout, remark))
        time.sleep(timeout)
620
621
class TestCasePrinter(object):
    """Print a heading the first time a test case class is seen.

    All instances share one state dictionary (``_shared_state``), so the set
    of already-announced test case classes is common to every instance.
    """
    _shared_state = {}

    def __init__(self):
        self.__dict__ = self._shared_state
        if not hasattr(self, "_test_case_set"):
            self._test_case_set = set()

    @staticmethod
    def _heading_for(test_class):
        """Return the first docstring line of ``test_class``, or its name.

        getdoc() returns None for a class without a docstring, so fall back
        to the class name instead of crashing.
        """
        doc = getdoc(test_class)
        if not doc:
            return test_class.__name__
        return doc.splitlines()[0]

    def print_test_case_heading_if_first_time(self, case):
        """Print a heading for ``case``'s class unless already printed."""
        if case.__class__ not in self._test_case_set:
            print(double_line_delim)
            print(colorize(self._heading_for(case.__class__), YELLOW))
            print(double_line_delim)
            self._test_case_set.add(case.__class__)
636
637
class VppTestResult(unittest.TestResult):
    """
    Test result that logs each outcome to the test's logger and prints a
    one-line colored status per test.

    @property result_string
     String variable to store the test case result string. Reset to None
     at the start of every test.
    @property errors
     List variable containing 2-tuples of TestCase instances and strings
     holding formatted tracebacks. Each tuple represents a test which
     raised an unexpected exception.
    @property failures
     List variable containing 2-tuples of TestCase instances and strings
     holding formatted tracebacks. Each tuple represents a test where
     a failure was explicitly signalled using the TestCase.assert*()
     methods.
    """

    def __init__(self, stream, descriptions, verbosity):
        """
        :param stream File descriptor to store where to report test results.
            Set to the standard error stream by default.
        :param descriptions Boolean variable to store information if to use
            test case descriptions.
        :param verbosity Integer variable to store required verbosity level.
        """
        unittest.TestResult.__init__(self, stream, descriptions, verbosity)
        self.stream = stream
        self.descriptions = descriptions
        self.verbosity = verbosity
        self.result_string = None
        self.printer = TestCasePrinter()

    @staticmethod
    def _test_id(test):
        """Return "Class.method(doc)" as used in the per-test debug logs."""
        return "%s.%s(%s)" % (test.__class__.__name__,
                              test._testMethodName,
                              test._testMethodDoc)

    @staticmethod
    def _with_tempdir(status, test):
        """Append the test's temp dir (or a "no temp dir" note) to status."""
        if hasattr(test, 'tempdir'):
            return status + \
                ' [ temp dir used by test case: ' + test.tempdir + ' ]'
        return status + ' [no temp dir]'

    @staticmethod
    def _log_exception(test, method_name, err):
        """Debug-log a failure/error report and its formatted traceback."""
        if hasattr(test, 'logger'):
            test.logger.debug("--- %s() %s called, err is %s"
                              % (method_name, VppTestResult._test_id(test),
                                 err))
            test.logger.debug("formatted exception is:\n%s" %
                              "".join(format_exception(*err)))

    def addSuccess(self, test):
        """
        Record a test succeeded result

        :param test:

        """
        if hasattr(test, 'logger'):
            test.logger.debug("--- addSuccess() %s called"
                              % self._test_id(test))
        unittest.TestResult.addSuccess(self, test)
        self.result_string = colorize("OK", GREEN)

    def addSkip(self, test, reason):
        """
        Record a test skipped.

        :param test:
        :param reason:

        """
        if hasattr(test, 'logger'):
            test.logger.debug("--- addSkip() %s called, reason is %s"
                              % (self._test_id(test), reason))
        unittest.TestResult.addSkip(self, test, reason)
        self.result_string = colorize("SKIP", YELLOW)

    def addFailure(self, test, err):
        """
        Record a test failed result

        :param test:
        :param err: exc_info tuple (type, value, traceback)

        """
        self._log_exception(test, "addFailure", err)
        unittest.TestResult.addFailure(self, test, err)
        self.result_string = self._with_tempdir(colorize("FAIL", RED), test)

    def addError(self, test, err):
        """
        Record a test error result

        :param test:
        :param err: exc_info tuple (type, value, traceback)

        """
        self._log_exception(test, "addError", err)
        unittest.TestResult.addError(self, test, err)
        self.result_string = self._with_tempdir(colorize("ERROR", RED), test)

    def addExpectedFailure(self, test, err):
        """
        Record a test that failed as expected (@unittest.expectedFailure)

        :param test:
        :param err: exc_info tuple (type, value, traceback)

        """
        self._log_exception(test, "addExpectedFailure", err)
        unittest.TestResult.addExpectedFailure(self, test, err)
        self.result_string = colorize("EXPECTED FAIL", GREEN)

    def addUnexpectedSuccess(self, test):
        """
        Record a test marked @unittest.expectedFailure that passed anyway

        :param test:

        """
        if hasattr(test, 'logger'):
            test.logger.debug("--- addUnexpectedSuccess() %s called"
                              % self._test_id(test))
        unittest.TestResult.addUnexpectedSuccess(self, test)
        self.result_string = self._with_tempdir(
            colorize("UNEXPECTED OK", RED), test)

    def getDescription(self, test):
        """
        Get test description

        :param test:
        :returns: the test's short description when descriptions are enabled
            and one exists, else str(test)

        """
        short_description = test.shortDescription()
        if self.descriptions and short_description:
            return short_description
        else:
            return str(test)

    def startTest(self, test):
        """
        Start a test

        :param test:

        """
        # forget the previous test's status so it can never be reported
        # for this one
        self.result_string = None
        self.printer.print_test_case_heading_if_first_time(test)
        unittest.TestResult.startTest(self, test)
        if self.verbosity > 0:
            self.stream.writeln(
                "Starting " + self.getDescription(test) + " ...")
            self.stream.writeln(single_line_delim)

    def stopTest(self, test):
        """
        Stop a test

        :param test:

        """
        unittest.TestResult.stopTest(self, test)
        status_line = "%-73s%s" % (self.getDescription(test),
                                   self.result_string)
        if self.verbosity > 0:
            self.stream.writeln(single_line_delim)
            self.stream.writeln(status_line)
            self.stream.writeln(single_line_delim)
        else:
            self.stream.writeln(status_line)

    def printErrors(self):
        """
        Print errors from running the test case
        """
        self.stream.writeln()
        self.printErrorList('ERROR', self.errors)
        self.printErrorList('FAIL', self.failures)

    def printErrorList(self, flavour, errors):
        """
        Print error list to the output stream together with error type
        and test case description.

        :param flavour: error type
        :param errors: iterable of (test, formatted traceback) pairs

        """
        for test, err in errors:
            self.stream.writeln(double_line_delim)
            self.stream.writeln("%s: %s" %
                                (flavour, self.getDescription(test)))
            self.stream.writeln(single_line_delim)
            self.stream.writeln("%s" % err)
813
814
class VppTestRunner(unittest.TextTestRunner):
    """
    A basic test runner implementation which prints results to standard
    output (sys.stdout), regardless of the stream passed to the constructor.

    Tests can be narrowed down with the environment variable named by
    ``test_option`` (see parse_test_option).
    """

    @property
    def resultclass(self):
        """Class maintaining the results of the tests.

        VppTestResult, unless a different class was passed to the
        constructor (or assigned to this attribute).
        """
        override = getattr(self, '_resultclass_override', None)
        return VppTestResult if override is None else override

    @resultclass.setter
    def resultclass(self, value):
        # unittest.TextTestRunner.__init__ assigns self.resultclass when an
        # explicit resultclass is given; without this setter that assignment
        # would raise AttributeError on a read-only property.
        self._resultclass_override = value

    def __init__(self, stream=sys.stderr, descriptions=True, verbosity=1,
                 failfast=False, buffer=False, resultclass=None):
        """
        :param stream: ignored; output always goes to sys.stdout
        :param descriptions: passed through to unittest.TextTestRunner
        :param verbosity: passed through to unittest.TextTestRunner
        :param failfast: passed through to unittest.TextTestRunner
        :param buffer: passed through to unittest.TextTestRunner
        :param resultclass: optional result class overriding VppTestResult
        """
        # ignore stream setting here, use hard-coded stdout to be in sync
        # with prints from VppTestCase methods ...
        super(VppTestRunner, self).__init__(sys.stdout, descriptions,
                                            verbosity, failfast, buffer,
                                            resultclass)

    # name of the environment variable holding the test filter
    test_option = "TEST"

    def parse_test_option(self):
        """
        Parse the test filter from the environment variable ``test_option``.

        The value has the form ``file[.class[.function]]``. Any component may
        be ``*`` or empty to mean "match anything". A file component without
        the ``test_`` prefix gets it prepended.

        :returns: tuple (filter_file_name, filter_class_name,
                  filter_func_name); an item is None when it does not filter
        :raises Exception: if the value has more than three dot-separated
                  components
        """
        value = os.getenv(self.test_option)
        if not value:
            return None, None, None
        parts = value.split('.')
        if len(parts) > 3:
            raise Exception("Unrecognized %s option: %s" %
                            (self.test_option, value))
        # pad missing trailing components so all three are handled uniformly;
        # this also makes a bare "*" mean "no file filter" just like "*."
        parts += [''] * (3 - len(parts))
        file_part, class_part, func_part = [
            None if part in ('*', '') else part for part in parts]
        if file_part is not None and not file_part.startswith('test_'):
            file_part = 'test_%s' % file_part
        return file_part, class_part, func_part

    def filter_tests(self, tests, filter_file, filter_class, filter_func):
        """
        Build a new suite containing only the tests matching the filters.

        Nested suites are filtered recursively and dropped when empty. A
        filter that is None matches anything. Tests whose id() does not have
        exactly three dot-separated parts, and objects that are neither
        suites nor test cases, are kept unfiltered.

        :param tests: iterable of tests / suites
        :param filter_file: required module name or None
        :param filter_class: required class name or None
        :param filter_func: required test method name or None
        :returns: unittest.suite.TestSuite with the matching tests
        """
        result = unittest.suite.TestSuite()
        for t in tests:
            if isinstance(t, unittest.suite.TestSuite):
                # this is a bunch of tests, recursively filter...
                x = self.filter_tests(t, filter_file, filter_class,
                                      filter_func)
                if x.countTestCases() > 0:
                    result.addTest(x)
            elif isinstance(t, unittest.TestCase):
                # this is a single test
                parts = t.id().split('.')
                # t.id() for common cases like this:
                # test_classifier.TestClassifier.test_acl_ip
                # apply filtering only if it is so
                if len(parts) == 3:
                    if filter_file and filter_file != parts[0]:
                        continue
                    if filter_class and filter_class != parts[1]:
                        continue
                    if filter_func and filter_func != parts[2]:
                        continue
                result.addTest(t)
            else:
                # unexpected object, don't touch it
                result.addTest(t)
        return result

    def run(self, test):
        """
        Run the tests matching the filter from parse_test_option.

        Garbage collection is disabled here and not re-enabled by this
        method.

        :param test: test suite to filter and run
        :returns: the result object from unittest.TextTestRunner.run
        """
        gc.disable()  # disable garbage collection, we'll do that manually
        print("Running tests using custom test runner")  # debug message
        filter_file, filter_class, filter_func = self.parse_test_option()
        print("Active filters: file=%s, class=%s, function=%s" % (
            filter_file, filter_class, filter_func))
        filtered = self.filter_tests(test, filter_file, filter_class,
                                     filter_func)
        print("%s out of %s tests match specified filters" % (
            filtered.countTestCases(), test.countTestCases()))
        return super(VppTestRunner, self).run(filtered)
908             filtered.countTestCases(), test.countTestCases()))
909         return super(VppTestRunner, self).run(filtered)