code review - cont. 86/5286/1
authorimarom <[email protected]>
Sun, 29 Jan 2017 15:49:53 +0000 (17:49 +0200)
committerimarom <[email protected]>
Sun, 29 Jan 2017 16:06:26 +0000 (18:06 +0200)
Signed-off-by: imarom <[email protected]>
scripts/automation/trex_control_plane/stl/console/trex_capture.py
scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_client.py
scripts/automation/trex_control_plane/stl/trex_stl_lib/trex_stl_types.py
scripts/automation/trex_control_plane/stl/trex_stl_lib/utils/text_opts.py

index aac685a..2132458 100644 (file)
@@ -4,53 +4,44 @@ import threading
 import tempfile
 import select
 
+# defines a generic monitor writer
 class CaptureMonitorWriter(object):
-    def init (self, start_ts):
-        raise NotImplementedError
 
     def deinit(self):
-        raise NotImplementedError
+        # by default - nothing to deinit
+        pass
         
     def handle_pkts (self, pkts):
         raise NotImplementedError
         
     def periodic_check (self):
-        raise NotImplementedError
-        
+        # by default - nothing to check
+        pass
         
+   
+# a stdout monitor
 class CaptureMonitorWriterStdout(CaptureMonitorWriter):
-    def __init__ (self, logger, is_brief):
+    def __init__ (self, logger, is_brief, start_ts):
         self.logger      = logger
         self.is_brief    = is_brief
-
+        self.start_ts    = start_ts
+        
+        # unicode arrows
         self.RX_ARROW = u'\u25c0\u2500\u2500'
         self.TX_ARROW = u'\u25b6\u2500\u2500'
-
-    def init (self, start_ts):
-        self.start_ts = start_ts
         
+        # decode issues with Python 2
+        if sys.version_info < (3,0):
+            self.RX_ARROW = self.RX_ARROW.encode('utf-8')
+            self.TX_ARROW = self.TX_ARROW.encode('utf-8')
+
+
         self.logger.pre_cmd("Starting stdout capture monitor - verbose: '{0}'".format('low' if self.is_brief else 'high'))
         self.logger.post_cmd(RC_OK)
         
         self.logger.log(format_text("\n*** use 'capture monitor stop' to abort capturing... ***\n", 'bold')) 
         
-        
-    def deinit (self):
-        pass
-        
-       
-    def periodic_check (self):
-        return RC_OK()
-        
-    def handle_pkts (self, pkts):
-        byte_count = 0
-        
-        for pkt in pkts:
-            byte_count += self.__handle_pkt(pkt)
-        
-        self.logger.prompt_redraw()
-        return RC_OK(byte_count)
-        
         
     def get_scapy_name (self, pkt_scapy):
         layer = pkt_scapy
@@ -62,9 +53,9 @@ class CaptureMonitorWriterStdout(CaptureMonitorWriter):
                 
     def format_origin (self, origin):
         if origin == 'RX':
-            return u'{0} {1}'.format(self.RX_ARROW, 'RX')
+            return '{0} {1}'.format(self.RX_ARROW, 'RX')
         elif origin == 'TX':
-            return u'{0} {1}'.format(self.TX_ARROW, 'TX')
+            return '{0} {1}'.format(self.TX_ARROW, 'TX')
         else:
             return '{0}'.format(origin)
             
@@ -73,10 +64,9 @@ class CaptureMonitorWriterStdout(CaptureMonitorWriter):
         pkt_bin = base64.b64decode(pkt['binary'])
 
         pkt_scapy = Ether(pkt_bin)
-        self.logger.log(format_text(u'\n\n#{} Port: {} {}\n'.format(pkt['index'], pkt['port'], self.format_origin(pkt['origin'])), 'bold', ''))
+        self.logger.log(format_text('\n\n#{} Port: {} {}\n'.format(pkt['index'], pkt['port'], self.format_origin(pkt['origin'])), 'bold', ''))
         self.logger.log(format_text('    Type: {}, Size: {} B, TS: {:.2f} [sec]\n'.format(self.get_scapy_name(pkt_scapy), len(pkt_bin), pkt['ts'] - self.start_ts), 'bold'))
 
-        
         if self.is_brief:
             self.logger.log('    {0}'.format(pkt_scapy.command()))
         else:
@@ -85,16 +75,29 @@ class CaptureMonitorWriterStdout(CaptureMonitorWriter):
 
         return len(pkt_bin)
 
