Binary API generator improvements
[govpp.git] / binapigen / generate.go
index 8a34445..d35427f 100644 (file)
@@ -118,9 +118,12 @@ func generateImports(ctx *GenFile, w io.Writer) {
        fmt.Fprintln(w, `       "bytes"`)
        fmt.Fprintln(w, `       "context"`)
        fmt.Fprintln(w, `       "encoding/binary"`)
+       fmt.Fprintln(w, `       "fmt"`)
        fmt.Fprintln(w, `       "io"`)
        fmt.Fprintln(w, `       "math"`)
+       fmt.Fprintln(w, `       "net"`)
        fmt.Fprintln(w, `       "strconv"`)
+       fmt.Fprintln(w, `       "strings"`)
        fmt.Fprintln(w)
        fmt.Fprintf(w, "\tapi \"%s\"\n", "git.fd.io/govpp.git/api")
        fmt.Fprintf(w, "\tcodec \"%s\"\n", "git.fd.io/govpp.git/codec")
@@ -164,7 +167,9 @@ func generateTypes(ctx *GenFile, w io.Writer) {
        if len(ctx.file.Enums) > 0 {
                for _, enum := range ctx.file.Enums {
                        if imp, ok := ctx.file.imports[enum.Name]; ok {
-                               generateImportedAlias(ctx, w, enum.GoName, imp)
+                               if strings.HasSuffix(ctx.file.Name, "_types") {
+                                       generateImportedAlias(ctx, w, enum.GoName, imp)
+                               }
                                continue
                        }
                        generateEnum(ctx, w, enum)
@@ -175,7 +180,9 @@ func generateTypes(ctx *GenFile, w io.Writer) {
        if len(ctx.file.Aliases) > 0 {
                for _, alias := range ctx.file.Aliases {
                        if imp, ok := ctx.file.imports[alias.Name]; ok {
-                               generateImportedAlias(ctx, w, alias.GoName, imp)
+                               if strings.HasSuffix(ctx.file.Name, "_types") {
+                                       generateImportedAlias(ctx, w, alias.GoName, imp)
+                               }
                                continue
                        }
                        generateAlias(ctx, w, alias)
@@ -186,7 +193,9 @@ func generateTypes(ctx *GenFile, w io.Writer) {
        if len(ctx.file.Structs) > 0 {
                for _, typ := range ctx.file.Structs {
                        if imp, ok := ctx.file.imports[typ.Name]; ok {
-                               generateImportedAlias(ctx, w, typ.GoName, imp)
+                               if strings.HasSuffix(ctx.file.Name, "_types") {
+                                       generateImportedAlias(ctx, w, typ.GoName, imp)
+                               }
                                continue
                        }
                        generateStruct(ctx, w, typ)
@@ -197,7 +206,9 @@ func generateTypes(ctx *GenFile, w io.Writer) {
        if len(ctx.file.Unions) > 0 {
                for _, union := range ctx.file.Unions {
                        if imp, ok := ctx.file.imports[union.Name]; ok {
-                               generateImportedAlias(ctx, w, union.GoName, imp)
+                               if strings.HasSuffix(ctx.file.Name, "_types") {
+                                       generateImportedAlias(ctx, w, union.GoName, imp)
+                               }
                                continue
                        }
                        generateUnion(ctx, w, union)
@@ -245,9 +256,12 @@ func generateImportRefs(ctx *GenFile, w io.Writer) {
        fmt.Fprintf(w, "var _ = context.Background\n")
        fmt.Fprintf(w, "var _ = io.Copy\n")
        fmt.Fprintf(w, "var _ = strconv.Itoa\n")
+       fmt.Fprintf(w, "var _ = strings.Contains\n")
        fmt.Fprintf(w, "var _ = struc.Pack\n")
        fmt.Fprintf(w, "var _ = binary.BigEndian\n")
        fmt.Fprintf(w, "var _ = math.Float32bits\n")
+       fmt.Fprintf(w, "var _ = net.ParseIP\n")
+       fmt.Fprintf(w, "var _ = fmt.Errorf\n")
 }
 
 func generateComment(ctx *GenFile, w io.Writer, goName string, vppName string, objKind string) {
@@ -325,6 +339,13 @@ func generateAlias(ctx *GenFile, w io.Writer, alias *Alias) {
        dataType := convertToGoType(ctx.file, alias.Type)
        fmt.Fprintf(w, "%s\n", dataType)
 
+       // generate alias-specific methods
+       switch alias.Name {
+       case "mac_address":
+               fmt.Fprintln(w)
+               generateMacAddressConversion(w, name)
+       }
+
        fmt.Fprintln(w)
 }
 
@@ -356,6 +377,16 @@ func generateStruct(ctx *GenFile, w io.Writer, typ *Struct) {
        // generate name getter
        generateTypeNameGetter(w, name, typ.Name)
 
+       // generate type-specific methods
+       switch typ.Name {
+       case "address":
+               fmt.Fprintln(w)
+               generateIPAddressConversion(w, name)
+       case "prefix":
+               fmt.Fprintln(w)
+               generatePrefixConversion(w, name)
+       }
+
        fmt.Fprintln(w)
 }
 
@@ -1198,6 +1229,105 @@ func generateTypeNameGetter(w io.Writer, structName, msgName string) {
        fmt.Fprintf(w, "func (*%s) GetTypeName() string { return %q }\n", structName, msgName)
 }
 
+func generateIPAddressConversion(w io.Writer, structName string) {
+       f1 := func(ipVer, ipVerExt int) string {
+               return fmt.Sprintf(`address.Af = ADDRESS_IP%[1]d
+               var ip%[1]daddr IP%[1]dAddress
+               copy(ip%[1]daddr[:], netIP.To%[2]d())
+               address.Un.SetIP%[1]d(ip%[1]daddr)`, ipVer, ipVerExt)
+       }
+       f2 := func(ipVer, ipVerExt int) string {
+               return fmt.Sprintf("ip%[1]dAddress := a.Un.GetIP%[1]d()\nip = net.IP(ip%[1]dAddress[:]).To%[2]d().String()",
+                       ipVer, ipVerExt)
+       }
+       // IP to Address
+       fmt.Fprintf(w, `func ParseAddress(ip string) (%[1]s, error) {
+       var address %[1]s
+       netIP := net.ParseIP(ip)
+       if netIP == nil {
+               return address, fmt.Errorf("invalid address: %[2]s", ip)
+       }
+       if ip4 := netIP.To4(); ip4 == nil {
+               %[3]s
+       } else {
+               %[4]s
+       }
+       return address, nil
+}
+`, structName, "%s", f1(6, 16), f1(4, 4))
+       fmt.Fprintln(w)
+
+       // Address to IP
+       fmt.Fprintln(w)
+       fmt.Fprintf(w, `func (a *%[1]s) ToString() string {
+       var ip string
+       if a.Af == ADDRESS_IP6 {
+               %[2]s
+       } else {
+               %[3]s
+       }
+       return ip
+}`, structName, f2(6, 16), f2(4, 4))
+}
+
+func generatePrefixConversion(w io.Writer, structName string) {
+       fErr := func() string {
+               return fmt.Sprintf(`if err != nil {
+                       return Prefix{}, fmt.Errorf("invalid IP %s: %s", ip, err)
+               }`, "%s", "%v")
+       }
+
+       // IP to Prefix
+       fmt.Fprintf(w, `func ParsePrefix(ip string) (prefix %[1]s, err error) {
+       hasPrefix := strings.Contains(ip, "/")
+       if hasPrefix {
+               netIP, network, err := net.ParseCIDR(ip)
+               %[2]s
+       maskSize, _ := network.Mask.Size()
+       prefix.Len = byte(maskSize)
+       prefix.Address, err = ParseAddress(netIP.String())
+               %[2]s
+       } else {
+               netIP := net.ParseIP(ip)
+               defaultMaskSize, _ := net.CIDRMask(32, 32).Size()
+               if netIP.To4() == nil {
+                       defaultMaskSize, _ = net.CIDRMask(128, 128).Size()
+               }
+               prefix.Len = byte(defaultMaskSize)
+               prefix.Address, err = ParseAddress(netIP.String())
+               %[2]s
+       }
+       return prefix, nil
+}`, structName, fErr(), nil)
+       fmt.Fprintln(w)
+
+       // Prefix to IP
+       fmt.Fprintln(w)
+       fmt.Fprintf(w, `func (p *%[1]s) ToString() string {
+               ip := p.Address.ToString()
+               return ip + "/" + strconv.Itoa(int(p.Len))
+       }`, structName)
+}
+
+func generateMacAddressConversion(w io.Writer, structName string) {
+       // string to MAC
+       fmt.Fprintf(w, `func ParseMAC(mac string) (parsed %[1]s, err error) {
+       var hw net.HardwareAddr
+       if hw, err = net.ParseMAC(mac); err != nil {
+               return
+       }
+       copy(parsed[:], hw[:])
+       return
+}`, structName)
+       fmt.Fprintln(w)
+
+       // MAC to string
+       fmt.Fprintln(w)
+       fmt.Fprintf(w, `func (m *%[1]s) ToString() string {
+               return net.HardwareAddr(m[:]).String()
+       }`, structName)
+}
+
 func generateCrcGetter(w io.Writer, structName, crc string) {
        crc = strings.TrimPrefix(crc, "0x")
        fmt.Fprintf(w, "func (*%s) GetCrcString() string { return %q }\n", structName, crc)