Fix encode/decode for []bool
[govpp.git] / binapigen / gen_encoding.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "fmt"
19         "strconv"
20
21         "github.com/sirupsen/logrus"
22 )
23
24 func genMessageSize(g *GenFile, name string, fields []*Field) {
25         g.P("func (m *", name, ") Size() (size int) {")
26         g.P("if m == nil { return 0 }")
27
28         sizeBaseType := func(typ, name string, length int, sizefrom string) {
29                 switch typ {
30                 case STRING:
31                         if length > 0 {
32                                 g.P("size += ", length, " // ", name)
33                         } else {
34                                 g.P("size += 4 + len(", name, ")", " // ", name)
35                         }
36                 default:
37                         var size = BaseTypeSizes[typ]
38                         if sizefrom != "" {
39                                 g.P("size += ", size, " * len(", name, ")", " // ", name)
40                         } else {
41                                 if length > 0 {
42                                         g.P("size += ", size, " * ", length, " // ", name)
43                                 } else {
44                                         g.P("size += ", size, " // ", name)
45                                 }
46                         }
47                 }
48         }
49
50         lvl := 0
51         var sizeFields func(fields []*Field, parentName string)
52         sizeFields = func(fields []*Field, parentName string) {
53                 lvl++
54                 defer func() { lvl-- }()
55
56                 getFieldName := func(name string) string {
57                         return fmt.Sprintf("%s.%s", parentName, name)
58                 }
59
60                 for _, field := range fields {
61                         name := getFieldName(field.GoName)
62
63                         var sizeFromName string
64                         if field.FieldSizeFrom != nil {
65                                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
66                         }
67
68                         if _, ok := BaseTypesGo[field.Type]; ok {
69                                 sizeBaseType(field.Type, name, field.Length, sizeFromName)
70                                 continue
71                         }
72
73                         if field.Array {
74                                 index := fmt.Sprintf("j%d", lvl)
75                                 if field.Length > 0 {
76                                         g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
77                                 } else if field.FieldSizeFrom != nil {
78                                         g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
79                                 }
80                                 if field.Length == 0 || field.SizeFrom != "" {
81                                         char := fmt.Sprintf("s%d", lvl)
82                                         g.P("var ", char, " ", fieldGoType(g, field))
83                                         g.P("_ = ", char)
84                                         g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
85                                         name = char
86                                 } else {
87                                         name = fmt.Sprintf("%s[%s]", name, index)
88                                 }
89                         }
90
91                         switch {
92                         case field.TypeEnum != nil:
93                                 enum := field.TypeEnum
94                                 if _, ok := BaseTypesGo[enum.Type]; ok {
95                                         sizeBaseType(enum.Type, name, 0, "")
96                                 } else {
97                                         logrus.Panicf("\t// ??? ENUM %s %s\n", name, enum.Type)
98                                 }
99                         case field.TypeAlias != nil:
100                                 alias := field.TypeAlias
101                                 if typ := alias.TypeStruct; typ != nil {
102                                         sizeFields(typ.Fields, name)
103                                 } else {
104                                         sizeBaseType(alias.Type, name, alias.Length, "")
105                                 }
106                         case field.TypeStruct != nil:
107                                 typ := field.TypeStruct
108                                 sizeFields(typ.Fields, name)
109                         case field.TypeUnion != nil:
110                                 union := field.TypeUnion
111                                 maxSize := getUnionSize(union)
112                                 sizeBaseType("u8", name, maxSize, "")
113                         default:
114                                 logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
115                         }
116
117                         if field.Array {
118                                 g.P("}")
119                         }
120                 }
121         }
122         sizeFields(fields, "m")
123
124         g.P("return size")
125         g.P("}")
126 }
127
128 func genMessageMarshal(g *GenFile, name string, fields []*Field) {
129         g.P("func (m *", name, ") Marshal(b []byte) ([]byte, error) {")
130         g.P("if b == nil {")
131         g.P("b = make([]byte, m.Size())")
132         g.P("}")
133         g.P("buf := ", govppCodecPkg.Ident("NewBuffer"), "(b)")
134
135         encodeFields(g, fields, "m", 0)
136
137         g.P("return buf.Bytes(), nil")
138         g.P("}")
139 }
140
141 func encodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
142         getFieldName := func(name string) string {
143                 return fmt.Sprintf("%s.%s", parentName, name)
144         }
145
146         for _, field := range fields {
147                 name := getFieldName(field.GoName)
148
149                 encodeField(g, field, name, getFieldName, lvl)
150         }
151 }
152
153 func encodeField(g *GenFile, field *Field, name string, getFieldName func(name string) string, lvl int) {
154         if f := field.FieldSizeOf; f != nil {
155                 if _, ok := BaseTypesGo[field.Type]; ok {
156                         val := fmt.Sprintf("len(%s)", getFieldName(f.GoName))
157                         encodeBaseType(g, field.Type, "int", val, 0, "", false)
158                         return
159                 } else {
160                         panic(fmt.Sprintf("failed to encode base type of sizefrom field: %s (%s)", field.Name, field.Type))
161                 }
162         }
163         var sizeFromName string
164         if field.FieldSizeFrom != nil {
165                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
166         }
167
168         if _, ok := BaseTypesGo[field.Type]; ok {
169                 encodeBaseType(g, field.Type, fieldGoType(g, field), name, field.Length, sizeFromName, true)
170                 return
171         }
172
173         if field.Array {
174                 index := fmt.Sprintf("j%d", lvl)
175                 if field.Length > 0 {
176                         g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
177                 } else if field.SizeFrom != "" {
178                         g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
179                 }
180                 if field.Length == 0 || field.SizeFrom != "" {
181                         char := fmt.Sprintf("v%d", lvl)
182                         g.P("var ", char, " ", fieldGoType(g, field), "// ", field.GoName)
183                         g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
184                         name = char
185                 } else {
186                         name = fmt.Sprintf("%s[%s]", name, index)
187                 }
188         }
189
190         switch {
191         case field.TypeEnum != nil:
192                 encodeBaseType(g, field.TypeEnum.Type, fieldGoType(g, field), name, 0, "", false)
193         case field.TypeAlias != nil:
194                 alias := field.TypeAlias
195                 if typ := alias.TypeStruct; typ != nil {
196                         encodeFields(g, typ.Fields, name, lvl+1)
197                 } else {
198                         if alias.Length > 0 {
199                                 encodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
200                         } else {
201                                 encodeBaseType(g, alias.Type, fieldGoType(g, field), name, 0, "", false)
202                         }
203                 }
204         case field.TypeStruct != nil:
205                 encodeFields(g, field.TypeStruct.Fields, name, lvl+1)
206         case field.TypeUnion != nil:
207                 maxSize := getUnionSize(field.TypeUnion)
208                 g.P("buf.EncodeBytes(", name, ".", fieldUnionData, "[:], ", maxSize, ")")
209         default:
210                 logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
211         }
212
213         if field.Array {
214                 g.P("}")
215         }
216 }
217
218 func encodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
219         isArray := length > 0 || sizefrom != ""
220         if isArray {
221                 switch typ {
222                 case U8:
223                         if alloc {
224                                 g.P("buf.EncodeBytes(", name, ", ", length, ")")
225                         } else {
226                                 g.P("buf.EncodeBytes(", name, "[:], ", length, ")")
227                         }
228                         return
229                 case I8, I16, U16, I32, U32, I64, U64, F64, BOOL:
230                         gotype := BaseTypesGo[typ]
231                         if length != 0 {
232                                 g.P("for i := 0; i < ", length, "; i++ {")
233                         } else if sizefrom != "" {
234                                 g.P("for i := 0; i < len(", name, "); i++ {")
235                         }
236                         if alloc {
237                                 g.P("var x ", gotype)
238                                 g.P("if i < len(", name, ") { x = ", gotype, "(", name, "[i]) }")
239                                 name = "x"
240                         }
241                 }
242         }
243         conv := func(s string) string {
244                 if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig {
245                         return fmt.Sprintf("%s(%s)", gotype, s)
246                 }
247                 return s
248         }
249         switch typ {
250         case I8, I16, I32, I64:
251                 typsize := BaseTypeSizes[typ]
252                 g.P("buf.EncodeInt", typsize*8, "(", conv(name), ")")
253         case U8, U16, U32, U64:
254                 typsize := BaseTypeSizes[typ]
255                 g.P("buf.EncodeUint", typsize*8, "(", conv(name), ")")
256         case F64:
257                 g.P("buf.EncodeFloat64(", conv(name), ")")
258         case BOOL:
259                 g.P("buf.EncodeBool(", name, ")")
260         case STRING:
261                 g.P("buf.EncodeString(", name, ", ", length, ")")
262         default:
263                 logrus.Panicf("// ??? %s %s\n", name, typ)
264         }
265         if isArray {
266                 switch typ {
267                 case I8, U8, I16, U16, I32, U32, I64, U64, F64, BOOL:
268                         g.P("}")
269                 }
270         }
271 }
272
273 func genMessageUnmarshal(g *GenFile, name string, fields []*Field) {
274         g.P("func (m *", name, ") Unmarshal(b []byte) error {")
275
276         if len(fields) > 0 {
277                 g.P("buf := ", govppCodecPkg.Ident("NewBuffer"), "(b)")
278                 decodeFields(g, fields, "m", 0)
279         }
280
281         g.P("return nil")
282         g.P("}")
283 }
284
285 func decodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
286         getFieldName := func(name string) string {
287                 return fmt.Sprintf("%s.%s", parentName, name)
288         }
289
290         for _, field := range fields {
291                 name := getFieldName(field.GoName)
292
293                 decodeField(g, field, name, getFieldName, lvl)
294         }
295 }
296
297 func decodeField(g *GenFile, field *Field, name string, getFieldName func(string) string, lvl int) {
298         var sizeFromName string
299         if field.FieldSizeFrom != nil {
300                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
301         }
302
303         if _, ok := BaseTypesGo[field.Type]; ok {
304                 decodeBaseType(g, field.Type, fieldGoType(g, field), name, field.Length, sizeFromName, true)
305                 return
306         }
307
308         if field.Array {
309                 index := fmt.Sprintf("j%d", lvl)
310                 if field.Length > 0 {
311                         g.P("for ", index, " := 0; ", index, " < ", field.Length, ";", index, "++ {")
312                 } else if field.SizeFrom != "" {
313                         g.P(name, " = make(", getFieldType(g, field), ", ", sizeFromName, ")")
314                         g.P("for ", index, " := 0; ", index, " < len(", name, ");", index, "++ {")
315                 }
316                 name = fmt.Sprintf("%s[%s]", name, index)
317         }
318
319         if enum := field.TypeEnum; enum != nil {
320                 if _, ok := BaseTypesGo[enum.Type]; ok {
321                         decodeBaseType(g, enum.Type, fieldGoType(g, field), name, 0, "", false)
322                 } else {
323                         logrus.Panicf("\t// ??? ENUM %s %s\n", name, enum.Type)
324                 }
325         } else if alias := field.TypeAlias; alias != nil {
326                 if typ := alias.TypeStruct; typ != nil {
327                         decodeFields(g, typ.Fields, name, lvl+1)
328                 } else {
329                         if alias.Length > 0 {
330                                 decodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
331                         } else {
332                                 decodeBaseType(g, alias.Type, fieldGoType(g, field), name, 0, "", false)
333                         }
334                 }
335         } else if typ := field.TypeStruct; typ != nil {
336                 decodeFields(g, typ.Fields, name, lvl+1)
337         } else if union := field.TypeUnion; union != nil {
338                 maxSize := getUnionSize(union)
339                 g.P("copy(", name, ".", fieldUnionData, "[:], buf.DecodeBytes(", maxSize, "))")
340         } else {
341                 logrus.Panicf("\t// ??? %s (%v)\n", field.GoName, field.Type)
342         }
343
344         if field.Array {
345                 g.P("}")
346         }
347 }
348
349 func decodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
350         isArray := length > 0 || sizefrom != ""
351         if isArray {
352                 var size string
353                 switch {
354                 case length > 0:
355                         size = strconv.Itoa(length)
356                 case sizefrom != "":
357                         size = sizefrom
358                 }
359                 switch typ {
360                 case U8:
361                         if alloc {
362                                 g.P(name, " = make([]byte, ", size, ")")
363                                 g.P("copy(", name, ", buf.DecodeBytes(len(", name, ")))")
364                         } else {
365                                 g.P("copy(", name, "[:], buf.DecodeBytes(", size, "))")
366                         }
367                         return
368                 case I8, I16, U16, I32, U32, I64, U64, F64, BOOL:
369                         if alloc {
370                                 g.P(name, " = make([]", orig, ", ", size, ")")
371                         }
372                         g.P("for i := 0; i < len(", name, "); i++ {")
373                         name = fmt.Sprintf("%s[i]", name)
374                 }
375         }
376         conv := func(s string) string {
377                 if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig {
378                         return fmt.Sprintf("%s(%s)", orig, s)
379                 }
380                 return s
381         }
382         switch typ {
383         case I8, I16, I32, I64:
384                 typsize := BaseTypeSizes[typ]
385                 g.P(name, " = ", conv(fmt.Sprintf("buf.DecodeInt%d()", typsize*8)))
386         case U8, U16, U32, U64:
387                 typsize := BaseTypeSizes[typ]
388                 g.P(name, " = ", conv(fmt.Sprintf("buf.DecodeUint%d()", typsize*8)))
389         case F64:
390                 g.P(name, " = ", conv("buf.DecodeFloat64()"))
391         case BOOL:
392                 g.P(name, " = buf.DecodeBool()")
393         case STRING:
394                 g.P(name, " = buf.DecodeString(", length, ")")
395         default:
396                 logrus.Panicf("\t// ??? %s %s\n", name, typ)
397         }
398         if isArray {
399                 switch typ {
400                 case I8, U8, I16, U16, I32, U32, I64, U64, F64, BOOL:
401                         g.P("}")
402                 }
403         }
404 }