-#
+
+    def handle_pkts (self, pkts):
+        try:
+            byte_count = 0
+            for pkt in pkts:
+                byte_count += self.__handle_pkt(pkt)
+
+            return byte_count
+
+        finally:
+            # make sure to restore the logger
+            self.logger.prompt_redraw()
+
+
+# a pipe based monitor
 class CaptureMonitorWriterPipe(CaptureMonitorWriter):
-    def __init__ (self, logger):
+    def __init__ (self, logger, start_ts):
+        
         self.logger    = logger
-        self.fifo_name = None
         self.fifo      = None
-        self.start_ts  = None
-        
-    def init (self, start_ts):
         self.start_ts  = start_ts
+        
+        # generate a temp fifo pipe
         self.fifo_name = tempfile.mktemp()
         
         try:
@@ -105,27 +108,35 @@ class CaptureMonitorWriterPipe(CaptureMonitorWriter):
             self.logger.log(format_text("*** Please run 'wireshark -k -i {0}' ***".format(self.fifo_name), 'bold'))
             
             self.logger.pre_cmd("Waiting for Wireshark pipe connection")
+            
+            # blocks until pipe is connected
             self.fifo = os.open(self.fifo_name, os.O_WRONLY)
             self.logger.post_cmd(RC_OK())
             
             self.logger.log(format_text('\n*** Capture monitoring started ***\n', 'bold'))
             
+            # open for write using a PCAP writer
             self.writer = RawPcapWriter(self.fifo_name, linktype = 1, sync = True)
             self.writer._write_header(None)
             
             # register a poller
             self.poll = select.poll()
             self.poll.register(self.fifo, select.EPOLLERR)
-            
+        
+            self.is_init = True
+                
         except KeyboardInterrupt as e:
+            self.deinit()
             self.logger.post_cmd(RC_ERR(""))
             raise STLError("*** pipe monitor aborted...cleaning up")
 
         except OSError as e:
+            self.deinit()
             self.logger.post_cmd(RC_ERR(""))
             raise STLError("failed to create pipe {0}\n{1}".format(self.fifo_name, str(e)))
         
         
+        
     def deinit (self):
         try:
             if self.fifo:
@@ -138,141 +149,109 @@ class CaptureMonitorWriterPipe(CaptureMonitorWriter):
                 
         except OSError:
             pass
+            
 
        
     def periodic_check (self):
-        return self.check_pipe()
+        self.check_pipe()
         
         
     def check_pipe (self):
         if self.poll.poll(0):
-            return RC_ERR('*** pipe has been disconnected - aborting monitoring ***')
-            
-        return RC_OK()
+            raise STLError('pipe has been disconnected')
         
         
     def handle_pkts (self, pkts):
-        rc = self.check_pipe()
-        if not rc:
-            return rc
+        # first check the pipe is alive
+        self.check_pipe()
+
+        return self.handle_pkts_internal(pkts)
+            
+        
+    def handle_pkts_internal (self, pkts):
         
         byte_count = 0
         
         for pkt in pkts:
             pkt_bin = base64.b64decode(pkt['binary'])
-            ts      = pkt['ts']
-            sec     = int(ts)
-            usec    = int( (ts - sec) * 1e6 )
-                
+            ts_sec, ts_usec = sec_split_usec(pkt['ts'] - self.start_ts)
+            
             try:
-                self.writer._write_packet(pkt_bin, sec = sec, usec = usec)
-            except IOError:
-                return RC_ERR("*** failed to write packet to pipe ***")
-             
+                self.writer._write_packet(pkt_bin, sec = ts_sec, usec = ts_usec)
+            except Exception as e:
+                raise STLError('fail to write packets to pipe: {}'.format(str(e)))
+                
             byte_count += len(pkt_bin)
                
-        return RC_OK(byte_count)
+        return byte_count
         
         
+# capture monitor - a live capture
 class CaptureMonitor(object):
-    def __init__ (self, client, cmd_lock):
+    def __init__ (self, client, cmd_lock, tx_port_list, rx_port_list, rate_pps, mon_type):
         self.client      = client
-        self.cmd_lock    = cmd_lock
-        self.active      = False
-        self.capture_id  = None
         self.logger      = client.logger
-        self.writer      = None
-        
-    def is_active (self):
-        return self.active
-        
+        self.cmd_lock    = cmd_lock
 
-    def get_capture_id (self):
-        return self.capture_id
+        self.t           = None
+        self.writer      = None
+        self.capture_id  = None
         
