Format generated Go source code in-process
[govpp.git] / cmd / binapi-generator / generate.go
index cb1f470..715836d 100644 (file)
@@ -18,6 +18,8 @@ import (
        "bytes"
        "fmt"
        "io"
+       "os/exec"
+       "path"
        "path/filepath"
        "sort"
        "strings"
@@ -34,8 +36,6 @@ const (
        inputFileExt  = ".api.json" // file extension of the VPP API files
        outputFileExt = ".ba.go"    // file extension of the Go generated files
 
-       govppApiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API package
-
        constModuleName = "ModuleName" // module name constant
        constAPIVersion = "APIVersion" // API version constant
        constVersionCrc = "VersionCrc" // version CRC constant
@@ -52,6 +52,8 @@ type context struct {
        inputFile  string // input file with VPP API in JSON
        outputFile string // output file with generated Go package
 
+       importPrefix string // defines import path prefix for importing types
+
        inputData []byte // contents of the input file
 
        includeAPIVersion  bool // include constant with API version string
@@ -98,13 +100,16 @@ func newContext(inputFile, outputDir string) (*context, error) {
 }
 
 func generatePackage(ctx *context, w io.Writer) error {
+       logf("----------------------------")
        logf("generating package %q", ctx.packageName)
+       logf("----------------------------")
 
        fmt.Fprintln(w, "// Code generated by GoVPP's binapi-generator. DO NOT EDIT.")
        fmt.Fprintf(w, "// source: %s\n", ctx.inputFile)
        fmt.Fprintln(w)
 
        generateHeader(ctx, w)
+       generateImports(ctx, w)
 
        // generate module desc
        fmt.Fprintln(w, "const (")
@@ -125,6 +130,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate enums
        if len(ctx.packageData.Enums) > 0 {
                for _, enum := range ctx.packageData.Enums {
+                       if imp, ok := ctx.packageData.Imports[enum.Name]; ok {
+                               generateImportedAlias(ctx, w, enum.Name, &imp)
+                               continue
+                       }
                        generateEnum(ctx, w, &enum)
                }
        }
@@ -132,6 +141,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate aliases
        if len(ctx.packageData.Aliases) > 0 {
                for _, alias := range ctx.packageData.Aliases {
+                       if imp, ok := ctx.packageData.Imports[alias.Name]; ok {
+                               generateImportedAlias(ctx, w, alias.Name, &imp)
+                               continue
+                       }
                        generateAlias(ctx, w, &alias)
                }
        }
@@ -139,6 +152,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate types
        if len(ctx.packageData.Types) > 0 {
                for _, typ := range ctx.packageData.Types {
+                       if imp, ok := ctx.packageData.Imports[typ.Name]; ok {
+                               generateImportedAlias(ctx, w, typ.Name, &imp)
+                               continue
+                       }
                        generateType(ctx, w, &typ)
                }
        }
@@ -146,6 +163,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate unions
        if len(ctx.packageData.Unions) > 0 {
                for _, union := range ctx.packageData.Unions {
+                       if imp, ok := ctx.packageData.Imports[union.Name]; ok {
+                               generateImportedAlias(ctx, w, union.Name, &imp)
+                               continue
+                       }
                        generateUnion(ctx, w, &union)
                }
        }
@@ -217,17 +238,56 @@ func generateHeader(ctx *context, w io.Writer) {
        fmt.Fprintf(w, "package %s\n", ctx.packageName)
        fmt.Fprintln(w)
 
+}
+
+func generateImports(ctx *context, w io.Writer) {
        fmt.Fprintln(w, "import (")
-       fmt.Fprintf(w, "\tapi \"%s\"\n", govppApiImportPath)
-       fmt.Fprintf(w, "\tbytes \"%s\"\n", "bytes")
-       fmt.Fprintf(w, "\tcontext \"%s\"\n", "context")
-       fmt.Fprintf(w, "\tio \"%s\"\n", "io")
-       fmt.Fprintf(w, "\tstrconv \"%s\"\n", "strconv")
+       fmt.Fprintln(w, `       "bytes"`)
+       fmt.Fprintln(w, `       "context"`)
+       fmt.Fprintln(w, `       "io"`)
+       fmt.Fprintln(w, `       "strconv"`)
+       fmt.Fprintln(w)
+       fmt.Fprintf(w, "\tapi \"%s\"\n", "git.fd.io/govpp.git/api")
        fmt.Fprintf(w, "\tstruc \"%s\"\n", "github.com/lunixbochs/struc")
+       if len(ctx.packageData.Imports) > 0 {
+               fmt.Fprintln(w)
+               for _, imp := range getImports(ctx) {
+                       importPath := path.Join(ctx.importPrefix, imp)
+                       if importPath == "" {
+                               importPath = getImportPath(filepath.Dir(ctx.outputFile), imp)
+                       }
+                       fmt.Fprintf(w, "\t%s \"%s\"\n", imp, strings.TrimSpace(importPath))
+               }
+       }
        fmt.Fprintln(w, ")")
        fmt.Fprintln(w)
 }
 
+func getImportPath(outputDir string, pkg string) string {
+       absPath, _ := filepath.Abs(filepath.Join(outputDir, "..", pkg))
+       cmd := exec.Command("go", "list", absPath)
+       var errbuf, outbuf bytes.Buffer
+       cmd.Stdout = &outbuf
+       cmd.Stderr = &errbuf
+       if err := cmd.Run(); err != nil {
+               fmt.Printf("ERR: %v\n", errbuf.String())
+               panic(err)
+       }
+       return outbuf.String()
+}
+
+func getImports(ctx *context) (imports []string) {
+       impmap := map[string]struct{}{}
+       for _, imp := range ctx.packageData.Imports {
+               if _, ok := impmap[imp.Package]; !ok {
+                       imports = append(imports, imp.Package)
+                       impmap[imp.Package] = struct{}{}
+               }
+       }
+       sort.Strings(imports)
+       return imports
+}
+
 func generateFooter(ctx *context, w io.Writer) {
        fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file")
        fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.")
@@ -350,6 +410,14 @@ func generateEnum(ctx *context, w io.Writer, enum *Enum) {
        fmt.Fprintln(w)
 }
 
+func generateImportedAlias(ctx *context, w io.Writer, tName string, imp *Import) {
+       name := camelCaseName(tName)
+
+       fmt.Fprintf(w, "type %s = %s.%s\n", name, imp.Package, name)
+
+       fmt.Fprintln(w)
+}
+
 func generateAlias(ctx *context, w io.Writer, alias *Alias) {
        name := camelCaseName(alias.Name)
 
@@ -550,13 +618,10 @@ func generateMessage(ctx *context, w io.Writer, msg *Message) {
        // generate end of the struct
        fmt.Fprintln(w, "}")
 
-       // generate name getter
+       // generate message methods
+       generateMessageResetMethod(w, name)
        generateMessageNameGetter(w, name, msg.Name)
-
-       // generate CRC getter
        generateCrcGetter(w, name, msg.CRC)
-
-       // generate message type getter method
        generateMessageTypeGetter(w, name, msgType)
 
        fmt.Fprintln(w)
@@ -637,40 +702,36 @@ func generateField(ctx *context, w io.Writer, fields []Field, i int) {
        fmt.Fprintln(w)
 }
 
-func generateMessageNameGetter(w io.Writer, structName, msgName string) {
-       fmt.Fprintf(w, `func (*%s) GetMessageName() string {
-       return %q
+func generateMessageResetMethod(w io.Writer, structName string) {
+       fmt.Fprintf(w, "func (m *%[1]s) Reset() { *m = %[1]s{} }\n", structName)
 }
-`, structName, msgName)
+
+func generateMessageNameGetter(w io.Writer, structName, msgName string) {
+       fmt.Fprintf(w, "func (*%s) GetMessageName() string {    return %q }\n", structName, msgName)
 }
 
 func generateTypeNameGetter(w io.Writer, structName, msgName string) {
-       fmt.Fprintf(w, `func (*%s) GetTypeName() string {
-       return %q
-}
-`, structName, msgName)
+       fmt.Fprintf(w, "func (*%s) GetTypeName() string { return %q }\n", structName, msgName)
 }
 
 func generateCrcGetter(w io.Writer, structName, crc string) {
        crc = strings.TrimPrefix(crc, "0x")
-       fmt.Fprintf(w, `func (*%s) GetCrcString() string {
-       return %q
-}
-`, structName, crc)
+       fmt.Fprintf(w, "func (*%s) GetCrcString() string { return %q }\n", structName, crc)
 }
 
 func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageType) {
-       fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
+       fmt.Fprintf(w, "func (*"+structName+") GetMessageType() api.MessageType {")
        if msgType == requestMessage {
-               fmt.Fprintln(w, "\treturn api.RequestMessage")
+               fmt.Fprintf(w, "\treturn api.RequestMessage")
        } else if msgType == replyMessage {
-               fmt.Fprintln(w, "\treturn api.ReplyMessage")
+               fmt.Fprintf(w, "\treturn api.ReplyMessage")
        } else if msgType == eventMessage {
-               fmt.Fprintln(w, "\treturn api.EventMessage")
+               fmt.Fprintf(w, "\treturn api.EventMessage")
        } else {
-               fmt.Fprintln(w, "\treturn api.OtherMessage")
+               fmt.Fprintf(w, "\treturn api.OtherMessage")
        }
        fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
 }
 
 func generateServices(ctx *context, w io.Writer, services []Service) {