API: Add support for type aliases
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import struct
17 import collections
18 from enum import IntEnum
19 import logging
20 from .vpp_format import VPPFormat
21
22 #
23 # Set log-level in application by doing e.g.:
24 # logger = logging.getLogger('vpp_serializer')
25 # logger.setLevel(logging.DEBUG)
26 #
27 logger = logging.getLogger(__name__)
28
29
30 class BaseTypes(object):
31     def __init__(self, type, elements=0):
32         base_types = {'u8': '>B',
33                       'u16': '>H',
34                       'u32': '>I',
35                       'i32': '>i',
36                       'u64': '>Q',
37                       'f64': '>d',
38                       'bool': '>?',
39                       'header': '>HI'}
40
41         if elements > 0 and type == 'u8':
42             self.packer = struct.Struct('>%ss' % elements)
43         else:
44             self.packer = struct.Struct(base_types[type])
45         self.size = self.packer.size
46         logger.debug('Adding {} with format: {}'
47                      .format(type, base_types[type]))
48
49     def pack(self, data, kwargs=None):
50         if not data:  # Default to zero if not specified
51             data = 0
52         return self.packer.pack(data)
53
54     def unpack(self, data, offset, result=None):
55         return self.packer.unpack_from(data, offset)[0], self.packer.size
56
57
58 types = {}
59 types['u8'] = BaseTypes('u8')
60 types['u16'] = BaseTypes('u16')
61 types['u32'] = BaseTypes('u32')
62 types['i32'] = BaseTypes('i32')
63 types['u64'] = BaseTypes('u64')
64 types['f64'] = BaseTypes('f64')
65 types['bool'] = BaseTypes('bool')
66
67
68 def vpp_get_type(name):
69     try:
70         return types[name]
71     except KeyError:
72         return None
73
74
75 class FixedList_u8(object):
76     def __init__(self, name, field_type, num):
77         self.name = name
78         self.num = num
79         self.packer = BaseTypes(field_type, num)
80         self.size = self.packer.size
81
82     def pack(self, list, kwargs = None):
83         """Packs a fixed length bytestring. Left-pads with zeros
84         if input data is too short."""
85         if not list:
86             return b'\x00' * self.size
87         if len(list) > self.num:
88             raise ValueError('Fixed list length error for "{}", got: {}'
89                              ' expected: {}'
90                              .format(self.name, len(list), self.num))
91         return self.packer.pack(list)
92
93     def unpack(self, data, offset=0, result=None):
94         if len(data[offset:]) < self.num:
95             raise ValueError('Invalid array length for "{}" got {}'
96                              ' expected {}'
97                              .format(self.name, len(data[offset:]), self.num))
98         return self.packer.unpack(data, offset)
99
100
101 class FixedList(object):
102     def __init__(self, name, field_type, num):
103         self.num = num
104         self.packer = types[field_type]
105         self.size = self.packer.size * num
106
107     def pack(self, list, kwargs):
108         if len(list) != self.num:
109             raise ValueError('Fixed list length error, got: {} expected: {}'
110                              .format(len(list), self.num))
111         b = bytes()
112         for e in list:
113             b += self.packer.pack(e)
114         return b
115
116     def unpack(self, data, offset=0, result=None):
117         # Return a list of arguments
118         result = []
119         total = 0
120         for e in range(self.num):
121             x, size = self.packer.unpack(data, offset)
122             result.append(x)
123             offset += size
124             total += size
125         return result, total
126
127
128 class VLAList(object):
129     def __init__(self, name, field_type, len_field_name, index):
130         self.name = name
131         self.index = index
132         self.packer = types[field_type]
133         self.size = self.packer.size
134         self.length_field = len_field_name
135
136     def pack(self, list, kwargs=None):
137         if not list:
138             return b""
139         if len(list) != kwargs[self.length_field]:
140             raise ValueError('Variable length error, got: {} expected: {}'
141                              .format(len(list), kwargs[self.length_field]))
142         b = bytes()
143
144         # u8 array
145         if self.packer.size == 1:
146             return bytearray(list)
147
148         for e in list:
149             b += self.packer.pack(e)
150         return b
151
152     def unpack(self, data, offset=0, result=None):
153         # Return a list of arguments
154         total = 0
155
156         # u8 array
157         if self.packer.size == 1:
158             if result[self.index] == 0:
159                 return b'', 0
160             p = BaseTypes('u8', result[self.index])
161             return p.unpack(data, offset)
162
163         r = []
164         for e in range(result[self.index]):
165             x, size = self.packer.unpack(data, offset)
166             r.append(x)
167             offset += size
168             total += size
169         return r, total
170
171
172 class VLAList_legacy():
173     def __init__(self, name, field_type):
174         self.packer = types[field_type]
175         self.size = self.packer.size
176
177     def pack(self, list, kwargs=None):
178         if self.packer.size == 1:
179             return bytes(list)
180
181         b = bytes()
182         for e in list:
183             b += self.packer.pack(e)
184         return b
185
186     def unpack(self, data, offset=0, result=None):
187         total = 0
188         # Return a list of arguments
189         if (len(data) - offset) % self.packer.size:
190             raise ValueError('Legacy Variable Length Array length mismatch.')
191         elements = int((len(data) - offset) / self.packer.size)
192         r = []
193         for e in range(elements):
194             x, size = self.packer.unpack(data, offset)
195             r.append(x)
196             offset += self.packer.size
197             total += size
198         return r, total
199
200
201 class VPPEnumType(object):
202     def __init__(self, name, msgdef):
203         self.size = types['u32'].size
204         e_hash = {}
205         for f in msgdef:
206             if type(f) is dict and 'enumtype' in f:
207                 if f['enumtype'] != 'u32':
208                     raise NotImplementedError
209                 continue
210             ename, evalue = f
211             e_hash[ename] = evalue
212         self.enum = IntEnum(name, e_hash)
213         types[name] = self
214         logger.debug('Adding enum {}'.format(name))
215
216     def __getattr__(self, name):
217         return self.enum[name]
218
219     def __nonzero__(self):
220         return True
221
222     def pack(self, data, kwargs=None):
223         return types['u32'].pack(data)
224
225     def unpack(self, data, offset=0, result=None):
226         x, size = types['u32'].unpack(data, offset)
227         return self.enum(x), size
228
229
230 class VPPUnionType(object):
231     def __init__(self, name, msgdef):
232         self.name = name
233         self.size = 0
234         self.maxindex = 0
235         fields = []
236         self.packers = collections.OrderedDict()
237         for i, f in enumerate(msgdef):
238             if type(f) is dict and 'crc' in f:
239                 self.crc = f['crc']
240                 continue
241             f_type, f_name = f
242             if f_type not in types:
243                 logger.debug('Unknown union type {}'.format(f_type))
244                 raise ValueError('Unknown message type {}'.format(f_type))
245             fields.append(f_name)
246             size = types[f_type].size
247             self.packers[f_name] = types[f_type]
248             if size > self.size:
249                 self.size = size
250                 self.maxindex = i
251
252         types[name] = self
253         self.tuple = collections.namedtuple(name, fields, rename=True)
254         logger.debug('Adding union {}'.format(name))
255
256     # Union of variable length?
257     def pack(self, data, kwargs=None):
258         if not data:
259             return b'\x00' * self.size
260
261         for k, v in data.items():
262             logger.debug("Key: {} Value: {}".format(k, v))
263             b = self.packers[k].pack(v, kwargs)
264             break
265         r = bytearray(self.size)
266         r[:len(b)] = b
267         return r
268
269     def unpack(self, data, offset=0, result=None):
270         r = []
271         maxsize = 0
272         for k, p in self.packers.items():
273             x, size = p.unpack(data, offset)
274             if size > maxsize:
275                 maxsize = size
276             r.append(x)
277         return self.tuple._make(r), maxsize
278
279
280 def VPPTypeAlias(name, msgdef):
281     t = vpp_get_type(msgdef['type'])
282     if not t:
283         raise ValueError()
284     if 'length' in msgdef:
285         if msgdef['length'] == 0:
286             raise ValueError()
287         types[name] = FixedList(name, msgdef['type'], msgdef['length'])
288     else:
289         types[name] = t
290
291
292 class VPPType(object):
293     # Set everything up to be able to pack / unpack
294     def __init__(self, name, msgdef):
295         self.name = name
296         self.msgdef = msgdef
297         self.packers = []
298         self.fields = []
299         self.fieldtypes = []
300         self.field_by_name = {}
301         size = 0
302         for i, f in enumerate(msgdef):
303             if type(f) is dict and 'crc' in f:
304                 self.crc = f['crc']
305                 continue
306             f_type, f_name = f[:2]
307             self.fields.append(f_name)
308             self.field_by_name[f_name] = None
309             self.fieldtypes.append(f_type)
310             if f_type not in types:
311                 logger.debug('Unknown type {}'.format(f_type))
312                 raise ValueError('Unknown message type {}'.format(f_type))
313             if len(f) == 3:  # list
314                 list_elements = f[2]
315                 if list_elements == 0:
316                     p = VLAList_legacy(f_name, f_type)
317                     self.packers.append(p)
318                 elif f_type == 'u8':
319                     p = FixedList_u8(f_name, f_type, list_elements)
320                     self.packers.append(p)
321                     size += p.size
322                 else:
323                     p = FixedList(f_name, f_type, list_elements)
324                     self.packers.append(p)
325                     size += p.size
326             elif len(f) == 4:  # Variable length list
327                     # Find index of length field
328                     length_index = self.fields.index(f[3])
329                     p = VLAList(f_name, f_type, f[3], length_index)
330                     self.packers.append(p)
331             else:
332                 self.packers.append(types[f_type])
333                 size += types[f_type].size
334
335         self.size = size
336         self.tuple = collections.namedtuple(name, self.fields, rename=True)
337         types[name] = self
338         logger.debug('Adding type {}'.format(name))
339
340     def pack(self, data, kwargs=None):
341         if not kwargs:
342             kwargs = data
343         b = bytes()
344         for i, a in enumerate(self.fields):
345
346             # Try one of the format functions
347             if data and type(data) is not dict and a not in data:
348                 raise ValueError("Invalid argument: {} expected {}.{}".
349                                  format(data, self.name, a))
350
351             # Defaulting to zero.
352             if not data or a not in data:  # Default to 0
353                 arg = None
354                 kwarg = None  # No default for VLA
355             else:
356                 arg = data[a]
357                 kwarg = kwargs[a] if a in kwargs else None
358
359             if isinstance(self.packers[i], VPPType):
360                 try:
361                     b += self.packers[i].pack(arg, kwarg)
362                 except ValueError:
363                     # Invalid argument, can we convert it?
364                     arg = VPPFormat.format(self.packers[i].name, data[a])
365                     data[a] = arg
366                     kwarg = arg
367                     b += self.packers[i].pack(arg, kwarg)
368             else:
369                 b += self.packers[i].pack(arg, kwargs)
370
371         return b
372
373     def unpack(self, data, offset=0, result=None):
374         # Return a list of arguments
375         result = []
376         total = 0
377         for p in self.packers:
378             x, size = p.unpack(data, offset, result)
379             if type(x) is tuple and len(x) == 1:
380                 x = x[0]
381             result.append(x)
382             offset += size
383             total += size
384         t = self.tuple._make(result)
385         return t, total
386
387
388 class VPPMessage(VPPType):
389     pass