misc: vpp_papi- add tests, clean up pep8
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_papi.py
1 #!/usr/bin/env python
2 #
3 # Copyright (c) 2016 Cisco and/or its affiliates.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at:
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 #
16
17 from __future__ import print_function
18 from __future__ import absolute_import
19 import ctypes
20 import sys
21 import multiprocessing as mp
22 import os
23 import logging
24 import functools
25 import json
26 import threading
27 import fnmatch
28 import weakref
29 import atexit
30 from . vpp_serializer import VPPType, VPPEnumType, VPPUnionType
31 from . vpp_serializer import VPPMessage, vpp_get_type, VPPTypeAlias
32
33 logger = logging.getLogger(__name__)
34
35 if sys.version[0] == '2':
36     import Queue as queue
37 else:
38     import queue as queue
39
40 __all__ = ('FuncWrapper', 'VPP', 'VppApiDynamicMethodHolder',
41            'VppEnum', 'VppEnumType',
42            'VPPIOError', 'VPPRuntimeError', 'VPPValueError',
43            'VPPApiClient', )
44
45
46 def metaclass(metaclass):
47     @functools.wraps(metaclass)
48     def wrapper(cls):
49         return metaclass(cls.__name__, cls.__bases__, cls.__dict__.copy())
50
51     return wrapper
52
53
54 class VppEnumType(type):
55     def __getattr__(cls, name):
56         t = vpp_get_type(name)
57         return t.enum
58
59
60 @metaclass(VppEnumType)
61 class VppEnum(object):
62     pass
63
64
65 def vpp_atexit(vpp_weakref):
66     """Clean up VPP connection on shutdown."""
67     vpp_instance = vpp_weakref()
68     if vpp_instance and vpp_instance.transport.connected:
69         vpp_instance.logger.debug('Cleaning up VPP on exit')
70         vpp_instance.disconnect()
71
72
73 if sys.version[0] == '2':
74     def vpp_iterator(d):
75         return d.iteritems()
76 else:
77     def vpp_iterator(d):
78         return d.items()
79
80
81 def call_logger(msgdef, kwargs):
82     s = 'Calling {}('.format(msgdef.name)
83     for k, v in kwargs.items():
84         s += '{}:{} '.format(k, v)
85     s += ')'
86     return s
87
88
89 def return_logger(r):
90     s = 'Return from {}'.format(r)
91     return s
92
93
94 class VppApiDynamicMethodHolder(object):
95     pass
96
97
98 class FuncWrapper(object):
99     def __init__(self, func):
100         self._func = func
101         self.__name__ = func.__name__
102         self.__doc__ = func.__doc__
103
104     def __call__(self, **kwargs):
105         return self._func(**kwargs)
106
107
108 class VPPApiError(Exception):
109     pass
110
111
112 class VPPNotImplementedError(NotImplementedError):
113     pass
114
115
116 class VPPIOError(IOError):
117     pass
118
119
120 class VPPRuntimeError(RuntimeError):
121     pass
122
123
124 class VPPValueError(ValueError):
125     pass
126
127
128 class VPPApiClient(object):
129     """VPP interface.
130
131     This class provides the APIs to VPP.  The APIs are loaded
132     from provided .api.json files and makes functions accordingly.
133     These functions are documented in the VPP .api files, as they
134     are dynamically created.
135
136     Additionally, VPP can send callback messages; this class
137     provides a means to register a callback function to receive
138     these messages in a background thread.
139     """
140     apidir = None
141     VPPApiError = VPPApiError
142     VPPRuntimeError = VPPRuntimeError
143     VPPValueError = VPPValueError
144     VPPNotImplementedError = VPPNotImplementedError
145     VPPIOError = VPPIOError
146
147     def process_json_file(self, apidef_file):
148         api = json.load(apidef_file)
149         types = {}
150         for t in api['enums']:
151             t[0] = 'vl_api_' + t[0] + '_t'
152             types[t[0]] = {'type': 'enum', 'data': t}
153         for t in api['unions']:
154             t[0] = 'vl_api_' + t[0] + '_t'
155             types[t[0]] = {'type': 'union', 'data': t}
156         for t in api['types']:
157             t[0] = 'vl_api_' + t[0] + '_t'
158             types[t[0]] = {'type': 'type', 'data': t}
159         for t, v in api['aliases'].items():
160             types['vl_api_' + t + '_t'] = {'type': 'alias', 'data': v}
161         self.services.update(api['services'])
162
163         i = 0
164         while True:
165             unresolved = {}
166             for k, v in types.items():
167                 t = v['data']
168                 if not vpp_get_type(k):
169                     if v['type'] == 'enum':
170                         try:
171                             VPPEnumType(t[0], t[1:])
172                         except ValueError:
173                             unresolved[k] = v
174                     elif v['type'] == 'union':
175                         try:
176                             VPPUnionType(t[0], t[1:])
177                         except ValueError:
178                             unresolved[k] = v
179                     elif v['type'] == 'type':
180                         try:
181                             VPPType(t[0], t[1:])
182                         except ValueError:
183                             unresolved[k] = v
184                     elif v['type'] == 'alias':
185                         try:
186                             VPPTypeAlias(k, t)
187                         except ValueError:
188                             unresolved[k] = v
189             if len(unresolved) == 0:
190                 break
191             if i > 3:
192                 raise VPPValueError('Unresolved type definitions {}'
193                                     .format(unresolved))
194             types = unresolved
195             i += 1
196
197         for m in api['messages']:
198             try:
199                 self.messages[m[0]] = VPPMessage(m[0], m[1:])
200             except VPPNotImplementedError:
201                 self.logger.error('Not implemented error for {}'.format(m[0]))
202
203     def __init__(self, apifiles=None, testmode=False, async_thread=True,
204                  logger=None, loglevel=None,
205                  read_timeout=5, use_socket=False,
206                  server_address='/run/vpp-api.sock'):
207         """Create a VPP API object.
208
209         apifiles is a list of files containing API
210         descriptions that will be loaded - methods will be
211         dynamically created reflecting these APIs.  If not
212         provided this will load the API files from VPP's
213         default install location.
214
215         logger, if supplied, is the logging logger object to log to.
216         loglevel, if supplied, is the log level this logger is set
217         to report at (from the loglevels in the logging module).
218         """
219         if logger is None:
220             logger = logging.getLogger(__name__)
221             if loglevel is not None:
222                 logger.setLevel(loglevel)
223         self.logger = logger
224
225         self.messages = {}
226         self.services = {}
227         self.id_names = []
228         self.id_msgdef = []
229         self.header = VPPType('header', [['u16', 'msgid'],
230                                          ['u32', 'client_index']])
231         self.apifiles = []
232         self.event_callback = None
233         self.message_queue = queue.Queue()
234         self.read_timeout = read_timeout
235         self.async_thread = async_thread
236         self.event_thread = None
237         self.testmode = testmode
238         self.use_socket = use_socket
239         self.server_address = server_address
240         self._apifiles = apifiles
241
242         if use_socket:
243             from . vpp_transport_socket import VppTransport
244         else:
245             from . vpp_transport_shmem import VppTransport
246
247         if not apifiles:
248             # Pick up API definitions from default directory
249             try:
250                 apifiles = self.find_api_files()
251             except RuntimeError:
252                 # In test mode we don't care that we can't find the API files
253                 if testmode:
254                     apifiles = []
255                 else:
256                     raise VPPRuntimeError
257
258         for file in apifiles:
259             with open(file) as apidef_file:
260                 self.process_json_file(apidef_file)
261
262         self.apifiles = apifiles
263
264         # Basic sanity check
265         if len(self.messages) == 0 and not testmode:
266             raise VPPValueError(1, 'Missing JSON message definitions')
267
268         self.transport = VppTransport(self, read_timeout=read_timeout,
269                                       server_address=server_address)
270         # Make sure we allow VPP to clean up the message rings.
271         atexit.register(vpp_atexit, weakref.ref(self))
272
273     class ContextId(object):
274         """Multiprocessing-safe provider of unique context IDs."""
275         def __init__(self):
276             self.context = mp.Value(ctypes.c_uint, 0)
277             self.lock = mp.Lock()
278
279         def __call__(self):
280             """Get a new unique (or, at least, not recently used) context."""
281             with self.lock:
282                 self.context.value += 1
283                 return self.context.value
284     get_context = ContextId()
285
286     def get_type(self, name):
287         return vpp_get_type(name)
288
289     @classmethod
290     def find_api_dir(cls):
291         """Attempt to find the best directory in which API definition
292         files may reside. If the value VPP_API_DIR exists in the environment
293         then it is first on the search list. If we're inside a recognized
294         location in a VPP source tree (src/scripts and src/vpp-api/python)
295         then entries from there to the likely locations in build-root are
296         added. Finally the location used by system packages is added.
297
298         :returns: A single directory name, or None if no such directory
299             could be found.
300         """
301         dirs = [cls.apidir] if cls.apidir else []
302
303         # perhaps we're in the 'src/scripts' or 'src/vpp-api/python' dir;
304         # in which case, plot a course to likely places in the src tree
305         import __main__ as main
306         if hasattr(main, '__file__'):
307             # get the path of the calling script
308             localdir = os.path.dirname(os.path.realpath(main.__file__))
309         else:
310             # use cwd if there is no calling script
311             localdir = os.getcwd()
312         localdir_s = localdir.split(os.path.sep)
313
314         def dmatch(dir):
315             """Match dir against right-hand components of the script dir"""
316             d = dir.split('/')  # param 'dir' assumes a / separator
317             length = len(d)
318             return len(localdir_s) > length and localdir_s[-length:] == d
319
320         def sdir(srcdir, variant):
321             """Build a path from srcdir to the staged API files of
322             'variant'  (typically '' or '_debug')"""
323             # Since 'core' and 'plugin' files are staged
324             # in separate directories, we target the parent dir.
325             return os.path.sep.join((
326                 srcdir,
327                 'build-root',
328                 'install-vpp%s-native' % variant,
329                 'vpp',
330                 'share',
331                 'vpp',
332                 'api',
333             ))
334
335         srcdir = None
336         if dmatch('src/scripts'):
337             srcdir = os.path.sep.join(localdir_s[:-2])
338         elif dmatch('src/vpp-api/python'):
339             srcdir = os.path.sep.join(localdir_s[:-3])
340         elif dmatch('test'):
341             # we're apparently running tests
342             srcdir = os.path.sep.join(localdir_s[:-1])
343
344         if srcdir:
345             # we're in the source tree, try both the debug and release
346             # variants.
347             dirs.append(sdir(srcdir, '_debug'))
348             dirs.append(sdir(srcdir, ''))
349
350         # Test for staged copies of the scripts
351         # For these, since we explicitly know if we're running a debug versus
352         # release variant, target only the relevant directory
353         if dmatch('build-root/install-vpp_debug-native/vpp/bin'):
354             srcdir = os.path.sep.join(localdir_s[:-4])
355             dirs.append(sdir(srcdir, '_debug'))
356         if dmatch('build-root/install-vpp-native/vpp/bin'):
357             srcdir = os.path.sep.join(localdir_s[:-4])
358             dirs.append(sdir(srcdir, ''))
359
360         # finally, try the location system packages typically install into
361         dirs.append(os.path.sep.join(('', 'usr', 'share', 'vpp', 'api')))
362
363         # check the directories for existence; first one wins
364         for dir in dirs:
365             if os.path.isdir(dir):
366                 return dir
367
368         return None
369
370     @classmethod
371     def find_api_files(cls, api_dir=None, patterns='*'):
372         """Find API definition files from the given directory tree with the
373         given pattern. If no directory is given then find_api_dir() is used
374         to locate one. If no pattern is given then all definition files found
375         in the directory tree are used.
376
377         :param api_dir: A directory tree in which to locate API definition
378             files; subdirectories are descended into.
379             If this is None then find_api_dir() is called to discover it.
380         :param patterns: A list of patterns to use in each visited directory
381             when looking for files.
382             This can be a list/tuple object or a comma-separated string of
383             patterns. Each value in the list will have leading/trialing
384             whitespace stripped.
385             The pattern specifies the first part of the filename, '.api.json'
386             is appended.
387             The results are de-duplicated, thus overlapping patterns are fine.
388             If this is None it defaults to '*' meaning "all API files".
389         :returns: A list of file paths for the API files found.
390         """
391         if api_dir is None:
392             api_dir = cls.find_api_dir()
393             if api_dir is None:
394                 raise VPPApiError("api_dir cannot be located")
395
396         if isinstance(patterns, list) or isinstance(patterns, tuple):
397             patterns = [p.strip() + '.api.json' for p in patterns]
398         else:
399             patterns = [p.strip() + '.api.json' for p in patterns.split(",")]
400
401         api_files = []
402         for root, dirnames, files in os.walk(api_dir):
403             # iterate all given patterns and de-dup the result
404             files = set(sum([fnmatch.filter(files, p) for p in patterns], []))
405             for filename in files:
406                 api_files.append(os.path.join(root, filename))
407
408         return api_files
409
410     @property
411     def api(self):
412         if not hasattr(self, "_api"):
413             raise VPPApiError("Not connected, api definitions not available")
414         return self._api
415
416     def make_function(self, msg, i, multipart, do_async):
417         if (do_async):
418             def f(**kwargs):
419                 return self._call_vpp_async(i, msg, **kwargs)
420         else:
421             def f(**kwargs):
422                 return self._call_vpp(i, msg, multipart, **kwargs)
423
424         f.__name__ = str(msg.name)
425         f.__doc__ = ", ".join(["%s %s" %
426                                (msg.fieldtypes[j], k)
427                                for j, k in enumerate(msg.fields)])
428         f.msg = msg
429
430         return f
431
432     def _register_functions(self, do_async=False):
433         self.id_names = [None] * (self.vpp_dictionary_maxid + 1)
434         self.id_msgdef = [None] * (self.vpp_dictionary_maxid + 1)
435         self._api = VppApiDynamicMethodHolder()
436         for name, msg in vpp_iterator(self.messages):
437             n = name + '_' + msg.crc[2:]
438             i = self.transport.get_msg_index(n.encode('utf-8'))
439             if i > 0:
440                 self.id_msgdef[i] = msg
441                 self.id_names[i] = name
442
443                 # Create function for client side messages.
444                 if name in self.services:
445                     if 'stream' in self.services[name] and \
446                        self.services[name]['stream']:
447                         multipart = True
448                     else:
449                         multipart = False
450                     f = self.make_function(msg, i, multipart, do_async)
451                     setattr(self._api, name, FuncWrapper(f))
452             else:
453                 self.logger.debug(
454                     'No such message type or failed CRC checksum: %s', n)
455
456     def connect_internal(self, name, msg_handler, chroot_prefix, rx_qlen,
457                          do_async):
458         pfx = chroot_prefix.encode('utf-8') if chroot_prefix else None
459
460         rv = self.transport.connect(name.encode('utf-8'), pfx,
461                                     msg_handler, rx_qlen)
462         if rv != 0:
463             raise VPPIOError(2, 'Connect failed')
464         self.vpp_dictionary_maxid = self.transport.msg_table_max_index()
465         self._register_functions(do_async=do_async)
466
467         # Initialise control ping
468         crc = self.messages['control_ping'].crc
469         self.control_ping_index = self.transport.get_msg_index(
470             ('control_ping' + '_' + crc[2:]).encode('utf-8'))
471         self.control_ping_msgdef = self.messages['control_ping']
472         if self.async_thread:
473             self.event_thread = threading.Thread(
474                 target=self.thread_msg_handler)
475             self.event_thread.daemon = True
476             self.event_thread.start()
477         else:
478             self.event_thread = None
479         return rv
480
481     def connect(self, name, chroot_prefix=None, do_async=False, rx_qlen=32):
482         """Attach to VPP.
483
484         name - the name of the client.
485         chroot_prefix - if VPP is chroot'ed, the prefix of the jail
486         do_async - if true, messages are sent without waiting for a reply
487         rx_qlen - the length of the VPP message receive queue between
488         client and server.
489         """
490         msg_handler = self.transport.get_callback(do_async)
491         return self.connect_internal(name, msg_handler, chroot_prefix, rx_qlen,
492                                      do_async)
493
494     def connect_sync(self, name, chroot_prefix=None, rx_qlen=32):
495         """Attach to VPP in synchronous mode. Application must poll for events.
496
497         name - the name of the client.
498         chroot_prefix - if VPP is chroot'ed, the prefix of the jail
499         rx_qlen - the length of the VPP message receive queue between
500         client and server.
501         """
502
503         return self.connect_internal(name, None, chroot_prefix, rx_qlen,
504                                      do_async=False)
505
506     def disconnect(self):
507         """Detach from VPP."""
508         rv = self.transport.disconnect()
509         if self.event_thread is not None:
510             self.message_queue.put("terminate event thread")
511         return rv
512
513     def msg_handler_sync(self, msg):
514         """Process an incoming message from VPP in sync mode.
515
516         The message may be a reply or it may be an async notification.
517         """
518         r = self.decode_incoming_msg(msg)
519         if r is None:
520             return
521
522         # If we have a context, then use the context to find any
523         # request waiting for a reply
524         context = 0
525         if hasattr(r, 'context') and r.context > 0:
526             context = r.context
527
528         if context == 0:
529             # No context -> async notification that we feed to the callback
530             self.message_queue.put_nowait(r)
531         else:
532             raise VPPIOError(2, 'RPC reply message received in event handler')
533
534     def has_context(self, msg):
535         if len(msg) < 10:
536             return False
537
538         header = VPPType('header_with_context', [['u16', 'msgid'],
539                                                  ['u32', 'client_index'],
540                                                  ['u32', 'context']])
541
542         (i, ci, context), size = header.unpack(msg, 0)
543         if self.id_names[i] == 'rx_thread_exit':
544             return
545
546         #
547         # Decode message and returns a tuple.
548         #
549         msgobj = self.id_msgdef[i]
550         if 'context' in msgobj.field_by_name and context >= 0:
551             return True
552         return False
553
554     def decode_incoming_msg(self, msg, no_type_conversion=False):
555         if not msg:
556             self.logger.warning('vpp_api.read failed')
557             return
558
559         (i, ci), size = self.header.unpack(msg, 0)
560         if self.id_names[i] == 'rx_thread_exit':
561             return
562
563         #
564         # Decode message and returns a tuple.
565         #
566         msgobj = self.id_msgdef[i]
567         if not msgobj:
568             raise VPPIOError(2, 'Reply message undefined')
569
570         r, size = msgobj.unpack(msg, ntc=no_type_conversion)
571         return r
572
573     def msg_handler_async(self, msg):
574         """Process a message from VPP in async mode.
575
576         In async mode, all messages are returned to the callback.
577         """
578         r = self.decode_incoming_msg(msg)
579         if r is None:
580             return
581
582         msgname = type(r).__name__
583
584         if self.event_callback:
585             self.event_callback(msgname, r)
586
587     def _control_ping(self, context):
588         """Send a ping command."""
589         self._call_vpp_async(self.control_ping_index,
590                              self.control_ping_msgdef,
591                              context=context)
592
593     def validate_args(self, msg, kwargs):
594         d = set(kwargs.keys()) - set(msg.field_by_name.keys())
595         if d:
596             raise VPPValueError('Invalid argument {} to {}'
597                                 .format(list(d), msg.name))
598
599     def _call_vpp(self, i, msgdef, multipart, **kwargs):
600         """Given a message, send the message and await a reply.
601
602         msgdef - the message packing definition
603         i - the message type index
604         multipart - True if the message returns multiple
605         messages in return.
606         context - context number - chosen at random if not
607         supplied.
608         The remainder of the kwargs are the arguments to the API call.
609
610         The return value is the message or message array containing
611         the response.  It will raise an IOError exception if there was
612         no response within the timeout window.
613         """
614
615         if 'context' not in kwargs:
616             context = self.get_context()
617             kwargs['context'] = context
618         else:
619             context = kwargs['context']
620         kwargs['_vl_msg_id'] = i
621
622         no_type_conversion = kwargs.pop('_no_type_conversion', False)
623
624         try:
625             if self.transport.socket_index:
626                 kwargs['client_index'] = self.transport.socket_index
627         except AttributeError:
628             pass
629         self.validate_args(msgdef, kwargs)
630
631         logging.debug(call_logger(msgdef, kwargs))
632
633         b = msgdef.pack(kwargs)
634         self.transport.suspend()
635
636         self.transport.write(b)
637
638         if multipart:
639             # Send a ping after the request - we use its response
640             # to detect that we have seen all results.
641             self._control_ping(context)
642
643         # Block until we get a reply.
644         rl = []
645         while (True):
646             msg = self.transport.read()
647             if not msg:
648                 raise VPPIOError(2, 'VPP API client: read failed')
649             r = self.decode_incoming_msg(msg, no_type_conversion)
650             msgname = type(r).__name__
651             if context not in r or r.context == 0 or context != r.context:
652                 # Message being queued
653                 self.message_queue.put_nowait(r)
654                 continue
655
656             if not multipart:
657                 rl = r
658                 break
659             if msgname == 'control_ping_reply':
660                 break
661
662             rl.append(r)
663
664         self.transport.resume()
665
666         logger.debug(return_logger(rl))
667         return rl
668
669     def _call_vpp_async(self, i, msg, **kwargs):
670         """Given a message, send the message and await a reply.
671
672         msgdef - the message packing definition
673         i - the message type index
674         context - context number - chosen at random if not
675         supplied.
676         The remainder of the kwargs are the arguments to the API call.
677         """
678         if 'context' not in kwargs:
679             context = self.get_context()
680             kwargs['context'] = context
681         else:
682             context = kwargs['context']
683         try:
684             if self.transport.socket_index:
685                 kwargs['client_index'] = self.transport.socket_index
686         except AttributeError:
687             kwargs['client_index'] = 0
688         kwargs['_vl_msg_id'] = i
689         b = msg.pack(kwargs)
690
691         self.transport.write(b)
692
693     def register_event_callback(self, callback):
694         """Register a callback for async messages.
695
696         This will be called for async notifications in sync mode,
697         and all messages in async mode.  In sync mode, replies to
698         requests will not come here.
699
700         callback is a fn(msg_type_name, msg_type) that will be
701         called when a message comes in.  While this function is
702         executing, note that (a) you are in a background thread and
703         may wish to use threading.Lock to protect your datastructures,
704         and (b) message processing from VPP will stop (so if you take
705         a long while about it you may provoke reply timeouts or cause
706         VPP to fill the RX buffer).  Passing None will disable the
707         callback.
708         """
709         self.event_callback = callback
710
711     def thread_msg_handler(self):
712         """Python thread calling the user registered message handler.
713
714         This is to emulate the old style event callback scheme. Modern
715         clients should provide their own thread to poll the event
716         queue.
717         """
718         while True:
719             r = self.message_queue.get()
720             if r == "terminate event thread":
721                 break
722             msgname = type(r).__name__
723             if self.event_callback:
724                 self.event_callback(msgname, r)
725
726     def __repr__(self):
727         return "<VPPApiClient apifiles=%s, testmode=%s, async_thread=%s, " \
728                "logger=%s, read_timeout=%s, use_socket=%s, " \
729                "server_address='%s'>" % (
730                    self._apifiles, self.testmode, self.async_thread,
731                    self.logger, self.read_timeout, self.use_socket,
732                    self.server_address)
733
734
735 # Provide the old name for backward compatibility.
736 VPP = VPPApiClient
737
738 # vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4