papi: support default for type alias decaying to basetype
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15 import collections
16 import logging
17 import socket
18 import struct
19 import sys
20
21 if sys.version_info <= (3, 4):
22     from aenum import IntEnum  # noqa: F401
23 else:
24     from enum import IntEnum  # noqa: F401
25
26 if sys.version_info <= (3, 6):
27     from aenum import IntFlag  # noqa: F401
28 else:
29
30     from enum import IntFlag  # noqa: F401
31
32 from . import vpp_format  # noqa: E402
33
34 #
35 # Set log-level in application by doing e.g.:
36 # logger = logging.getLogger('vpp_serializer')
37 # logger.setLevel(logging.DEBUG)
38 #
39 logger = logging.getLogger(__name__)
40
41 if sys.version[0] == '2':
42     def check(d): type(d) is dict
43 else:
44     def check(d): type(d) is dict or type(d) is bytes
45
46
47 def conversion_required(data, field_type):
48     if check(data):
49         return False
50     try:
51         if type(data).__name__ in vpp_format.conversion_table[field_type]:
52             return True
53     except KeyError:
54         return False
55
56
57 def conversion_packer(data, field_type):
58     t = type(data).__name__
59     return types[field_type].pack(vpp_format.
60                                   conversion_table[field_type][t](data))
61
62
63 def conversion_unpacker(data, field_type):
64     if field_type not in vpp_format.conversion_unpacker_table:
65         return data
66     return vpp_format.conversion_unpacker_table[field_type](data)
67
68
69 # TODO: post 20.01, remove inherit from object.
70 class Packer(object):
71     options = {}
72
73     def pack(self, data, kwargs):
74         raise NotImplementedError
75
76     def unpack(self, data, offset, result=None, ntc=False):
77         raise NotImplementedError
78
79     # override as appropriate in subclasses
80     def _get_packer_with_options(self, f_type, options):
81         return types[f_type]
82
83     def get_packer_with_options(self, f_type, options):
84         if options is not None:
85             try:
86                 return self._get_packer_with_options(f_type, options)
87             except IndexError:
88                 raise VPPSerializerValueError(
89                     "Options not supported for {}{} ({})".
90                         format(f_type, types[f_type].__class__,
91                                options))
92
93
94 class BaseTypes(Packer):
95     def __init__(self, type, elements=0, options=None):
96         base_types = {'u8': '>B',
97                       'i8': '>b',
98                       'string': '>s',
99                       'u16': '>H',
100                       'i16': '>h',
101                       'u32': '>I',
102                       'i32': '>i',
103                       'u64': '>Q',
104                       'i64': '>q',
105                       'f64': '=d',
106                       'bool': '>?',
107                       'header': '>HI'}
108
109         if elements > 0 and (type == 'u8' or type == 'string'):
110             self.packer = struct.Struct('>%ss' % elements)
111         else:
112             self.packer = struct.Struct(base_types[type])
113         self.size = self.packer.size
114         self.options = options
115
116     def pack(self, data, kwargs=None):
117         if data is None:  # Default to zero if not specified
118             if self.options and 'default' in self.options:
119                 data = self.options['default']
120             else:
121                 data = 0
122         return self.packer.pack(data)
123
124     def unpack(self, data, offset, result=None, ntc=False):
125         return self.packer.unpack_from(data, offset)[0], self.packer.size
126
127     def _get_packer_with_options(self, f_type, options):
128         c = types[f_type].__class__
129         return c(f_type, options=options)
130
131
132 class String(Packer):
133     def __init__(self, name, num, options):
134         self.name = name
135         self.num = num
136         self.size = 1
137         self.length_field_packer = BaseTypes('u32')
138         self.limit = options['limit'] if 'limit' in options else num
139         self.fixed = True if num else False
140         if self.fixed and not self.limit:
141             raise VPPSerializerValueError(
142                 "Invalid argument length for: {}, {} maximum {}".
143                 format(list, len(list), self.limit))
144
145     def pack(self, list, kwargs=None):
146         if not list:
147             if self.fixed:
148                 return b"\x00" * self.limit
149             return self.length_field_packer.pack(0) + b""
150         if self.limit and len(list) > self.limit - 1:
151             raise VPPSerializerValueError(
152                 "Invalid argument length for: {}, {} maximum {}".
153                 format(list, len(list), self.limit - 1))
154         if self.fixed:
155             return list.encode('ascii').ljust(self.limit, b'\x00')
156         return self.length_field_packer.pack(len(list)) + list.encode('ascii')
157
158     def unpack(self, data, offset=0, result=None, ntc=False):
159         if self.fixed:
160             p = BaseTypes('u8', self.num)
161             s = p.unpack(data, offset)
162             s2 = s[0].split(b'\0', 1)[0]
163             return (s2.decode('ascii'), self.num)
164
165         length, length_field_size = self.length_field_packer.unpack(data,
166                                                                     offset)
167         if length == 0:
168             return '', 0
169         p = BaseTypes('u8', length)
170         x, size = p.unpack(data, offset + length_field_size)
171         #x2 = x.split(b'\0', 1)[0]
172         return (x.decode('ascii', errors='replace'), size + length_field_size)
173
174
175 types = {'u8': BaseTypes('u8'), 'u16': BaseTypes('u16'),
176          'u32': BaseTypes('u32'), 'i32': BaseTypes('i32'),
177          'u64': BaseTypes('u64'), 'f64': BaseTypes('f64'),
178          'bool': BaseTypes('bool'), 'string': String}
179
180 class_types = {}
181
182
183 def vpp_get_type(name):
184     try:
185         return types[name]
186     except KeyError:
187         return None
188
189
190 class VPPSerializerValueError(ValueError):
191     pass
192
193
194 class FixedList_u8(Packer):
195     def __init__(self, name, field_type, num):
196         self.name = name
197         self.num = num
198         self.packer = BaseTypes(field_type, num)
199         self.size = self.packer.size
200         self.field_type = field_type
201
202     def pack(self, data, kwargs=None):
203         """Packs a fixed length bytestring. Left-pads with zeros
204         if input data is too short."""
205         if not data:
206             return b'\x00' * self.size
207
208         if len(data) > self.num:
209             raise VPPSerializerValueError(
210                 'Fixed list length error for "{}", got: {}'
211                 ' expected: {}'
212                 .format(self.name, len(data), self.num))
213
214         try:
215             return self.packer.pack(data)
216         except struct.error:
217             raise VPPSerializerValueError(
218                 'Packing failed for "{}" {}'
219                 .format(self.name, kwargs))
220
221     def unpack(self, data, offset=0, result=None, ntc=False):
222         if len(data[offset:]) < self.num:
223             raise VPPSerializerValueError(
224                 'Invalid array length for "{}" got {}'
225                 ' expected {}'
226                 .format(self.name, len(data[offset:]), self.num))
227         return self.packer.unpack(data, offset)
228
229
230 class FixedList(Packer):
231     def __init__(self, name, field_type, num):
232         self.num = num
233         self.packer = types[field_type]
234         self.size = self.packer.size * num
235         self.name = name
236         self.field_type = field_type
237
238     def pack(self, list, kwargs):
239         if len(list) != self.num:
240             raise VPPSerializerValueError(
241                 'Fixed list length error, got: {} expected: {}'
242                 .format(len(list), self.num))
243         b = bytes()
244         for e in list:
245             b += self.packer.pack(e)
246         return b
247
248     def unpack(self, data, offset=0, result=None, ntc=False):
249         # Return a list of arguments
250         result = []
251         total = 0
252         for e in range(self.num):
253             x, size = self.packer.unpack(data, offset, ntc=ntc)
254             result.append(x)
255             offset += size
256             total += size
257         return result, total
258
259
260 class VLAList(Packer):
261     def __init__(self, name, field_type, len_field_name, index):
262         self.name = name
263         self.field_type = field_type
264         self.index = index
265         self.packer = types[field_type]
266         self.size = self.packer.size
267         self.length_field = len_field_name
268
269     def pack(self, lst, kwargs=None):
270         if not lst:
271             return b""
272         if len(lst) != kwargs[self.length_field]:
273             raise VPPSerializerValueError(
274                 'Variable length error, got: {} expected: {}'
275                 .format(len(lst), kwargs[self.length_field]))
276
277         # u8 array
278         if self.packer.size == 1:
279             if isinstance(lst, list):
280                 return b''.join(lst)
281             return bytes(lst)
282
283         b = bytes()
284         for e in lst:
285             b += self.packer.pack(e)
286         return b
287
288     def unpack(self, data, offset=0, result=None, ntc=False):
289         # Return a list of arguments
290         total = 0
291
292         # u8 array
293         if self.packer.size == 1:
294             if result[self.index] == 0:
295                 return b'', 0
296             p = BaseTypes('u8', result[self.index])
297             return p.unpack(data, offset, ntc=ntc)
298
299         r = []
300         for e in range(result[self.index]):
301             x, size = self.packer.unpack(data, offset, ntc=ntc)
302             r.append(x)
303             offset += size
304             total += size
305         return r, total
306
307
308 class VLAList_legacy(Packer):
309     def __init__(self, name, field_type):
310         self.packer = types[field_type]
311         self.size = self.packer.size
312
313     def pack(self, list, kwargs=None):
314         if self.packer.size == 1:
315             return bytes(list)
316
317         b = bytes()
318         for e in list:
319             b += self.packer.pack(e)
320         return b
321
322     def unpack(self, data, offset=0, result=None, ntc=False):
323         total = 0
324         # Return a list of arguments
325         if (len(data) - offset) % self.packer.size:
326             raise VPPSerializerValueError(
327                 'Legacy Variable Length Array length mismatch.')
328         elements = int((len(data) - offset) / self.packer.size)
329         r = []
330         for e in range(elements):
331             x, size = self.packer.unpack(data, offset, ntc=ntc)
332             r.append(x)
333             offset += self.packer.size
334             total += size
335         return r, total
336
337
338 class VPPEnumType(Packer):
339     def __init__(self, name, msgdef, options=None):
340         self.size = types['u32'].size
341         self.name = name
342         self.enumtype = 'u32'
343         self.msgdef = msgdef
344         e_hash = {}
345         for f in msgdef:
346             if type(f) is dict and 'enumtype' in f:
347                 if f['enumtype'] != 'u32':
348                     self.size = types[f['enumtype']].size
349                     self.enumtype = f['enumtype']
350                 continue
351             ename, evalue = f
352             e_hash[ename] = evalue
353         self.enum = IntFlag(name, e_hash)
354         types[name] = self
355         class_types[name] = VPPEnumType
356         self.options = options
357
358     def __getattr__(self, name):
359         return self.enum[name]
360
361     def __bool__(self):
362         return True
363
364     # TODO: Remove post 20.01.
365     if sys.version[0] == '2':
366         __nonzero__ = __bool__
367
368     def pack(self, data, kwargs=None):
369         if data is None:  # Default to zero if not specified
370             if self.options and 'default' in self.options:
371                 data = self.options['default']
372             else:
373                 data = 0
374
375         return types[self.enumtype].pack(data)
376
377     def unpack(self, data, offset=0, result=None, ntc=False):
378         x, size = types[self.enumtype].unpack(data, offset)
379         return self.enum(x), size
380
381     def _get_packer_with_options(self, f_type, options):
382         c = types[f_type].__class__
383         return c(f_type, types[f_type].msgdef, options=options)
384
385
386 class VPPUnionType(Packer):
387     def __init__(self, name, msgdef):
388         self.name = name
389         self.size = 0
390         self.maxindex = 0
391         fields = []
392         self.packers = collections.OrderedDict()
393         for i, f in enumerate(msgdef):
394             if type(f) is dict and 'crc' in f:
395                 self.crc = f['crc']
396                 continue
397             f_type, f_name = f
398             if f_type not in types:
399                 logger.debug('Unknown union type {}'.format(f_type))
400                 raise VPPSerializerValueError(
401                     'Unknown message type {}'.format(f_type))
402             fields.append(f_name)
403             size = types[f_type].size
404             self.packers[f_name] = types[f_type]
405             if size > self.size:
406                 self.size = size
407                 self.maxindex = i
408
409         types[name] = self
410         self.tuple = collections.namedtuple(name, fields, rename=True)
411
412     # Union of variable length?
413     def pack(self, data, kwargs=None):
414         if not data:
415             return b'\x00' * self.size
416
417         for k, v in data.items():
418             logger.debug("Key: {} Value: {}".format(k, v))
419             b = self.packers[k].pack(v, kwargs)
420             break
421         r = bytearray(self.size)
422         r[:len(b)] = b
423         return r
424
425     def unpack(self, data, offset=0, result=None, ntc=False):
426         r = []
427         maxsize = 0
428         for k, p in self.packers.items():
429             x, size = p.unpack(data, offset, ntc=ntc)
430             if size > maxsize:
431                 maxsize = size
432             r.append(x)
433         return self.tuple._make(r), maxsize
434
435
436 class VPPTypeAlias(Packer):
437     def __init__(self, name, msgdef, options=None):
438         self.name = name
439         self.msgdef = msgdef
440         t = vpp_get_type(msgdef['type'])
441         if not t:
442             raise ValueError('No such type: {}'.format(msgdef['type']))
443         if 'length' in msgdef:
444             if msgdef['length'] == 0:
445                 raise ValueError()
446             if msgdef['type'] == 'u8':
447                 self.packer = FixedList_u8(name, msgdef['type'],
448                                            msgdef['length'])
449                 self.size = self.packer.size
450             else:
451                 self.packer = FixedList(name, msgdef['type'], msgdef['length'])
452         else:
453             self.packer = t
454             self.size = t.size
455
456         types[name] = self
457         self.toplevelconversion = False
458         self.options = options
459
460     def pack(self, data, kwargs=None):
461         if data and conversion_required(data, self.name):
462             try:
463                 return conversion_packer(data, self.name)
464             # Python 2 and 3 raises different exceptions from inet_pton
465             except(OSError, socket.error, TypeError):
466                 pass
467         if data is None:  # Default to zero if not specified
468             if self.options and 'default' in self.options:
469                 data = self.options['default']
470             else:
471                 data = 0
472
473         return self.packer.pack(data, kwargs)
474
475     def _get_packer_with_options(self, f_type, options):
476         c = types[f_type].__class__
477         return c(f_type, types[f_type].msgdef, options=options)
478
479     def unpack(self, data, offset=0, result=None, ntc=False):
480         if ntc is False and self.name in vpp_format.conversion_unpacker_table:
481             # Disable type conversion for dependent types
482             ntc = True
483             self.toplevelconversion = True
484         t, size = self.packer.unpack(data, offset, result, ntc=ntc)
485         if self.toplevelconversion:
486             self.toplevelconversion = False
487             return conversion_unpacker(t, self.name), size
488         return t, size
489
490
491 class VPPType(Packer):
492     # Set everything up to be able to pack / unpack
493     def __init__(self, name, msgdef):
494         self.name = name
495         self.msgdef = msgdef
496         self.packers = []
497         self.fields = []
498         self.fieldtypes = []
499         self.field_by_name = {}
500         size = 0
501         for i, f in enumerate(msgdef):
502             if type(f) is dict and 'crc' in f:
503                 self.crc = f['crc']
504                 continue
505             f_type, f_name = f[:2]
506             self.fields.append(f_name)
507             self.field_by_name[f_name] = None
508             self.fieldtypes.append(f_type)
509             if f_type not in types:
510                 logger.debug('Unknown type {}'.format(f_type))
511                 raise VPPSerializerValueError(
512                     'Unknown message type {}'.format(f_type))
513
514             fieldlen = len(f)
515             options = [x for x in f if type(x) is dict]
516             if len(options):
517                 self.options = options[0]
518                 fieldlen -= 1
519             else:
520                 self.options = {}
521             if fieldlen == 3:  # list
522                 list_elements = f[2]
523                 if list_elements == 0:
524                     if f_type == 'string':
525                         p = String(f_name, 0, self.options)
526                     else:
527                         p = VLAList_legacy(f_name, f_type)
528                     self.packers.append(p)
529                 elif f_type == 'u8':
530                     p = FixedList_u8(f_name, f_type, list_elements)
531                     self.packers.append(p)
532                     size += p.size
533                 elif f_type == 'string':
534                     p = String(f_name, list_elements, self.options)
535                     self.packers.append(p)
536                     size += p.size
537                 else:
538                     p = FixedList(f_name, f_type, list_elements)
539                     self.packers.append(p)
540                     size += p.size
541             elif fieldlen == 4:  # Variable length list
542                 length_index = self.fields.index(f[3])
543                 p = VLAList(f_name, f_type, f[3], length_index)
544                 self.packers.append(p)
545             else:
546                 # default support for types that decay to basetype
547                 if 'default' in self.options:
548                     p = self.get_packer_with_options(f_type, self.options)
549                 else:
550                     p = types[f_type]
551
552                 self.packers.append(p)
553                 size += p.size
554
555         self.size = size
556         self.tuple = collections.namedtuple(name, self.fields, rename=True)
557         types[name] = self
558         self.toplevelconversion = False
559
560     def pack(self, data, kwargs=None):
561         if not kwargs:
562             kwargs = data
563         b = bytes()
564
565         # Try one of the format functions
566         if data and conversion_required(data, self.name):
567             return conversion_packer(data, self.name)
568
569         for i, a in enumerate(self.fields):
570             if data and type(data) is not dict and a not in data:
571                 raise VPPSerializerValueError(
572                     "Invalid argument: {} expected {}.{}".
573                     format(data, self.name, a))
574
575             # Defaulting to zero.
576             if not data or a not in data:  # Default to 0
577                 arg = None
578                 kwarg = None  # No default for VLA
579             else:
580                 arg = data[a]
581                 kwarg = kwargs[a] if a in kwargs else None
582             if isinstance(self.packers[i], VPPType):
583                 b += self.packers[i].pack(arg, kwarg)
584             else:
585                 b += self.packers[i].pack(arg, kwargs)
586
587         return b
588
589     def unpack(self, data, offset=0, result=None, ntc=False):
590         # Return a list of arguments
591         result = []
592         total = 0
593         if ntc is False and self.name in vpp_format.conversion_unpacker_table:
594             # Disable type conversion for dependent types
595             ntc = True
596             self.toplevelconversion = True
597
598         for p in self.packers:
599             x, size = p.unpack(data, offset, result, ntc)
600             if type(x) is tuple and len(x) == 1:
601                 x = x[0]
602             result.append(x)
603             offset += size
604             total += size
605         t = self.tuple._make(result)
606
607         if self.toplevelconversion:
608             self.toplevelconversion = False
609             t = conversion_unpacker(t, self.name)
610         return t, total
611
612
613 class VPPMessage(VPPType):
614     pass