5dce03b6188f770ac30339afbe4707ebf36557eb
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import struct
17 import collections
18 from enum import IntEnum
19 import logging
20 from . import vpp_format
21 import ipaddress
22 import sys
23 import socket
24
25 #
26 # Set log-level in application by doing e.g.:
27 # logger = logging.getLogger('vpp_serializer')
28 # logger.setLevel(logging.DEBUG)
29 #
30 logger = logging.getLogger(__name__)
31
32 if sys.version[0] == '2':
33     check = lambda d: type(d) is dict
34 else:
35     check = lambda d: type(d) is dict or type(d) is bytes
36
37 def conversion_required(data, field_type):
38     if check(data):
39         return False
40     try:
41         if type(data).__name__ in vpp_format.conversion_table[field_type]:
42             return True
43     except KeyError:
44         return False
45
46
47 def conversion_packer(data, field_type):
48     t = type(data).__name__
49     return types[field_type].pack(vpp_format.
50                                   conversion_table[field_type][t](data))
51
52
53 def conversion_unpacker(data, field_type):
54     if field_type not in vpp_format.conversion_unpacker_table:
55         return data
56     return vpp_format.conversion_unpacker_table[field_type](data)
57
58
59 class BaseTypes(object):
60     def __init__(self, type, elements=0):
61         base_types = {'u8': '>B',
62                       'u16': '>H',
63                       'u32': '>I',
64                       'i32': '>i',
65                       'u64': '>Q',
66                       'f64': '>d',
67                       'bool': '>?',
68                       'header': '>HI'}
69
70         if elements > 0 and type == 'u8':
71             self.packer = struct.Struct('>%ss' % elements)
72         else:
73             self.packer = struct.Struct(base_types[type])
74         self.size = self.packer.size
75         logger.debug('Adding {} with format: {}'
76                      .format(type, base_types[type]))
77
78     def pack(self, data, kwargs=None):
79         if not data:  # Default to zero if not specified
80             data = 0
81         return self.packer.pack(data)
82
83     def unpack(self, data, offset, result=None, ntc=False):
84         return self.packer.unpack_from(data, offset)[0], self.packer.size
85
86
87 types = {'u8': BaseTypes('u8'), 'u16': BaseTypes('u16'),
88          'u32': BaseTypes('u32'), 'i32': BaseTypes('i32'),
89          'u64': BaseTypes('u64'), 'f64': BaseTypes('f64'),
90          'bool': BaseTypes('bool')}
91
92
93 def vpp_get_type(name):
94     try:
95         return types[name]
96     except KeyError:
97         return None
98
99
100 class VPPSerializerValueError(ValueError):
101     pass
102
103
104 class FixedList_u8(object):
105     def __init__(self, name, field_type, num):
106         self.name = name
107         self.num = num
108         self.packer = BaseTypes(field_type, num)
109         self.size = self.packer.size
110
111     def pack(self, data, kwargs=None):
112         """Packs a fixed length bytestring. Left-pads with zeros
113         if input data is too short."""
114         if not data:
115             return b'\x00' * self.size
116
117         if len(data) > self.num:
118             raise VPPSerializerValueError(
119                 'Fixed list length error for "{}", got: {}'
120                 ' expected: {}'
121                 .format(self.name, len(data), self.num))
122
123         return self.packer.pack(data)
124
125     def unpack(self, data, offset=0, result=None, ntc=False):
126         if len(data[offset:]) < self.num:
127             raise VPPSerializerValueError(
128                 'Invalid array length for "{}" got {}'
129                 ' expected {}'
130                 .format(self.name, len(data[offset:]), self.num))
131         return self.packer.unpack(data, offset)
132
133
134 class FixedList(object):
135     def __init__(self, name, field_type, num):
136         self.num = num
137         self.packer = types[field_type]
138         self.size = self.packer.size * num
139         self.name = name
140         self.field_type = field_type
141
142     def pack(self, list, kwargs):
143         if len(list) != self.num:
144             raise VPPSerializerValueError(
145                 'Fixed list length error, got: {} expected: {}'
146                 .format(len(list), self.num))
147         b = bytes()
148         for e in list:
149             b += self.packer.pack(e)
150         return b
151
152     def unpack(self, data, offset=0, result=None, ntc=False):
153         # Return a list of arguments
154         result = []
155         total = 0
156         for e in range(self.num):
157             x, size = self.packer.unpack(data, offset, ntc=ntc)
158             result.append(x)
159             offset += size
160             total += size
161         return result, total
162
163
164 class VLAList(object):
165     def __init__(self, name, field_type, len_field_name, index):
166         self.name = name
167         self.index = index
168         self.packer = types[field_type]
169         self.size = self.packer.size
170         self.length_field = len_field_name
171
172     def pack(self, list, kwargs=None):
173         if not list:
174             return b""
175         if len(list) != kwargs[self.length_field]:
176             raise VPPSerializerValueError(
177                 'Variable length error, got: {} expected: {}'
178                 .format(len(list), kwargs[self.length_field]))
179         b = bytes()
180
181         # u8 array
182         if self.packer.size == 1:
183             return bytearray(list)
184
185         for e in list:
186             b += self.packer.pack(e)
187         return b
188
189     def unpack(self, data, offset=0, result=None, ntc=False):
190         # Return a list of arguments
191         total = 0
192
193         # u8 array
194         if self.packer.size == 1:
195             if result[self.index] == 0:
196                 return b'', 0
197             p = BaseTypes('u8', result[self.index])
198             return p.unpack(data, offset, ntc=ntc)
199
200         r = []
201         for e in range(result[self.index]):
202             x, size = self.packer.unpack(data, offset, ntc=ntc)
203             r.append(x)
204             offset += size
205             total += size
206         return r, total
207
208
209 class VLAList_legacy():
210     def __init__(self, name, field_type):
211         self.packer = types[field_type]
212         self.size = self.packer.size
213
214     def pack(self, list, kwargs=None):
215         if self.packer.size == 1:
216             return bytes(list)
217
218         b = bytes()
219         for e in list:
220             b += self.packer.pack(e)
221         return b
222
223     def unpack(self, data, offset=0, result=None, ntc=False):
224         total = 0
225         # Return a list of arguments
226         if (len(data) - offset) % self.packer.size:
227             raise VPPSerializerValueError(
228                 'Legacy Variable Length Array length mismatch.')
229         elements = int((len(data) - offset) / self.packer.size)
230         r = []
231         for e in range(elements):
232             x, size = self.packer.unpack(data, offset, ntc=ntc)
233             r.append(x)
234             offset += self.packer.size
235             total += size
236         return r, total
237
238
239 class VPPEnumType(object):
240     def __init__(self, name, msgdef):
241         self.size = types['u32'].size
242         e_hash = {}
243         for f in msgdef:
244             if type(f) is dict and 'enumtype' in f:
245                 if f['enumtype'] != 'u32':
246                     raise NotImplementedError
247                 continue
248             ename, evalue = f
249             e_hash[ename] = evalue
250         self.enum = IntEnum(name, e_hash)
251         types[name] = self
252         logger.debug('Adding enum {}'.format(name))
253
254     def __getattr__(self, name):
255         return self.enum[name]
256
257     def __nonzero__(self):
258         return True
259
260     def pack(self, data, kwargs=None):
261         return types['u32'].pack(data)
262
263     def unpack(self, data, offset=0, result=None, ntc=False):
264         x, size = types['u32'].unpack(data, offset)
265         return self.enum(x), size
266
267
268 class VPPUnionType(object):
269     def __init__(self, name, msgdef):
270         self.name = name
271         self.size = 0
272         self.maxindex = 0
273         fields = []
274         self.packers = collections.OrderedDict()
275         for i, f in enumerate(msgdef):
276             if type(f) is dict and 'crc' in f:
277                 self.crc = f['crc']
278                 continue
279             f_type, f_name = f
280             if f_type not in types:
281                 logger.debug('Unknown union type {}'.format(f_type))
282                 raise VPPSerializerValueError(
283                     'Unknown message type {}'.format(f_type))
284             fields.append(f_name)
285             size = types[f_type].size
286             self.packers[f_name] = types[f_type]
287             if size > self.size:
288                 self.size = size
289                 self.maxindex = i
290
291         types[name] = self
292         self.tuple = collections.namedtuple(name, fields, rename=True)
293         logger.debug('Adding union {}'.format(name))
294
295     # Union of variable length?
296     def pack(self, data, kwargs=None):
297         if not data:
298             return b'\x00' * self.size
299
300         for k, v in data.items():
301             logger.debug("Key: {} Value: {}".format(k, v))
302             b = self.packers[k].pack(v, kwargs)
303             break
304         r = bytearray(self.size)
305         r[:len(b)] = b
306         return r
307
308     def unpack(self, data, offset=0, result=None, ntc=False):
309         r = []
310         maxsize = 0
311         for k, p in self.packers.items():
312             x, size = p.unpack(data, offset, ntc=ntc)
313             if size > maxsize:
314                 maxsize = size
315             r.append(x)
316         return self.tuple._make(r), maxsize
317
318
319 class VPPTypeAlias(object):
320     def __init__(self, name, msgdef):
321         self.name = name
322         t = vpp_get_type(msgdef['type'])
323         if not t:
324             raise ValueError()
325         if 'length' in msgdef:
326             if msgdef['length'] == 0:
327                 raise ValueError()
328             if msgdef['type'] == 'u8':
329                 self.packer = FixedList_u8(name, msgdef['type'],
330                                            msgdef['length'])
331                 self.size = self.packer.size
332             else:
333                 self.packer = FixedList(name, msgdef['type'], msgdef['length'])
334         else:
335             self.packer = t
336             self.size = t.size
337
338         types[name] = self
339
340     def pack(self, data, kwargs=None):
341         if data and conversion_required(data, self.name):
342             try:
343                 return conversion_packer(data, self.name)
344             # Python 2 and 3 raises different exceptions from inet_pton
345             except(OSError, socket.error, TypeError):
346                 pass
347
348         return self.packer.pack(data, kwargs)
349
350     def unpack(self, data, offset=0, result=None, ntc=False):
351         t, size = self.packer.unpack(data, offset, result, ntc=ntc)
352         if not ntc:
353             return conversion_unpacker(t, self.name), size
354         return t, size
355
356
357 class VPPType(object):
358     # Set everything up to be able to pack / unpack
359     def __init__(self, name, msgdef):
360         self.name = name
361         self.msgdef = msgdef
362         self.packers = []
363         self.fields = []
364         self.fieldtypes = []
365         self.field_by_name = {}
366         size = 0
367         for i, f in enumerate(msgdef):
368             if type(f) is dict and 'crc' in f:
369                 self.crc = f['crc']
370                 continue
371             f_type, f_name = f[:2]
372             self.fields.append(f_name)
373             self.field_by_name[f_name] = None
374             self.fieldtypes.append(f_type)
375             if f_type not in types:
376                 logger.debug('Unknown type {}'.format(f_type))
377                 raise VPPSerializerValueError(
378                     'Unknown message type {}'.format(f_type))
379             if len(f) == 3:  # list
380                 list_elements = f[2]
381                 if list_elements == 0:
382                     p = VLAList_legacy(f_name, f_type)
383                     self.packers.append(p)
384                 elif f_type == 'u8':
385                     p = FixedList_u8(f_name, f_type, list_elements)
386                     self.packers.append(p)
387                     size += p.size
388                 else:
389                     p = FixedList(f_name, f_type, list_elements)
390                     self.packers.append(p)
391                     size += p.size
392             elif len(f) == 4:  # Variable length list
393                     # Find index of length field
394                     length_index = self.fields.index(f[3])
395                     p = VLAList(f_name, f_type, f[3], length_index)
396                     self.packers.append(p)
397             else:
398                 self.packers.append(types[f_type])
399                 size += types[f_type].size
400
401         self.size = size
402         self.tuple = collections.namedtuple(name, self.fields, rename=True)
403         types[name] = self
404         logger.debug('Adding type {}'.format(name))
405
406     def pack(self, data, kwargs=None):
407         if not kwargs:
408             kwargs = data
409         b = bytes()
410
411         # Try one of the format functions
412         if data and conversion_required(data, self.name):
413             return conversion_packer(data, self.name)
414
415         for i, a in enumerate(self.fields):
416             if data and type(data) is not dict and a not in data:
417                 raise VPPSerializerValueError(
418                     "Invalid argument: {} expected {}.{}".
419                     format(data, self.name, a))
420
421             # Defaulting to zero.
422             if not data or a not in data:  # Default to 0
423                 arg = None
424                 kwarg = None  # No default for VLA
425             else:
426                 arg = data[a]
427                 kwarg = kwargs[a] if a in kwargs else None
428             if isinstance(self.packers[i], VPPType):
429                 b += self.packers[i].pack(arg, kwarg)
430             else:
431                 b += self.packers[i].pack(arg, kwargs)
432
433         return b
434
435     def unpack(self, data, offset=0, result=None, ntc=False):
436         # Return a list of arguments
437         result = []
438         total = 0
439         for p in self.packers:
440             x, size = p.unpack(data, offset, result, ntc)
441             if type(x) is tuple and len(x) == 1:
442                 x = x[0]
443             result.append(x)
444             offset += size
445             total += size
446         t = self.tuple._make(result)
447         if not ntc:
448             t = conversion_unpacker(t, self.name)
449         return t, total
450
451
452 class VPPMessage(VPPType):
453     pass