Support imported type aliases 58/24658/2
authorOndrej Fabry <ofabry@cisco.com>
Tue, 28 Jan 2020 17:03:58 +0000 (18:03 +0100)
committerOndrej Fabry <ofabry@cisco.com>
Fri, 31 Jan 2020 10:23:35 +0000 (11:23 +0100)
Change-Id: I2e6ad9fb51e1cf55a52267720f2394e792946f7e
Signed-off-by: Ondrej Fabry <ofabry@cisco.com>
cmd/binapi-generator/generate.go
cmd/binapi-generator/main.go
cmd/binapi-generator/objects.go
cmd/binapi-generator/parse.go

index cb1f470..fb6cee5 100644 (file)
@@ -18,6 +18,7 @@ import (
        "bytes"
        "fmt"
        "io"
+       "os/exec"
        "path/filepath"
        "sort"
        "strings"
@@ -125,6 +126,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate enums
        if len(ctx.packageData.Enums) > 0 {
                for _, enum := range ctx.packageData.Enums {
+                       if imp, ok := ctx.packageData.Imports[enum.Name]; ok {
+                               generateImportedAlias(ctx, w, enum.Name, &imp)
+                               continue
+                       }
                        generateEnum(ctx, w, &enum)
                }
        }
@@ -132,6 +137,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate aliases
        if len(ctx.packageData.Aliases) > 0 {
                for _, alias := range ctx.packageData.Aliases {
+                       if imp, ok := ctx.packageData.Imports[alias.Name]; ok {
+                               generateImportedAlias(ctx, w, alias.Name, &imp)
+                               continue
+                       }
                        generateAlias(ctx, w, &alias)
                }
        }
@@ -139,6 +148,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate types
        if len(ctx.packageData.Types) > 0 {
                for _, typ := range ctx.packageData.Types {
+                       if imp, ok := ctx.packageData.Imports[typ.Name]; ok {
+                               generateImportedAlias(ctx, w, typ.Name, &imp)
+                               continue
+                       }
                        generateType(ctx, w, &typ)
                }
        }
@@ -146,6 +159,10 @@ func generatePackage(ctx *context, w io.Writer) error {
        // generate unions
        if len(ctx.packageData.Unions) > 0 {
                for _, union := range ctx.packageData.Unions {
+                       if imp, ok := ctx.packageData.Imports[union.Name]; ok {
+                               generateImportedAlias(ctx, w, union.Name, &imp)
+                               continue
+                       }
                        generateUnion(ctx, w, &union)
                }
        }
@@ -224,10 +241,42 @@ func generateHeader(ctx *context, w io.Writer) {
        fmt.Fprintf(w, "\tio \"%s\"\n", "io")
        fmt.Fprintf(w, "\tstrconv \"%s\"\n", "strconv")
        fmt.Fprintf(w, "\tstruc \"%s\"\n", "github.com/lunixbochs/struc")
+       if len(ctx.packageData.Imports) > 0 {
+               fmt.Fprintln(w)
+               for _, imp := range getImports(ctx) {
+                       impPkg := getImportPkg(filepath.Dir(ctx.outputFile), imp)
+                       fmt.Fprintf(w, "\t%s \"%s\"\n", imp, strings.TrimSpace(impPkg))
+               }
+       }
        fmt.Fprintln(w, ")")
        fmt.Fprintln(w)
 }
 
+func getImportPkg(outputDir string, pkg string) string {
+       absPath, _ := filepath.Abs(filepath.Join(outputDir, "..", pkg))
+       cmd := exec.Command("go", "list", absPath)
+       var errbuf, outbuf bytes.Buffer
+       cmd.Stdout = &outbuf
+       cmd.Stderr = &errbuf
+       if err := cmd.Run(); err != nil {
+               fmt.Printf("ERR: %v\n", errbuf.String())
+               panic(err)
+       }
+       return outbuf.String()
+}
+
+func getImports(ctx *context) (imports []string) {
+       impmap := map[string]struct{}{}
+       for _, imp := range ctx.packageData.Imports {
+               if _, ok := impmap[imp.Package]; !ok {
+                       imports = append(imports, imp.Package)
+                       impmap[imp.Package] = struct{}{}
+               }
+       }
+       sort.Strings(imports)
+       return imports
+}
+
 func generateFooter(ctx *context, w io.Writer) {
        fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file")
        fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.")
@@ -350,6 +399,14 @@ func generateEnum(ctx *context, w io.Writer, enum *Enum) {
        fmt.Fprintln(w)
 }
 
+func generateImportedAlias(ctx *context, w io.Writer, tName string, imp *Import) {
+       name := camelCaseName(tName)
+
+       fmt.Fprintf(w, "type %s = %s.%s\n", name, imp.Package, name)
+
+       fmt.Fprintln(w)
+}
+
 func generateAlias(ctx *context, w io.Writer, alias *Alias) {
        name := camelCaseName(alias.Name)
 
@@ -550,13 +607,10 @@ func generateMessage(ctx *context, w io.Writer, msg *Message) {
        // generate end of the struct
        fmt.Fprintln(w, "}")
 
-       // generate name getter
+       // generate message methods
+       generateMessageResetMethod(w, name)
        generateMessageNameGetter(w, name, msg.Name)
-
-       // generate CRC getter
        generateCrcGetter(w, name, msg.CRC)
-
-       // generate message type getter method
        generateMessageTypeGetter(w, name, msgType)
 
        fmt.Fprintln(w)
@@ -637,40 +691,36 @@ func generateField(ctx *context, w io.Writer, fields []Field, i int) {
        fmt.Fprintln(w)
 }
 
-func generateMessageNameGetter(w io.Writer, structName, msgName string) {
-       fmt.Fprintf(w, `func (*%s) GetMessageName() string {
-       return %q
+func generateMessageResetMethod(w io.Writer, structName string) {
+       fmt.Fprintf(w, "func (m *%[1]s) Reset() { *m = %[1]s{} }\n", structName)
 }
-`, structName, msgName)
+
+func generateMessageNameGetter(w io.Writer, structName, msgName string) {
+       fmt.Fprintf(w, "func (*%s) GetMessageName() string {    return %q }\n", structName, msgName)
 }
 
 func generateTypeNameGetter(w io.Writer, structName, msgName string) {
-       fmt.Fprintf(w, `func (*%s) GetTypeName() string {
-       return %q
-}
-`, structName, msgName)
+       fmt.Fprintf(w, "func (*%s) GetTypeName() string { return %q }\n", structName, msgName)
 }
 
 func generateCrcGetter(w io.Writer, structName, crc string) {
        crc = strings.TrimPrefix(crc, "0x")
-       fmt.Fprintf(w, `func (*%s) GetCrcString() string {
-       return %q
-}
-`, structName, crc)
+       fmt.Fprintf(w, "func (*%s) GetCrcString() string { return %q }\n", structName, crc)
 }
 
 func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageType) {
-       fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
+       fmt.Fprintf(w, "func (*"+structName+") GetMessageType() api.MessageType {")
        if msgType == requestMessage {
-               fmt.Fprintln(w, "\treturn api.RequestMessage")
+               fmt.Fprintf(w, "\treturn api.RequestMessage")
        } else if msgType == replyMessage {
-               fmt.Fprintln(w, "\treturn api.ReplyMessage")
+               fmt.Fprintf(w, "\treturn api.ReplyMessage")
        } else if msgType == eventMessage {
-               fmt.Fprintln(w, "\treturn api.EventMessage")
+               fmt.Fprintf(w, "\treturn api.EventMessage")
        } else {
-               fmt.Fprintln(w, "\treturn api.OtherMessage")
+               fmt.Fprintf(w, "\treturn api.OtherMessage")
        }
        fmt.Fprintln(w, "}")
+       fmt.Fprintln(w)
 }
 
 func generateServices(ctx *context, w io.Writer, services []Service) {
index c66fc4f..e0e2f08 100644 (file)
@@ -32,9 +32,10 @@ import (
 )
 
 var (
-       theInputFile = flag.String("input-file", "", "Input file with VPP API in JSON format.")
-       theInputDir  = flag.String("input-dir", "/usr/share/vpp/api", "Input directory with VPP API files in JSON format.")
-       theOutputDir = flag.String("output-dir", ".", "Output directory where package folders will be generated.")
+       theInputFile  = flag.String("input-file", "", "Input file with VPP API in JSON format.")
+       theInputTypes = flag.String("input-types", "", "Types input file with VPP API in JSON format. (split by comma)")
+       theInputDir   = flag.String("input-dir", "/usr/share/vpp/api", "Input directory with VPP API files in JSON format.")
+       theOutputDir  = flag.String("output-dir", ".", "Output directory where package folders will be generated.")
 
        includeAPIVer      = flag.Bool("include-apiver", true, "Include APIVersion constant for each module.")
        includeServices    = flag.Bool("include-services", true, "Include RPC service api and client implementation.")
@@ -84,14 +85,23 @@ func main() {
        }
 }
 
-func run(inputFile, inputDir string, outputDir string, continueErr bool) error {
+func run(inputFile, inputDir string, outputDir string, continueErr bool) (err error) {
        if inputFile == "" && inputDir == "" {
                return fmt.Errorf("input-file or input-dir must be specified")
        }
 
+       var typesPkgs []*context
+       if *theInputTypes != "" {
+               types := strings.Split(*theInputTypes, ",")
+               typesPkgs, err = loadTypesPackages(types...)
+               if err != nil {
+                       return fmt.Errorf("loading types input failed: %v", err)
+               }
+       }
+
        if inputFile != "" {
                // process one input file
-               if err := generateFromFile(inputFile, outputDir); err != nil {
+               if err := generateFromFile(inputFile, outputDir, typesPkgs); err != nil {
                        return fmt.Errorf("code generation from %s failed: %v\n", inputFile, err)
                }
        } else {
@@ -107,7 +117,7 @@ func run(inputFile, inputDir string, outputDir string, continueErr bool) error {
                        return fmt.Errorf("no input files found in input directory: %v\n", dir)
                }
                for _, file := range files {
-                       if err := generateFromFile(file, outputDir); err != nil {
+                       if err := generateFromFile(file, outputDir, typesPkgs); err != nil {
                                if continueErr {
                                        logrus.Warnf("code generation from %s failed: %v (error ignored)\n", file, err)
                                        continue
@@ -151,7 +161,7 @@ func parseInputJSON(inputData []byte) (*jsongo.Node, error) {
 }
 
 // generateFromFile generates Go package from one input JSON file
-func generateFromFile(inputFile, outputDir string) error {
+func generateFromFile(inputFile, outputDir string, typesPkgs []*context) error {
        // create generator context
        ctx, err := newContext(inputFile, outputDir)
        if err != nil {
@@ -185,6 +195,13 @@ func generateFromFile(inputFile, outputDir string) error {
                return fmt.Errorf("parsing package %s failed: %v", ctx.packageName, err)
        }
 
+       if len(typesPkgs) > 0 {
+               err = loadTypeAliases(ctx, typesPkgs)
+               if err != nil {
+                       return fmt.Errorf("loading type aliases failed: %v", err)
+               }
+       }
+
        // generate Go package code
        var buf bytes.Buffer
        if err := generatePackage(ctx, &buf); err != nil {
@@ -210,6 +227,109 @@ func generateFromFile(inputFile, outputDir string) error {
        return nil
 }
 
+func loadTypesPackages(types ...string) ([]*context, error) {
+       var ctxs []*context
+       for _, inputFile := range types {
+               // create generator context
+               ctx, err := newContext(inputFile, "")
+               if err != nil {
+                       return nil, err
+               }
+               // read API definition from input file
+               ctx.inputData, err = ioutil.ReadFile(ctx.inputFile)
+               if err != nil {
+                       return nil, fmt.Errorf("reading input file %s failed: %v", ctx.inputFile, err)
+               }
+               // parse JSON data into objects
+               jsonRoot, err := parseInputJSON(ctx.inputData)
+               if err != nil {
+                       return nil, fmt.Errorf("parsing JSON input failed: %v", err)
+               }
+               ctx.packageData, err = parsePackage(ctx, jsonRoot)
+               if err != nil {
+                       return nil, fmt.Errorf("parsing package %s failed: %v", ctx.packageName, err)
+               }
+               ctxs = append(ctxs, ctx)
+       }
+       return ctxs, nil
+}
+
+func loadTypeAliases(ctx *context, typesCtxs []*context) error {
+       for _, t := range ctx.packageData.Types {
+               for _, c := range typesCtxs {
+                       if _, ok := ctx.packageData.Imports[t.Name]; ok {
+                               break
+                       }
+                       for _, at := range c.packageData.Types {
+                               if at.Name != t.Name {
+                                       continue
+                               }
+                               if len(at.Fields) != len(t.Fields) {
+                                       continue
+                               }
+                               ctx.packageData.Imports[t.Name] = Import{
+                                       Package: c.packageName,
+                               }
+                       }
+               }
+       }
+       for _, t := range ctx.packageData.Aliases {
+               for _, c := range typesCtxs {
+                       if _, ok := ctx.packageData.Imports[t.Name]; ok {
+                               break
+                       }
+                       for _, at := range c.packageData.Aliases {
+                               if at.Name != t.Name {
+                                       continue
+                               }
+                               if at.Length != t.Length {
+                                       continue
+                               }
+                               if at.Type != t.Type {
+                                       continue
+                               }
+                               ctx.packageData.Imports[t.Name] = Import{
+                                       Package: c.packageName,
+                               }
+                       }
+               }
+       }
+       for _, t := range ctx.packageData.Enums {
+               for _, c := range typesCtxs {
+                       if _, ok := ctx.packageData.Imports[t.Name]; ok {
+                               break
+                       }
+                       for _, at := range c.packageData.Enums {
+                               if at.Name != t.Name {
+                                       continue
+                               }
+                               if at.Type != t.Type {
+                                       continue
+                               }
+                               ctx.packageData.Imports[t.Name] = Import{
+                                       Package: c.packageName,
+                               }
+                       }
+               }
+       }
+       for _, t := range ctx.packageData.Unions {
+               for _, c := range typesCtxs {
+                       if _, ok := ctx.packageData.Imports[t.Name]; ok {
+                               break
+                       }
+                       for _, at := range c.packageData.Unions {
+                               if at.Name != t.Name {
+                                       continue
+                               }
+                               ctx.packageData.Imports[t.Name] = Import{
+                                       Package: c.packageName,
+                               }
+                       }
+               }
+       }
+       return nil
+}
+
 func logf(f string, v ...interface{}) {
        if *debugMode {
                logrus.Debugf(f, v...)
index 9325d03..9871abc 100644 (file)
@@ -14,6 +14,11 @@ type Package struct {
        Unions   []Union
        Messages []Message
        RefMap   map[string]string
+       Imports  map[string]Import
+}
+
+type Import struct {
+       Package string
 }
 
 // Service represents VPP binary API service
index 9eed08c..6598b7b 100644 (file)
@@ -75,8 +75,9 @@ const (
 // parsePackage parses provided JSON data into objects prepared for code generation
 func parsePackage(ctx *context, jsonRoot *jsongo.Node) (*Package, error) {
        pkg := Package{
-               Name:   ctx.packageName,
-               RefMap: make(map[string]string),
+               Name:    ctx.packageName,
+               RefMap:  make(map[string]string),
+               Imports: map[string]Import{},
        }
 
        // parse CRC for API version