Improve binapi generator
[govpp.git] / binapigen / generator.go
index 07c1b13..e42e7fb 100644 (file)
 package binapigen
 
 import (
+       "bufio"
        "bytes"
        "fmt"
+       "go/ast"
        "go/format"
+       "go/parser"
+       "go/printer"
+       "go/token"
        "io/ioutil"
        "os"
        "path"
        "path/filepath"
-       "regexp"
+       "sort"
+       "strconv"
+       "strings"
 
        "github.com/sirupsen/logrus"
 
        "git.fd.io/govpp.git/binapigen/vppapi"
 )
 
-type Options struct {
-       VPPVersion string // version of VPP that produced API files
-
-       FilesToGenerate []string // list of API files to generate
-
-       ImportPrefix       string // defines import path prefix for importing types
-       ImportTypes        bool   // generate packages for import types
-       IncludeAPIVersion  bool   // include constant with API version string
-       IncludeComments    bool   // include parts of original source in comments
-       IncludeBinapiNames bool   // include binary API names as struct tag
-       IncludeServices    bool   // include service interface with client implementation
-       IncludeVppVersion  bool   // include info about used VPP version
-}
-
 type Generator struct {
-       Options
-
        Files       []*File
-       FilesByPath map[string]*File
        FilesByName map[string]*File
 
-       enumsByName   map[string]*Enum
-       aliasesByName map[string]*Alias
-       structsByName map[string]*Struct
-       unionsByName  map[string]*Union
+       opts       Options
+       apifiles   []*vppapi.File
+       vppVersion string
+
+       filesToGen []string
+       genfiles   []*GenFile
 
-       genfiles []*GenFile
+       enumsByName    map[string]*Enum
+       aliasesByName  map[string]*Alias
+       structsByName  map[string]*Struct
+       unionsByName   map[string]*Union
+       messagesByName map[string]*Message
 }
 
