Python API: Add enum and union support.
[vpp.git] / src / vpp-api / python / vpp_papi / vpp_serializer.py
1 #
2 # Copyright (c) 2018 Cisco and/or its affiliates.
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 #
15
16 import struct
17 import collections
18 from enum import IntEnum
19 import logging
20
21 #
22 # Set log-level in application by doing e.g.:
23 # logger = logging.getLogger('vpp_serializer')
24 # logger.setLevel(logging.DEBUG)
25 #
26 logger = logging.getLogger(__name__)
27 FORMAT = "[%(filename)s:%(lineno)s - %(funcName)s() ] %(message)s"
28 logging.basicConfig(format=FORMAT)
29
30
31 class BaseTypes():
32     def __init__(self, type, elements=0):
33         base_types = {'u8': '>B',
34                       'u16': '>H',
35                       'u32': '>I',
36                       'i32': '>i',
37                       'u64': '>Q',
38                       'f64': '>d',
39                       'header': '>HI'}
40
41         if elements > 0 and type == 'u8':
42             self.packer = struct.Struct('>%ss' % elements)
43         else:
44             self.packer = struct.Struct(base_types[type])
45         self.size = self.packer.size
46         logger.debug('Adding {} with format: {}'
47                      .format(type, base_types[type]))
48
49     def pack(self, data, kwargs=None):
50         logger.debug("Data: {} Format: {}".format(data, self.packer.format))
51         return self.packer.pack(data)
52
53     def unpack(self, data, offset, result=None):
54         logger.debug("@ {} Format: {}".format(offset, self.packer.format))
55         return self.packer.unpack_from(data, offset)[0]
56
57
58 types = {}
59 types['u8'] = BaseTypes('u8')
60 types['u16'] = BaseTypes('u16')
61 types['u32'] = BaseTypes('u32')
62 types['i32'] = BaseTypes('i32')
63 types['u64'] = BaseTypes('u64')
64 types['f64'] = BaseTypes('f64')
65
66
67 class FixedList_u8():
68     def __init__(self, name, field_type, num):
69         self.name = name
70         self.num = num
71         self.packer = BaseTypes(field_type, num)
72         self.size = self.packer.size
73
74     def pack(self, list, kwargs):
75         logger.debug("Data: {}".format(list))
76
77         if len(list) > self.num:
78             raise ValueError('Fixed list length error for "{}", got: {}'
79                              ' expected: {}'
80                              .format(self.name, len(list), self.num))
81         return self.packer.pack(list)
82
83     def unpack(self, data, offset=0, result=None):
84         if len(data[offset:]) < self.num:
85             raise ValueError('Invalid array length for "{}" got {}'
86                              ' expected {}'
87                              .format(self.name, len(data), self.num))
88         return self.packer.unpack(data, offset)
89
90
91 class FixedList():
92     def __init__(self, name, field_type, num):
93         self.num = num
94         self.packer = types[field_type]
95         self.size = self.packer.size * num
96
97     def pack(self, list, kwargs):
98         logger.debug("Data: {}".format(list))
99
100         if len(list) != self.num:
101             raise ValueError('Fixed list length error, got: {} expected: {}'
102                              .format(len(list), self.num))
103         b = bytes()
104         for e in list:
105             b += self.packer.pack(e)
106         return b
107
108     def unpack(self, data, offset=0, result=None):
109         # Return a list of arguments
110         result = []
111         for e in range(self.num):
112             x = self.packer.unpack(data, offset)
113             result.append(x)
114             offset += self.packer.size
115         return result
116
117
118 class VLAList():
119     def __init__(self, name, field_type, len_field_name, index):
120         self.index = index
121         self.packer = types[field_type]
122         self.size = self.packer.size
123         self.length_field = len_field_name
124
125     def pack(self, list, kwargs=None):
126         logger.debug("Data: {}".format(list))
127         if len(list) != kwargs[self.length_field]:
128             raise ValueError('Variable length error, got: {} expected: {}'
129                              .format(len(list), kwargs[self.length_field]))
130         b = bytes()
131
132         # u8 array
133         if self.packer.size == 1:
134             p = BaseTypes('u8', len(list))
135             return p.pack(list)
136
137         for e in list:
138             b += self.packer.pack(e)
139         return b
140
141     def unpack(self, data, offset=0, result=None):
142         logger.debug("Data: {} @ {} Result: {}"
143                      .format(list, offset, result[self.index]))
144         # Return a list of arguments
145
146         # u8 array
147         if self.packer.size == 1:
148             if result[self.index] == 0:
149                 return b''
150             p = BaseTypes('u8', result[self.index])
151             r = p.unpack(data, offset)
152             return r
153
154         r = []
155         for e in range(result[self.index]):
156             x = self.packer.unpack(data, offset)
157             r.append(x)
158             offset += self.packer.size
159         return r
160
161
162 class VLAList_legacy():
163     def __init__(self, name, field_type):
164         self.packer = types[field_type]
165         self.size = self.packer.size
166
167     def pack(self, list, kwargs=None):
168         logger.debug("Data: {}".format(list))
169         b = bytes()
170         for e in list:
171             b += self.packer.pack(e)
172         return b
173
174     def unpack(self, data, offset=0, result=None):
175         # Return a list of arguments
176         if (len(data) - offset) % self.packer.size:
177             raise ValueError('Legacy Variable Length Array length mismatch.')
178         elements = int((len(data) - offset) / self.packer.size)
179         r = []
180         logger.debug("Legacy VLA: {} elements of size {}"
181                      .format(elements, self.packer.size))
182         for e in range(elements):
183             x = self.packer.unpack(data, offset)
184             r.append(x)
185             offset += self.packer.size
186         return r
187
188
189 class VPPEnumType():
190     def __init__(self, name, msgdef):
191         self.size = types['u32'].size
192         e_hash = {}
193         for f in msgdef:
194             if type(f) is dict and 'enumtype' in f:
195                 if f['enumtype'] != 'u32':
196                     raise NotImplementedError
197                 continue
198             ename, evalue = f
199             e_hash[ename] = evalue
200         self.enum = IntEnum(name, e_hash)
201         types[name] = self
202         logger.debug('Adding enum {}'.format(name))
203
204     def __getattr__(self, name):
205         return self.enum[name]
206
207     def pack(self, data, kwargs=None):
208         logger.debug("Data: {}".format(data))
209         return types['u32'].pack(data, kwargs)
210
211     def unpack(self, data, offset=0, result=None):
212         x = types['u32'].unpack(data, offset)
213         return self.enum(x)
214
215
216 class VPPUnionType():
217     def __init__(self, name, msgdef):
218         self.name = name
219         self.size = 0
220         self.maxindex = 0
221         fields = []
222         self.packers = collections.OrderedDict()
223         for i, f in enumerate(msgdef):
224             if type(f) is dict and 'crc' in f:
225                 self.crc = f['crc']
226                 continue
227             f_type, f_name = f
228             if f_type not in types:
229                 logger.debug('Unknown union type {}'.format(f_type))
230                 raise ValueError('Unknown message type {}'.format(f_type))
231             fields.append(f_name)
232             size = types[f_type].size
233             self.packers[f_name] = types[f_type]
234             if size > self.size:
235                 self.size = size
236                 self.maxindex = i
237
238         types[name] = self
239         self.tuple = collections.namedtuple(name, fields, rename=True)
240         logger.debug('Adding union {}'.format(name))
241
242     def pack(self, data, kwargs=None):
243         logger.debug("Data: {}".format(data))
244         for k, v in data.items():
245             logger.debug("Key: {} Value: {}".format(k, v))
246             b = self.packers[k].pack(v, kwargs)
247             offset = self.size - self.packers[k].size
248             break
249         r = bytearray(self.size)
250         r[offset:] = b
251         return r
252
253     def unpack(self, data, offset=0, result=None):
254         r = []
255         for k, p in self.packers.items():
256             union_offset = self.size - p.size
257             r.append(p.unpack(data, offset + union_offset))
258         return self.tuple._make(r)
259
260
261 class VPPType():
262     # Set everything up to be able to pack / unpack
263     def __init__(self, name, msgdef):
264         self.name = name
265         self.msgdef = msgdef
266         self.packers = []
267         self.fields = []
268         self.fieldtypes = []
269         self.field_by_name = {}
270         size = 0
271         for i, f in enumerate(msgdef):
272             if type(f) is dict and 'crc' in f:
273                 self.crc = f['crc']
274                 continue
275             f_type, f_name = f[:2]
276             self.fields.append(f_name)
277             self.field_by_name[f_name] = None
278             self.fieldtypes.append(f_type)
279             if f_type not in types:
280                 logger.debug('Unknown type {}'.format(f_type))
281                 raise ValueError('Unknown message type {}'.format(f_type))
282             if len(f) == 3:  # list
283                 list_elements = f[2]
284                 if list_elements == 0:
285                     p = VLAList_legacy(f_name, f_type)
286                     self.packers.append(p)
287                 elif f_type == 'u8':
288                     p = FixedList_u8(f_name, f_type, list_elements)
289                     self.packers.append(p)
290                     size += p.size
291                 else:
292                     p = FixedList(f_name, f_type, list_elements)
293                     self.packers.append(p)
294                     size += p.size
295             elif len(f) == 4:  # Variable length list
296                     # Find index of length field
297                     length_index = self.fields.index(f[3])
298                     p = VLAList(f_name, f_type, f[3], length_index)
299                     self.packers.append(p)
300             else:
301                 self.packers.append(types[f_type])
302                 size += types[f_type].size
303
304         self.size = size
305         self.tuple = collections.namedtuple(name, self.fields, rename=True)
306         types[name] = self
307         logger.debug('Adding type {}'.format(name))
308
309     def pack(self, data, kwargs=None):
310         if not kwargs:
311             kwargs = data
312         logger.debug("Data: {}".format(data))
313         b = bytes()
314         for i, a in enumerate(self.fields):
315             if a not in data:
316                 logger.debug("Argument {} not given, defaulting to 0"
317                              .format(a))
318                 b += b'\x00' * self.packers[i].size
319                 continue
320             b += self.packers[i].pack(data[a], kwargs)
321         return b
322
323     def unpack(self, data, offset=0, result=None):
324         # Return a list of arguments
325         result = []
326         for p in self.packers:
327             x = p.unpack(data, offset, result)
328             if type(x) is tuple and len(x) == 1:
329                 x = x[0]
330             result.append(x)
331             offset += p.size
332         return self.tuple._make(result)