API: Change ip4_address and ip6_address to use type alias.
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import struct
17 import collections
18 from enum import IntEnum
19 import logging
20 from .vpp_format import VPPFormat
21
22 #
23 # Set log-level in application by doing e.g.:
24 # logger = logging.getLogger('vpp_serializer')
25 # logger.setLevel(logging.DEBUG)
26 #
27 logger = logging.getLogger(__name__)
28
29
30 class BaseTypes(object):
31     def __init__(self, type, elements=0):
32         base_types = {'u8': '>B',
33                       'u16': '>H',
34                       'u32': '>I',
35                       'i32': '>i',
36                       'u64': '>Q',
37                       'f64': '>d',
38                       'bool': '>?',
39                       'header': '>HI'}
40
41         if elements > 0 and type == 'u8':
42             self.packer = struct.Struct('>%ss' % elements)
43         else:
44             self.packer = struct.Struct(base_types[type])
45         self.size = self.packer.size
46         logger.debug('Adding {} with format: {}'
47                      .format(type, base_types[type]))
48
49     def pack(self, data, kwargs=None):
50         if not data:  # Default to zero if not specified
51             data = 0
52         return self.packer.pack(data)
53
54     def unpack(self, data, offset, result=None):
55         return self.packer.unpack_from(data, offset)[0], self.packer.size
56
57
58 types = {'u8': BaseTypes('u8'), 'u16': BaseTypes('u16'),
59          'u32': BaseTypes('u32'), 'i32': BaseTypes('i32'),
60          'u64': BaseTypes('u64'), 'f64': BaseTypes('f64'),
61          'bool': BaseTypes('bool')}
62
63
64 def vpp_get_type(name):
65     try:
66         return types[name]
67     except KeyError:
68         return None
69
70
71 class VPPSerializerValueError(ValueError):
72     pass
73
74
75 class FixedList_u8(object):
76     def __init__(self, name, field_type, num):
77         self.name = name
78         self.num = num
79         self.packer = BaseTypes(field_type, num)
80         self.size = self.packer.size
81
82     def pack(self, list, kwargs=None):
83         """Packs a fixed length bytestring. Left-pads with zeros
84         if input data is too short."""
85         if not list:
86             return b'\x00' * self.size
87         if len(list) > self.num:
88             raise VPPSerializerValueError(
89                 'Fixed list length error for "{}", got: {}'
90                 ' expected: {}'
91                 .format(self.name, len(list), self.num))
92         return self.packer.pack(list)
93
94     def unpack(self, data, offset=0, result=None):
95         if len(data[offset:]) < self.num:
96             raise VPPSerializerValueError(
97                 'Invalid array length for "{}" got {}'
98                 ' expected {}'
99                 .format(self.name, len(data[offset:]), self.num))
100         return self.packer.unpack(data, offset)
101
102
103 class FixedList(object):
104     def __init__(self, name, field_type, num):
105         self.num = num
106         self.packer = types[field_type]
107         self.size = self.packer.size * num
108
109     def pack(self, list, kwargs):
110         if len(list) != self.num:
111             raise VPPSerializerValueError(
112                 'Fixed list length error, got: {} expected: {}'
113                 .format(len(list), self.num))
114         b = bytes()
115         for e in list:
116             b += self.packer.pack(e)
117         return b
118
119     def unpack(self, data, offset=0, result=None):
120         # Return a list of arguments
121         result = []
122         total = 0
123         for e in range(self.num):
124             x, size = self.packer.unpack(data, offset)
125             result.append(x)
126             offset += size
127             total += size
128         return result, total
129
130
131 class VLAList(object):
132     def __init__(self, name, field_type, len_field_name, index):
133         self.name = name
134         self.index = index
135         self.packer = types[field_type]
136         self.size = self.packer.size
137         self.length_field = len_field_name
138
139     def pack(self, list, kwargs=None):
140         if not list:
141             return b""
142         if len(list) != kwargs[self.length_field]:
143             raise VPPSerializerValueError(
144                 'Variable length error, got: {} expected: {}'
145                 .format(len(list), kwargs[self.length_field]))
146         b = bytes()
147
148         # u8 array
149         if self.packer.size == 1:
150             return bytearray(list)
151
152         for e in list:
153             b += self.packer.pack(e)
154         return b
155
156     def unpack(self, data, offset=0, result=None):
157         # Return a list of arguments
158         total = 0
159
160         # u8 array
161         if self.packer.size == 1:
162             if result[self.index] == 0:
163                 return b'', 0
164             p = BaseTypes('u8', result[self.index])
165             return p.unpack(data, offset)
166
167         r = []
168         for e in range(result[self.index]):
169             x, size = self.packer.unpack(data, offset)
170             r.append(x)
171             offset += size
172             total += size
173         return r, total
174
175
176 class VLAList_legacy():
177     def __init__(self, name, field_type):
178         self.packer = types[field_type]
179         self.size = self.packer.size
180
181     def pack(self, list, kwargs=None):
182         if self.packer.size == 1:
183             return bytes(list)
184
185         b = bytes()
186         for e in list:
187             b += self.packer.pack(e)
188         return b
189
190     def unpack(self, data, offset=0, result=None):
191         total = 0
192         # Return a list of arguments
193         if (len(data) - offset) % self.packer.size:
194             raise VPPSerializerValueError(
195                 'Legacy Variable Length Array length mismatch.')
196         elements = int((len(data) - offset) / self.packer.size)
197         r = []
198         for e in range(elements):
199             x, size = self.packer.unpack(data, offset)
200             r.append(x)
201             offset += self.packer.size
202             total += size
203         return r, total
204
205
206 class VPPEnumType(object):
207     def __init__(self, name, msgdef):
208         self.size = types['u32'].size
209         e_hash = {}
210         for f in msgdef:
211             if type(f) is dict and 'enumtype' in f:
212                 if f['enumtype'] != 'u32':
213                     raise NotImplementedError
214                 continue
215             ename, evalue = f
216             e_hash[ename] = evalue
217         self.enum = IntEnum(name, e_hash)
218         types[name] = self
219         logger.debug('Adding enum {}'.format(name))
220
221     def __getattr__(self, name):
222         return self.enum[name]
223
224     def __nonzero__(self):
225         return True
226
227     def pack(self, data, kwargs=None):
228         return types['u32'].pack(data)
229
230     def unpack(self, data, offset=0, result=None):
231         x, size = types['u32'].unpack(data, offset)
232         return self.enum(x), size
233
234
235 class VPPUnionType(object):
236     def __init__(self, name, msgdef):
237         self.name = name
238         self.size = 0
239         self.maxindex = 0
240         fields = []
241         self.packers = collections.OrderedDict()
242         for i, f in enumerate(msgdef):
243             if type(f) is dict and 'crc' in f:
244                 self.crc = f['crc']
245                 continue
246             f_type, f_name = f
247             if f_type not in types:
248                 logger.debug('Unknown union type {}'.format(f_type))
249                 raise VPPSerializerValueError(
250                     'Unknown message type {}'.format(f_type))
251             fields.append(f_name)
252             size = types[f_type].size
253             self.packers[f_name] = types[f_type]
254             if size > self.size:
255                 self.size = size
256                 self.maxindex = i
257
258         types[name] = self
259         self.tuple = collections.namedtuple(name, fields, rename=True)
260         logger.debug('Adding union {}'.format(name))
261
262     # Union of variable length?
263     def pack(self, data, kwargs=None):
264         if not data:
265             return b'\x00' * self.size
266
267         for k, v in data.items():
268             logger.debug("Key: {} Value: {}".format(k, v))
269             b = self.packers[k].pack(v, kwargs)
270             break
271         r = bytearray(self.size)
272         r[:len(b)] = b
273         return r
274
275     def unpack(self, data, offset=0, result=None):
276         r = []
277         maxsize = 0
278         for k, p in self.packers.items():
279             x, size = p.unpack(data, offset)
280             if size > maxsize:
281                 maxsize = size
282             r.append(x)
283         return self.tuple._make(r), maxsize
284
285
286 def VPPTypeAlias(name, msgdef):
287     t = vpp_get_type(msgdef['type'])
288     if not t:
289         raise ValueError()
290     if 'length' in msgdef:
291         if msgdef['length'] == 0:
292             raise ValueError()
293         if msgdef['type'] == 'u8':
294             types[name] = FixedList_u8(name, msgdef['type'],
295                                        msgdef['length'])
296         else:
297             types[name] = FixedList(name, msgdef['type'], msgdef['length'])
298     else:
299         types[name] = t
300
301
302 class VPPType(object):
303     # Set everything up to be able to pack / unpack
304     def __init__(self, name, msgdef):
305         self.name = name
306         self.msgdef = msgdef
307         self.packers = []
308         self.fields = []
309         self.fieldtypes = []
310         self.field_by_name = {}
311         size = 0
312         for i, f in enumerate(msgdef):
313             if type(f) is dict and 'crc' in f:
314                 self.crc = f['crc']
315                 continue
316             f_type, f_name = f[:2]
317             self.fields.append(f_name)
318             self.field_by_name[f_name] = None
319             self.fieldtypes.append(f_type)
320             if f_type not in types:
321                 logger.debug('Unknown type {}'.format(f_type))
322                 raise VPPSerializerValueError(
323                     'Unknown message type {}'.format(f_type))
324             if len(f) == 3:  # list
325                 list_elements = f[2]
326                 if list_elements == 0:
327                     p = VLAList_legacy(f_name, f_type)
328                     self.packers.append(p)
329                 elif f_type == 'u8':
330                     p = FixedList_u8(f_name, f_type, list_elements)
331                     self.packers.append(p)
332                     size += p.size
333                 else:
334                     p = FixedList(f_name, f_type, list_elements)
335                     self.packers.append(p)
336                     size += p.size
337             elif len(f) == 4:  # Variable length list
338                     # Find index of length field
339                     length_index = self.fields.index(f[3])
340                     p = VLAList(f_name, f_type, f[3], length_index)
341                     self.packers.append(p)
342             else:
343                 self.packers.append(types[f_type])
344                 size += types[f_type].size
345
346         self.size = size
347         self.tuple = collections.namedtuple(name, self.fields, rename=True)
348         types[name] = self
349         logger.debug('Adding type {}'.format(name))
350
351     def pack(self, data, kwargs=None):
352         if not kwargs:
353             kwargs = data
354         b = bytes()
355         for i, a in enumerate(self.fields):
356
357             # Try one of the format functions
358             if data and type(data) is not dict and a not in data:
359                 raise VPPSerializerValueError(
360                     "Invalid argument: {} expected {}.{}".
361                     format(data, self.name, a))
362
363             # Defaulting to zero.
364             if not data or a not in data:  # Default to 0
365                 arg = None
366                 kwarg = None  # No default for VLA
367             else:
368                 arg = data[a]
369                 kwarg = kwargs[a] if a in kwargs else None
370
371             if isinstance(self.packers[i], VPPType):
372                 try:
373                     b += self.packers[i].pack(arg, kwarg)
374                 except ValueError:
375                     # Invalid argument, can we convert it?
376                     arg = VPPFormat.format(self.packers[i].name, data[a])
377                     data[a] = arg
378                     kwarg = arg
379                     b += self.packers[i].pack(arg, kwarg)
380             else:
381                 b += self.packers[i].pack(arg, kwargs)
382
383         return b
384
385     def unpack(self, data, offset=0, result=None):
386         # Return a list of arguments
387         result = []
388         total = 0
389         for p in self.packers:
390             x, size = p.unpack(data, offset, result)
391             if type(x) is tuple and len(x) == 1:
392                 x = x[0]
393             result.append(x)
394             offset += size
395             total += size
396         t = self.tuple._make(result)
397         return t, total
398
399
400 class VPPMessage(VPPType):
401     pass