Recognize stat_dir_type_empty
[govpp.git] / binapigen / gen_encoding.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "fmt"
19         "strconv"
20
21         "github.com/sirupsen/logrus"
22 )
23
24 func init() {
25         //RegisterPlugin("encoding", GenerateEncoding)
26 }
27
28 func genMessageSize(g *GenFile, name string, fields []*Field) {
29         g.P("func (m *", name, ") Size() (size int) {")
30         g.P("if m == nil { return 0 }")
31
32         sizeBaseType := func(typ, name string, length int, sizefrom string) {
33                 switch typ {
34                 case STRING:
35                         if length > 0 {
36                                 g.P("size += ", length, " // ", name)
37                         } else {
38                                 g.P("size += 4 + len(", name, ")", " // ", name)
39                         }
40                 default:
41                         var size = BaseTypeSizes[typ]
42                         if sizefrom != "" {
43                                 g.P("size += ", size, " * len(", name, ")", " // ", name)
44                         } else {
45                                 if length > 0 {
46                                         g.P("size += ", size, " * ", length, " // ", name)
47                                 } else {
48                                         g.P("size += ", size, " // ", name)
49                                 }
50                         }
51                 }
52         }
53
54         lvl := 0
55         var sizeFields func(fields []*Field, parentName string)
56         sizeFields = func(fields []*Field, parentName string) {
57                 lvl++
58                 defer func() { lvl-- }()
59
60                 getFieldName := func(name string) string {
61                         return fmt.Sprintf("%s.%s", parentName, name)
62                 }
63
64                 for _, field := range fields {
65                         name := getFieldName(field.GoName)
66
67                         var sizeFromName string
68                         if field.FieldSizeFrom != nil {
69                                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
70                         }
71
72                         if _, ok := BaseTypesGo[field.Type]; ok {
73                                 sizeBaseType(field.Type, name, field.Length, sizeFromName)
74                                 continue
75                         }
76
77                         if field.Array {
78                                 index := fmt.Sprintf("j%d", lvl)
79                                 if field.Length > 0 {
80                                         g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
81                                 } else if field.FieldSizeFrom != nil {
82                                         g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
83                                 }
84                                 if field.Length == 0 || field.SizeFrom != "" {
85                                         char := fmt.Sprintf("s%d", lvl)
86                                         g.P("var ", char, " ", fieldGoType(g, field))
87                                         g.P("_ = ", char)
88                                         g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
89                                         name = char
90                                 } else {
91                                         name = fmt.Sprintf("%s[%s]", name, index)
92                                 }
93                         }
94
95                         switch {
96                         case field.TypeEnum != nil:
97                                 enum := field.TypeEnum
98                                 if _, ok := BaseTypesGo[enum.Type]; ok {
99                                         sizeBaseType(enum.Type, name, 0, "")
100                                 } else {
101                                         logrus.Panicf("\t// ??? ENUM %s %s\n", name, enum.Type)
102                                 }
103                         case field.TypeAlias != nil:
104                                 alias := field.TypeAlias
105                                 if typ := alias.TypeStruct; typ != nil {
106                                         sizeFields(typ.Fields, name)
107                                 } else {
108                                         sizeBaseType(alias.Type, name, alias.Length, "")
109                                 }
110                         case field.TypeStruct != nil:
111                                 typ := field.TypeStruct
112                                 sizeFields(typ.Fields, name)
113                         case field.TypeUnion != nil:
114                                 union := field.TypeUnion
115                                 maxSize := getUnionSize(union)
116                                 sizeBaseType("u8", name, maxSize, "")
117                         default:
118                                 logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
119                         }
120
121                         if field.Array {
122                                 g.P("}")
123                         }
124                 }
125         }
126         sizeFields(fields, "m")
127
128         g.P("return size")
129         g.P("}")
130 }
131
132 func genMessageMarshal(g *GenFile, name string, fields []*Field) {
133         g.P("func (m *", name, ") Marshal(b []byte) ([]byte, error) {")
134         g.P("if b == nil {")
135         g.P("b = make([]byte, m.Size())")
136         g.P("}")
137         g.P("buf := ", govppCodecPkg.Ident("NewBuffer"), "(b)")
138
139         encodeFields(g, fields, "m", 0)
140
141         g.P("return buf.Bytes(), nil")
142         g.P("}")
143 }
144
145 func encodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
146         getFieldName := func(name string) string {
147                 return fmt.Sprintf("%s.%s", parentName, name)
148         }
149
150         for _, field := range fields {
151                 name := getFieldName(field.GoName)
152
153                 encodeField(g, field, name, getFieldName, lvl)
154         }
155 }
156
157 func encodeField(g *GenFile, field *Field, name string, getFieldName func(name string) string, lvl int) {
158         if f := field.FieldSizeOf; f != nil {
159                 if _, ok := BaseTypesGo[field.Type]; ok {
160                         val := fmt.Sprintf("len(%s)", getFieldName(f.GoName))
161                         encodeBaseType(g, field.Type, "int", val, 0, "", false)
162                         return
163                 } else {
164                         panic(fmt.Sprintf("failed to encode base type of sizefrom field: %s (%s)", field.Name, field.Type))
165                 }
166         }
167         var sizeFromName string
168         if field.FieldSizeFrom != nil {
169                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
170         }
171
172         if _, ok := BaseTypesGo[field.Type]; ok {
173                 encodeBaseType(g, field.Type, fieldGoType(g, field), name, field.Length, sizeFromName, true)
174                 return
175         }
176
177         if field.Array {
178                 index := fmt.Sprintf("j%d", lvl)
179                 if field.Length > 0 {
180                         g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
181                 } else if field.SizeFrom != "" {
182                         g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
183                 }
184                 if field.Length == 0 || field.SizeFrom != "" {
185                         char := fmt.Sprintf("v%d", lvl)
186                         g.P("var ", char, " ", fieldGoType(g, field), "// ", field.GoName)
187                         g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
188                         name = char
189                 } else {
190                         name = fmt.Sprintf("%s[%s]", name, index)
191                 }
192         }
193
194         switch {
195         case field.TypeEnum != nil:
196                 encodeBaseType(g, field.TypeEnum.Type, fieldGoType(g, field), name, 0, "", false)
197         case field.TypeAlias != nil:
198                 alias := field.TypeAlias
199                 if typ := alias.TypeStruct; typ != nil {
200                         encodeFields(g, typ.Fields, name, lvl+1)
201                 } else {
202                         if alias.Length > 0 {
203                                 encodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
204                         } else {
205                                 encodeBaseType(g, alias.Type, fieldGoType(g, field), name, 0, "", false)
206                         }
207                 }
208         case field.TypeStruct != nil:
209                 encodeFields(g, field.TypeStruct.Fields, name, lvl+1)
210         case field.TypeUnion != nil:
211                 maxSize := getUnionSize(field.TypeUnion)
212                 g.P("buf.EncodeBytes(", name, ".", fieldUnionData, "[:], ", maxSize, ")")
213         default:
214                 logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
215         }
216
217         if field.Array {
218                 g.P("}")
219         }
220 }
221
222 func encodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
223         isArray := length > 0 || sizefrom != ""
224         if isArray {
225                 switch typ {
226                 case U8:
227                         if alloc {
228                                 g.P("buf.EncodeBytes(", name, ", ", length, ")")
229                         } else {
230                                 g.P("buf.EncodeBytes(", name, "[:], ", length, ")")
231                         }
232                         return
233                 case I8, I16, U16, I32, U32, I64, U64, F64:
234                         gotype := BaseTypesGo[typ]
235                         if length != 0 {
236                                 g.P("for i := 0; i < ", length, "; i++ {")
237                         } else if sizefrom != "" {
238                                 g.P("for i := 0; i < len(", name, "); i++ {")
239                         }
240                         if alloc {
241                                 g.P("var x ", gotype)
242                                 g.P("if i < len(", name, ") { x = ", gotype, "(", name, "[i]) }")
243                                 name = "x"
244                         }
245                 }
246         }
247         conv := func(s string) string {
248                 if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig {
249                         return fmt.Sprintf("%s(%s)", gotype, s)
250                 }
251                 return s
252         }
253         switch typ {
254         case I8, I16, I32, I64:
255                 typsize := BaseTypeSizes[typ]
256                 g.P("buf.EncodeInt", typsize*8, "(", conv(name), ")")
257         case U8, U16, U32, U64:
258                 typsize := BaseTypeSizes[typ]
259                 g.P("buf.EncodeUint", typsize*8, "(", conv(name), ")")
260         case F64:
261                 g.P("buf.EncodeFloat64(", conv(name), ")")
262         case BOOL:
263                 g.P("buf.EncodeBool(", name, ")")
264         case STRING:
265                 g.P("buf.EncodeString(", name, ", ", length, ")")
266         default:
267                 logrus.Panicf("// ??? %s %s\n", name, typ)
268         }
269         if isArray {
270                 switch typ {
271                 case I8, U8, I16, U16, I32, U32, I64, U64, F64:
272                         g.P("}")
273                 }
274         }
275 }
276
277 func genMessageUnmarshal(g *GenFile, name string, fields []*Field) {
278         g.P("func (m *", name, ") Unmarshal(b []byte) error {")
279
280         if len(fields) > 0 {
281                 g.P("buf := ", govppCodecPkg.Ident("NewBuffer"), "(b)")
282                 decodeFields(g, fields, "m", 0)
283         }
284
285         g.P("return nil")
286         g.P("}")
287 }
288
289 func decodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
290         getFieldName := func(name string) string {
291                 return fmt.Sprintf("%s.%s", parentName, name)
292         }
293
294         for _, field := range fields {
295                 name := getFieldName(field.GoName)
296
297                 decodeField(g, field, name, getFieldName, lvl)
298         }
299 }
300
301 func decodeField(g *GenFile, field *Field, name string, getFieldName func(string) string, lvl int) {
302         var sizeFromName string
303         if field.FieldSizeFrom != nil {
304                 sizeFromName = getFieldName(field.FieldSizeFrom.GoName)
305         }
306
307         if _, ok := BaseTypesGo[field.Type]; ok {
308                 decodeBaseType(g, field.Type, fieldGoType(g, field), name, field.Length, sizeFromName, true)
309                 return
310         }
311
312         if field.Array {
313                 index := fmt.Sprintf("j%d", lvl)
314                 if field.Length > 0 {
315                         g.P("for ", index, " := 0; ", index, " < ", field.Length, ";", index, "++ {")
316                 } else if field.SizeFrom != "" {
317                         g.P(name, " = make(", getFieldType(g, field), ", ", sizeFromName, ")")
318                         g.P("for ", index, " := 0; ", index, " < len(", name, ");", index, "++ {")
319                 }
320                 name = fmt.Sprintf("%s[%s]", name, index)
321         }
322
323         if enum := field.TypeEnum; enum != nil {
324                 if _, ok := BaseTypesGo[enum.Type]; ok {
325                         decodeBaseType(g, enum.Type, fieldGoType(g, field), name, 0, "", false)
326                 } else {
327                         logrus.Panicf("\t// ??? ENUM %s %s\n", name, enum.Type)
328                 }
329         } else if alias := field.TypeAlias; alias != nil {
330                 if typ := alias.TypeStruct; typ != nil {
331                         decodeFields(g, typ.Fields, name, lvl+1)
332                 } else {
333                         if alias.Length > 0 {
334                                 decodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
335                         } else {
336                                 decodeBaseType(g, alias.Type, fieldGoType(g, field), name, 0, "", false)
337                         }
338                 }
339         } else if typ := field.TypeStruct; typ != nil {
340                 decodeFields(g, typ.Fields, name, lvl+1)
341         } else if union := field.TypeUnion; union != nil {
342                 maxSize := getUnionSize(union)
343                 g.P("copy(", name, ".", fieldUnionData, "[:], buf.DecodeBytes(", maxSize, "))")
344         } else {
345                 logrus.Panicf("\t// ??? %s (%v)\n", field.GoName, field.Type)
346         }
347
348         if field.Array {
349                 g.P("}")
350         }
351 }
352
353 func decodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
354         isArray := length > 0 || sizefrom != ""
355         if isArray {
356                 var size string
357                 switch {
358                 case length > 0:
359                         size = strconv.Itoa(length)
360                 case sizefrom != "":
361                         size = sizefrom
362                 }
363                 switch typ {
364                 case U8:
365                         if alloc {
366                                 g.P(name, " = make([]byte, ", size, ")")
367                                 g.P("copy(", name, ", buf.DecodeBytes(len(", name, ")))")
368                         } else {
369                                 g.P("copy(", name, "[:], buf.DecodeBytes(", size, "))")
370                         }
371                         return
372                 case I8, I16, U16, I32, U32, I64, U64, F64:
373                         if alloc {
374                                 g.P(name, " = make([]", orig, ", ", size, ")")
375                         }
376                         g.P("for i := 0; i < len(", name, "); i++ {")
377                         name = fmt.Sprintf("%s[i]", name)
378                 }
379         }
380         conv := func(s string) string {
381                 if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig {
382                         return fmt.Sprintf("%s(%s)", orig, s)
383                 }
384                 return s
385         }
386         switch typ {
387         case I8, I16, I32, I64:
388                 typsize := BaseTypeSizes[typ]
389                 g.P(name, " = ", conv(fmt.Sprintf("buf.DecodeInt%d()", typsize*8)))
390         case U8, U16, U32, U64:
391                 typsize := BaseTypeSizes[typ]
392                 g.P(name, " = ", conv(fmt.Sprintf("buf.DecodeUint%d()", typsize*8)))
393         case F64:
394                 g.P(name, " = ", conv("buf.DecodeFloat64()"))
395         case BOOL:
396                 g.P(name, " = buf.DecodeBool()")
397         case STRING:
398                 g.P(name, " = buf.DecodeString(", length, ")")
399         default:
400                 logrus.Panicf("\t// ??? %s %s\n", name, typ)
401         }
402         if isArray {
403                 switch typ {
404                 case I8, U8, I16, U16, I32, U32, I64, U64, F64:
405                         g.P("}")
406                 }
407         }
408 }