papi: update VPPEnumType for python3
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15 import collections
16 import logging
17 import socket
18 import struct
19 import sys
20
21 if sys.version_info <= (3, 4):
22     from aenum import IntEnum  # noqa: F401
23 else:
24     from enum import IntEnum  # noqa: F401
25
26 if sys.version_info <= (3, 6):
27     from aenum import IntFlag  # noqa: F401
28 else:
29
30     from enum import IntFlag  # noqa: F401
31
32 from . import vpp_format  # noqa: E402
33
34 #
35 # Set log-level in application by doing e.g.:
36 # logger = logging.getLogger('vpp_serializer')
37 # logger.setLevel(logging.DEBUG)
38 #
39 logger = logging.getLogger(__name__)
40
41 if sys.version[0] == '2':
42     def check(d): type(d) is dict
43 else:
44     def check(d): type(d) is dict or type(d) is bytes
45
46
47 def conversion_required(data, field_type):
48     if check(data):
49         return False
50     try:
51         if type(data).__name__ in vpp_format.conversion_table[field_type]:
52             return True
53     except KeyError:
54         return False
55
56
57 def conversion_packer(data, field_type):
58     t = type(data).__name__
59     return types[field_type].pack(vpp_format.
60                                   conversion_table[field_type][t](data))
61
62
63 def conversion_unpacker(data, field_type):
64     if field_type not in vpp_format.conversion_unpacker_table:
65         return data
66     return vpp_format.conversion_unpacker_table[field_type](data)
67
68
69 class BaseTypes(object):
70     def __init__(self, type, elements=0, options=None):
71         base_types = {'u8': '>B',
72                       'string': '>s',
73                       'u16': '>H',
74                       'u32': '>I',
75                       'i32': '>i',
76                       'u64': '>Q',
77                       'f64': '=d',
78                       'bool': '>?',
79                       'header': '>HI'}
80
81         if elements > 0 and (type == 'u8' or type == 'string'):
82             self.packer = struct.Struct('>%ss' % elements)
83         else:
84             self.packer = struct.Struct(base_types[type])
85         self.size = self.packer.size
86         self.options = options
87
88     def __call__(self, args):
89         self.options = args
90         return self
91
92     def pack(self, data, kwargs=None):
93         if not data:  # Default to zero if not specified
94             if self.options and 'default' in self.options:
95                 data = self.options['default']
96             else:
97                 data = 0
98         return self.packer.pack(data)
99
100     def unpack(self, data, offset, result=None, ntc=False):
101         return self.packer.unpack_from(data, offset)[0], self.packer.size
102
103
104 class String(object):
105     def __init__(self, options):
106         self.name = 'string'
107         self.size = 1
108         self.length_field_packer = BaseTypes('u32')
109         self.limit = options['limit'] if 'limit' in options else None
110
111     def pack(self, list, kwargs=None):
112         if not list:
113             return self.length_field_packer.pack(0) + b""
114         if self.limit and len(list) > self.limit:
115             raise VPPSerializerValueError(
116                 "Invalid argument length for: {}, {} maximum {}".
117                 format(list, len(list), self.limit))
118
119         return self.length_field_packer.pack(len(list)) + list.encode('utf8')
120
121     def unpack(self, data, offset=0, result=None, ntc=False):
122         length, length_field_size = self.length_field_packer.unpack(data,
123                                                                     offset)
124         if length == 0:
125             return b'', 0
126         p = BaseTypes('u8', length)
127         x, size = p.unpack(data, offset + length_field_size)
128         x2 = x.split(b'\0', 1)[0]
129         return (x2.decode('utf8'), size + length_field_size)
130
131
132 types = {'u8': BaseTypes('u8'), 'u16': BaseTypes('u16'),
133          'u32': BaseTypes('u32'), 'i32': BaseTypes('i32'),
134          'u64': BaseTypes('u64'), 'f64': BaseTypes('f64'),
135          'bool': BaseTypes('bool'), 'string': String}
136
137
138 def vpp_get_type(name):
139     try:
140         return types[name]
141     except KeyError:
142         return None
143
144
145 class VPPSerializerValueError(ValueError):
146     pass
147
148
149 class FixedList_u8(object):
150     def __init__(self, name, field_type, num):
151         self.name = name
152         self.num = num
153         self.packer = BaseTypes(field_type, num)
154         self.size = self.packer.size
155         self.field_type = field_type
156
157     def __call__(self, args):
158         self.options = args
159         return self
160
161     def pack(self, data, kwargs=None):
162         """Packs a fixed length bytestring. Left-pads with zeros
163         if input data is too short."""
164         if not data:
165             return b'\x00' * self.size
166
167         if len(data) > self.num:
168             raise VPPSerializerValueError(
169                 'Fixed list length error for "{}", got: {}'
170                 ' expected: {}'
171                 .format(self.name, len(data), self.num))
172
173         return self.packer.pack(data)
174
175     def unpack(self, data, offset=0, result=None, ntc=False):
176         if len(data[offset:]) < self.num:
177             raise VPPSerializerValueError(
178                 'Invalid array length for "{}" got {}'
179                 ' expected {}'
180                 .format(self.name, len(data[offset:]), self.num))
181         if self.field_type == 'string':
182             s = self.packer.unpack(data, offset)
183             s2 = s[0].split(b'\0', 1)[0]
184             return (s2.decode('utf-8'), self.num)
185         return self.packer.unpack(data, offset)
186
187
188 class FixedList(object):
189     def __init__(self, name, field_type, num):
190         self.num = num
191         self.packer = types[field_type]
192         self.size = self.packer.size * num
193         self.name = name
194         self.field_type = field_type
195
196     def __call__(self, args):
197         self.options = args
198         return self
199
200     def pack(self, list, kwargs):
201         if len(list) != self.num:
202             raise VPPSerializerValueError(
203                 'Fixed list length error, got: {} expected: {}'
204                 .format(len(list), self.num))
205         b = bytes()
206         for e in list:
207             b += self.packer.pack(e)
208         return b
209
210     def unpack(self, data, offset=0, result=None, ntc=False):
211         # Return a list of arguments
212         result = []
213         total = 0
214         for e in range(self.num):
215             x, size = self.packer.unpack(data, offset, ntc=ntc)
216             result.append(x)
217             offset += size
218             total += size
219         return result, total
220
221
222 class VLAList(object):
223     def __init__(self, name, field_type, len_field_name, index):
224         self.name = name
225         self.field_type = field_type
226         self.index = index
227         self.packer = types[field_type]
228         self.size = self.packer.size
229         self.length_field = len_field_name
230
231     def __call__(self, args):
232         self.options = args
233         return self
234
235     def pack(self, list, kwargs=None):
236         if not list:
237             return b""
238         if len(list) != kwargs[self.length_field]:
239             raise VPPSerializerValueError(
240                 'Variable length error, got: {} expected: {}'
241                 .format(len(list), kwargs[self.length_field]))
242         b = bytes()
243
244         # u8 array
245
246         if self.packer.size == 1:
247             return bytearray(list)
248
249         for e in list:
250             b += self.packer.pack(e)
251         return b
252
253     def unpack(self, data, offset=0, result=None, ntc=False):
254         # Return a list of arguments
255         total = 0
256
257         # u8 array
258         if self.packer.size == 1:
259             if result[self.index] == 0:
260                 return b'', 0
261             p = BaseTypes('u8', result[self.index])
262             return p.unpack(data, offset, ntc=ntc)
263
264         r = []
265         for e in range(result[self.index]):
266             x, size = self.packer.unpack(data, offset, ntc=ntc)
267             r.append(x)
268             offset += size
269             total += size
270         return r, total
271
272
273 class VLAList_legacy():
274     def __init__(self, name, field_type):
275         self.packer = types[field_type]
276         self.size = self.packer.size
277
278     def __call__(self, args):
279         self.options = args
280         return self
281
282     def pack(self, list, kwargs=None):
283         if self.packer.size == 1:
284             return bytes(list)
285
286         b = bytes()
287         for e in list:
288             b += self.packer.pack(e)
289         return b
290
291     def unpack(self, data, offset=0, result=None, ntc=False):
292         total = 0
293         # Return a list of arguments
294         if (len(data) - offset) % self.packer.size:
295             raise VPPSerializerValueError(
296                 'Legacy Variable Length Array length mismatch.')
297         elements = int((len(data) - offset) / self.packer.size)
298         r = []
299         for e in range(elements):
300             x, size = self.packer.unpack(data, offset, ntc=ntc)
301             r.append(x)
302             offset += self.packer.size
303             total += size
304         return r, total
305
306
307 class VPPEnumType(object):
308     def __init__(self, name, msgdef):
309         self.size = types['u32'].size
310         self.enumtype = 'u32'
311         e_hash = {}
312         for f in msgdef:
313             if type(f) is dict and 'enumtype' in f:
314                 if f['enumtype'] != 'u32':
315                     self.size = types[f['enumtype']].size
316                     self.enumtype = f['enumtype']
317                 continue
318             ename, evalue = f
319             e_hash[ename] = evalue
320         self.enum = IntFlag(name, e_hash)
321         types[name] = self
322
323     def __call__(self, args):
324         self.options = args
325         return self
326
327     def __getattr__(self, name):
328         return self.enum[name]
329
330     def __bool__(self):
331         return True
332
333     if sys.version[0] == '2':
334         __nonzero__ = __bool__
335
336     def pack(self, data, kwargs=None):
337         return types[self.enumtype].pack(data)
338
339     def unpack(self, data, offset=0, result=None, ntc=False):
340         x, size = types[self.enumtype].unpack(data, offset)
341         return self.enum(x), size
342
343
344 class VPPUnionType(object):
345     def __init__(self, name, msgdef):
346         self.name = name
347         self.size = 0
348         self.maxindex = 0
349         fields = []
350         self.packers = collections.OrderedDict()
351         for i, f in enumerate(msgdef):
352             if type(f) is dict and 'crc' in f:
353                 self.crc = f['crc']
354                 continue
355             f_type, f_name = f
356             if f_type not in types:
357                 logger.debug('Unknown union type {}'.format(f_type))
358                 raise VPPSerializerValueError(
359                     'Unknown message type {}'.format(f_type))
360             fields.append(f_name)
361             size = types[f_type].size
362             self.packers[f_name] = types[f_type]
363             if size > self.size:
364                 self.size = size
365                 self.maxindex = i
366
367         types[name] = self
368         self.tuple = collections.namedtuple(name, fields, rename=True)
369
370     def __call__(self, args):
371         self.options = args
372         return self
373
374     # Union of variable length?
375     def pack(self, data, kwargs=None):
376         if not data:
377             return b'\x00' * self.size
378
379         for k, v in data.items():
380             logger.debug("Key: {} Value: {}".format(k, v))
381             b = self.packers[k].pack(v, kwargs)
382             break
383         r = bytearray(self.size)
384         r[:len(b)] = b
385         return r
386
387     def unpack(self, data, offset=0, result=None, ntc=False):
388         r = []
389         maxsize = 0
390         for k, p in self.packers.items():
391             x, size = p.unpack(data, offset, ntc=ntc)
392             if size > maxsize:
393                 maxsize = size
394             r.append(x)
395         return self.tuple._make(r), maxsize
396
397
398 class VPPTypeAlias(object):
399     def __init__(self, name, msgdef):
400         self.name = name
401         t = vpp_get_type(msgdef['type'])
402         if not t:
403             raise ValueError()
404         if 'length' in msgdef:
405             if msgdef['length'] == 0:
406                 raise ValueError()
407             if msgdef['type'] == 'u8':
408                 self.packer = FixedList_u8(name, msgdef['type'],
409                                            msgdef['length'])
410                 self.size = self.packer.size
411             else:
412                 self.packer = FixedList(name, msgdef['type'], msgdef['length'])
413         else:
414             self.packer = t
415             self.size = t.size
416
417         types[name] = self
418
419     def __call__(self, args):
420         self.options = args
421         return self
422
423     def pack(self, data, kwargs=None):
424         if data and conversion_required(data, self.name):
425             try:
426                 return conversion_packer(data, self.name)
427             # Python 2 and 3 raises different exceptions from inet_pton
428             except(OSError, socket.error, TypeError):
429                 pass
430
431         return self.packer.pack(data, kwargs)
432
433     def unpack(self, data, offset=0, result=None, ntc=False):
434         t, size = self.packer.unpack(data, offset, result, ntc=ntc)
435         if not ntc:
436             return conversion_unpacker(t, self.name), size
437         return t, size
438
439
440 class VPPType(object):
441     # Set everything up to be able to pack / unpack
442     def __init__(self, name, msgdef):
443         self.name = name
444         self.msgdef = msgdef
445         self.packers = []
446         self.fields = []
447         self.fieldtypes = []
448         self.field_by_name = {}
449         size = 0
450         for i, f in enumerate(msgdef):
451             if type(f) is dict and 'crc' in f:
452                 self.crc = f['crc']
453                 continue
454             f_type, f_name = f[:2]
455             self.fields.append(f_name)
456             self.field_by_name[f_name] = None
457             self.fieldtypes.append(f_type)
458             if f_type not in types:
459                 logger.debug('Unknown type {}'.format(f_type))
460                 raise VPPSerializerValueError(
461                     'Unknown message type {}'.format(f_type))
462
463             fieldlen = len(f)
464             options = [x for x in f if type(x) is dict]
465             if len(options):
466                 self.options = options[0]
467                 fieldlen -= 1
468             else:
469                 self.options = {}
470             if fieldlen == 3:  # list
471                 list_elements = f[2]
472                 if list_elements == 0:
473                     p = VLAList_legacy(f_name, f_type)
474                     self.packers.append(p)
475                 elif f_type == 'u8' or f_type == 'string':
476                     p = FixedList_u8(f_name, f_type, list_elements)
477                     self.packers.append(p)
478                     size += p.size
479                 else:
480                     p = FixedList(f_name, f_type, list_elements)
481                     self.packers.append(p)
482                     size += p.size
483             elif fieldlen == 4:  # Variable length list
484                 length_index = self.fields.index(f[3])
485                 p = VLAList(f_name, f_type, f[3], length_index)
486                 self.packers.append(p)
487             else:
488                 p = types[f_type](self.options)
489                 self.packers.append(p)
490                 size += p.size
491
492         self.size = size
493         self.tuple = collections.namedtuple(name, self.fields, rename=True)
494         types[name] = self
495
496     def __call__(self, args):
497         self.options = args
498         return self
499
500     def pack(self, data, kwargs=None):
501         if not kwargs:
502             kwargs = data
503         b = bytes()
504
505         # Try one of the format functions
506         if data and conversion_required(data, self.name):
507             return conversion_packer(data, self.name)
508
509         for i, a in enumerate(self.fields):
510             if data and type(data) is not dict and a not in data:
511                 raise VPPSerializerValueError(
512                     "Invalid argument: {} expected {}.{}".
513                     format(data, self.name, a))
514
515             # Defaulting to zero.
516             if not data or a not in data:  # Default to 0
517                 arg = None
518                 kwarg = None  # No default for VLA
519             else:
520                 arg = data[a]
521                 kwarg = kwargs[a] if a in kwargs else None
522             if isinstance(self.packers[i], VPPType):
523                 b += self.packers[i].pack(arg, kwarg)
524             else:
525                 b += self.packers[i].pack(arg, kwargs)
526
527         return b
528
529     def unpack(self, data, offset=0, result=None, ntc=False):
530         # Return a list of arguments
531         result = []
532         total = 0
533         for p in self.packers:
534             x, size = p.unpack(data, offset, result, ntc)
535             if type(x) is tuple and len(x) == 1:
536                 x = x[0]
537             result.append(x)
538             offset += size
539             total += size
540         t = self.tuple._make(result)
541         if not ntc:
542             t = conversion_unpacker(t, self.name)
543         return t, total
544
545
546 class VPPMessage(VPPType):
547     pass