Fix binapigen decoding and minor improvements
[govpp.git] / binapigen / gen_encoding.go
index 1cd3eb3..d946771 100644 (file)
@@ -17,7 +17,6 @@ package binapigen
 import (
        "fmt"
        "strconv"
-       "strings"
 
        "github.com/sirupsen/logrus"
 )
@@ -26,10 +25,9 @@ func init() {
        //RegisterPlugin("encoding", GenerateEncoding)
 }
 
-func generateMessageSize(g *GenFile, name string, fields []*Field) {
-       g.P("func (m *", name, ") Size() int {")
+func genMessageSize(g *GenFile, name string, fields []*Field) {
+       g.P("func (m *", name, ") Size() (size int) {")
        g.P("if m == nil { return 0 }")
-       g.P("var size int")
 
        sizeBaseType := func(typ, name string, length int, sizefrom string) {
                switch typ {
@@ -77,17 +75,21 @@ func generateMessageSize(g *GenFile, name string, fields []*Field) {
                        }
 
                        if field.Array {
-                               char := fmt.Sprintf("s%d", lvl)
                                index := fmt.Sprintf("j%d", lvl)
                                if field.Length > 0 {
                                        g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
                                } else if field.FieldSizeFrom != nil {
                                        g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
                                }
-                               g.P("var ", char, " ", fieldGoType(g, field))
-                               g.P("_ = ", char)
-                               g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
-                               name = char
+                               if field.Length == 0 || field.SizeFrom != "" {
+                                       char := fmt.Sprintf("s%d", lvl)
+                                       g.P("var ", char, " ", fieldGoType(g, field))
+                                       g.P("_ = ", char)
+                                       g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
+                                       name = char
+                               } else {
+                                       name = fmt.Sprintf("%s[%s]", name, index)
+                               }
                        }
 
                        switch {
@@ -127,44 +129,17 @@ func generateMessageSize(g *GenFile, name string, fields []*Field) {
        g.P("}")
 }
 
-func encodeBaseType(g *GenFile, typ, name string, length int, sizefrom string) {
-       isArray := length > 0 || sizefrom != ""
-       if isArray {
-               switch typ {
-               case U8:
-                       g.P("buf.EncodeBytes(", name, "[:], ", length, ")")
-                       return
-               case I8, I16, U16, I32, U32, I64, U64, F64:
-                       gotype := BaseTypesGo[typ]
-                       if length != 0 {
-                               g.P("for i := 0; i < ", length, "; i++ {")
-                       } else if sizefrom != "" {
-                               g.P("for i := 0; i < len(", name, "); i++ {")
-                       }
-                       g.P("var x ", gotype)
-                       g.P("if i < len(", name, ") { x = ", gotype, "(", name, "[i]) }")
-                       name = "x"
-               }
-       }
-       switch typ {
-       case I8, U8, I16, U16, I32, U32, I64, U64:
-               typsize := BaseTypeSizes[typ]
-               g.P("buf.EncodeUint", typsize*8, "(uint", typsize*8, "(", name, "))")
-       case F64:
-               g.P("buf.EncodeFloat64(float64(", name, "))")
-       case BOOL:
-               g.P("buf.EncodeBool(", name, ")")
-       case STRING:
-               g.P("buf.EncodeString(", name, ", ", length, ")")
-       default:
-               logrus.Panicf("// ??? %s %s\n", name, typ)
-       }
-       if isArray {
-               switch typ {
-               case I8, U8, I16, U16, I32, U32, I64, U64, F64:
-                       g.P("}")
-               }
-       }
+func genMessageMarshal(g *GenFile, name string, fields []*Field) {
+       g.P("func (m *", name, ") Marshal(b []byte) ([]byte, error) {")
+       g.P("if b == nil {")
+       g.P("b = make([]byte, m.Size())")
+       g.P("}")
+       g.P("buf := ", govppCodecPkg.Ident("NewBuffer"), "(b)")
+
+       encodeFields(g, fields, "m", 0)
+
+       g.P("return buf.Bytes(), nil")
+       g.P("}")
 }
 
 func encodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
@@ -182,7 +157,8 @@ func encodeFields(g *GenFile, fields []*Field, parentName string, lvl int) {
 func encodeField(g *GenFile, field *Field, name string, getFieldName func(name string) string, lvl int) {
        if f := field.FieldSizeOf; f != nil {
                if _, ok := BaseTypesGo[field.Type]; ok {
-                       encodeBaseType(g, field.Type, fmt.Sprintf("len(%s)", getFieldName(f.GoName)), field.Length, "")
+                       val := fmt.Sprintf("len(%s)", getFieldName(f.GoName))
+                       encodeBaseType(g, field.Type, "int", val, 0, "", false)
                        return
                } else {
                        panic(fmt.Sprintf("failed to encode base type of sizefrom field: %s (%s)", field.Name, field.Type))
@@ -194,37 +170,46 @@ func encodeField(g *GenFile, field *Field, name string, getFieldName func(name s
        }
 
        if _, ok := BaseTypesGo[field.Type]; ok {
-               encodeBaseType(g, field.Type, name, field.Length, sizeFromName)
+               encodeBaseType(g, field.Type, fieldGoType(g, field), name, field.Length, sizeFromName, true)
                return
        }
 
        if field.Array {
-               char := fmt.Sprintf("v%d", lvl)
                index := fmt.Sprintf("j%d", lvl)
                if field.Length > 0 {
                        g.P("for ", index, " := 0; ", index, " < ", field.Length, "; ", index, "++ {")
                } else if field.SizeFrom != "" {
                        g.P("for ", index, " := 0; ", index, " < len(", name, "); ", index, "++ {")
                }
-               g.P("var ", char, " ", fieldGoType(g, field))
-               g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
-               name = char
+               if field.Length == 0 || field.SizeFrom != "" {
+                       char := fmt.Sprintf("v%d", lvl)
+                       g.P("var ", char, " ", fieldGoType(g, field), "// ", field.GoName)
+                       g.P("if ", index, " < len(", name, ") { ", char, " = ", name, "[", index, "] }")
+                       name = char
+               } else {
+                       name = fmt.Sprintf("%s[%s]", name, index)
+               }
        }
 
        switch {
        case field.TypeEnum != nil:
-               encodeBaseType(g, field.TypeEnum.Type, name, 0, "")
+               encodeBaseType(g, field.TypeEnum.Type, fieldGoType(g, field), name, 0, "", false)
        case field.TypeAlias != nil:
                alias := field.TypeAlias
                if typ := alias.TypeStruct; typ != nil {
                        encodeFields(g, typ.Fields, name, lvl+1)
                } else {
-                       encodeBaseType(g, alias.Type, name, alias.Length, "")
+                       if alias.Length > 0 {
+                               encodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
+                       } else {
+                               encodeBaseType(g, alias.Type, fieldGoType(g, field), name, 0, "", false)
+                       }
                }
        case field.TypeStruct != nil:
                encodeFields(g, field.TypeStruct.Fields, name, lvl+1)
        case field.TypeUnion != nil:
-               g.P("buf.EncodeBytes(", name, ".", fieldUnionData, "[:], 0)")
+               maxSize := getUnionSize(field.TypeUnion)
+               g.P("buf.EncodeBytes(", name, ".", fieldUnionData, "[:], ", maxSize, ")")
        default:
                logrus.Panicf("\t// ??? buf[pos] = %s (%s)\n", name, field.Type)
        }
@@ -234,61 +219,52 @@ func encodeField(g *GenFile, field *Field, name string, getFieldName func(name s
        }
 }
 
-func generateMessageMarshal(g *GenFile, name string, fields []*Field) {
-       g.P("func (m *", name, ") Marshal(b []byte) ([]byte, error) {")
-       g.P("var buf *", govppCodecPkg.Ident("Buffer"))
-       g.P("if b == nil {")
-       g.P("buf = ", govppCodecPkg.Ident("NewBuffer"), "(make([]byte, m.Size()))")
-       g.P("} else {")
-       g.P("buf = ", govppCodecPkg.Ident("NewBuffer"), "(b)")
-       g.P("}")
-
-       encodeFields(g, fields, "m", 0)
-
-       g.P("return buf.Bytes(), nil")
-       g.P("}")
-}
-
-func decodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
+func encodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
        isArray := length > 0 || sizefrom != ""
        if isArray {
                switch typ {
                case U8:
-                       g.P("copy(", name, "[:], buf.DecodeBytes(", length, "))")
+                       if alloc {
+                               g.P("buf.EncodeBytes(", name, ", ", length, ")")
+                       } else {
+                               g.P("buf.EncodeBytes(", name, "[:], ", length, ")")
+                       }
                        return
                case I8, I16, U16, I32, U32, I64, U64, F64:
+                       gotype := BaseTypesGo[typ]
+                       if length != 0 {
+                               g.P("for i := 0; i < ", length, "; i++ {")
+                       } else if sizefrom != "" {
+                               g.P("for i := 0; i < len(", name, "); i++ {")
+                       }
                        if alloc {
-                               var size string
-                               switch {
-                               case length > 0:
-                                       size = strconv.Itoa(length)
-                               case sizefrom != "":
-                                       size = sizefrom
-                               }
-                               if size != "" {
-                                       g.P(name, " = make([]", orig, ", ", size, ")")
-                               }
+                               g.P("var x ", gotype)
+                               g.P("if i < len(", name, ") { x = ", gotype, "(", name, "[i]) }")
+                               name = "x"
                        }
-                       g.P("for i := 0; i < len(", name, "); i++ {")
-                       name = fmt.Sprintf("%s[i]", name)
                }
        }
+       conv := func(s string) string {
+               if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig {
+                       return fmt.Sprintf("%s(%s)", gotype, s)
+               }
+               return s
+       }
        switch typ {
-       case I8, U8, I16, U16, I32, U32, I64, U64:
+       case I8, I16, I32, I64:
                typsize := BaseTypeSizes[typ]
-               if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig || strings.HasPrefix(orig, "i") {
-                       g.P(name, " = ", orig, "(buf.DecodeUint", typsize*8, "())")
-               } else {
-                       g.P(name, " = buf.DecodeUint", typsize*8, "()")
-               }
+               g.P("buf.EncodeInt", typsize*8, "(", conv(name), ")")
+       case U8, U16, U32, U64:
+               typsize := BaseTypeSizes[typ]
+               g.P("buf.EncodeUint", typsize*8, "(", conv(name), ")")
        case F64:
-               g.P(name, " = ", orig, "(buf.DecodeFloat64())")
+               g.P("buf.EncodeFloat64(", conv(name), ")")
        case BOOL:
-               g.P(name, " = buf.DecodeBool()")
+               g.P("buf.EncodeBool(", name, ")")
        case STRING:
-               g.P(name, " = buf.DecodeString(", length, ")")
+               g.P("buf.EncodeString(", name, ", ", length, ")")
        default:
-               logrus.Panicf("\t// ??? %s %s\n", name, typ)
+               logrus.Panicf("// ??? %s %s\n", name, typ)
        }
        if isArray {
                switch typ {
@@ -298,7 +274,7 @@ func decodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom str
        }
 }
 
-func generateMessageUnmarshal(g *GenFile, name string, fields []*Field) {
+func genMessageUnmarshal(g *GenFile, name string, fields []*Field) {
        g.P("func (m *", name, ") Unmarshal(b []byte) error {")
 
        if len(fields) > 0 {
@@ -338,7 +314,7 @@ func decodeField(g *GenFile, field *Field, name string, getFieldName func(string
                if field.Length > 0 {
                        g.P("for ", index, " := 0; ", index, " < ", field.Length, ";", index, "++ {")
                } else if field.SizeFrom != "" {
-                       g.P(name, " = make(", getFieldType(g, field), ", int(", sizeFromName, "))")
+                       g.P(name, " = make(", getFieldType(g, field), ", ", sizeFromName, ")")
                        g.P("for ", index, " := 0; ", index, " < len(", name, ");", index, "++ {")
                }
                name = fmt.Sprintf("%s[%s]", name, index)
@@ -357,7 +333,7 @@ func decodeField(g *GenFile, field *Field, name string, getFieldName func(string
                        if alias.Length > 0 {
                                decodeBaseType(g, alias.Type, BaseTypesGo[alias.Type], name, alias.Length, "", false)
                        } else {
-                               decodeBaseType(g, alias.Type, fieldGoType(g, field), name, alias.Length, "", false)
+                               decodeBaseType(g, alias.Type, fieldGoType(g, field), name, 0, "", false)
                        }
                }
        } else if typ := field.TypeStruct; typ != nil {
@@ -373,3 +349,60 @@ func decodeField(g *GenFile, field *Field, name string, getFieldName func(string
                g.P("}")
        }
 }
+
+func decodeBaseType(g *GenFile, typ, orig, name string, length int, sizefrom string, alloc bool) {
+       isArray := length > 0 || sizefrom != ""
+       if isArray {
+               var size string
+               switch {
+               case length > 0:
+                       size = strconv.Itoa(length)
+               case sizefrom != "":
+                       size = sizefrom
+               }
+               switch typ {
+               case U8:
+                       if alloc {
+                               g.P(name, " = make([]byte, ", size, ")")
+                               g.P("copy(", name, ", buf.DecodeBytes(len(", name, ")))")
+                       } else {
+                               g.P("copy(", name, "[:], buf.DecodeBytes(", size, "))")
+                       }
+                       return
+               case I8, I16, U16, I32, U32, I64, U64, F64:
+                       if alloc {
+                               g.P(name, " = make([]", orig, ", ", size, ")")
+                       }
+                       g.P("for i := 0; i < len(", name, "); i++ {")
+                       name = fmt.Sprintf("%s[i]", name)
+               }
+       }
+       conv := func(s string) string {
+               if gotype, ok := BaseTypesGo[typ]; !ok || gotype != orig {
+                       return fmt.Sprintf("%s(%s)", orig, s)
+               }
+               return s
+       }
+       switch typ {
+       case I8, I16, I32, I64:
+               typsize := BaseTypeSizes[typ]
+               g.P(name, " = ", conv(fmt.Sprintf("buf.DecodeInt%d()", typsize*8)))
+       case U8, U16, U32, U64:
+               typsize := BaseTypeSizes[typ]
+               g.P(name, " = ", conv(fmt.Sprintf("buf.DecodeUint%d()", typsize*8)))
+       case F64:
+               g.P(name, " = ", conv("buf.DecodeFloat64()"))
+       case BOOL:
+               g.P(name, " = buf.DecodeBool()")
+       case STRING:
+               g.P(name, " = buf.DecodeString(", length, ")")
+       default:
+               logrus.Panicf("\t// ??? %s %s\n", name, typ)
+       }
+       if isArray {
+               switch typ {
+               case I8, U8, I16, U16, I32, U32, I64, U64, F64:
+                       g.P("}")
+               }
+       }
+}