test framework: add factory function and default parameters
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_papi.py
1 #!/usr/bin/env python
2 #
3 # Copyright (c) 2016 Cisco and/or its affiliates.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at:
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 #
16
17 from __future__ import print_function
18 from __future__ import absolute_import
19 import sys
20 import os
21 import logging
22 import collections
23 import struct
24 import functools
25 import json
26 import threading
27 import fnmatch
28 import weakref
29 import atexit
30 from . vpp_serializer import VPPType, VPPEnumType, VPPUnionType, BaseTypes
31 from . vpp_serializer import VPPMessage, vpp_get_type, VPPTypeAlias
32 from . macaddress import MACAddress, mac_pton, mac_ntop
33
34 logger = logging.getLogger(__name__)
35
36 if sys.version[0] == '2':
37     import Queue as queue
38 else:
39     import queue as queue
40
41
42 def metaclass(metaclass):
43     @functools.wraps(metaclass)
44     def wrapper(cls):
45         return metaclass(cls.__name__, cls.__bases__, cls.__dict__.copy())
46
47     return wrapper
48
49
50 class VppEnumType(type):
51     def __getattr__(cls, name):
52         t = vpp_get_type(name)
53         return t.enum
54
55
56 @metaclass(VppEnumType)
57 class VppEnum(object):
58     pass
59
60
61 def vpp_atexit(vpp_weakref):
62     """Clean up VPP connection on shutdown."""
63     vpp_instance = vpp_weakref()
64     if vpp_instance and vpp_instance.transport.connected:
65         vpp_instance.logger.debug('Cleaning up VPP on exit')
66         vpp_instance.disconnect()
67
68
69 if sys.version[0] == '2':
70     def vpp_iterator(d):
71         return d.iteritems()
72 else:
73     def vpp_iterator(d):
74         return d.items()
75
76
77 def call_logger(msgdef, kwargs):
78     s = 'Calling {}('.format(msgdef.name)
79     for k, v in kwargs.items():
80         s += '{}:{} '.format(k, v)
81     s += ')'
82     return s
83
84
85 def return_logger(r):
86     s = 'Return from {}'.format(r)
87     return s
88
89
90 class VppApiDynamicMethodHolder(object):
91     pass
92
93
94 class FuncWrapper(object):
95     def __init__(self, func):
96         self._func = func
97         self.__name__ = func.__name__
98
99     def __call__(self, **kwargs):
100         return self._func(**kwargs)
101
102
103 class VPPApiError(Exception):
104     pass
105
106
107 class VPPNotImplementedError(NotImplementedError):
108     pass
109
110
111 class VPPIOError(IOError):
112     pass
113
114
115 class VPPRuntimeError(RuntimeError):
116     pass
117
118
119 class VPPValueError(ValueError):
120     pass
121
122
123 class VPP(object):
124     """VPP interface.
125
126     This class provides the APIs to VPP.  The APIs are loaded
127     from provided .api.json files and makes functions accordingly.
128     These functions are documented in the VPP .api files, as they
129     are dynamically created.
130
131     Additionally, VPP can send callback messages; this class
132     provides a means to register a callback function to receive
133     these messages in a background thread.
134     """
135     VPPApiError = VPPApiError
136     VPPRuntimeError = VPPRuntimeError
137     VPPValueError = VPPValueError
138     VPPNotImplementedError = VPPNotImplementedError
139     VPPIOError = VPPIOError
140
141     def process_json_file(self, apidef_file):
142         api = json.load(apidef_file)
143         types = {}
144         for t in api['enums']:
145             t[0] = 'vl_api_' + t[0] + '_t'
146             types[t[0]] = {'type': 'enum', 'data': t}
147         for t in api['unions']:
148             t[0] = 'vl_api_' + t[0] + '_t'
149             types[t[0]] = {'type': 'union', 'data': t}
150         for t in api['types']:
151             t[0] = 'vl_api_' + t[0] + '_t'
152             types[t[0]] = {'type': 'type', 'data': t}
153         for t, v in api['aliases'].items():
154             types['vl_api_' + t + '_t'] = {'type': 'alias', 'data': v}
155         self.services.update(api['services'])
156
157         i = 0
158         while True:
159             unresolved = {}
160             for k, v in types.items():
161                 t = v['data']
162                 if not vpp_get_type(k):
163                     if v['type'] == 'enum':
164                         try:
165                             VPPEnumType(t[0], t[1:])
166                         except ValueError:
167                             unresolved[k] = v
168                     elif v['type'] == 'union':
169                         try:
170                             VPPUnionType(t[0], t[1:])
171                         except ValueError:
172                             unresolved[k] = v
173                     elif v['type'] == 'type':
174                         try:
175                             VPPType(t[0], t[1:])
176                         except ValueError:
177                             unresolved[k] = v
178                     elif v['type'] == 'alias':
179                         try:
180                             VPPTypeAlias(k, t)
181                         except ValueError:
182                             unresolved[k] = v
183             if len(unresolved) == 0:
184                 break
185             if i > 3:
186                 raise VPPValueError('Unresolved type definitions {}'
187                                     .format(unresolved))
188             types = unresolved
189             i += 1
190
191         for m in api['messages']:
192             try:
193                 self.messages[m[0]] = VPPMessage(m[0], m[1:])
194             except VPPNotImplementedError:
195                 self.logger.error('Not implemented error for {}'.format(m[0]))
196
197     def __init__(self, apifiles=None, testmode=False, async_thread=True,
198                  logger=None, loglevel=None,
199                  read_timeout=5, use_socket=False,
200                  server_address='/run/vpp-api.sock'):
201         """Create a VPP API object.
202
203         apifiles is a list of files containing API
204         descriptions that will be loaded - methods will be
205         dynamically created reflecting these APIs.  If not
206         provided this will load the API files from VPP's
207         default install location.
208
209         logger, if supplied, is the logging logger object to log to.
210         loglevel, if supplied, is the log level this logger is set
211         to report at (from the loglevels in the logging module).
212         """
213         if logger is None:
214             logger = logging.getLogger(__name__)
215             if loglevel is not None:
216                 logger.setLevel(loglevel)
217         self.logger = logger
218
219         self.messages = {}
220         self.services = {}
221         self.id_names = []
222         self.id_msgdef = []
223         self.header = VPPType('header', [['u16', 'msgid'],
224                                          ['u32', 'client_index']])
225         self.apifiles = []
226         self.event_callback = None
227         self.message_queue = queue.Queue()
228         self.read_timeout = read_timeout
229         self.async_thread = async_thread
230
231         if use_socket:
232             from . vpp_transport_socket import VppTransport
233         else:
234             from . vpp_transport_shmem import VppTransport
235
236         if not apifiles:
237             # Pick up API definitions from default directory
238             try:
239                 apifiles = self.find_api_files()
240             except RuntimeError:
241                 # In test mode we don't care that we can't find the API files
242                 if testmode:
243                     apifiles = []
244                 else:
245                     raise VPPRuntimeError
246
247         for file in apifiles:
248             with open(file) as apidef_file:
249                 self.process_json_file(apidef_file)
250
251         self.apifiles = apifiles
252
253         # Basic sanity check
254         if len(self.messages) == 0 and not testmode:
255             raise VPPValueError(1, 'Missing JSON message definitions')
256
257         self.transport = VppTransport(self, read_timeout=read_timeout,
258                                       server_address=server_address)
259         # Make sure we allow VPP to clean up the message rings.
260         atexit.register(vpp_atexit, weakref.ref(self))
261
262     class ContextId(object):
263         """Thread-safe provider of unique context IDs."""
264         def __init__(self):
265             self.context = 0
266             self.lock = threading.Lock()
267
268         def __call__(self):
269             """Get a new unique (or, at least, not recently used) context."""
270             with self.lock:
271                 self.context += 1
272                 return self.context
273     get_context = ContextId()
274
275     def get_type(self, name):
276         return vpp_get_type(name)
277
278     @classmethod
279     def find_api_dir(cls):
280         """Attempt to find the best directory in which API definition
281         files may reside. If the value VPP_API_DIR exists in the environment
282         then it is first on the search list. If we're inside a recognized
283         location in a VPP source tree (src/scripts and src/vpp-api/python)
284         then entries from there to the likely locations in build-root are
285         added. Finally the location used by system packages is added.
286
287         :returns: A single directory name, or None if no such directory
288             could be found.
289         """
290         dirs = []
291
292         if 'VPP_API_DIR' in os.environ:
293             dirs.append(os.environ['VPP_API_DIR'])
294
295         # perhaps we're in the 'src/scripts' or 'src/vpp-api/python' dir;
296         # in which case, plot a course to likely places in the src tree
297         import __main__ as main
298         if hasattr(main, '__file__'):
299             # get the path of the calling script
300             localdir = os.path.dirname(os.path.realpath(main.__file__))
301         else:
302             # use cwd if there is no calling script
303             localdir = os.getcwd()
304         localdir_s = localdir.split(os.path.sep)
305
306         def dmatch(dir):
307             """Match dir against right-hand components of the script dir"""
308             d = dir.split('/')  # param 'dir' assumes a / separator
309             length = len(d)
310             return len(localdir_s) > length and localdir_s[-length:] == d
311
312         def sdir(srcdir, variant):
313             """Build a path from srcdir to the staged API files of
314             'variant'  (typically '' or '_debug')"""
315             # Since 'core' and 'plugin' files are staged
316             # in separate directories, we target the parent dir.
317             return os.path.sep.join((
318                 srcdir,
319                 'build-root',
320                 'install-vpp%s-native' % variant,
321                 'vpp',
322                 'share',
323                 'vpp',
324                 'api',
325             ))
326
327         srcdir = None
328         if dmatch('src/scripts'):
329             srcdir = os.path.sep.join(localdir_s[:-2])
330         elif dmatch('src/vpp-api/python'):
331             srcdir = os.path.sep.join(localdir_s[:-3])
332         elif dmatch('test'):
333             # we're apparently running tests
334             srcdir = os.path.sep.join(localdir_s[:-1])
335
336         if srcdir:
337             # we're in the source tree, try both the debug and release
338             # variants.
339             dirs.append(sdir(srcdir, '_debug'))
340             dirs.append(sdir(srcdir, ''))
341
342         # Test for staged copies of the scripts
343         # For these, since we explicitly know if we're running a debug versus
344         # release variant, target only the relevant directory
345         if dmatch('build-root/install-vpp_debug-native/vpp/bin'):
346             srcdir = os.path.sep.join(localdir_s[:-4])
347             dirs.append(sdir(srcdir, '_debug'))
348         if dmatch('build-root/install-vpp-native/vpp/bin'):
349             srcdir = os.path.sep.join(localdir_s[:-4])
350             dirs.append(sdir(srcdir, ''))
351
352         # finally, try the location system packages typically install into
353         dirs.append(os.path.sep.join(('', 'usr', 'share', 'vpp', 'api')))
354
355         # check the directories for existance; first one wins
356         for dir in dirs:
357             if os.path.isdir(dir):
358                 return dir
359
360         return None
361
362     @classmethod
363     def find_api_files(cls, api_dir=None, patterns='*'):
364         """Find API definition files from the given directory tree with the
365         given pattern. If no directory is given then find_api_dir() is used
366         to locate one. If no pattern is given then all definition files found
367         in the directory tree are used.
368
369         :param api_dir: A directory tree in which to locate API definition
370             files; subdirectories are descended into.
371             If this is None then find_api_dir() is called to discover it.
372         :param patterns: A list of patterns to use in each visited directory
373             when looking for files.
374             This can be a list/tuple object or a comma-separated string of
375             patterns. Each value in the list will have leading/trialing
376             whitespace stripped.
377             The pattern specifies the first part of the filename, '.api.json'
378             is appended.
379             The results are de-duplicated, thus overlapping patterns are fine.
380             If this is None it defaults to '*' meaning "all API files".
381         :returns: A list of file paths for the API files found.
382         """
383         if api_dir is None:
384             api_dir = cls.find_api_dir()
385             if api_dir is None:
386                 raise VPPApiError("api_dir cannot be located")
387
388         if isinstance(patterns, list) or isinstance(patterns, tuple):
389             patterns = [p.strip() + '.api.json' for p in patterns]
390         else:
391             patterns = [p.strip() + '.api.json' for p in patterns.split(",")]
392
393         api_files = []
394         for root, dirnames, files in os.walk(api_dir):
395             # iterate all given patterns and de-dup the result
396             files = set(sum([fnmatch.filter(files, p) for p in patterns], []))
397             for filename in files:
398                 api_files.append(os.path.join(root, filename))
399
400         return api_files
401
402     @property
403     def api(self):
404         if not hasattr(self, "_api"):
405             raise VPPApiError("Not connected, api definitions not available")
406         return self._api
407
408     def make_function(self, msg, i, multipart, do_async):
409         if (do_async):
410             def f(**kwargs):
411                 return self._call_vpp_async(i, msg, **kwargs)
412         else:
413             def f(**kwargs):
414                 return self._call_vpp(i, msg, multipart, **kwargs)
415
416         f.__name__ = str(msg.name)
417         f.__doc__ = ", ".join(["%s %s" %
418                                (msg.fieldtypes[j], k)
419                                for j, k in enumerate(msg.fields)])
420         f.msg = msg
421
422         return f
423
424     def _register_functions(self, do_async=False):
425         self.id_names = [None] * (self.vpp_dictionary_maxid + 1)
426         self.id_msgdef = [None] * (self.vpp_dictionary_maxid + 1)
427         self._api = VppApiDynamicMethodHolder()
428         for name, msg in vpp_iterator(self.messages):
429             n = name + '_' + msg.crc[2:]
430             i = self.transport.get_msg_index(n.encode())
431             if i > 0:
432                 self.id_msgdef[i] = msg
433                 self.id_names[i] = name
434
435                 # Create function for client side messages.
436                 if name in self.services:
437                     if 'stream' in self.services[name] and \
438                        self.services[name]['stream']:
439                         multipart = True
440                     else:
441                         multipart = False
442                     f = self.make_function(msg, i, multipart, do_async)
443                     setattr(self._api, name, FuncWrapper(f))
444             else:
445                 self.logger.debug(
446                     'No such message type or failed CRC checksum: %s', n)
447
448     def connect_internal(self, name, msg_handler, chroot_prefix, rx_qlen,
449                          do_async):
450         pfx = chroot_prefix.encode() if chroot_prefix else None
451
452         rv = self.transport.connect(name.encode(), pfx, msg_handler, rx_qlen)
453         if rv != 0:
454             raise VPPIOError(2, 'Connect failed')
455         self.vpp_dictionary_maxid = self.transport.msg_table_max_index()
456         self._register_functions(do_async=do_async)
457
458         # Initialise control ping
459         crc = self.messages['control_ping'].crc
460         self.control_ping_index = self.transport.get_msg_index(
461             ('control_ping' + '_' + crc[2:]).encode())
462         self.control_ping_msgdef = self.messages['control_ping']
463         if self.async_thread:
464             self.event_thread = threading.Thread(
465                 target=self.thread_msg_handler)
466             self.event_thread.daemon = True
467             self.event_thread.start()
468         return rv
469
470     def connect(self, name, chroot_prefix=None, do_async=False, rx_qlen=32):
471         """Attach to VPP.
472
473         name - the name of the client.
474         chroot_prefix - if VPP is chroot'ed, the prefix of the jail
475         do_async - if true, messages are sent without waiting for a reply
476         rx_qlen - the length of the VPP message receive queue between
477         client and server.
478         """
479         msg_handler = self.transport.get_callback(do_async)
480         return self.connect_internal(name, msg_handler, chroot_prefix, rx_qlen,
481                                      do_async)
482
483     def connect_sync(self, name, chroot_prefix=None, rx_qlen=32):
484         """Attach to VPP in synchronous mode. Application must poll for events.
485
486         name - the name of the client.
487         chroot_prefix - if VPP is chroot'ed, the prefix of the jail
488         rx_qlen - the length of the VPP message receive queue between
489         client and server.
490         """
491
492         return self.connect_internal(name, None, chroot_prefix, rx_qlen,
493                                      do_async=False)
494
495     def disconnect(self):
496         """Detach from VPP."""
497         rv = self.transport.disconnect()
498         self.message_queue.put("terminate event thread")
499         return rv
500
501     def msg_handler_sync(self, msg):
502         """Process an incoming message from VPP in sync mode.
503
504         The message may be a reply or it may be an async notification.
505         """
506         r = self.decode_incoming_msg(msg)
507         if r is None:
508             return
509
510         # If we have a context, then use the context to find any
511         # request waiting for a reply
512         context = 0
513         if hasattr(r, 'context') and r.context > 0:
514             context = r.context
515
516         if context == 0:
517             # No context -> async notification that we feed to the callback
518             self.message_queue.put_nowait(r)
519         else:
520             raise VPPIOError(2, 'RPC reply message received in event handler')
521
522     def has_context(self, msg):
523         if len(msg) < 10:
524             return False
525
526         header = VPPType('header_with_context', [['u16', 'msgid'],
527                                                  ['u32', 'client_index'],
528                                                  ['u32', 'context']])
529
530         (i, ci, context), size = header.unpack(msg, 0)
531         if self.id_names[i] == 'rx_thread_exit':
532             return
533
534         #
535         # Decode message and returns a tuple.
536         #
537         msgobj = self.id_msgdef[i]
538         if 'context' in msgobj.field_by_name and context >= 0:
539             return True
540         return False
541
542     def decode_incoming_msg(self, msg, no_type_conversion=False):
543         if not msg:
544             self.logger.warning('vpp_api.read failed')
545             return
546
547         (i, ci), size = self.header.unpack(msg, 0)
548         if self.id_names[i] == 'rx_thread_exit':
549             return
550
551         #
552         # Decode message and returns a tuple.
553         #
554         msgobj = self.id_msgdef[i]
555         if not msgobj:
556             raise VPPIOError(2, 'Reply message undefined')
557
558         r, size = msgobj.unpack(msg, ntc=no_type_conversion)
559         return r
560
561     def msg_handler_async(self, msg):
562         """Process a message from VPP in async mode.
563
564         In async mode, all messages are returned to the callback.
565         """
566         r = self.decode_incoming_msg(msg)
567         if r is None:
568             return
569
570         msgname = type(r).__name__
571
572         if self.event_callback:
573             self.event_callback(msgname, r)
574
575     def _control_ping(self, context):
576         """Send a ping command."""
577         self._call_vpp_async(self.control_ping_index,
578                              self.control_ping_msgdef,
579                              context=context)
580
581     def validate_args(self, msg, kwargs):
582         d = set(kwargs.keys()) - set(msg.field_by_name.keys())
583         if d:
584             raise VPPValueError('Invalid argument {} to {}'
585                                 .format(list(d), msg.name))
586
587     def _call_vpp(self, i, msgdef, multipart, **kwargs):
588         """Given a message, send the message and await a reply.
589
590         msgdef - the message packing definition
591         i - the message type index
592         multipart - True if the message returns multiple
593         messages in return.
594         context - context number - chosen at random if not
595         supplied.
596         The remainder of the kwargs are the arguments to the API call.
597
598         The return value is the message or message array containing
599         the response.  It will raise an IOError exception if there was
600         no response within the timeout window.
601         """
602
603         if 'context' not in kwargs:
604             context = self.get_context()
605             kwargs['context'] = context
606         else:
607             context = kwargs['context']
608         kwargs['_vl_msg_id'] = i
609
610         no_type_conversion = kwargs.pop('_no_type_conversion', False)
611
612         try:
613             if self.transport.socket_index:
614                 kwargs['client_index'] = self.transport.socket_index
615         except AttributeError:
616             pass
617         self.validate_args(msgdef, kwargs)
618
619         logging.debug(call_logger(msgdef, kwargs))
620
621         b = msgdef.pack(kwargs)
622         self.transport.suspend()
623
624         self.transport.write(b)
625
626         if multipart:
627             # Send a ping after the request - we use its response
628             # to detect that we have seen all results.
629             self._control_ping(context)
630
631         # Block until we get a reply.
632         rl = []
633         while (True):
634             msg = self.transport.read()
635             if not msg:
636                 raise VPPIOError(2, 'VPP API client: read failed')
637             r = self.decode_incoming_msg(msg, no_type_conversion)
638             msgname = type(r).__name__
639             if context not in r or r.context == 0 or context != r.context:
640                 # Message being queued
641                 self.message_queue.put_nowait(r)
642                 continue
643
644             if not multipart:
645                 rl = r
646                 break
647             if msgname == 'control_ping_reply':
648                 break
649
650             rl.append(r)
651
652         self.transport.resume()
653
654         logger.debug(return_logger(rl))
655         return rl
656
657     def _call_vpp_async(self, i, msg, **kwargs):
658         """Given a message, send the message and await a reply.
659
660         msgdef - the message packing definition
661         i - the message type index
662         context - context number - chosen at random if not
663         supplied.
664         The remainder of the kwargs are the arguments to the API call.
665         """
666         if 'context' not in kwargs:
667             context = self.get_context()
668             kwargs['context'] = context
669         else:
670             context = kwargs['context']
671         try:
672             if self.transport.socket_index:
673                 kwargs['client_index'] = self.transport.socket_index
674         except AttributeError:
675             kwargs['client_index'] = 0
676         kwargs['_vl_msg_id'] = i
677         b = msg.pack(kwargs)
678
679         self.transport.write(b)
680
681     def register_event_callback(self, callback):
682         """Register a callback for async messages.
683
684         This will be called for async notifications in sync mode,
685         and all messages in async mode.  In sync mode, replies to
686         requests will not come here.
687
688         callback is a fn(msg_type_name, msg_type) that will be
689         called when a message comes in.  While this function is
690         executing, note that (a) you are in a background thread and
691         may wish to use threading.Lock to protect your datastructures,
692         and (b) message processing from VPP will stop (so if you take
693         a long while about it you may provoke reply timeouts or cause
694         VPP to fill the RX buffer).  Passing None will disable the
695         callback.
696         """
697         self.event_callback = callback
698
699     def thread_msg_handler(self):
700         """Python thread calling the user registered message handler.
701
702         This is to emulate the old style event callback scheme. Modern
703         clients should provide their own thread to poll the event
704         queue.
705         """
706         while True:
707             r = self.message_queue.get()
708             if r == "terminate event thread":
709                 break
710             msgname = type(r).__name__
711             if self.event_callback:
712                 self.event_callback(msgname, r)
713
714
715 # vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4