Update japi to support type aliases
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import struct
17 import collections
18 from enum import IntEnum
19 import logging
20 from .vpp_format import VPPFormat
21
22 #
23 # Set log-level in application by doing e.g.:
24 # logger = logging.getLogger('vpp_serializer')
25 # logger.setLevel(logging.DEBUG)
26 #
27 logger = logging.getLogger(__name__)
28
29
30 class BaseTypes(object):
31     def __init__(self, type, elements=0):
32         base_types = {'u8': '>B',
33                       'u16': '>H',
34                       'u32': '>I',
35                       'i32': '>i',
36                       'u64': '>Q',
37                       'f64': '>d',
38                       'bool': '>?',
39                       'header': '>HI'}
40
41         if elements > 0 and type == 'u8':
42             self.packer = struct.Struct('>%ss' % elements)
43         else:
44             self.packer = struct.Struct(base_types[type])
45         self.size = self.packer.size
46         logger.debug('Adding {} with format: {}'
47                      .format(type, base_types[type]))
48
49     def pack(self, data, kwargs=None):
50         if not data:  # Default to zero if not specified
51             data = 0
52         return self.packer.pack(data)
53
54     def unpack(self, data, offset, result=None):
55         return self.packer.unpack_from(data, offset)[0], self.packer.size
56
57
58 types = {'u8': BaseTypes('u8'), 'u16': BaseTypes('u16'),
59          'u32': BaseTypes('u32'), 'i32': BaseTypes('i32'),
60          'u64': BaseTypes('u64'), 'f64': BaseTypes('f64'),
61          'bool': BaseTypes('bool')}
62
63
64 def vpp_get_type(name):
65     try:
66         return types[name]
67     except KeyError:
68         return None
69
70
71 class VPPSerializerValueError(ValueError):
72     pass
73
74
75 class FixedList_u8(object):
76     def __init__(self, name, field_type, num):
77         self.name = name
78         self.num = num
79         self.packer = BaseTypes(field_type, num)
80         self.size = self.packer.size
81
82     def pack(self, list, kwargs = None):
83         """Packs a fixed length bytestring. Left-pads with zeros
84         if input data is too short."""
85         if not list:
86             return b'\x00' * self.size
87         if len(list) > self.num:
88             raise VPPSerializerValueError(
89                 'Fixed list length error for "{}", got: {}'
90                 ' expected: {}'
91                 .format(self.name, len(list), self.num))
92         return self.packer.pack(list)
93
94     def unpack(self, data, offset=0, result=None):
95         if len(data[offset:]) < self.num:
96             raise VPPSerializerValueError(
97                 'Invalid array length for "{}" got {}'
98                 ' expected {}'
99                 .format(self.name, len(data[offset:]), self.num))
100         return self.packer.unpack(data, offset)
101
102
103 class FixedList(object):
104     def __init__(self, name, field_type, num):
105         self.num = num
106         self.packer = types[field_type]
107         self.size = self.packer.size * num
108
109     def pack(self, list, kwargs):
110         if len(list) != self.num:
111             raise VPPSerializerValueError(
112                 'Fixed list length error, got: {} expected: {}'
113                 .format(len(list), self.num))
114         b = bytes()
115         for e in list:
116             b += self.packer.pack(e)
117         return b
118
119     def unpack(self, data, offset=0, result=None):
120         # Return a list of arguments
121         result = []
122         total = 0
123         for e in range(self.num):
124             x, size = self.packer.unpack(data, offset)
125             result.append(x)
126             offset += size
127             total += size
128         return result, total
129
130
131 class VLAList(object):
132     def __init__(self, name, field_type, len_field_name, index):
133         self.name = name
134         self.index = index
135         self.packer = types[field_type]
136         self.size = self.packer.size
137         self.length_field = len_field_name
138
139     def pack(self, list, kwargs=None):
140         if not list:
141             return b""
142         if len(list) != kwargs[self.length_field]:
143             raise VPPSerializerValueError(
144                 'Variable length error, got: {} expected: {}'
145                 .format(len(list), kwargs[self.length_field]))
146         b = bytes()
147
148         # u8 array
149         if self.packer.size == 1:
150             return bytearray(list)
151
152         for e in list:
153             b += self.packer.pack(e)
154         return b
155
156     def unpack(self, data, offset=0, result=None):
157         # Return a list of arguments
158         total = 0
159
160         # u8 array
161         if self.packer.size == 1:
162             if result[self.index] == 0:
163                 return b'', 0
164             p = BaseTypes('u8', result[self.index])
165             return p.unpack(data, offset)
166
167         r = []
168         for e in range(result[self.index]):
169             x, size = self.packer.unpack(data, offset)
170             r.append(x)
171             offset += size
172             total += size
173         return r, total
174
175
176 class VLAList_legacy():
177     def __init__(self, name, field_type):
178         self.packer = types[field_type]
179         self.size = self.packer.size
180
181     def pack(self, list, kwargs=None):
182         if self.packer.size == 1:
183             return bytes(list)
184
185         b = bytes()
186         for e in list:
187             b += self.packer.pack(e)
188         return b
189
190     def unpack(self, data, offset=0, result=None):
191         total = 0
192         # Return a list of arguments
193         if (len(data) - offset) % self.packer.size:
194             raise VPPSerializerValueError(
195                 'Legacy Variable Length Array length mismatch.')
196         elements = int((len(data) - offset) / self.packer.size)
197         r = []
198         for e in range(elements):
199             x, size = self.packer.unpack(data, offset)
200             r.append(x)
201             offset += self.packer.size
202             total += size
203         return r, total
204
205
206 class VPPEnumType(object):
207     def __init__(self, name, msgdef):
208         self.size = types['u32'].size
209         e_hash = {}
210         for f in msgdef:
211             if type(f) is dict and 'enumtype' in f:
212                 if f['enumtype'] != 'u32':
213                     raise NotImplementedError
214                 continue
215             ename, evalue = f
216             e_hash[ename] = evalue
217         self.enum = IntEnum(name, e_hash)
218         types[name] = self
219         logger.debug('Adding enum {}'.format(name))
220
221     def __getattr__(self, name):
222         return self.enum[name]
223
224     def __nonzero__(self):
225         return True
226
227     def pack(self, data, kwargs=None):
228         return types['u32'].pack(data)
229
230     def unpack(self, data, offset=0, result=None):
231         x, size = types['u32'].unpack(data, offset)
232         return self.enum(x), size
233
234
235 class VPPUnionType(object):
236     def __init__(self, name, msgdef):
237         self.name = name
238         self.size = 0
239         self.maxindex = 0
240         fields = []
241         self.packers = collections.OrderedDict()
242         for i, f in enumerate(msgdef):
243             if type(f) is dict and 'crc' in f:
244                 self.crc = f['crc']
245                 continue
246             f_type, f_name = f
247             if f_type not in types:
248                 logger.debug('Unknown union type {}'.format(f_type))
249                 raise VPPSerializerValueError(
250                     'Unknown message type {}'.format(f_type))
251             fields.append(f_name)
252             size = types[f_type].size
253             self.packers[f_name] = types[f_type]
254             if size > self.size:
255                 self.size = size
256                 self.maxindex = i
257
258         types[name] = self
259         self.tuple = collections.namedtuple(name, fields, rename=True)
260         logger.debug('Adding union {}'.format(name))
261
262     # Union of variable length?
263     def pack(self, data, kwargs=None):
264         if not data:
265             return b'\x00' * self.size
266
267         for k, v in data.items():
268             logger.debug("Key: {} Value: {}".format(k, v))
269             b = self.packers[k].pack(v, kwargs)
270             break
271         r = bytearray(self.size)
272         r[:len(b)] = b
273         return r
274
275     def unpack(self, data, offset=0, result=None):
276         r = []
277         maxsize = 0
278         for k, p in self.packers.items():
279             x, size = p.unpack(data, offset)
280             if size > maxsize:
281                 maxsize = size
282             r.append(x)
283         return self.tuple._make(r), maxsize
284
285
286 def VPPTypeAlias(name, msgdef):
287     t = vpp_get_type(msgdef['type'])
288     if not t:
289         raise ValueError()
290     if 'length' in msgdef:
291         if msgdef['length'] == 0:
292             raise ValueError()
293         types[name] = FixedList(name, msgdef['type'], msgdef['length'])
294     else:
295         types[name] = t
296
297
298 class VPPType(object):
299     # Set everything up to be able to pack / unpack
300     def __init__(self, name, msgdef):
301         self.name = name
302         self.msgdef = msgdef
303         self.packers = []
304         self.fields = []
305         self.fieldtypes = []
306         self.field_by_name = {}
307         size = 0
308         for i, f in enumerate(msgdef):
309             if type(f) is dict and 'crc' in f:
310                 self.crc = f['crc']
311                 continue
312             f_type, f_name = f[:2]
313             self.fields.append(f_name)
314             self.field_by_name[f_name] = None
315             self.fieldtypes.append(f_type)
316             if f_type not in types:
317                 logger.debug('Unknown type {}'.format(f_type))
318                 raise VPPSerializerValueError(
319                     'Unknown message type {}'.format(f_type))
320             if len(f) == 3:  # list
321                 list_elements = f[2]
322                 if list_elements == 0:
323                     p = VLAList_legacy(f_name, f_type)
324                     self.packers.append(p)
325                 elif f_type == 'u8':
326                     p = FixedList_u8(f_name, f_type, list_elements)
327                     self.packers.append(p)
328                     size += p.size
329                 else:
330                     p = FixedList(f_name, f_type, list_elements)
331                     self.packers.append(p)
332                     size += p.size
333             elif len(f) == 4:  # Variable length list
334                     # Find index of length field
335                     length_index = self.fields.index(f[3])
336                     p = VLAList(f_name, f_type, f[3], length_index)
337                     self.packers.append(p)
338             else:
339                 self.packers.append(types[f_type])
340                 size += types[f_type].size
341
342         self.size = size
343         self.tuple = collections.namedtuple(name, self.fields, rename=True)
344         types[name] = self
345         logger.debug('Adding type {}'.format(name))
346
347     def pack(self, data, kwargs=None):
348         if not kwargs:
349             kwargs = data
350         b = bytes()
351         for i, a in enumerate(self.fields):
352
353             # Try one of the format functions
354             if data and type(data) is not dict and a not in data:
355                 raise VPPSerializerValueError(
356                     "Invalid argument: {} expected {}.{}".
357                     format(data, self.name, a))
358
359             # Defaulting to zero.
360             if not data or a not in data:  # Default to 0
361                 arg = None
362                 kwarg = None  # No default for VLA
363             else:
364                 arg = data[a]
365                 kwarg = kwargs[a] if a in kwargs else None
366
367             if isinstance(self.packers[i], VPPType):
368                 try:
369                     b += self.packers[i].pack(arg, kwarg)
370                 except ValueError:
371                     # Invalid argument, can we convert it?
372                     arg = VPPFormat.format(self.packers[i].name, data[a])
373                     data[a] = arg
374                     kwarg = arg
375                     b += self.packers[i].pack(arg, kwarg)
376             else:
377                 b += self.packers[i].pack(arg, kwargs)
378
379         return b
380
381     def unpack(self, data, offset=0, result=None):
382         # Return a list of arguments
383         result = []
384         total = 0
385         for p in self.packers:
386             x, size = p.unpack(data, offset, result)
387             if type(x) is tuple and len(x) == 1:
388                 x = x[0]
389             result.append(x)
390             offset += size
391             total += size
392         t = self.tuple._make(result)
393         return t, total
394
395
396 class VPPMessage(VPPType):
397     pass