stats: fix race conditions in vpp-api stats client
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_stats.py
1 #!/usr/bin/env python3
2 #
3 # Copyright (c) 2021 Cisco and/or its affiliates.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at:
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 #
16
17 '''
18 This module implement Python access to the VPP statistics segment. It
19 accesses the data structures directly in shared memory.
20 VPP uses optimistic locking, so data structures may change underneath
21 us while we are reading. Data is copied out and it's important to
22 spend as little time as possible "holding the lock".
23
24 Counters are stored in VPP as a two dimensional array.
25 Index by thread and index (typically sw_if_index).
26 Simple counters count only packets, Combined counters count packets
27 and octets.
28
29 Counters can be accessed in either dimension.
30 stat['/if/rx'] - returns 2D lists
31 stat['/if/rx'][0] - returns counters for all interfaces for thread 0
32 stat['/if/rx'][0][1] - returns counter for interface 1 on thread 0
33 stat['/if/rx'][0][1]['packets'] - returns the packet counter
34                                   for interface 1 on thread 0
35 stat['/if/rx'][:, 1] - returns the counters for interface 1 on all threads
36 stat['/if/rx'][:, 1].packets() - returns the packet counters for
37                                  interface 1 on all threads
38 stat['/if/rx'][:, 1].sum_packets() - returns the sum of packet counters for
39                                      interface 1 on all threads
40 stat['/if/rx-miss'][:, 1].sum() - returns the sum of packet counters for
41                                   interface 1 on all threads for simple counters
42 '''
43
44 import os
45 import socket
46 import array
47 import mmap
48 from struct import Struct
49 import time
50 import unittest
51 import re
52
53 def recv_fd(sock):
54     '''Get file descriptor for memory map'''
55     fds = array.array("i")   # Array of ints
56     _, ancdata, _, _ = sock.recvmsg(0, socket.CMSG_LEN(4))
57     for cmsg_level, cmsg_type, cmsg_data in ancdata:
58         if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS:
59             fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
60     return list(fds)[0]
61
62 VEC_LEN_FMT = Struct('I')
63 def get_vec_len(stats, vector_offset):
64     '''Equivalent to VPP vec_len()'''
65     return VEC_LEN_FMT.unpack_from(stats.statseg, vector_offset - 8)[0]
66
67 def get_string(stats, ptr):
68     '''Get a string from a VPP vector'''
69     namevector = ptr - stats.base
70     namevectorlen = get_vec_len(stats, namevector)
71     if namevector + namevectorlen >= stats.size:
72         raise IOError('String overruns stats segment')
73     return stats.statseg[namevector:namevector+namevectorlen-1].decode('ascii')
74
75
76 class StatsVector:
77     '''A class representing a VPP vector'''
78
79     def __init__(self, stats, ptr, fmt):
80         self.vec_start = ptr - stats.base
81         self.vec_len = get_vec_len(stats, ptr - stats.base)
82         self.struct = Struct(fmt)
83         self.fmtlen = len(fmt)
84         self.elementsize = self.struct.size
85         self.statseg = stats.statseg
86         self.stats = stats
87
88         if self.vec_start + self.vec_len * self.elementsize >= stats.size:
89             raise IOError('Vector overruns stats segment')
90
91     def __iter__(self):
92         with self.stats.lock:
93             return self.struct.iter_unpack(self.statseg[self.vec_start:self.vec_start +
94                                                         self.elementsize*self.vec_len])
95
96     def __getitem__(self, index):
97         if index > self.vec_len:
98             raise IOError('Index beyond end of vector')
99         with self.stats.lock:
100             if self.fmtlen == 1:
101                 return self.struct.unpack_from(self.statseg, self.vec_start +
102                                                (index * self.elementsize))[0]
103             return self.struct.unpack_from(self.statseg, self.vec_start +
104                                            (index * self.elementsize))
105
106 class VPPStats():
107     '''Main class implementing Python access to the VPP statistics segment'''
108     # pylint: disable=too-many-instance-attributes
109     shared_headerfmt = Struct('QPQQPP')
110     default_socketname = '/run/vpp/stats.sock'
111
112     def __init__(self, socketname=default_socketname, timeout=10):
113         self.socketname = socketname
114         self.timeout = timeout
115         self.directory = {}
116         self.lock = StatsLock(self)
117         self.connected = False
118         self.size = 0
119         self.last_epoch = 0
120         self.error_vectors = 0
121         self.statseg = 0
122
123     def connect(self):
124         '''Connect to stats segment'''
125         if self.connected:
126             return
127         sock = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
128         sock.connect(self.socketname)
129
130         mfd = recv_fd(sock)
131         sock.close()
132
133         stat_result = os.fstat(mfd)
134         self.statseg = mmap.mmap(mfd, stat_result.st_size, mmap.PROT_READ, mmap.MAP_SHARED)
135         os.close(mfd)
136
137         self.size = stat_result.st_size
138         if self.version != 2:
139             raise Exception('Incompatbile stat segment version {}'
140                             .format(self.version))
141
142         self.refresh()
143         self.connected = True
144
145     def disconnect(self):
146         '''Disconnect from stats segment'''
147         if self.connected:
148             self.statseg.close()
149             self.connected = False
150
151     @property
152     def version(self):
153         '''Get version of stats segment'''
154         return self.shared_headerfmt.unpack_from(self.statseg)[0]
155
156     @property
157     def base(self):
158         '''Get base pointer of stats segment'''
159         return self.shared_headerfmt.unpack_from(self.statseg)[1]
160
161     @property
162     def epoch(self):
163         '''Get current epoch value from stats segment'''
164         return self.shared_headerfmt.unpack_from(self.statseg)[2]
165
166     @property
167     def in_progress(self):
168         '''Get value of in_progress from stats segment'''
169         return self.shared_headerfmt.unpack_from(self.statseg)[3]
170
171     @property
172     def directory_vector(self):
173         '''Get pointer of directory vector'''
174         return self.shared_headerfmt.unpack_from(self.statseg)[4]
175
176     @property
177     def error_vector(self):
178         '''Get pointer of error vector'''
179         return self.shared_headerfmt.unpack_from(self.statseg)[5]
180
181     elementfmt = 'IQ128s'
182
183     def refresh(self, blocking=True):
184         '''Refresh directory vector cache (epoch changed)'''
185         directory = {}
186         directory_by_idx = {}
187         while True:
188             try:
189                 with self.lock:
190                     self.last_epoch = self.epoch
191                     for i, direntry in enumerate(StatsVector(self, self.directory_vector, self.elementfmt)):
192                         path_raw = direntry[2].find(b'\x00')
193                         path = direntry[2][:path_raw].decode('ascii')
194                         directory[path] = StatsEntry(direntry[0], direntry[1])
195                         directory_by_idx[i] = path
196                     self.directory = directory
197                     self.directory_by_idx = directory_by_idx
198
199                     # Cache the error index vectors
200                     self.error_vectors = []
201                     for threads in StatsVector(self, self.error_vector, 'P'):
202                         self.error_vectors.append(StatsVector(self, threads[0], 'Q'))
203                 # Return statement must be outside the lock block to be sure
204                 # lock.release is executed
205                 return
206             except IOError:
207                 if not blocking:
208                     raise
209
210     def __getitem__(self, item, blocking=True):
211         if not self.connected:
212             self.connect()
213         while True:
214             try:
215                 if self.last_epoch != self.epoch:
216                     self.refresh(blocking)
217                 with self.lock:
218                     result = self.directory[item].get_counter(self)
219                 # Return statement must be outside the lock block to be sure
220                 # lock.release is executed
221                 return result
222             except IOError:
223                 if not blocking:
224                     raise
225
226     def __iter__(self):
227         return iter(self.directory.items())
228
229     def set_errors(self, blocking=True):
230         '''Return dictionary of error counters > 0'''
231         if not self.connected:
232             self.connect()
233
234         errors = {k:v for k, v in self.directory.items() if k.startswith("/err/")}
235         result = {}
236         while True:
237             try:
238                 if self.last_epoch != self.epoch:
239                     self.refresh(blocking)
240                 with self.lock:
241                     for k, entry in errors.items():
242                         total = 0
243                         i = entry.value
244                         for per_thread in self.error_vectors:
245                             total += per_thread[i]
246                         if total:
247                             result[k] = total
248                 return result
249             except IOError:
250                 if not blocking:
251                     raise
252
253     def set_errors_str(self, blocking=True):
254         '''Return all errors counters > 0 pretty printed'''
255         error_string = ['ERRORS:']
256         error_counters = self.set_errors(blocking)
257         for k in sorted(error_counters):
258             error_string.append('{:<60}{:>10}'.format(k, error_counters[k]))
259         return '%s\n' % '\n'.join(error_string)
260
261     def get_counter(self, name, blocking=True):
262         '''Alternative call to __getitem__'''
263         return self.__getitem__(name, blocking)
264
265     def get_err_counter(self, name, blocking=True):
266         '''Return a single value (sum of all threads)'''
267         if not self.connected:
268             self.connect()
269         if name.startswith("/err/"):
270             while True:
271                 try:
272                     if self.last_epoch != self.epoch:
273                         self.refresh(blocking)
274                     with self.lock:
275                         result =  sum(self.directory[name].get_counter(self))
276                     # Return statement must be outside the lock block to be sure
277                     # lock.release is executed
278                     return result
279                 except IOError:
280                     if not blocking:
281                         raise
282
283     def ls(self, patterns):
284         '''Returns list of counters matching pattern'''
285         # pylint: disable=invalid-name
286         if not self.connected:
287             self.connect()
288         if not isinstance(patterns, list):
289             patterns = [patterns]
290         regex = [re.compile(i) for i in patterns]
291         return [k for k, v in self.directory.items()
292                 if any(re.match(pattern, k) for pattern in regex)]
293
294     def dump(self, counters, blocking=True):
295         '''Given a list of counters return a dictionary of results'''
296         if not self.connected:
297             self.connect()
298         result = {}
299         for cnt in counters:
300             result[cnt] = self.__getitem__(cnt,blocking)
301         return result
302
303 class StatsLock():
304     '''Stat segment optimistic locking'''
305
306     def __init__(self, stats):
307         self.stats = stats
308         self.epoch = 0
309
310     def __enter__(self):
311         acquired = self.acquire(blocking=True)
312         assert acquired, "Lock wasn't acquired, but blocking=True"
313         return self
314
315     def __exit__(self, exc_type=None, exc_value=None, traceback=None):
316         self.release()
317
318     def acquire(self, blocking=True, timeout=-1):
319         '''Acquire the lock. Await in progress to go false. Record epoch.'''
320         self.epoch = self.stats.epoch
321         if timeout > 0:
322             start = time.monotonic()
323         while self.stats.in_progress:
324             if not blocking:
325                 time.sleep(0.01)
326                 if timeout > 0:
327                     if start + time.monotonic() > timeout:
328                         return False
329         return True
330
331     def release(self):
332         '''Check if data read while locked is valid'''
333         if self.stats.in_progress or self.stats.epoch != self.epoch:
334             raise IOError('Optimistic lock failed, retry')
335
336     def locked(self):
337         '''Not used'''
338
339
340 class StatsCombinedList(list):
341     '''Column slicing for Combined counters list'''
342
343     def __getitem__(self, item):
344         '''Supports partial numpy style 2d support. Slice by column [:,1]'''
345         if isinstance(item, int):
346             return list.__getitem__(self, item)
347         return CombinedList([row[item[1]] for row in self])
348
349 class CombinedList(list):
350     '''Combined Counters 2-dimensional by thread by index of packets/octets'''
351
352     def packets(self):
353         '''Return column (2nd dimension). Packets for all threads'''
354         return [pair[0] for pair in self]
355
356     def octets(self):
357         '''Return column (2nd dimension). Octets for all threads'''
358         return [pair[1] for pair in self]
359
360     def sum_packets(self):
361         '''Return column (2nd dimension). Sum of all packets for all threads'''
362         return sum(self.packets())
363
364     def sum_octets(self):
365         '''Return column (2nd dimension). Sum of all octets for all threads'''
366         return sum(self.octets())
367
368 class StatsTuple(tuple):
369     '''A Combined vector tuple (packets, octets)'''
370     def __init__(self, data):
371         self.dictionary = {'packets': data[0], 'bytes': data[1]}
372         super().__init__()
373
374     def __repr__(self):
375         return dict.__repr__(self.dictionary)
376
377     def __getitem__(self, item):
378         if isinstance(item, int):
379             return tuple.__getitem__(self, item)
380         if item == 'packets':
381             return tuple.__getitem__(self, 0)
382         return tuple.__getitem__(self, 1)
383
384 class StatsSimpleList(list):
385     '''Simple Counters 2-dimensional by thread by index of packets'''
386
387     def __getitem__(self, item):
388         '''Supports partial numpy style 2d support. Slice by column [:,1]'''
389         if isinstance(item, int):
390             return list.__getitem__(self, item)
391         return SimpleList([row[item[1]] for row in self])
392
393 class SimpleList(list):
394     '''Simple counter'''
395
396     def sum(self):
397         '''Sum the vector'''
398         return sum(self)
399
400 class StatsEntry():
401     '''An individual stats entry'''
402     # pylint: disable=unused-argument,no-self-use
403
404     def __init__(self, stattype, statvalue):
405         self.type = stattype
406         self.value = statvalue
407
408         if stattype == 1:
409             self.function = self.scalar
410         elif stattype == 2:
411             self.function = self.simple
412         elif stattype == 3:
413             self.function = self.combined
414         elif stattype == 4:
415             self.function = self.error
416         elif stattype == 5:
417             self.function = self.name
418         elif stattype == 7:
419             self.function = self.symlink
420         else:
421             self.function = self.illegal
422
423     def illegal(self, stats):
424         '''Invalid or unknown counter type'''
425         return None
426
427     def scalar(self, stats):
428         '''Scalar counter'''
429         return self.value
430
431     def simple(self, stats):
432         '''Simple counter'''
433         counter = StatsSimpleList()
434         for threads in StatsVector(stats, self.value, 'P'):
435             clist = [v[0] for v in StatsVector(stats, threads[0], 'Q')]
436             counter.append(clist)
437         return counter
438
439     def combined(self, stats):
440         '''Combined counter'''
441         counter = StatsCombinedList()
442         for threads in StatsVector(stats, self.value, 'P'):
443             clist = [StatsTuple(cnt) for cnt in StatsVector(stats, threads[0], 'QQ')]
444             counter.append(clist)
445         return counter
446
447     def error(self, stats):
448         '''Error counter'''
449         counter = SimpleList()
450         for clist in stats.error_vectors:
451             counter.append(clist[self.value])
452         return counter
453
454     def name(self, stats):
455         '''Name counter'''
456         counter = []
457         for name in StatsVector(stats, self.value, 'P'):
458             if name[0]:
459                 counter.append(get_string(stats, name[0]))
460         return counter
461
462     SYMLINK_FMT1 = Struct('II')
463     SYMLINK_FMT2 = Struct('Q')
464     def symlink(self, stats):
465         '''Symlink counter'''
466         b = self.SYMLINK_FMT2.pack(self.value)
467         index1, index2 = self.SYMLINK_FMT1.unpack(b)
468         name = stats.directory_by_idx[index1]
469         return stats[name][:,index2]
470
471     def get_counter(self, stats):
472         '''Return a list of counters'''
473         if stats:
474             return self.function(stats)
475
476 class TestStats(unittest.TestCase):
477     '''Basic statseg tests'''
478
479     def setUp(self):
480         '''Connect to statseg'''
481         self.stat = VPPStats()
482         self.stat.connect()
483         self.profile = cProfile.Profile()
484         self.profile.enable()
485
486     def tearDown(self):
487         '''Disconnect from statseg'''
488         self.stat.disconnect()
489         profile = Stats(self.profile)
490         profile.strip_dirs()
491         profile.sort_stats('cumtime')
492         profile.print_stats()
493         print("\n--->>>")
494
495     def test_counters(self):
496         '''Test access to statseg'''
497
498         print('/err/abf-input-ip4/missed', self.stat['/err/abf-input-ip4/missed'])
499         print('/sys/heartbeat', self.stat['/sys/heartbeat'])
500         print('/if/names', self.stat['/if/names'])
501         print('/if/rx-miss', self.stat['/if/rx-miss'])
502         print('/if/rx-miss', self.stat['/if/rx-miss'][1])
503         print('/nat44-ed/out2in/slowpath/drops', self.stat['/nat44-ed/out2in/slowpath/drops'])
504         print('Set Errors', self.stat.set_errors())
505         with self.assertRaises(KeyError):
506             print('NO SUCH COUNTER', self.stat['foobar'])
507         print('/if/rx', self.stat.get_counter('/if/rx'))
508         print('/err/ethernet-input/no error',
509               self.stat.get_err_counter('/err/ethernet-input/no error'))
510
511     def test_column(self):
512         '''Test column slicing'''
513
514         print('/if/rx-miss', self.stat['/if/rx-miss'])
515         print('/if/rx', self.stat['/if/rx'])  # All interfaces for thread #1
516         print('/if/rx thread #1', self.stat['/if/rx'][0])  # All interfaces for thread #1
517         print('/if/rx thread #1, interface #1',
518               self.stat['/if/rx'][0][1])  # All interfaces for thread #1
519         print('/if/rx if_index #1', self.stat['/if/rx'][:, 1])
520         print('/if/rx if_index #1 packets', self.stat['/if/rx'][:, 1].packets())
521         print('/if/rx if_index #1 packets', self.stat['/if/rx'][:, 1].sum_packets())
522         print('/if/rx if_index #1 packets', self.stat['/if/rx'][:, 1].octets())
523         print('/if/rx-miss', self.stat['/if/rx-miss'])
524         print('/if/rx-miss if_index #1 packets', self.stat['/if/rx-miss'][:, 1].sum())
525         print('/if/rx if_index #1 packets', self.stat['/if/rx'][0][1]['packets'])
526
527     def test_error(self):
528         '''Test the error vector'''
529
530         print('/err/ethernet-input', self.stat['/err/ethernet-input/no error'])
531         print('/err/nat44-ei-ha/pkts-processed', self.stat['/err/nat44-ei-ha/pkts-processed'])
532         print('/err/ethernet-input', self.stat.get_err_counter('/err/ethernet-input/no error'))
533         print('/err/ethernet-input', self.stat['/err/ethernet-input/no error'].sum())
534
535     def test_nat44(self):
536         '''Test the nat counters'''
537
538         print('/nat44-ei/ha/del-event-recv', self.stat['/nat44-ei/ha/del-event-recv'])
539         print('/err/nat44-ei-ha/pkts-processed', self.stat['/err/nat44-ei-ha/pkts-processed'].sum())
540
541     def test_legacy(self):
542         '''Legacy interface'''
543         directory = self.stat.ls(["^/if", "/err/ip4-input", "/sys/node/ip4-input"])
544         data = self.stat.dump(directory)
545         print(data)
546         print('Looking up sys node')
547         directory = self.stat.ls(["^/sys/node"])
548         print('Dumping sys node')
549         data = self.stat.dump(directory)
550         print(data)
551         directory = self.stat.ls(["^/foobar"])
552         data = self.stat.dump(directory)
553         print(data)
554
555     def test_sys_nodes(self):
556         '''Test /sys/nodes'''
557         counters = self.stat.ls('^/sys/node')
558         print('COUNTERS:', counters)
559         print('/sys/node', self.stat.dump(counters))
560         print('/net/route/to', self.stat['/net/route/to'])
561
562     def test_symlink(self):
563         '''Symbolic links'''
564         print('/interface/local0/rx', self.stat['/interfaces/local0/rx'])
565         print('/sys/nodes/unix-epoll-input', self.stat['/nodes/unix-epoll-input/calls'])
566
567 if __name__ == '__main__':
568     import cProfile
569     from pstats import Stats
570
571     unittest.main()