Python API: Re-adding rudimentary variable length array pack support.
[vpp.git] / vppapigen / pyvppapigen.py
1 #!/usr/bin/env python
2 #
3 # Copyright (c) 2016 Cisco and/or its affiliates.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at:
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 #
16
17 from __future__ import print_function
18 import argparse, sys, os, importlib, pprint
19
20 parser = argparse.ArgumentParser(description='VPP Python API generator')
21 parser.add_argument('-i', '--input', action="store", dest="inputfile", type=argparse.FileType('r'))
22 parser.add_argument('-c', '--cfile', action="store")
23 args = parser.parse_args()
24
25 #
26 # Read API definitions file into vppapidefs
27 #
28 exec(args.inputfile.read())
29
30 # https://docs.python.org/3/library/struct.html
31 format_struct = {'u8': 'B',
32                  'u16' : 'H',
33                  'u32' : 'I',
34                  'i32' : 'i',
35                  'u64' : 'Q',
36                  'f64' : 'd',
37                  'vl_api_ip4_fib_counter_t' : 'IBQQ',
38                  'vl_api_ip6_fib_counter_t' : 'QQBQQ',
39                  };
40 #
41 # NB: If new types are introduced in vpe.api, these must be updated.
42 #
43 type_size = {'u8':   1,
44              'u16' : 2,
45              'u32' : 4,
46              'i32' : 4,
47              'u64' : 8,
48              'f64' : 8,
49              'vl_api_ip4_fib_counter_t' : 21,
50              'vl_api_ip6_fib_counter_t' : 33,
51 };
52
53 def eprint(*args, **kwargs):
54     print(*args, file=sys.stderr, **kwargs)
55
56 def get_args(t):
57     argslist = []
58     for i in t:
59         if i[1][0] == '_':
60             argslist.append(i[1][1:])
61         else:
62             argslist.append(i[1])
63
64     return argslist
65
66 def get_pack(f):
67     zeroarray = False
68     bytecount = 0
69     pack = ''
70     elements = 1
71     if len(f) is 3 or len(f) is 4:
72         size = type_size[f[0]]
73         bytecount += size * int(f[2])
74         # Check if we have a zero length array
75         if f[2] == '0':
76             # If len 3 zero array
77             elements = 0;
78             pack += format_struct[f[0]]
79             bytecount = size
80         elif size == 1:
81             n = f[2] * size
82             pack += str(n) + 's'
83         else:
84             pack += format_struct[f[0]] * int(f[2])
85             elements = int(f[2])
86     else:
87         bytecount += type_size[f[0]]
88         pack += format_struct[f[0]]
89     return (pack, elements, bytecount)
90
91
92 '''
93 def get_reply_func(f):
94     if f['name']+'_reply' in func_name:
95         return func_name[f['name']+'_reply']
96     if f['name'].find('_dump') > 0:
97         r = f['name'].replace('_dump','_details')
98         if r in func_name:
99             return func_name[r]
100     return None
101 '''
102
103 def footer_print():
104     print('''
105 def msg_id_base_set(b):
106     global base
107     base = b
108
109 import os
110 name = os.path.splitext(os.path.basename(__file__))[0]
111     ''')
112     print(u"plugin_register(name, api_func_table, api_name_to_id,", vl_api_version, ", msg_id_base_set)")
113
114 def api_table_print(name, i):
115     msg_id_in = 'VL_API_' + name.upper()
116     fstr = name + '_decode'
117     print('api_func_table.append(' + fstr + ')')
118     print('api_name_to_id["' + msg_id_in + '"] =', i)
119     print('')
120
121 def encode_print(name, id, t):
122     total = 0
123     args = get_args(t)
124     pack = '>'
125     for i, f in enumerate(t):
126         p, elements, size = get_pack(f)
127         pack += p
128         total += size
129
130     if name.find('_dump') > 0:
131         multipart = True
132     else:
133         multipart = False
134
135     if len(args) < 4:
136         print(u"def", name + "(async = False):")
137     else:
138         print(u"def", name + "(" + ', '.join(args[3:]) + ", async = False):")
139     print(u"    global base")
140     print(u"    context = get_context(base + " + id + ")")
141
142     print('''
143     results_prepare(context)
144     waiting_for_reply_set()
145     ''')
146     if multipart == True:
147         print(u"    results_more_set(context)")
148
149     pack = '>'
150     start = 0
151     end = 0
152     offset = 0
153     t = list(t)
154     i = 0
155
156     while t:
157         t, i, pack, offset, array = get_normal_pack(t, i, pack, offset)
158         if array:
159             print(u"    vpp_api.write(pack('" + pack + "', base + " +
160                   id + ", 0, context, " + ', '.join(args[3:-1]) + ") + "
161                   + args[-1] + ")")
162         else:
163             print(u"    vpp_api.write(pack('" + pack + "', base + " + id +
164                   ", 0, context, " + ', '.join(args[3:]) + "))")
165
166     if multipart == True:
167         print(
168             u"    vpp_api.write(pack('>HII', VL_API_CONTROL_PING, 0, context))")
169
170     print('''
171     if not async:
172         results_event_wait(context, 5)
173         return results_get(context)
174     return context
175     ''')
176
177 def get_normal_pack(t, i, pack, offset):
178     while t:
179         f = t.pop(0)
180         i += 1
181         if len(f) >= 3:
182             return t, i, pack, offset, f
183         p, elements, size = get_pack(f)
184         pack += p
185         offset += size
186     return t, i, pack, offset, None
187
188 def decode_print(name, t):
189     #
190     # Generate code for each element
191     #
192     print(u'def ' + name + u'_decode(msg):')
193     total = 0
194     args = get_args(t)
195     print(u"    n = namedtuple('" + name + "', '" + ', '.join(args) + "')")
196     print(u"    res = []")
197
198     pack = '>'
199     start = 0
200     end = 0
201     offset = 0
202     t = list(t)
203     i = 0
204     while t:
205         t, i, pack, offset, array = get_normal_pack(t, i, pack, offset)
206         if array:
207             p, elements, size = get_pack(array)
208
209             # Byte string
210             if elements > 0 and type_size[array[0]] == 1:
211                 pack += p
212                 offset += size * elements
213                 continue
214
215             # Dump current pack string
216             if pack != '>':
217                 print(u"    tr = unpack_from('" + pack + "', msg[" + str(start) + ":])")
218                 print(u"    res.extend(list(tr))")
219                 start += offset
220             pack = '>'
221
222             if elements == 0:
223                 # This has to be the last element
224                 if len(array) == 3:
225                     print(u"    res.append(msg[" + str(offset) + ":])")
226                     if len(t) > 0:
227                         eprint('WARNING: Variable length array must be last element in message', name, array)
228
229                     continue
230                 if size == 1 or len(p) == 1:
231                     # Do it as a bytestring.
232                     if p == 'B':
233                         p = 's'
234                     # XXX: Assume that length parameter is the previous field. Add validation.
235                     print(u"    c = res[" + str(i - 2) + "]")
236                     print(u"    tr = unpack_from('>' + str(c) + '" + p + "', msg[" + str(start) + ":])")
237                     print(u"    res.append(tr)")
238                     continue
239                 print(u"    tr2 = []")
240                 print(u"    offset = " + str(total))
241                 print(u"    for j in range(res[" + str(i - 2) + "]):")
242                 print(u"        tr2.append(unpack_from('>" + p + "', msg[" + str(start) + ":], offset))")
243                 print(u"        offset += " + str(size))
244                 print(u"    res.append(tr2)")
245                 continue
246
247             # Missing something!!
248             print(u"    tr = unpack_from('>" + p + "', msg[" + str(start) + ":])")
249             start += size
250
251             print(u"    res.append(tr)")
252
253     if pack != '>':
254         print(u"    tr = unpack_from('" + pack + "', msg[" + str(start) + ":])")
255         print(u"    res.extend(list(tr))")
256     print(u"    return n._make(res)")
257     print('')
258
259 #
260 # Generate the main Python file
261 #
262 def main():
263     print('''
264 #
265 # AUTO-GENERATED FILE. PLEASE DO NOT EDIT.
266 #
267 from vpp_api_base import *
268 from struct import *
269 from collections import namedtuple
270 import vpp_api
271 api_func_table = []
272 api_name_to_id = {}
273     ''')
274
275     for i, a in enumerate(vppapidef):
276         name = a[0]
277         encode_print(name, str(i), a[1:])
278         decode_print(name, a[1:])
279         api_table_print(name, i)
280     footer_print()
281
282 if __name__ == "__main__":
283     main()