+        self.tx_port_list = tx_port_list
+        self.rx_port_list = rx_port_list
+        self.rate_pps     = rate_pps
+        self.mon_type     = mon_type
         
-    def start (self, tx_port_list, rx_port_list, rate_pps, mon_type):
+        # try to launch
         try:
-            self.start_internal(tx_port_list, rx_port_list, rate_pps, mon_type)
+            self.__start()
         except Exception as e:
             self.__stop()
-            raise e
+            raise
             
-    def start_internal (self,  tx_port_list, rx_port_list, rate_pps, mon_type):
-        # stop any previous monitors
-        if self.active:
-            self.stop()
-        
-        self.tx_port_list = tx_port_list
-        self.rx_port_list = rx_port_list
-
-        if mon_type == 'compact':
-            self.writer = CaptureMonitorWriterStdout(self.logger, is_brief = True)
-        elif mon_type == 'verbose':
-            self.writer = CaptureMonitorWriterStdout(self.logger, is_brief = False)
-        elif mon_type == 'pipe':
-            self.writer = CaptureMonitorWriterPipe(self.logger)
-        else:
-            raise STLError('unknown writer type')
             
+    def __start (self):
         
+        # create a capture on the server
         with self.logger.supress():
-            data = self.client.start_capture(tx_port_list, rx_port_list, limit = rate_pps, mode = 'cyclic')
-        
+            data = self.client.start_capture(self.tx_port_list, self.rx_port_list, limit = self.rate_pps, mode = 'cyclic')
+
         self.capture_id = data['id']
         self.start_ts   = data['ts']
-
-        self.writer.init(self.start_ts)
-
         
-        self.t = threading.Thread(target = self.__thread_cb)
-        self.t.setDaemon(True)
-        
-        try:
-            self.active = True
-            self.t.start()
-        except Exception as e:
-            self.active = False
-            self.stop()
-            raise e
-        
-    # entry point stop 
-    def stop (self):
 
-        if self.active:
-            self.stop_logged()
+        # create a writer
+        if self.mon_type == 'compact':
+            self.writer = CaptureMonitorWriterStdout(self.logger, True, self.start_ts)
+        elif self.mon_type == 'verbose':
+            self.writer = CaptureMonitorWriterStdout(self.logger, False, self.start_ts)
+        elif self.mon_type == 'pipe':
+            self.writer = CaptureMonitorWriterPipe(self.logger, self.start_ts)
         else:
-            self.__stop()
-        
-    # wraps stop with a logging
-    def stop_logged (self):
-        self.logger.pre_cmd("Stopping capture monitor")
+            raise STLError('Internal error: unknown writer type')
         
-        try:
-            self.__stop()
-        except Exception as e:
-            self.logger.post_cmd(RC_ERR(""))
-            raise e
-        
-        self.logger.post_cmd(RC_OK())
+        # start the fetching thread
+        self.t = threading.Thread(target = self.__thread_cb)
+        self.t.setDaemon(True)
+        self.active = True
+        self.t.start()
+  
             
     # internal stop
     def __stop (self):
 
-        # shutdown thread
-        if self.active:
+        # stop the thread
+        if self.t and self.t.is_alive():
             self.active = False
             self.t.join()
+            self.t = None
             
         # deinit the writer
-        if self.writer is not None:
+        if self.writer:
             self.writer.deinit()
             self.writer = None
             
-        # cleanup capture ID if possible
-        if self.capture_id is None:
-            return
-
+        # take the capture ID
         capture_id = self.capture_id
         self.capture_id = None
         
@@ -280,31 +259,48 @@ class CaptureMonitor(object):
         if not self.client.is_connected():
             return
             
-        try:
-            captures = [x['id'] for x in self.client.get_capture_status()]
-            if capture_id not in captures:
-                return
-                
-            with self.logger.supress():
-                self.client.stop_capture(capture_id)
+        # make sure the capture is active on the server
+        captures = [x['id'] for x in self.client.get_capture_status()]
+        if capture_id not in captures:
+            return
             
-        except STLError as e:
-            self.logger.post_cmd(RC_ERR(""))
-            raise e
+        # remove the capture                
+        with self.logger.supress():
+            self.client.stop_capture(capture_id)
             
-                
+           
+    # user call for stop (adds log)
+    def stop (self):
+        self.logger.pre_cmd("Stopping capture monitor")
+        
+        try:
+            self.__stop()
+        except Exception as e:
+            self.logger.post_cmd(RC_ERR(""))
+            raise
+        
+        self.logger.post_cmd(RC_OK())
+
+
     def get_mon_row (self):
