Add various generator improvements
[govpp.git] / cmd / binapi-generator / generate.go
index 48c3a41..d9555e7 100644 (file)
 package main
 
 import (
-       "bufio"
        "bytes"
        "fmt"
        "io"
        "path/filepath"
+       "sort"
        "strings"
        "unicode"
 )
 
+// generatedCodeVersion indicates a version of the generated code.
+// It is incremented whenever an incompatibility between the generated code and
+// GoVPP api package is introduced; the generated code references
+// a constant, api.GoVppAPIPackageIsVersionN (where N is generatedCodeVersion).
+const generatedCodeVersion = 1
+
 const (
+       inputFileExt  = ".api.json" // file extension of the VPP API files
+       outputFileExt = ".ba.go"    // file extension of the Go generated files
+
        govppApiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API package
-       inputFileExt       = ".api.json"               // file extension of the VPP binary API files
-       outputFileExt      = ".ba.go"                  // file extension of the Go generated files
+
+       constModuleName = "ModuleName" // module name constant
+       constAPIVersion = "APIVersion" // API version constant
+       constVersionCrc = "VersionCrc" // version CRC constant
+
+       unionDataField = "XXX_UnionData" // name for the union data field
 )
 
 // context is a structure storing data for code generation
@@ -37,6 +50,11 @@ type context struct {
 
        inputData []byte // contents of the input file
 
+       includeAPIVersion  bool // include constant with API version string
+       includeComments    bool // include parts of original source in comments
+       includeBinapiNames bool // include binary API names as struct tag
+       includeServices    bool // include service interface with client implementation
+
        moduleName  string // name of the source VPP module
        packageName string // name of the Go package being generated
 
@@ -76,26 +94,28 @@ func getContext(inputFile, outputDir string) (*context, error) {
 }
 
 // generatePackage generates code for the parsed package data and writes it into w
-func generatePackage(ctx *context, w *bufio.Writer) error {
+func generatePackage(ctx *context, w io.Writer) error {
        logf("generating package %q", ctx.packageName)
 
        // generate file header
        generateHeader(ctx, w)
        generateImports(ctx, w)
 
-       if *includeAPIVer {
-               const APIVerConstName = "VlAPIVersion"
-               fmt.Fprintf(w, "// %s represents version of the binary API module.\n", APIVerConstName)
-               fmt.Fprintf(w, "const %s = %v\n", APIVerConstName, ctx.packageData.APIVersion)
-               fmt.Fprintln(w)
-       }
-
-       // generate services
-       if len(ctx.packageData.Services) > 0 {
-               generateServices(ctx, w, ctx.packageData.Services)
+       // generate module desc
+       fmt.Fprintln(w, "const (")
+       fmt.Fprintf(w, "\t// %s is the name of this module.\n", constModuleName)
+       fmt.Fprintf(w, "\t%s = \"%s\"\n", constModuleName, ctx.moduleName)
 
-               // TODO: generate implementation for Services interface
+       if ctx.includeAPIVersion {
+               if ctx.packageData.Version != "" {
+                       fmt.Fprintf(w, "\t// %s is the API version of this module.\n", constAPIVersion)
+                       fmt.Fprintf(w, "\t%s = \"%s\"\n", constAPIVersion, ctx.packageData.Version)
+               }
+               fmt.Fprintf(w, "\t// %s is the CRC of this module.\n", constVersionCrc)
+               fmt.Fprintf(w, "\t%s = %v\n", constVersionCrc, ctx.packageData.CRC)
        }
+       fmt.Fprintln(w, ")")
+       fmt.Fprintln(w)
 
        // generate enums
        if len(ctx.packageData.Enums) > 0 {
@@ -148,11 +168,25 @@ func generatePackage(ctx *context, w *bufio.Writer) error {
                        fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n", name, ctx.moduleName+"."+name)
                }
                fmt.Fprintln(w, "}")
+               fmt.Fprintln(w)
+
+               // generate list of messages
+               fmt.Fprintf(w, "// Messages returns list of all messages in this module.\n")
+               fmt.Fprintln(w, "func AllMessages() []api.Message {")
+               fmt.Fprintln(w, "\treturn []api.Message{")
+               for _, msg := range ctx.packageData.Messages {
+                       name := camelCaseName(msg.Name)
+                       fmt.Fprintf(w, "\t(*%s)(nil),\n", name)
+               }
+               fmt.Fprintln(w, "}")
+               fmt.Fprintln(w, "}")
        }
 
-       // flush the data:
-       if err := w.Flush(); err != nil {
-               return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
+       if ctx.includeServices {
+               // generate services
+               if len(ctx.packageData.Services) > 0 {
+                       generateServices(ctx, w, ctx.packageData.Services)
+               }
        }
 
        return nil
@@ -161,17 +195,18 @@ func generatePackage(ctx *context, w *bufio.Writer) error {
 // generateHeader writes generated package header into w
 func generateHeader(ctx *context, w io.Writer) {
        fmt.Fprintln(w, "// Code generated by GoVPP binapi-generator. DO NOT EDIT.")
-       fmt.Fprintf(w, "//  source: %s\n", ctx.inputFile)
+       fmt.Fprintf(w, "// source: %s\n", ctx.inputFile)
        fmt.Fprintln(w)
 
        fmt.Fprintln(w, "/*")
-       fmt.Fprintf(w, " Package %s is a generated from VPP binary API module '%s'.\n", ctx.packageName, ctx.moduleName)
+       fmt.Fprintf(w, "Package %s is a generated from VPP binary API module '%s'.\n", ctx.packageName, ctx.moduleName)
        fmt.Fprintln(w)
-       fmt.Fprintln(w, " It contains following objects:")
+       fmt.Fprintf(w, " The %s module consists of:\n", ctx.moduleName)
        var printObjNum = func(obj string, num int) {
                if num > 0 {
                        if num > 1 {
                                if strings.HasSuffix(obj, "s") {
+
                                        obj += "es"
                                } else {
                                        obj += "s"
@@ -181,39 +216,55 @@ func generateHeader(ctx *context, w io.Writer) {
                }
        }
 
-       printObjNum("service", len(ctx.packageData.Services))
        printObjNum("enum", len(ctx.packageData.Enums))
        printObjNum("alias", len(ctx.packageData.Aliases))
        printObjNum("type", len(ctx.packageData.Types))
        printObjNum("union", len(ctx.packageData.Unions))
        printObjNum("message", len(ctx.packageData.Messages))
+       printObjNum("service", len(ctx.packageData.Services))
        fmt.Fprintln(w, "*/")
+
        fmt.Fprintf(w, "package %s\n", ctx.packageName)
        fmt.Fprintln(w)
 }
 
 // generateImports writes generated package imports into w
 func generateImports(ctx *context, w io.Writer) {
-       fmt.Fprintf(w, "import \"%s\"\n", govppApiImportPath)
-       fmt.Fprintf(w, "import \"%s\"\n", "github.com/lunixbochs/struc")
-       fmt.Fprintf(w, "import \"%s\"\n", "bytes")
+       fmt.Fprintf(w, "import api \"%s\"\n", govppApiImportPath)
+       fmt.Fprintf(w, "import bytes \"%s\"\n", "bytes")
+       fmt.Fprintf(w, "import context \"%s\"\n", "context")
+       fmt.Fprintf(w, "import strconv \"%s\"\n", "strconv")
+       fmt.Fprintf(w, "import struc \"%s\"\n", "github.com/lunixbochs/struc")
        fmt.Fprintln(w)
 
        fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n")
        fmt.Fprintf(w, "var _ = api.RegisterMessage\n")
-       fmt.Fprintf(w, "var _ = struc.Pack\n")
        fmt.Fprintf(w, "var _ = bytes.NewBuffer\n")
+       fmt.Fprintf(w, "var _ = context.Background\n")
+       fmt.Fprintf(w, "var _ = strconv.Itoa\n")
+       fmt.Fprintf(w, "var _ = struc.Pack\n")
+       fmt.Fprintln(w)
+
+       fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file")
+       fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.")
+       fmt.Fprintln(w, "// A compilation error at this line likely means your copy of the")
+       fmt.Fprintln(w, "// GoVPP api package needs to be updated.")
+       fmt.Fprintf(w, "const _ = api.GoVppAPIPackageIsVersion%d // please upgrade the GoVPP api package\n", generatedCodeVersion)
        fmt.Fprintln(w)
 }
 
 // generateComment writes generated comment for the object into w
 func generateComment(ctx *context, w io.Writer, goName string, vppName string, objKind string) {
        if objKind == "service" {
-               fmt.Fprintf(w, "// %s represents VPP binary API services:\n", goName)
+               fmt.Fprintf(w, "// %s represents VPP binary API services in %s module.\n", ctx.moduleName, goName)
        } else {
                fmt.Fprintf(w, "// %s represents VPP binary API %s '%s':\n", goName, objKind, vppName)
        }
 
+       if !ctx.includeComments {
+               return
+       }
+
        var isNotSpace = func(r rune) bool {
                return !unicode.IsSpace(r)
        }
@@ -265,37 +316,93 @@ func generateComment(ctx *context, w io.Writer, goName string, vppName string, o
 }
 
 // generateServices writes generated code for the Services interface into w
-func generateServices(ctx *context, w *bufio.Writer, services []Service) {
+func generateServices(ctx *context, w io.Writer, services []Service) {
+       const apiName = "Service"
+       const implName = "service"
+
        // generate services comment
-       generateComment(ctx, w, "Services", "services", "service")
+       generateComment(ctx, w, apiName, "services", "service")
 
        // generate interface
-       fmt.Fprintf(w, "type %s interface {\n", "Services")
-       for _, svc := range ctx.packageData.Services {
-               generateService(ctx, w, &svc)
+       fmt.Fprintf(w, "type %s interface {\n", apiName)
+       for _, svc := range services {
+               generateServiceMethod(ctx, w, &svc)
+               fmt.Fprintln(w)
        }
        fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
+
+       // generate client implementation
+       fmt.Fprintf(w, "type %s struct {\n", implName)
+       fmt.Fprintf(w, "\tch api.Channel\n")
+       fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
+
+       fmt.Fprintf(w, "func New%[1]s(ch api.Channel) %[1]s {\n", apiName)
+       fmt.Fprintf(w, "\treturn &%s{ch}\n", implName)
+       fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
+
+       for _, svc := range services {
+               fmt.Fprintf(w, "func (c *%s) ", implName)
+               generateServiceMethod(ctx, w, &svc)
+               fmt.Fprintln(w, " {")
+               if svc.Stream {
+                       // TODO: stream responses
+                       //fmt.Fprintf(w, "\tstream := make(chan *%s)\n", camelCaseName(svc.ReplyType))
+                       replyTyp := camelCaseName(svc.ReplyType)
+                       fmt.Fprintf(w, "\tvar dump []*%s\n", replyTyp)
+                       fmt.Fprintf(w, "\treq := c.ch.SendMultiRequest(in)\n")
+                       fmt.Fprintf(w, "\tfor {\n")
+                       fmt.Fprintf(w, "\tm := new(%s)\n", replyTyp)
+                       fmt.Fprintf(w, "\tstop, err := req.ReceiveReply(m)\n")
+                       fmt.Fprintf(w, "\tif stop { break }\n")
+                       fmt.Fprintf(w, "\tif err != nil { return nil, err }\n")
+                       fmt.Fprintf(w, "\tdump = append(dump, m)\n")
+                       fmt.Fprintln(w, "}")
+                       fmt.Fprintf(w, "\treturn dump, nil\n")
+               } else if replyTyp := camelCaseName(svc.ReplyType); replyTyp != "" {
+                       fmt.Fprintf(w, "\tout := new(%s)\n", replyTyp)
+                       fmt.Fprintf(w, "\terr:= c.ch.SendRequest(in).ReceiveReply(out)\n")
+                       fmt.Fprintf(w, "\tif err != nil { return nil, err }\n")
+                       fmt.Fprintf(w, "\treturn out, nil\n")
+               } else {
+                       fmt.Fprintf(w, "\tc.ch.SendRequest(in)\n")
+                       fmt.Fprintf(w, "\treturn nil\n")
+               }
+               fmt.Fprintln(w, "}")
+               fmt.Fprintln(w)
+       }
 
        fmt.Fprintln(w)
 }
 
-// generateService writes generated code for the service into w
-func generateService(ctx *context, w io.Writer, svc *Service) {
+// generateServiceMethod writes generated code for the service into w
+func generateServiceMethod(ctx *context, w io.Writer, svc *Service) {
        reqTyp := camelCaseName(svc.RequestType)
 
        // method name is same as parameter type name by default
-       method := svc.MethodName()
-       params := fmt.Sprintf("*%s", reqTyp)
+       method := reqTyp
+       if svc.Stream {
+               // use Dump as prefix instead of suffix for stream services
+               if m := strings.TrimSuffix(method, "Dump"); method != m {
+                       method = "Dump" + m
+               }
+       }
+
+       params := fmt.Sprintf("in *%s", reqTyp)
        returns := "error"
        if replyType := camelCaseName(svc.ReplyType); replyType != "" {
-               repTyp := fmt.Sprintf("*%s", replyType)
+               replyTyp := fmt.Sprintf("*%s", replyType)
                if svc.Stream {
-                       repTyp = fmt.Sprintf("[]%s", repTyp)
+                       // TODO: stream responses
+                       //replyTyp = fmt.Sprintf("<-chan %s", replyTyp)
+                       replyTyp = fmt.Sprintf("[]%s", replyTyp)
                }
-               returns = fmt.Sprintf("(%s, error)", repTyp)
+               returns = fmt.Sprintf("(%s, error)", replyTyp)
        }
 
-       fmt.Fprintf(w, "\t%s(%s) %s\n", method, params, returns)
+       fmt.Fprintf(w, "\t%s(ctx context.Context, %s) %s", method, params, returns)
 }
 
 // generateEnum writes generated code for the enum into w
@@ -312,15 +419,34 @@ func generateEnum(ctx *context, w io.Writer, enum *Enum) {
        fmt.Fprintf(w, "type %s %s\n", name, typ)
        fmt.Fprintln(w)
 
-       fmt.Fprintln(w, "const (")
-
        // generate enum entries
+       fmt.Fprintln(w, "const (")
        for _, entry := range enum.Entries {
                fmt.Fprintf(w, "\t%s %s = %v\n", entry.Name, name, entry.Value)
        }
-
        fmt.Fprintln(w, ")")
+       fmt.Fprintln(w)
 
+       // generate enum conversion maps
+       fmt.Fprintf(w, "var %s_name = map[%s]string{\n", name, typ)
+       for _, entry := range enum.Entries {
+               fmt.Fprintf(w, "\t%v: \"%s\",\n", entry.Value, entry.Name)
+       }
+       fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
+
+       fmt.Fprintf(w, "var %s_value = map[string]%s{\n", name, typ)
+       for _, entry := range enum.Entries {
+               fmt.Fprintf(w, "\t\"%s\": %v,\n", entry.Name, entry.Value)
+       }
+       fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
+
+       fmt.Fprintf(w, "func (x %s) String() string {\n", name)
+       fmt.Fprintf(w, "\ts, ok := %s_name[%s(x)]\n", name, typ)
+       fmt.Fprintf(w, "\tif ok { return s }\n")
+       fmt.Fprintf(w, "\treturn strconv.Itoa(int(x))\n")
+       fmt.Fprintln(w, "}")
        fmt.Fprintln(w)
 }
 
@@ -362,8 +488,7 @@ func generateUnion(ctx *context, w io.Writer, union *Union) {
        maxSize := getUnionSize(ctx, union)
 
        // generate data field
-       fieldName := "Union_data"
-       fmt.Fprintf(w, "\t%s [%d]byte\n", fieldName, maxSize)
+       fmt.Fprintf(w, "\t%s [%d]byte\n", unionDataField, maxSize)
 
        // generate end of the struct
        fmt.Fprintln(w, "}")
@@ -388,9 +513,9 @@ func generateUnion(ctx *context, w io.Writer, union *Union) {
 }
 
 // generateUnionMethods generates methods that implement struc.Custom
-// interface to allow having Union_data field unexported
+// interface to allow having XXX_uniondata field unexported
 // TODO: do more testing when unions are actually used in some messages
-func generateUnionMethods(w io.Writer, structName string) {
+/*func generateUnionMethods(w io.Writer, structName string) {
        // generate struc.Custom implementation for union
        fmt.Fprintf(w, `
 func (u *%[1]s) Pack(p []byte, opt *struc.Options) (int, error) {
@@ -411,7 +536,7 @@ func (u *%[1]s) String() string {
        return string(u.union_data[:])
 }
 `, structName)
-}
+}*/
 
 func generateUnionGetterSetter(w io.Writer, structName string, getterField, getterStruct string) {
        fmt.Fprintf(w, `
@@ -424,14 +549,14 @@ func (u *%[1]s) Set%[2]s(a %[3]s) {
        if err := struc.Pack(b, &a); err != nil {
                return
        }
-       copy(u.Union_data[:], b.Bytes())
+       copy(u.%[4]s[:], b.Bytes())
 }
 func (u *%[1]s) Get%[2]s() (a %[3]s) {
-       var b = bytes.NewReader(u.Union_data[:])
+       var b = bytes.NewReader(u.%[4]s[:])
        struc.Unpack(b, &a)
        return
 }
-`, structName, getterField, getterStruct)
+`, structName, getterField, getterStruct, unionDataField)
 }
 
 // generateType writes generated code for the type into w
@@ -450,7 +575,7 @@ func generateType(ctx *context, w io.Writer, typ *Type) {
        for i, field := range typ.Fields {
                // skip internal fields
                switch strings.ToLower(field.Name) {
-               case "crc", "_vl_msg_id":
+               case crcField, msgIdField:
                        continue
                }
 
@@ -488,17 +613,17 @@ func generateMessage(ctx *context, w io.Writer, msg *Message) {
        n := 0
        for i, field := range msg.Fields {
                if i == 1 {
-                       if field.Name == "client_index" {
+                       if field.Name == clientIndexField {
                                // "client_index" as the second member,
                                // this might be an event message or a request
                                msgType = eventMessage
                                wasClientIndex = true
-                       } else if field.Name == "context" {
+                       } else if field.Name == contextField {
                                // reply needs "context" as the second member
                                msgType = replyMessage
                        }
                } else if i == 2 {
-                       if wasClientIndex && field.Name == "context" {
+                       if wasClientIndex && field.Name == contextField {
                                // request needs "client_index" as the second member
                                // and "context" as the third member
                                msgType = requestMessage
@@ -507,9 +632,9 @@ func generateMessage(ctx *context, w io.Writer, msg *Message) {
 
                // skip internal fields
                switch strings.ToLower(field.Name) {
-               case "crc", "_vl_msg_id":
+               case crcField, msgIdField:
                        continue
-               case "client_index", "context":
+               case clientIndexField, contextField:
                        if n == 0 {
                                continue
                        }
@@ -550,9 +675,10 @@ func generateField(ctx *context, w io.Writer, fields []Field, i int) {
        }
 
        dataType := convertToGoType(ctx, field.Type)
-
        fieldType := dataType
-       if field.IsArray() {
+
+       // check if it is array
+       if field.Length > 0 || field.SizeFrom != "" {
                if dataType == "uint8" {
                        dataType = "byte"
                }
@@ -560,17 +686,48 @@ func generateField(ctx *context, w io.Writer, fields []Field, i int) {
        }
        fmt.Fprintf(w, "\t%s %s", fieldName, fieldType)
 
+       fieldTags := map[string]string{}
+
        if field.Length > 0 {
                // fixed size array
-               fmt.Fprintf(w, "\t`struc:\"[%d]%s\"`", field.Length, dataType)
+               fieldTags["struc"] = fmt.Sprintf("[%d]%s", field.Length, dataType)
        } else {
                for _, f := range fields {
                        if f.SizeFrom == field.Name {
                                // variable sized array
                                sizeOfName := camelCaseName(f.Name)
-                               fmt.Fprintf(w, "\t`struc:\"sizeof=%s\"`", sizeOfName)
+                               fieldTags["struc"] = fmt.Sprintf("sizeof=%s", sizeOfName)
+                       }
+               }
+       }
+
+       if ctx.includeBinapiNames {
+               fieldTags["binapi"] = field.Name
+       }
+       if field.Meta.Limit > 0 {
+               fieldTags["binapi"] = fmt.Sprintf("%s,limit=%d", fieldTags["binapi"], field.Meta.Limit)
+       }
+
+       if len(fieldTags) > 0 {
+               fmt.Fprintf(w, "\t`")
+               var keys []string
+               for k := range fieldTags {
+                       keys = append(keys, k)
+               }
+               sort.Strings(keys)
+               var n int
+               for _, tt := range keys {
+                       t, ok := fieldTags[tt]
+                       if !ok {
+                               continue
+                       }
+                       if n > 0 {
+                               fmt.Fprintf(w, " ")
                        }
+                       n++
+                       fmt.Fprintf(w, `%s:"%s"`, tt, t)
                }
+               fmt.Fprintf(w, "`")
        }
 
        fmt.Fprintln(w)