Fix codec fallback and generate type imports
[govpp.git] / binapigen / generate.go
index 1f9b89a..8a34445 100644 (file)
 package binapigen
 
 import (
-       "bytes"
        "fmt"
        "io"
-       "os/exec"
-       "path"
-       "path/filepath"
        "sort"
        "strings"
 
        "git.fd.io/govpp.git/version"
+       "github.com/sirupsen/logrus"
 )
 
 // generatedCodeVersion indicates a version of the generated code.
@@ -33,7 +30,7 @@ import (
 // a constant, api.GoVppAPIPackageIsVersionN (where N is generatedCodeVersion).
 const generatedCodeVersion = 2
 
-// message field names
+// common message fields
 const (
        msgIdField       = "_vl_msg_id"
        clientIndexField = "client_index"
@@ -41,23 +38,16 @@ const (
        retvalField      = "retval"
 )
 
+// global API info
 const (
-       outputFileExt = ".ba.go" // file extension of the Go generated files
-       rpcFileSuffix = "_rpc"   // file name suffix for the RPC services
-
        constModuleName = "ModuleName" // module name constant
        constAPIVersion = "APIVersion" // API version constant
        constVersionCrc = "VersionCrc" // version CRC constant
+)
 
+// generated fiels
+const (
        unionDataField = "XXX_UnionData" // name for the union data field
-
-       serviceApiName    = "RPCService"    // name for the RPC service interface
-       serviceImplName   = "serviceClient" // name for the RPC service implementation
-       serviceClientName = "ServiceClient" // name for the RPC service client
-
-       // TODO: register service descriptor
-       //serviceDescType = "ServiceDesc"             // name for service descriptor type
-       //serviceDescName = "_ServiceRPC_serviceDesc" // name for service descriptor var
 )
 
 // MessageType represents the type of a VPP message
@@ -70,22 +60,90 @@ const (
        otherMessage                      // other VPP message
 )
 
-type GenFile struct {
-       *Generator
-       filename   string
-       file       *File
-       packageDir string
-       buf        bytes.Buffer
-}
-
-func generatePackage(ctx *GenFile, w io.Writer) {
+func generateFileBinapi(ctx *GenFile, w io.Writer) {
        logf("----------------------------")
-       logf("generating binapi package: %q", ctx.file.PackageName)
+       logf("generating BINAPI file package: %q", ctx.file.PackageName)
        logf("----------------------------")
 
-       generateHeader(ctx, w)
+       // generate file header
+       fmt.Fprintln(w, "// Code generated by GoVPP's binapi-generator. DO NOT EDIT.")
+       fmt.Fprintln(w, "// versions:")
+       fmt.Fprintf(w, "//  binapi-generator: %s\n", version.Version())
+       if ctx.IncludeVppVersion {
+               fmt.Fprintf(w, "//  VPP:              %s\n", ctx.VPPVersion)
+       }
+       fmt.Fprintf(w, "// source: %s\n", ctx.file.Path)
+       fmt.Fprintln(w)
+
+       generatePackageHeader(ctx, w)
        generateImports(ctx, w)
 
+       generateApiInfo(ctx, w)
+       generateTypes(ctx, w)
+       generateMessages(ctx, w)
+
+       generateImportRefs(ctx, w)
+}
+
+func generatePackageHeader(ctx *GenFile, w io.Writer) {
+       fmt.Fprintln(w, "/*")
+       fmt.Fprintf(w, "Package %s contains generated code for VPP API file %s.api (%s).\n",
+               ctx.file.PackageName, ctx.file.Name, ctx.file.Version())
+       fmt.Fprintln(w)
+       fmt.Fprintln(w, "It consists of:")
+       printObjNum := func(obj string, num int) {
+               if num > 0 {
+                       if num > 1 {
+                               if strings.HasSuffix(obj, "s") {
+                                       obj += "es"
+                               } else {
+                                       obj += "s"
+                               }
+                       }
+                       fmt.Fprintf(w, "\t%3d %s\n", num, obj)
+               }
+       }
+       printObjNum("alias", len(ctx.file.Aliases))
+       printObjNum("enum", len(ctx.file.Enums))
+       printObjNum("message", len(ctx.file.Messages))
+       printObjNum("type", len(ctx.file.Structs))
+       printObjNum("union", len(ctx.file.Unions))
+       fmt.Fprintln(w, "*/")
+       fmt.Fprintf(w, "package %s\n", ctx.file.PackageName)
+       fmt.Fprintln(w)
+}
+
+func generateImports(ctx *GenFile, w io.Writer) {
+       fmt.Fprintln(w, "import (")
+       fmt.Fprintln(w, `       "bytes"`)
+       fmt.Fprintln(w, `       "context"`)
+       fmt.Fprintln(w, `       "encoding/binary"`)
+       fmt.Fprintln(w, `       "io"`)
+       fmt.Fprintln(w, `       "math"`)
+       fmt.Fprintln(w, `       "strconv"`)
+       fmt.Fprintln(w)
+       fmt.Fprintf(w, "\tapi \"%s\"\n", "git.fd.io/govpp.git/api")
+       fmt.Fprintf(w, "\tcodec \"%s\"\n", "git.fd.io/govpp.git/codec")
+       fmt.Fprintf(w, "\tstruc \"%s\"\n", "github.com/lunixbochs/struc")
+       imports := listImports(ctx)
+       if len(imports) > 0 {
+               fmt.Fprintln(w)
+               for imp, importPath := range imports {
+                       fmt.Fprintf(w, "\t%s \"%s\"\n", imp, importPath)
+               }
+       }
+       fmt.Fprintln(w, ")")
+       fmt.Fprintln(w)
+
+       fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file")
+       fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.")
+       fmt.Fprintln(w, "// A compilation error at this line likely means your copy of the")
+       fmt.Fprintln(w, "// GoVPP api package needs to be updated.")
+       fmt.Fprintf(w, "const _ = api.GoVppAPIPackageIsVersion%d // please upgrade the GoVPP api package\n", generatedCodeVersion)
+       fmt.Fprintln(w)
+}
+
+func generateApiInfo(ctx *GenFile, w io.Writer) {
        // generate module desc
        fmt.Fprintln(w, "const (")
        fmt.Fprintf(w, "\t// %s is the name of this module.\n", constModuleName)
@@ -99,7 +157,9 @@ func generatePackage(ctx *GenFile, w io.Writer) {
        }
        fmt.Fprintln(w, ")")
        fmt.Fprintln(w)
+}
 
+func generateTypes(ctx *GenFile, w io.Writer) {
        // generate enums
        if len(ctx.file.Enums) > 0 {
                for _, enum := range ctx.file.Enums {
@@ -143,129 +203,41 @@ func generatePackage(ctx *GenFile, w io.Writer) {
                        generateUnion(ctx, w, union)
                }
        }
-
-       // generate messages
-       if len(ctx.file.Messages) > 0 {
-               for _, msg := range ctx.file.Messages {
-                       generateMessage(ctx, w, msg)
-               }
-
-               initFnName := fmt.Sprintf("file_%s_binapi_init", ctx.file.PackageName)
-
-               // generate message registrations
-               fmt.Fprintf(w, "func init() { %s() }\n", initFnName)
-               fmt.Fprintf(w, "func %s() {\n", initFnName)
-               for _, msg := range ctx.file.Messages {
-                       fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n",
-                               msg.GoName, ctx.file.Name+"."+msg.GoName)
-               }
-               fmt.Fprintln(w, "}")
-               fmt.Fprintln(w)
-
-               // generate list of messages
-               fmt.Fprintf(w, "// Messages returns list of all messages in this module.\n")
-               fmt.Fprintln(w, "func AllMessages() []api.Message {")
-               fmt.Fprintln(w, "\treturn []api.Message{")
-               for _, msg := range ctx.file.Messages {
-                       fmt.Fprintf(w, "\t(*%s)(nil),\n", msg.GoName)
-               }
-               fmt.Fprintln(w, "}")
-               fmt.Fprintln(w, "}")
-       }
-
-       generateFooter(ctx, w)
-
 }
 
-func generateHeader(ctx *GenFile, w io.Writer) {
-       fmt.Fprintln(w, "// Code generated by GoVPP's binapi-generator. DO NOT EDIT.")
-       fmt.Fprintln(w, "// versions:")
-       fmt.Fprintf(w, "//  binapi-generator: %s\n", version.Version())
-       if ctx.IncludeVppVersion {
-               fmt.Fprintf(w, "//  VPP:              %s\n", ctx.VPPVersion)
+func generateMessages(ctx *GenFile, w io.Writer) {
+       if len(ctx.file.Messages) == 0 {
+               return
        }
-       fmt.Fprintf(w, "// source: %s\n", ctx.file.Path)
-       fmt.Fprintln(w)
-
-       fmt.Fprintln(w, "/*")
-       fmt.Fprintf(w, "Package %s contains generated code for VPP binary API defined by %s.api (version %s).\n",
-               ctx.file.PackageName, ctx.file.Name, ctx.file.Version())
-       fmt.Fprintln(w)
-       fmt.Fprintln(w, "It consists of:")
-       printObjNum := func(obj string, num int) {
-               if num > 0 {
-                       if num > 1 {
-                               if strings.HasSuffix(obj, "s") {
 
-                                       obj += "es"
-                               } else {
-                                       obj += "s"
-                               }
-                       }
-                       fmt.Fprintf(w, "\t%3d %s\n", num, obj)
-               }
+       for _, msg := range ctx.file.Messages {
+               generateMessage(ctx, w, msg)
        }
-       //printObjNum("RPC", len(ctx.file.Service.RPCs))
-       printObjNum("alias", len(ctx.file.Aliases))
-       printObjNum("enum", len(ctx.file.Enums))
-       printObjNum("message", len(ctx.file.Messages))
-       printObjNum("type", len(ctx.file.Structs))
-       printObjNum("union", len(ctx.file.Unions))
-       fmt.Fprintln(w, "*/")
-       fmt.Fprintf(w, "package %s\n", ctx.file.PackageName)
-       fmt.Fprintln(w)
-}
 
-func generateImports(ctx *GenFile, w io.Writer) {
-       fmt.Fprintln(w, "import (")
-       fmt.Fprintln(w, `       "bytes"`)
-       fmt.Fprintln(w, `       "context"`)
-       fmt.Fprintln(w, `       "encoding/binary"`)
-       fmt.Fprintln(w, `       "io"`)
-       fmt.Fprintln(w, `       "math"`)
-       fmt.Fprintln(w, `       "strconv"`)
-       fmt.Fprintln(w)
-       fmt.Fprintf(w, "\tapi \"%s\"\n", "git.fd.io/govpp.git/api")
-       fmt.Fprintf(w, "\tcodec \"%s\"\n", "git.fd.io/govpp.git/codec")
-       fmt.Fprintf(w, "\tstruc \"%s\"\n", "github.com/lunixbochs/struc")
-       if len(ctx.file.Imports) > 0 {
-               fmt.Fprintln(w)
-               for _, imp := range ctx.file.Imports {
-                       importPath := path.Join(ctx.ImportPrefix, imp)
-                       if ctx.ImportPrefix == "" {
-                               importPath = getImportPath(ctx.packageDir, imp)
-                       }
-                       fmt.Fprintf(w, "\t%s \"%s\"\n", imp, strings.TrimSpace(importPath))
-               }
-       }
-       fmt.Fprintln(w, ")")
-       fmt.Fprintln(w)
+       // generate message registrations
+       initFnName := fmt.Sprintf("file_%s_binapi_init", ctx.file.PackageName)
 
-       fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file")
-       fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.")
-       fmt.Fprintln(w, "// A compilation error at this line likely means your copy of the")
-       fmt.Fprintln(w, "// GoVPP api package needs to be updated.")
-       fmt.Fprintf(w, "const _ = api.GoVppAPIPackageIsVersion%d // please upgrade the GoVPP api package\n", generatedCodeVersion)
+       fmt.Fprintf(w, "func init() { %s() }\n", initFnName)
+       fmt.Fprintf(w, "func %s() {\n", initFnName)
+       for _, msg := range ctx.file.Messages {
+               fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n",
+                       msg.GoName, ctx.file.Name+"."+msg.GoName)
+       }
+       fmt.Fprintln(w, "}")
        fmt.Fprintln(w)
-}
 
-func getImportPath(outputDir string, pkg string) string {
-       absPath, err := filepath.Abs(filepath.Join(outputDir, "..", pkg))
-       if err != nil {
-               panic(err)
-       }
-       cmd := exec.Command("go", "list", absPath)
-       var errbuf, outbuf bytes.Buffer
-       cmd.Stdout = &outbuf
-       cmd.Stderr = &errbuf
-       if err := cmd.Run(); err != nil {
-               fmt.Printf("ERR: %v\n", errbuf.String())
-               panic(err)
+       // generate list of messages
+       fmt.Fprintf(w, "// Messages returns list of all messages in this module.\n")
+       fmt.Fprintln(w, "func AllMessages() []api.Message {")
+       fmt.Fprintln(w, "\treturn []api.Message{")
+       for _, msg := range ctx.file.Messages {
+               fmt.Fprintf(w, "\t(*%s)(nil),\n", msg.GoName)
        }
-       return outbuf.String()
+       fmt.Fprintln(w, "}")
+       fmt.Fprintln(w, "}")
 }
 
-func generateFooter(ctx *GenFile, w io.Writer) {
+func generateImportRefs(ctx *GenFile, w io.Writer) {
        fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n")
        fmt.Fprintf(w, "var _ = api.RegisterMessage\n")
        fmt.Fprintf(w, "var _ = codec.DecodeString\n")
@@ -522,7 +494,7 @@ func generateMessage(ctx *GenFile, w io.Writer, msg *Message) {
 
                // skip internal fields
                switch strings.ToLower(field.Name) {
-               case /*crcField,*/ msgIdField:
+               case msgIdField:
                        continue
                case clientIndexField, contextField:
                        if n == 0 {
@@ -590,20 +562,22 @@ func generateMessageSize(ctx *GenFile, w io.Writer, name string, fields []*Field
        }
 
        lvl := 0
-       var encodeFields func(fields []*Field, parentName string)
-       encodeFields = func(fields []*Field, parentName string) {
+       var sizeFields func(fields []*Field, parentName string)
+       sizeFields = func(fields []*Field, parentName string) {
                lvl++
                defer func() { lvl-- }()
 
                n := 0
                for _, field := range fields {
-                       // skip internal fields
-                       switch strings.ToLower(field.Name) {
-                       case /*crcField,*/ msgIdField:
-                               continue
-                       case clientIndexField, contextField:
-                               if n == 0 {
+                       if field.ParentMessage != nil {
+                               // skip internal fields
+                               switch strings.ToLower(field.Name) {
+                               case msgIdField:
                                        continue
+                               case clientIndexField, contextField:
+                                       if n == 0 {
+                                               continue
+                                       }
                                }
                        }
                        n++
@@ -646,12 +620,12 @@ func generateMessageSize(ctx *GenFile, w io.Writer, name string, fields []*Field
                        } else if alias := getAliasByRef(ctx.file, field.Type); alias != nil {
                                if encodeBaseType(alias.Type, name, alias.Length, "") {
                                } else if typ := getTypeByRef(ctx.file, alias.Type); typ != nil {
-                                       encodeFields(typ.Fields, name)
+                                       sizeFields(typ.Fields, name)
                                } else {
                                        fmt.Fprintf(w, "\t// ??? ALIAS %s %s\n", name, alias.Type)
                                }
                        } else if typ := getTypeByRef(ctx.file, field.Type); typ != nil {
-                               encodeFields(typ.Fields, name)
+                               sizeFields(typ.Fields, name)
                        } else if union := getUnionByRef(ctx.file, field.Type); union != nil {
                                maxSize := getUnionSize(ctx.file, union)
                                fmt.Fprintf(w, "\tsize += %d\n", maxSize)
@@ -665,7 +639,7 @@ func generateMessageSize(ctx *GenFile, w io.Writer, name string, fields []*Field
                }
        }
 
-       encodeFields(fields, "m")
+       sizeFields(fields, "m")
 
        fmt.Fprintf(w, "return size\n")
 
@@ -786,13 +760,15 @@ func generateMessageMarshal(ctx *GenFile, w io.Writer, name string, fields []*Fi
 
                n := 0
                for _, field := range fields {
-                       // skip internal fields
-                       switch strings.ToLower(field.Name) {
-                       case /*crcField,*/ msgIdField:
-                               continue
-                       case clientIndexField, contextField:
-                               if n == 0 {
+                       if field.ParentMessage != nil {
+                               // skip internal fields
+                               switch strings.ToLower(field.Name) {
+                               case msgIdField:
                                        continue
+                               case clientIndexField, contextField:
+                                       if n == 0 {
+                                               continue
+                                       }
                                }
                        }
                        n++
@@ -1004,13 +980,15 @@ func generateMessageUnmarshal(ctx *GenFile, w io.Writer, name string, fields []*
 
                n := 0
                for _, field := range fields {
-                       // skip internal fields
-                       switch strings.ToLower(field.Name) {
-                       case /*crcField,*/ msgIdField:
-                               continue
-                       case clientIndexField, contextField:
-                               if n == 0 {
+                       if field.ParentMessage != nil {
+                               // skip internal fields
+                               switch strings.ToLower(field.Name) {
+                               case msgIdField:
                                        continue
+                               case clientIndexField, contextField:
+                                       if n == 0 {
+                                               continue
+                                       }
                                }
                        }
                        n++
@@ -1239,3 +1217,7 @@ func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageTy
        fmt.Fprintln(w, "}")
        fmt.Fprintln(w)
 }
+
+func logf(f string, v ...interface{}) {
+       logrus.Debugf(f, v...)
+}