-        if not self.is_active():
-            return None
             
         return [self.capture_id,
+                format_text('ACTIVE' if self.t.is_alive() else 'DEAD', 'bold'),
                 self.pkt_count,
                 format_num(self.byte_count, suffix = 'B'),
                 ', '.join([str(x) for x in self.tx_port_list] if self.tx_port_list else '-'),
                 ', '.join([str(x) for x in self.rx_port_list] if self.rx_port_list else '-')
                 ]
         
-    
+
+    def is_active (self):
+        return self.active
+
+
+    def get_capture_id (self):
+        return self.capture_id
+        
+
     # sleeps with high freq checks for active
     def __sleep (self):
         for _ in range(5):
@@ -331,13 +327,18 @@ class CaptureMonitor(object):
     
     def __thread_cb (self):
         try:
-            rc = self.__thread_main_loop()
-        finally:
-            pass
+            self.__thread_main_loop()
+        
+        # common errors
+        except STLError as e:
+            self.logger.log(format_text("\n\nMonitor has encountered the following error: '{}'\n".format(e.brief()), 'bold'))
+            self.logger.log(format_text("\n*** monitor is inactive - please restart the monitor ***\n", 'bold'))
+            self.logger.prompt_redraw()
             
-        if not rc:
-            self.logger.log(str(rc))
-            self.logger.log(format_text('\n*** monitor is inactive - please restart the monitor ***\n', 'bold'))
+        # unexpected errors
+        except Exception as e:
+            self.logger.log("\n\n*** A fatal internal error has occurred: '{}'\n".format(str(e)))
+            self.logger.log(format_text("\n*** monitor is inactive - please restart the monitor ***\n", 'bold'))
             self.logger.prompt_redraw()
             
             
@@ -347,54 +348,50 @@ class CaptureMonitor(object):
         
         while self.active:
             
-            # sleep
+            # sleep - if interrupt by graceful shutdown - go out
             if not self.__sleep():
-                break
+                return
             
             # check that the writer is ok
-            rc = self.writer.periodic_check()
-            if not rc:
-                return rc
+            self.writer.periodic_check()
 
-            # try to lock
+            # try to lock - if interrupt by graceful shutdown - go out
             if not self.__lock():
-                break
+                return
                 
             try:
                 if not self.client.is_connected():
-                    return RC_ERR('*** client has been disconnected, aborting monitoring ***')
+                    raise STLError('client has been disconnected')
+                    
                 rc = self.client._transmit("capture", params = {'command': 'fetch', 'capture_id': self.capture_id, 'pkt_limit': 10})
                 if not rc:
-                    return rc
+                    raise STLError(rc)
                     
             finally:
                 self.__unlock()
                 
 
+            # no packets - skip
             pkts = rc.data()['pkts']
             if not pkts:
                 continue
             
-            rc = self.writer.handle_pkts(pkts)
-            if not rc:
-                return rc
+            byte_count = self.writer.handle_pkts(pkts)
             
             self.pkt_count  += len(pkts)
-            self.byte_count += rc.data()
-                
-        # graceful shutdown
-        return RC_OK()
-        
-     
+            self.byte_count += byte_count
+
+
+
 
 # main class
 class CaptureManager(object):
     def __init__ (self, client, cmd_lock):
         self.c          = client
         self.cmd_lock   = cmd_lock
-        self.monitor    = CaptureMonitor(client, cmd_lock)
         self.logger     = client.logger
-
+        self.monitor    = None
+        
         # install parsers
         
         self.parser = parsing_opts.gen_parser(self, "capture", self.parse_line_internal.__doc__)
@@ -445,7 +442,9 @@ class CaptureManager(object):
         
         
     def stop (self):
-        self.monitor.stop()
+        if self.monitor:
+            self.monitor.stop()
+            self.monitor = None
 
         
     # main entry point for parsing commands from console
@@ -453,7 +452,7 @@ class CaptureManager(object):
         try:
             self.parse_line_internal(line)
         except STLError as e:
-            self.logger.log("\nAction has failed with the following error:\n" + format_text(e.brief() + "\n", 'bold'))
+            self.logger.log("\nAction has failed with the following error:\n\n" + format_text(e.brief() + "\n", 'bold'))
             return RC_ERR(e.brief())
 
 