-func New(opts Options, apifiles []*vppapi.File) (*Generator, error) {
-       g := &Generator{
-               Options:       opts,
-               FilesByPath:   make(map[string]*File),
-               FilesByName:   make(map[string]*File),
-               enumsByName:   map[string]*Enum{},
-               aliasesByName: map[string]*Alias{},
-               structsByName: map[string]*Struct{},
-               unionsByName:  map[string]*Union{},
+func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator, error) {
+       gen := &Generator{
+               FilesByName:    make(map[string]*File),
+               opts:           opts,
+               apifiles:       apifiles,
+               filesToGen:     filesToGen,
+               enumsByName:    map[string]*Enum{},
+               aliasesByName:  map[string]*Alias{},
+               structsByName:  map[string]*Struct{},
+               unionsByName:   map[string]*Union{},
+               messagesByName: map[string]*Message{},
        }
 
-       logrus.Debugf("adding %d VPP API files to generator", len(apifiles))
+       // Normalize API files
+       SortFilesByImports(gen.apifiles)
        for _, apifile := range apifiles {
-               filename := apifile.Path
-               if filename == "" {
-                       filename = apifile.Name
-               }
-               if _, ok := g.FilesByPath[filename]; ok {
-                       return nil, fmt.Errorf("duplicate file name: %q", filename)
-               }
-               if _, ok := g.FilesByName[apifile.Name]; ok {
+               RemoveImportedTypes(gen.apifiles, apifile)
+               SortFileObjectsByName(apifile)
+       }
+
+       // prepare package names and import paths
+       packageNames := make(map[string]GoPackageName)
+       importPaths := make(map[string]GoImportPath)
+       for _, apifile := range gen.apifiles {
+               filename := getFilename(apifile)
+               packageNames[filename] = cleanPackageName(apifile.Name)
+               importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
+       }
+
+       logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
+
+       for _, apifile := range gen.apifiles {
+               filename := getFilename(apifile)
+
+               if _, ok := gen.FilesByName[apifile.Name]; ok {
                        return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
                }
 
-               file, err := newFile(g, apifile)
+               file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
                if err != nil {
-                       return nil, err
+                       return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
                }
-               g.Files = append(g.Files, file)
-               g.FilesByPath[filename] = file
-               g.FilesByName[apifile.Name] = file
+               gen.Files = append(gen.Files, file)
+               gen.FilesByName[apifile.Name] = file
 
                logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
-               if len(file.Imports) > 0 {
-                       logrus.Debugf(" - %d imports: %v", len(file.Imports), file.Imports)
-               }
        }
 
-       if len(opts.FilesToGenerate) > 0 {
-               logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate)
-               for _, genfile := range opts.FilesToGenerate {
-                       file, ok := g.FilesByPath[genfile]
+       // mark files for generation
+       if len(gen.filesToGen) > 0 {
+               logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
+               for _, genfile := range gen.filesToGen {
+                       file, ok := gen.FilesByName[genfile]
                        if !ok {
-                               file, ok = g.FilesByName[genfile]
-                               if !ok {
-                                       return nil, fmt.Errorf("no API file found for: %v", genfile)
-                               }
+                               return nil, fmt.Errorf("no API file found for: %v", genfile)
                        }
                        file.Generate = true
-                       if opts.ImportTypes {
-                               // generate all imported files
-                               for _, impFile := range file.importedFiles(g) {
-                                       impFile.Generate = true
-                               }
+                       // generate all imported files
+                       for _, impFile := range file.importedFiles(gen) {
+                               impFile.Generate = true
                        }
                }
        } else {
-               logrus.Debugf("Files to generate not specified, marking all %d files to generate", len(g.Files))
-               for _, file := range g.Files {
+               logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
+               for _, file := range gen.Files {
                        file.Generate = true
                }
        }
 
-       logrus.Debugf("Resolving imported types")
-       for _, file := range g.Files {
-               if !file.Generate {
-                       // skip resolving for non-generated files
-                       continue
-               }
-               var importedFiles []*File
-               for _, impFile := range file.importedFiles(g) {
-                       if !impFile.Generate {
-                               // exclude imports of non-generated files
-                               continue
-                       }
-                       importedFiles = append(importedFiles, impFile)
-               }
-               file.loadTypeImports(g, importedFiles)
-       }
+       return gen, nil
+}
 
-       return g, nil
+func getFilename(file *vppapi.File) string {
+       if file.Path == "" {
+               return file.Name
+       }
+       return file.Path
 }
 
 func (g *Generator) Generate() error {
@@ -147,127 +139,238 @@ func (g *Generator) Generate() error {
        }
 
        logrus.Infof("Generating %d files", len(g.genfiles))
+
        for _, genfile := range g.genfiles {
-               if err := writeSourceTo(genfile.filename, genfile.Content()); err != nil {
-                       return fmt.Errorf("writing source for RPC package %s failed: %v", genfile.filename, err)
+               content, err := genfile.Content()
+               if err != nil {
+                       return err
+               }
+               if err := writeSourceTo(genfile.filename, content); err != nil {
+                       return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
                }
        }
        return nil
 }
 
 type GenFile struct {
-       *Generator
-       filename  string
-       file      *File
-       outputDir string
-       buf       bytes.Buffer
+       gen           *Generator
+       file          *File
+       filename      string
+       goImportPath  GoImportPath
+       buf           bytes.Buffer
+       manualImports map[GoImportPath]bool
+       packageNames  map[GoImportPath]GoPackageName
 }
 
-func (g *Generator) NewGenFile(filename string) *GenFile {
+func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
        f := &GenFile{
-               Generator: g,
-               filename:  filename,
+               gen:           g,
+               filename:      filename,
+               goImportPath:  importPath,
+               manualImports: make(map[GoImportPath]bool),
+               packageNames:  make(map[GoImportPath]GoPackageName),
        }
        g.genfiles = append(g.genfiles, f)
        return f
 }
 
-func (f *GenFile) Content() []byte {
-       return f.buf.Bytes()
+func (g *GenFile) Write(p []byte) (n int, err error) {
+       return g.buf.Write(p)
 }
 
-func writeSourceTo(outputFile string, b []byte) error {
-       // create output directory
-       packageDir := filepath.Dir(outputFile)
-       if err := os.MkdirAll(packageDir, 0775); err != nil {
-               return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
-       }
+func (g *GenFile) Import(importPath GoImportPath) {
+       g.manualImports[importPath] = true
+}
 
-       // format generated source code
-       gosrc, err := format.Source(b)
-       if err != nil {
-               _ = ioutil.WriteFile(outputFile, b, 0666)
-               return fmt.Errorf("formatting source code failed: %v", err)
+func (g *GenFile) GoIdent(ident GoIdent) string {
+       if ident.GoImportPath == g.goImportPath {
+               return ident.GoName
        }
-
-       // write generated code to output file
-       if err := ioutil.WriteFile(outputFile, gosrc, 0666); err != nil {
-               return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
+       if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
+               return string(packageName) + "." + ident.GoName
        }
+       packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
+       g.packageNames[ident.GoImportPath] = packageName
+       return string(packageName) + "." + ident.GoName
+}
 
-       lines := bytes.Count(gosrc, []byte("\n"))
-       logf("wrote %d lines (%d bytes) of code to: %q", lines, len(gosrc), outputFile)
+func (g *GenFile) P(v ...interface{}) {
+       for _, x := range v {
+               switch x := x.(type) {
+               case GoIdent:
+                       fmt.Fprint(&g.buf, g.GoIdent(x))
+               default:
+                       fmt.Fprint(&g.buf, x)
+               }
+       }
+       fmt.Fprintln(&g.buf)
+}
 
-       return nil
+func (g *GenFile) Content() ([]byte, error) {
+       if !strings.HasSuffix(g.filename, ".go") {
+               return g.buf.Bytes(), nil
+       }
+       return g.injectImports(g.buf.Bytes())
 }
 
-func listImports(genfile *GenFile) map[string]string {
-       var importPath = genfile.ImportPrefix
-       if importPath == "" {
-               importPath = resolveImportPath(genfile.outputDir)
-               logrus.Debugf("resolved import path: %s", importPath)
+func (g *GenFile) injectImports(original []byte) ([]byte, error) {
+       // Parse source code
+       fset := token.NewFileSet()
+       file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
+       if err != nil {
+               var src bytes.Buffer
+               s := bufio.NewScanner(bytes.NewReader(original))
+               for line := 1; s.Scan(); line++ {
+                       fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
+               }
+               return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
+       }
+       type Import struct {
+               Name string
+               Path string
+       }
+       // Prepare list of all imports
+       var importPaths []Import
+       for importPath := range g.packageNames {
+               importPaths = append(importPaths, Import{
+                       Name: string(g.packageNames[GoImportPath(importPath)]),
+                       Path: string(importPath),
+               })
        }
-       imports := map[string]string{}
-       for _, imp := range genfile.file.imports {
-               if _, ok := imports[imp]; !ok {
-                       imports[imp] = path.Join(importPath, imp)
+       for importPath := range g.manualImports {
+               if _, ok := g.packageNames[importPath]; ok {
+                       continue
                }
+               importPaths = append(importPaths, Import{
+                       Name: "_",
+                       Path: string(importPath),
+               })
        }
-       return imports
-}
+       // Sort imports by import path
+       sort.Slice(importPaths, func(i, j int) bool {
+               return importPaths[i].Path < importPaths[j].Path
+       })
+       // Inject new import block into parsed AST
+       if len(importPaths) > 0 {
+               // Find import block position
+               pos := file.Package
+               tokFile := fset.File(file.Package)
+               pkgLine := tokFile.Line(file.Package)
+               for _, c := range file.Comments {
+                       if tokFile.Line(c.Pos()) > pkgLine {
+                               break
+                       }
+                       pos = c.End()
+               }
+               // Prepare the import block
+               impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
+               for _, importPath := range importPaths {
+                       var name *ast.Ident
+                       if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
+                               name = &ast.Ident{Name: importPath.Name, NamePos: pos}
+                       }
+                       impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
+                               Name:   name,
+                               Path:   &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath.Path), ValuePos: pos},
+                               EndPos: pos,
+                       })
+               }
 
-func resolveImportPath(outputDir string) string {
-       absPath, err := filepath.Abs(outputDir)
-       if err != nil {
-               panic(err)
+               file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
        }
-       modRoot := findModuleRoot(absPath)
-       if modRoot == "" {
-               logrus.Fatalf("module root not found at: %s", absPath)
+       // Reformat source code
+       var out bytes.Buffer
+       cfg := &printer.Config{
+               Mode:     printer.TabIndent | printer.UseSpaces,
+               Tabwidth: 8,
        }
-       modPath := findModulePath(path.Join(modRoot, "go.mod"))
-       if modPath == "" {
-               logrus.Fatalf("module path not found")
+       if err = cfg.Fprint(&out, fset, file); err != nil {
+               return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
        }
-       relDir, err := filepath.Rel(modRoot, absPath)
-       if err != nil {
-               panic(err)
+       return out.Bytes(), nil
+}
+
+// GoIdent is a Go identifier, consisting of a name and import path.
+// The name is a single identifier and may not be a dot-qualified selector.
+type GoIdent struct {
+       GoName       string
+       GoImportPath GoImportPath
+}
+
+func (id GoIdent) String() string {
+       return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName)
+}
+
+func newGoIdent(f *File, fullName string) GoIdent {
+       name := strings.TrimPrefix(fullName, string(f.PackageName)+".")
+       return GoIdent{
+               GoName:       camelCaseName(name),
+               GoImportPath: f.GoImportPath,
        }
-       return filepath.Join(modPath, relDir)
 }
 
-func findModuleRoot(dir string) (root string) {
-       if dir == "" {
-               panic("dir not set")
+// GoImportPath is a Go import path for a package.
+type GoImportPath string
+
+func (p GoImportPath) String() string {
+       return strconv.Quote(string(p))
+}
+
+func (p GoImportPath) Ident(s string) GoIdent {
+       return GoIdent{GoName: s, GoImportPath: p}
+}
+
+type GoPackageName string
+
+func cleanPackageName(name string) GoPackageName {
+       return GoPackageName(sanitizedName(name))
+}
+
+func sanitizedName(name string) string {
+       switch name {
+       case "interface":
+               return "interfaces"
+       case "map":
+               return "maps"
+       default:
+               return name
        }
-       dir = filepath.Clean(dir)
+}
 
-       // Look for enclosing go.mod.
-       for {
-               if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
-                       return dir
-               }
-               d := filepath.Dir(dir)
-               if d == dir {
-                       break
-               }
-               dir = d
+// baseName returns the last path element of the name, with the last dotted suffix removed.
+func baseName(name string) string {
+       // First, find the last element
+       if i := strings.LastIndex(name, "/"); i >= 0 {
+               name = name[i+1:]
+       }
+       // Now drop the suffix
+       if i := strings.LastIndex(name, "."); i >= 0 {
+               name = name[:i]
        }
-       return ""
+       return name
 }
 
-var (
-       modulePathRE = regexp.MustCompile(`module[ \t]+([^ \t\r\n]+)`)
-)
+func writeSourceTo(outputFile string, b []byte) error {
+       // create output directory
+       packageDir := filepath.Dir(outputFile)
+       if err := os.MkdirAll(packageDir, 0775); err != nil {
+               return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
+       }
 
-func findModulePath(file string) string {
-       data, err := ioutil.ReadFile(file)
+       // format generated source code
+       gosrc, err := format.Source(b)
        if err != nil {
-               return ""
+               _ = ioutil.WriteFile(outputFile, b, 0666)
+               return fmt.Errorf("formatting source code failed: %v", err)
        }
-       m := modulePathRE.FindSubmatch(data)
-       if m == nil {
-               return ""
+
+       // write generated code to output file
+       if err := ioutil.WriteFile(outputFile, gosrc, 0666); err != nil {
+               return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
        }
-       return string(m[1])
+
+       lines := bytes.Count(gosrc, []byte("\n"))
+       logf("wrote %d lines (%d bytes) to: %q", lines, len(gosrc), outputFile)
+
+       return nil
 }