vppapigen: allow for enum size other than u32
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import struct
17 import collections
18 import sys
19
20 if sys.version_info <= (3, 4):
21     from aenum import IntEnum
22 else:
23     from enum import IntEnum
24
25 if sys.version_info <= (3, 6):
26     from aenum import IntFlag
27 else:
28     from enum import IntFlag
29
30 import logging
31 from . import vpp_format
32 import ipaddress
33
34 import socket
35
36 #
37 # Set log-level in application by doing e.g.:
38 # logger = logging.getLogger('vpp_serializer')
39 # logger.setLevel(logging.DEBUG)
40 #
41 logger = logging.getLogger(__name__)
42
43 if sys.version[0] == '2':
44     def check(d): type(d) is dict
45 else:
46     def check(d): type(d) is dict or type(d) is bytes
47
48
49 def conversion_required(data, field_type):
50     if check(data):
51         return False
52     try:
53         if type(data).__name__ in vpp_format.conversion_table[field_type]:
54             return True
55     except KeyError:
56         return False
57
58
59 def conversion_packer(data, field_type):
60     t = type(data).__name__
61     return types[field_type].pack(vpp_format.
62                                   conversion_table[field_type][t](data))
63
64
65 def conversion_unpacker(data, field_type):
66     if field_type not in vpp_format.conversion_unpacker_table:
67         return data
68     return vpp_format.conversion_unpacker_table[field_type](data)
69
70
71 class BaseTypes(object):
72     def __init__(self, type, elements=0):
73         base_types = {'u8': '>B',
74                       'string': '>s',
75                       'u16': '>H',
76                       'u32': '>I',
77                       'i32': '>i',
78                       'u64': '>Q',
79                       'f64': '>d',
80                       'bool': '>?',
81                       'header': '>HI'}
82
83         if elements > 0 and (type == 'u8' or type == 'string'):
84             self.packer = struct.Struct('>%ss' % elements)
85         else:
86             self.packer = struct.Struct(base_types[type])
87         self.size = self.packer.size
88
89     def pack(self, data, kwargs=None):
90         if not data:  # Default to zero if not specified
91             data = 0
92         return self.packer.pack(data)
93
94     def unpack(self, data, offset, result=None, ntc=False):
95         return self.packer.unpack_from(data, offset)[0], self.packer.size
96
97
98 class String(object):
99     def __init__(self):
100         self.name = 'string'
101         self.size = 1
102         self.length_field_packer = BaseTypes('u32')
103
104     def pack(self, list, kwargs=None):
105         if not list:
106             return self.length_field_packer.pack(0) + b""
107         return self.length_field_packer.pack(len(list)) + list.encode('utf8')
108
109     def unpack(self, data, offset=0, result=None, ntc=False):
110         length, length_field_size = self.length_field_packer.unpack(data,
111                                                                     offset)
112         if length == 0:
113             return b'', 0
114         p = BaseTypes('u8', length)
115         x, size = p.unpack(data, offset + length_field_size)
116         x2 = x.split(b'\0', 1)[0]
117         return (x2.decode('utf8'), size + length_field_size)
118
119
120 types = {'u8': BaseTypes('u8'), 'u16': BaseTypes('u16'),
121          'u32': BaseTypes('u32'), 'i32': BaseTypes('i32'),
122          'u64': BaseTypes('u64'), 'f64': BaseTypes('f64'),
123          'bool': BaseTypes('bool'), 'string': String()}
124
125
126 def vpp_get_type(name):
127     try:
128         return types[name]
129     except KeyError:
130         return None
131
132
133 class VPPSerializerValueError(ValueError):
134     pass
135
136
137 class FixedList_u8(object):
138     def __init__(self, name, field_type, num):
139         self.name = name
140         self.num = num
141         self.packer = BaseTypes(field_type, num)
142         self.size = self.packer.size
143         self.field_type = field_type
144
145     def pack(self, data, kwargs=None):
146         """Packs a fixed length bytestring. Left-pads with zeros
147         if input data is too short."""
148         if not data:
149             return b'\x00' * self.size
150
151         if len(data) > self.num:
152             raise VPPSerializerValueError(
153                 'Fixed list length error for "{}", got: {}'
154                 ' expected: {}'
155                 .format(self.name, len(data), self.num))
156
157         return self.packer.pack(data)
158
159     def unpack(self, data, offset=0, result=None, ntc=False):
160         if len(data[offset:]) < self.num:
161             raise VPPSerializerValueError(
162                 'Invalid array length for "{}" got {}'
163                 ' expected {}'
164                 .format(self.name, len(data[offset:]), self.num))
165         if self.field_type == 'string':
166             s = self.packer.unpack(data, offset)
167             s2 = s[0].split(b'\0', 1)[0]
168             return (s2.decode('utf-8'), self.num)
169         return self.packer.unpack(data, offset)
170
171
172 class FixedList(object):
173     def __init__(self, name, field_type, num):
174         self.num = num
175         self.packer = types[field_type]
176         self.size = self.packer.size * num
177         self.name = name
178         self.field_type = field_type
179
180     def pack(self, list, kwargs):
181         if len(list) != self.num:
182             raise VPPSerializerValueError(
183                 'Fixed list length error, got: {} expected: {}'
184                 .format(len(list), self.num))
185         b = bytes()
186         for e in list:
187             b += self.packer.pack(e)
188         return b
189
190     def unpack(self, data, offset=0, result=None, ntc=False):
191         # Return a list of arguments
192         result = []
193         total = 0
194         for e in range(self.num):
195             x, size = self.packer.unpack(data, offset, ntc=ntc)
196             result.append(x)
197             offset += size
198             total += size
199         return result, total
200
201
202 class VLAList(object):
203     def __init__(self, name, field_type, len_field_name, index):
204         self.name = name
205         self.field_type = field_type
206         self.index = index
207         self.packer = types[field_type]
208         self.size = self.packer.size
209         self.length_field = len_field_name
210
211     def pack(self, list, kwargs=None):
212         if not list:
213             return b""
214         if len(list) != kwargs[self.length_field]:
215             raise VPPSerializerValueError(
216                 'Variable length error, got: {} expected: {}'
217                 .format(len(list), kwargs[self.length_field]))
218         b = bytes()
219
220         # u8 array
221
222         if self.packer.size == 1:
223             return bytearray(list)
224
225         for e in list:
226             b += self.packer.pack(e)
227         return b
228
229     def unpack(self, data, offset=0, result=None, ntc=False):
230         # Return a list of arguments
231         total = 0
232
233         # u8 array
234         if self.packer.size == 1:
235             if result[self.index] == 0:
236                 return b'', 0
237             p = BaseTypes('u8', result[self.index])
238             return p.unpack(data, offset, ntc=ntc)
239
240         r = []
241         for e in range(result[self.index]):
242             x, size = self.packer.unpack(data, offset, ntc=ntc)
243             r.append(x)
244             offset += size
245             total += size
246         return r, total
247
248
249 class VLAList_legacy():
250     def __init__(self, name, field_type):
251         self.packer = types[field_type]
252         self.size = self.packer.size
253
254     def pack(self, list, kwargs=None):
255         if self.packer.size == 1:
256             return bytes(list)
257
258         b = bytes()
259         for e in list:
260             b += self.packer.pack(e)
261         return b
262
263     def unpack(self, data, offset=0, result=None, ntc=False):
264         total = 0
265         # Return a list of arguments
266         if (len(data) - offset) % self.packer.size:
267             raise VPPSerializerValueError(
268                 'Legacy Variable Length Array length mismatch.')
269         elements = int((len(data) - offset) / self.packer.size)
270         r = []
271         for e in range(elements):
272             x, size = self.packer.unpack(data, offset, ntc=ntc)
273             r.append(x)
274             offset += self.packer.size
275             total += size
276         return r, total
277
278
279 class VPPEnumType(object):
280     def __init__(self, name, msgdef):
281         self.size = types['u32'].size
282         self.enumtype = 'u32'
283         e_hash = {}
284         for f in msgdef:
285             if type(f) is dict and 'enumtype' in f:
286                 if f['enumtype'] != 'u32':
287                     self.size = types[f['enumtype']].size
288                     self.enumtype = f['enumtype']
289                 continue
290             ename, evalue = f
291             e_hash[ename] = evalue
292         self.enum = IntFlag(name, e_hash)
293         types[name] = self
294
295     def __getattr__(self, name):
296         return self.enum[name]
297
298     def __nonzero__(self):
299         return True
300
301     def pack(self, data, kwargs=None):
302         return types[self.enumtype].pack(data)
303
304     def unpack(self, data, offset=0, result=None, ntc=False):
305         x, size = types[self.enumtype].unpack(data, offset)
306         return self.enum(x), size
307
308
309 class VPPUnionType(object):
310     def __init__(self, name, msgdef):
311         self.name = name
312         self.size = 0
313         self.maxindex = 0
314         fields = []
315         self.packers = collections.OrderedDict()
316         for i, f in enumerate(msgdef):
317             if type(f) is dict and 'crc' in f:
318                 self.crc = f['crc']
319                 continue
320             f_type, f_name = f
321             if f_type not in types:
322                 logger.debug('Unknown union type {}'.format(f_type))
323                 raise VPPSerializerValueError(
324                     'Unknown message type {}'.format(f_type))
325             fields.append(f_name)
326             size = types[f_type].size
327             self.packers[f_name] = types[f_type]
328             if size > self.size:
329                 self.size = size
330                 self.maxindex = i
331
332         types[name] = self
333         self.tuple = collections.namedtuple(name, fields, rename=True)
334
335     # Union of variable length?
336     def pack(self, data, kwargs=None):
337         if not data:
338             return b'\x00' * self.size
339
340         for k, v in data.items():
341             logger.debug("Key: {} Value: {}".format(k, v))
342             b = self.packers[k].pack(v, kwargs)
343             break
344         r = bytearray(self.size)
345         r[:len(b)] = b
346         return r
347
348     def unpack(self, data, offset=0, result=None, ntc=False):
349         r = []
350         maxsize = 0
351         for k, p in self.packers.items():
352             x, size = p.unpack(data, offset, ntc=ntc)
353             if size > maxsize:
354                 maxsize = size
355             r.append(x)
356         return self.tuple._make(r), maxsize
357
358
359 class VPPTypeAlias(object):
360     def __init__(self, name, msgdef):
361         self.name = name
362         t = vpp_get_type(msgdef['type'])
363         if not t:
364             raise ValueError()
365         if 'length' in msgdef:
366             if msgdef['length'] == 0:
367                 raise ValueError()
368             if msgdef['type'] == 'u8':
369                 self.packer = FixedList_u8(name, msgdef['type'],
370                                            msgdef['length'])
371                 self.size = self.packer.size
372             else:
373                 self.packer = FixedList(name, msgdef['type'], msgdef['length'])
374         else:
375             self.packer = t
376             self.size = t.size
377
378         types[name] = self
379
380     def pack(self, data, kwargs=None):
381         if data and conversion_required(data, self.name):
382             try:
383                 return conversion_packer(data, self.name)
384             # Python 2 and 3 raises different exceptions from inet_pton
385             except(OSError, socket.error, TypeError):
386                 pass
387
388         return self.packer.pack(data, kwargs)
389
390     def unpack(self, data, offset=0, result=None, ntc=False):
391         t, size = self.packer.unpack(data, offset, result, ntc=ntc)
392         if not ntc:
393             return conversion_unpacker(t, self.name), size
394         return t, size
395
396
397 class VPPType(object):
398     # Set everything up to be able to pack / unpack
399     def __init__(self, name, msgdef):
400         self.name = name
401         self.msgdef = msgdef
402         self.packers = []
403         self.fields = []
404         self.fieldtypes = []
405         self.field_by_name = {}
406         size = 0
407         for i, f in enumerate(msgdef):
408             if type(f) is dict and 'crc' in f:
409                 self.crc = f['crc']
410                 continue
411             f_type, f_name = f[:2]
412             self.fields.append(f_name)
413             self.field_by_name[f_name] = None
414             self.fieldtypes.append(f_type)
415             if f_type not in types:
416                 logger.debug('Unknown type {}'.format(f_type))
417                 raise VPPSerializerValueError(
418                     'Unknown message type {}'.format(f_type))
419             if len(f) == 3:  # list
420                 list_elements = f[2]
421                 if list_elements == 0:
422                     p = VLAList_legacy(f_name, f_type)
423                     self.packers.append(p)
424                 elif f_type == 'u8' or f_type == 'string':
425                     p = FixedList_u8(f_name, f_type, list_elements)
426                     self.packers.append(p)
427                     size += p.size
428                 else:
429                     p = FixedList(f_name, f_type, list_elements)
430                     self.packers.append(p)
431                     size += p.size
432             elif len(f) == 4:  # Variable length list
433                 length_index = self.fields.index(f[3])
434                 p = VLAList(f_name, f_type, f[3], length_index)
435                 self.packers.append(p)
436             else:
437                 self.packers.append(types[f_type])
438                 size += types[f_type].size
439
440         self.size = size
441         self.tuple = collections.namedtuple(name, self.fields, rename=True)
442         types[name] = self
443
444     def pack(self, data, kwargs=None):
445         if not kwargs:
446             kwargs = data
447         b = bytes()
448
449         # Try one of the format functions
450         if data and conversion_required(data, self.name):
451             return conversion_packer(data, self.name)
452
453         for i, a in enumerate(self.fields):
454             if data and type(data) is not dict and a not in data:
455                 raise VPPSerializerValueError(
456                     "Invalid argument: {} expected {}.{}".
457                     format(data, self.name, a))
458
459             # Defaulting to zero.
460             if not data or a not in data:  # Default to 0
461                 arg = None
462                 kwarg = None  # No default for VLA
463             else:
464                 arg = data[a]
465                 kwarg = kwargs[a] if a in kwargs else None
466             if isinstance(self.packers[i], VPPType):
467                 b += self.packers[i].pack(arg, kwarg)
468             else:
469                 b += self.packers[i].pack(arg, kwargs)
470
471         return b
472
473     def unpack(self, data, offset=0, result=None, ntc=False):
474         # Return a list of arguments
475         result = []
476         total = 0
477         for p in self.packers:
478             x, size = p.unpack(data, offset, result, ntc)
479             if type(x) is tuple and len(x) == 1:
480                 x = x[0]
481             result.append(x)
482             offset += size
483             total += size
484         t = self.tuple._make(result)
485         if not ntc:
486             t = conversion_unpacker(t, self.name)
487         return t, total
488
489
490 class VPPMessage(VPPType):
491     pass