Improve binapi generator
[govpp.git] / binapigen / gen_encoding.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "fmt"
19         "strconv"
20         "strings"
21
22         "github.com/sirupsen/logrus"
23 )
24
25 func init() {
26         //RegisterPlugin("encoding", GenerateEncoding)
27 }
28
29 func generateMessageSize(g *GenFile, name string, fields []*Field) {
30         g.P("func (m *", name, ") Size() int {")
31         g.P("if m == nil { return 0 }")
32         g.P("var size int")
33
34         sizeBaseType := func(typ, name string, length int, sizefrom string) {
35                 switch typ {
36                 case STRING:
37                         if length > 0 {
38                                 g.P("size += ", length, " // ", name)
39                         } else {
40                                 g.P("size += 4 + len(", name, ")", " // ", name)
41                         }
42                 default:
43                         var size = BaseTypeSizes[typ]
44                         if sizefrom != "" {
45                                 g.P("size += ", size, " * len(", name, ")", " // ", name)
46                         } else {
47                                 if length > 0 {
48                                         g.P("size += ", size, " * ", length, " // ", name)
49                                 } else {
50                                         g.P("size += ", size, " // ", name)
51                                 }
52                         }
53                 }
54         }
55
56         lvl := 0
57         var sizeFields func(fields []*Field, parentName string)
58         sizeFields = func(fields []*Field, parentName string) {
59                 lvl++
60                 defer func() { lvl-- }()
61
62                 getFieldName := func(name string) string {
63                         return fmt.Sprintf("%s.%s", parentName, name)
64                 }
65
66                 for _, field := range fields {
67                         name := getFieldName(field.GoName)
68
69                         var sizeFromName string
70                         if field.FieldSizeFrom != nil {
71                                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
72                         }
73
74                         if _, ok := BaseTypesGo[field.Type]; ok {
75                                 sizeBaseType(field.Type, name, field.Length, sizeFromName)
76                                 continue
77                         }
78
79                         if field.Array {
80                                 char := fmt.Sprintf("s%d", lvl)
81                                 index := fmt.Sprintf("j%d", lvl)
82                                 if field.Length > 0 {
83                                         g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
84                                 } else if field.FieldSizeFrom != nil {
85                                         g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
86                                 }
87                                 g.P("var ", char, " ", fieldGoType(g, field))
88                                 g.P("_ = ", char)
89                                 g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
90                                 name = char
91                         }
92
93                         switch {
94                         case field.TypeEnum != nil:
95                                 enum := field.TypeEnum
96                                 if _, ok := BaseTypesGo[enum.Type]; ok {
97                                         sizeBaseType(enum.Type, name, 0, "")
98                                 } else {
99                                         logrus.Panicf("\t// ??? ENUM %s %s\n", name, enum.Type)
100                                 }
101                         case field.TypeAlias != nil:
102                                 alias := field.TypeAlias
103                                 if typ := alias.TypeStruct; typ != nil {
104                                         sizeFields(typ.Fields, name)
105                                 } else {
106                                         sizeBaseType(alias.Type, name, alias.Length, "")
107                                 }
108                         case field.TypeStruct != nil:
109                                 typ := field.TypeStruct
110                                 sizeFields(typ.Fields, name)
111                         case field.TypeUnion != nil:
112                                 union := field.TypeUnion
113                                 maxSize := getUnionSize(union)
114                                 sizeBaseType("u8", name, maxSize, "")
115                         default:
116                                 logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
117                         }
118
119                         if field.Array {
120                                 g.P("}")
121                         }
122                 }
123         }
124         sizeFields(fields, "m")
125
126         g.P("return size")
127         g.P("}")
128 }
129
130 func encodeBaseType(g *GenFile, typ, name string, length int, sizefrom string) {
131         isArray := length > 0 || sizefrom != ""
132         if isArray {
133                 switch typ {
134                 case U8:
135                         g.P("buf.EncodeBytes(", name, "[:], ", length, ")")
136                         return
137                 case I8, I16, U16, I32, U32, I64, U64, F64:
138                         gotype := BaseTypesGo[typ]
139                         if length != 0 {
140                                 g.P("for i := 0; i < ", length, "; i++ {")
141                         } else if sizefrom != "" {
142                                 g.P("for i := 0; i < len(", name, "); i++ {")
143                         }
144                         g.P("var x ", gotype)
145                         g.P("if i < len(", name, ") { x = ", gotype, "(", name, "[i]) }")
146                         name = "x"
147                 }
148         }
149         switch typ {
150         case I8, U8, I16, U16, I32, U32, I64, U64:
151                 typsize := BaseTypeSizes[typ]
152                 g.P("buf.EncodeUint", typsize*8, "(uint", typsize*8, "(", name, "))")
153         case F64:
154                 g.P("buf.EncodeFloat64(float64(", name, "))")
155         case BOOL:
156                 g.P("buf.EncodeBool(", name, ")")
157         case STRING:
158                 g.P("buf.EncodeString(", name, ", ", length, ")")
159         default:
160                 logrus.Panicf("// ??? %s %s\n", name, typ)
161         }
162         if isArray {
163                 switch typ {
164                 case I8, U8, I16, U16, I32, U32, I64, U64, F64:
165                         g.P("}")
166                 }
167         }
168 }
169
170 func encodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
171         getFieldName := func(name string) string {
172                 return fmt.Sprintf("%s.%s", parentName, name)
173         }
174
175         for _, field := range fields {
176                 name := getFieldName(field.GoName)
177
178                 encodeField(g, field, name, getFieldName, lvl)
179         }
180 }
181
182 func encodeField(g *GenFile, field *Field, name string, getFieldName func(name string) string, lvl int) {
183         if f := field.FieldSizeOf; f != nil {
184                 if _, ok := BaseTypesGo[field.Type]; ok {
185                         encodeBaseType(g, field.Type, fmt.Sprintf("len(%s)", getFieldName(f.GoName)), field.Length, "")
186                         return
187                 } else {
188                         panic(fmt.Sprintf("failed to encode base type of sizefrom field: %s (%s)", field.Name, field.Type))
189                 }
190         }
191         var sizeFromName string
192         if field.FieldSizeFrom != nil {
193                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
194         }
195
196         if _, ok := BaseTypesGo[field.Type]; ok {
197                 encodeBaseType(g, field.Type, name, field.Length, sizeFromName)
198                 return
199         }
200
201         if field.Array {
202                 char := fmt.Sprintf("v%d", lvl)
203                 index := fmt.Sprintf("j%d", lvl)
204                 if field.Length > 0 {
205                         g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
206                 } else if field.SizeFrom != "" {
207                         g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
208                 }
209                 g.P("var ", char, " ", fieldGoType(g, field))
210                 g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
211                 name = char
212         }
213
214         switch {
215         case field.TypeEnum != nil:
216                 encodeBaseType(g, field.TypeEnum.Type, name, 0, "")
217         case field.TypeAlias != nil:
218                 alias := field.TypeAlias
219                 if typ := alias.TypeStruct; typ != nil {
220                         encodeFields(g, typ.Fields, name, lvl+1)
221                 } else {
222                         encodeBaseType(g, alias.Type, name, alias.Length, "")
223                 }
224         case field.TypeStruct != nil:
225                 encodeFields(g, field.TypeStruct.Fields, name, lvl+1)
226         case field.TypeUnion != nil:
227                 g.P("buf.EncodeBytes(", name, ".", fieldUnionData, "[:], 0)")
228         default:
229                 logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
230         }
231
232         if field.Array {
233                 g.P("}")
234         }
235 }
236
237 func generateMessageMarshal(g *GenFile, name string, fields []*Field) {
238         g.P("func (m *", name, ") Marshal(b []byte) ([]byte, error) {")
239         g.P("var buf *", govppCodecPkg.Ident("Buffer"))
240         g.P("if b == nil {")
241         g.P("buf = ", govppCodecPkg.Ident("NewBuffer"), "(make([]byte, m.Size()))")
242         g.P("} else {")
243         g.P("buf = ", govppCodecPkg.Ident("NewBuffer"), "(b)")
244         g.P("}")
245
246         encodeFields(g, fields, "m", 0)
247
248         g.P("return buf.Bytes(), nil")
249         g.P("}")
250 }
251
252 func decodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
253         isArray := length > 0 || sizefrom != ""
254         if isArray {
255                 switch typ {
256                 case U8:
257                         g.P("copy(", name, "[:], buf.DecodeBytes(", length, "))")
258                         return
259                 case I8, I16, U16, I32, U32, I64, U64, F64:
260                         if alloc {
261                                 var size string
262                                 switch {
263                                 case length > 0:
264                                         size = strconv.Itoa(length)
265                                 case sizefrom != "":
266                                         size = sizefrom
267                                 }
268                                 if size != "" {
269                                         g.P(name, " = make([]", orig, ", ", size, ")")
270                                 }
271                         }
272                         g.P("for i := 0; i < len(", name, "); i++ {")
273                         name = fmt.Sprintf("%s[i]", name)
274                 }
275         }
276         switch typ {
277         case I8, U8, I16, U16, I32, U32, I64, U64:
278                 typsize := BaseTypeSizes[typ]
279                 if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig || strings.HasPrefix(orig, "i") {
280                         g.P(name, " = ", orig, "(buf.DecodeUint", typsize*8, "())")
281                 } else {
282                         g.P(name, " = buf.DecodeUint", typsize*8, "()")
283                 }
284         case F64:
285                 g.P(name, " = ", orig, "(buf.DecodeFloat64())")
286         case BOOL:
287                 g.P(name, " = buf.DecodeBool()")
288         case STRING:
289                 g.P(name, " = buf.DecodeString(", length, ")")
290         default:
291                 logrus.Panicf("\t// ??? %s %s\n", name, typ)
292         }
293         if isArray {
294                 switch typ {
295                 case I8, U8, I16, U16, I32, U32, I64, U64, F64:
296                         g.P("}")
297                 }
298         }
299 }
300
301 func generateMessageUnmarshal(g *GenFile, name string, fields []*Field) {
302         g.P("func (m *", name, ") Unmarshal(b []byte) error {")
303
304         if len(fields) > 0 {
305                 g.P("buf := ", govppCodecPkg.Ident("NewBuffer"), "(b)")
306                 decodeFields(g, fields, "m", 0)
307         }
308
309         g.P("return nil")
310         g.P("}")
311 }
312
313 func decodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
314         getFieldName := func(name string) string {
315                 return fmt.Sprintf("%s.%s", parentName, name)
316         }
317
318         for _, field := range fields {
319                 name := getFieldName(field.GoName)
320
321                 decodeField(g, field, name, getFieldName, lvl)
322         }
323 }
324
325 func decodeField(g *GenFile, field *Field, name string, getFieldName func(string) string, lvl int) {
326         var sizeFromName string
327         if field.FieldSizeFrom != nil {
328                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
329         }
330
331         if _, ok := BaseTypesGo[field.Type]; ok {
332                 decodeBaseType(g, field.Type, fieldGoType(g, field), name, field.Length, sizeFromName, true)
333                 return
334         }
335
336         if field.Array {
337                 index := fmt.Sprintf("j%d", lvl)
338                 if field.Length > 0 {
339                         g.P("for ", index, " := 0; ", index, " < ", field.Length, ";", index, "++ {")
340                 } else if field.SizeFrom != "" {
341                         g.P(name, " = make(", getFieldType(g, field), ", int(", sizeFromName, "))")
342                         g.P("for ", index, " := 0; ", index, " < len(", name, ");", index, "++ {")
343                 }
344                 name = fmt.Sprintf("%s[%s]", name, index)
345         }
346
347         if enum := field.TypeEnum; enum != nil {
348                 if _, ok := BaseTypesGo[enum.Type]; ok {
349                         decodeBaseType(g, enum.Type, fieldGoType(g, field), name, 0, "", false)
350                 } else {
351                         logrus.Panicf("\t// ??? ENUM %s %s\n", name, enum.Type)
352                 }
353         } else if alias := field.TypeAlias; alias != nil {
354                 if typ := alias.TypeStruct; typ != nil {
355                         decodeFields(g, typ.Fields, name, lvl+1)
356                 } else {
357                         if alias.Length > 0 {
358                                 decodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
359                         } else {
360                                 decodeBaseType(g, alias.Type, fieldGoType(g, field), name, alias.Length, "", false)
361                         }
362                 }
363         } else if typ := field.TypeStruct; typ != nil {
364                 decodeFields(g, typ.Fields, name, lvl+1)
365         } else if union := field.TypeUnion; union != nil {
366                 maxSize := getUnionSize(union)
367                 g.P("copy(", name, ".", fieldUnionData, "[:], buf.DecodeBytes(", maxSize, "))")
368         } else {
369                 logrus.Panicf("\t// ??? %s (%v)\n", field.GoName, field.Type)
370         }
371
372         if field.Array {
373                 g.P("}")
374         }
375 }