@@ -497,7 +496,7 @@ class CaptureManager(object):
         captures = self.c.get_capture_status()
         ids = [c['id'] for c in captures]
         
-        if opts.capture_id == self.monitor.get_capture_id():
+        if self.monitor and (opts.capture_id == self.monitor.get_capture_id()):
             self.record_stop_parser.formatted_error("'{0}' is a monitor, please use 'capture monitor stop'".format(opts.capture_id))
             return
             
@@ -530,14 +529,24 @@ class CaptureManager(object):
             self.monitor_start_parser.formatted_error('please provide either --tx or --rx')
             return
         
-        self.monitor.stop()
-        self.monitor.start(opts.tx_port_list, opts.rx_port_list, 100, mon_type)
+        if self.monitor:
+            self.monitor.stop()
+            self.monitor = None
+            
+        self.monitor = CaptureMonitor(self.c, self.cmd_lock, opts.tx_port_list, opts.rx_port_list, 100, mon_type)
+        
     
     def parse_monitor_stop (self, opts):
-        self.monitor.stop()
+        if self.monitor:
+            self.monitor.stop()
+            self.monitor = None
+            
         
     def parse_clear (self, opts):
-        self.monitor.stop()
+        if self.monitor:
+            self.monitor.stop()
+            self.monitor = None
+            
         self.c.remove_all_captures()
         
         
@@ -552,13 +561,13 @@ class CaptureManager(object):
 
         # monitor
         mon_table = text_tables.TRexTextTable()
-        mon_table.set_cols_align(["c"] * 5)
-        mon_table.set_cols_width([15] * 5)
+        mon_table.set_cols_align(["c"] * 6)
+        mon_table.set_cols_width([15] * 6)
 
         for elem in data:
             id = elem['id']
 
-            if self.monitor.get_capture_id() == id:
+            if self.monitor and (self.monitor.get_capture_id() == id):
                 row = self.monitor.get_mon_row()
                 mon_table.add_rows([row], header=False)
 
@@ -573,7 +582,7 @@ class CaptureManager(object):
                 cap_table.add_rows([row], header=False)
 
         cap_table.header(['ID', 'Status', 'Packets', 'Bytes', 'TX Ports', 'RX Ports'])
-        mon_table.header(['ID', 'Packets Seen', 'Bytes Seen', 'TX Ports', 'RX Ports'])
+        mon_table.header(['ID', 'Status', 'Packets Seen', 'Bytes Seen', 'TX Ports', 'RX Ports'])
 
         if cap_table._rows:
             text_tables.print_table_with_header(cap_table, '\nActive Recorders')
index 571334e..c46a7d7 100755 (executable)
@@ -3296,7 +3296,7 @@ class STLClient(object):
             try:
                 rc = f(*args)
             except STLError as e:
-                client.logger.log("\nAction has failed with the following error:\n" + format_text(e.brief() + "\n", 'bold'))
+                client.logger.log("\nAction has failed with the following error:\n\n" + format_text(e.brief() + "\n", 'bold'))
                 return RC_ERR(e.brief())
 
             # if got true - print time
index 0230db2..7ac508a 100644 (file)
@@ -64,7 +64,7 @@ class RC():
                 err_count += 1
                 if len(err_list) < show_count:
                     err_list.append(format_text(x, 'bold'))
-            s = '\n'
+            s = ''
             if err_count > show_count:
                 s += format_text('Occurred %s errors, showing first %s:\n' % (err_count, show_count), 'bold')
             s += '\n'.join(err_list)
index 3ffd07e..477d81a 100644 (file)
@@ -133,16 +133,12 @@ def underline(text):
 
 # apply attribute on each non-empty line
 def text_attribute(text, attribute):
-    if isinstance(text, str):
-        return "{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'],
-                                           txt=text,
-                                           stop=TEXT_CODES[attribute]['end'])
-    elif isinstance(text, unicode):
-        return u"{start}{txt}{stop}".format(start=TEXT_CODES[attribute]['start'],
-                                            txt=text,
-                                            stop=TEXT_CODES[attribute]['end'])
-    else:
-        raise Exception("not a string")
+    return '\n'.join(['{start}{txt}{end}'.format(
+        start = TEXT_CODES[attribute]['start'],
+        txt = line,
+        end = TEXT_CODES[attribute]['end'])
+                      if line else '' for line in ('%s' % text).split('\n')])
+
 
 FUNC_DICT = {'blue': blue,
              'bold': bold,