Improve doc & fix import ordering
[govpp.git] / binapigen / generator.go
index 07c1b13..e5eed5a 100644 (file)
 package binapigen
 
 import (
+       "bufio"
        "bytes"
        "fmt"
-       "go/format"
+       "go/ast"
+       "go/parser"
+       "go/printer"
+       "go/token"
        "io/ioutil"
        "os"
        "path"
        "path/filepath"
-       "regexp"
+       "sort"
+       "strconv"
+       "strings"
 
        "github.com/sirupsen/logrus"
 
        "git.fd.io/govpp.git/binapigen/vppapi"
 )
 
-type Options struct {
-       VPPVersion string // version of VPP that produced API files
-
-       FilesToGenerate []string // list of API files to generate
-
-       ImportPrefix       string // defines import path prefix for importing types
-       ImportTypes        bool   // generate packages for import types
-       IncludeAPIVersion  bool   // include constant with API version string
-       IncludeComments    bool   // include parts of original source in comments
-       IncludeBinapiNames bool   // include binary API names as struct tag
-       IncludeServices    bool   // include service interface with client implementation
-       IncludeVppVersion  bool   // include info about used VPP version
-}
-
 type Generator struct {
-       Options
-
        Files       []*File
-       FilesByPath map[string]*File
        FilesByName map[string]*File
+       FilesByPath map[string]*File
+
+       opts       Options
+       apifiles   []*vppapi.File
+       vppVersion string
 
-       enumsByName   map[string]*Enum
-       aliasesByName map[string]*Alias
-       structsByName map[string]*Struct
-       unionsByName  map[string]*Union
+       filesToGen []string
+       genfiles   []*GenFile
 
-       genfiles []*GenFile
+       enumsByName    map[string]*Enum
+       aliasesByName  map[string]*Alias
+       structsByName  map[string]*Struct
+       unionsByName   map[string]*Union
+       messagesByName map[string]*Message
 }
 
-func New(opts Options, apifiles []*vppapi.File) (*Generator, error) {
-       g := &Generator{
-               Options:       opts,
-               FilesByPath:   make(map[string]*File),
-               FilesByName:   make(map[string]*File),
-               enumsByName:   map[string]*Enum{},
-               aliasesByName: map[string]*Alias{},
-               structsByName: map[string]*Struct{},
-               unionsByName:  map[string]*Union{},
+func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
+       gen := &Generator{
+               FilesByName:    make(map[string]*File),
+               FilesByPath:    make(map[string]*File),
+               opts:           opts,
+               apifiles:       apiFiles,
+               filesToGen:     filesToGen,
+               enumsByName:    map[string]*Enum{},
+               aliasesByName:  map[string]*Alias{},
+               structsByName:  map[string]*Struct{},
+               unionsByName:   map[string]*Union{},
+               messagesByName: map[string]*Message{},
        }
 
-       logrus.Debugf("adding %d VPP API files to generator", len(apifiles))
-       for _, apifile := range apifiles {
-               filename := apifile.Path
-               if filename == "" {
-                       filename = apifile.Name
-               }
-               if _, ok := g.FilesByPath[filename]; ok {
-                       return nil, fmt.Errorf("duplicate file name: %q", filename)
-               }
-               if _, ok := g.FilesByName[apifile.Name]; ok {
+       // Normalize API files
+       SortFilesByImports(gen.apifiles)
+       for _, apiFile := range apiFiles {
+               RemoveImportedTypes(gen.apifiles, apiFile)
+               SortFileObjectsByName(apiFile)
+       }
+
+       // prepare package names and import paths
+       packageNames := make(map[string]GoPackageName)
+       importPaths := make(map[string]GoImportPath)
+       for _, apifile := range gen.apifiles {
+               filename := getFilename(apifile)
+               packageNames[filename] = cleanPackageName(apifile.Name)
+               importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
+       }
+
+       logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
+
+       for _, apifile := range gen.apifiles {
+               if _, ok := gen.FilesByName[apifile.Name]; ok {
                        return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
                }
 
-               file, err := newFile(g, apifile)
+               filename := getFilename(apifile)
+               file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
                if err != nil {
-                       return nil, err
+                       return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
                }
-               g.Files = append(g.Files, file)
-               g.FilesByPath[filename] = file
-               g.FilesByName[apifile.Name] = file
+               gen.Files = append(gen.Files, file)
+               gen.FilesByName[apifile.Name] = file
+               gen.FilesByPath[apifile.Path] = file
 
                logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
-               if len(file.Imports) > 0 {
-                       logrus.Debugf(" - %d imports: %v", len(file.Imports), file.Imports)
-               }
        }
 
-       if len(opts.FilesToGenerate) > 0 {
-               logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate)
-               for _, genfile := range opts.FilesToGenerate {
-                       file, ok := g.FilesByPath[genfile]
-                       if !ok {
-                               file, ok = g.FilesByName[genfile]
-                               if !ok {
-                                       return nil, fmt.Errorf("no API file found for: %v", genfile)
-                               }
-                       }
-                       file.Generate = true
-                       if opts.ImportTypes {
+       // mark files for generation
+       if len(gen.filesToGen) > 0 {
+               logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
+               for _, genFile := range gen.filesToGen {
+                       markGen := func(file *File) {
+                               file.Generate = true
                                // generate all imported files
-                               for _, impFile := range file.importedFiles(g) {
+                               for _, impFile := range file.importedFiles(gen) {
                                        impFile.Generate = true
                                }
                        }
+                       if file, ok := gen.FilesByName[genFile]; ok {
+                               markGen(file)
+                               continue
+                       }
+                       logrus.Debugf("File %s was not found by name", genFile)
+                       if file, ok := gen.FilesByPath[genFile]; ok {
+                               markGen(file)
+                               continue
+                       }
+                       return nil, fmt.Errorf("no API file found for: %v", genFile)
                }
        } else {
-               logrus.Debugf("Files to generate not specified, marking all %d files to generate", len(g.Files))
-               for _, file := range g.Files {
+               logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
+               for _, file := range gen.Files {
                        file.Generate = true
                }
        }
 
-       logrus.Debugf("Resolving imported types")
-       for _, file := range g.Files {
-               if !file.Generate {
-                       // skip resolving for non-generated files
-                       continue
-               }
-               var importedFiles []*File
-               for _, impFile := range file.importedFiles(g) {
-                       if !impFile.Generate {
-                               // exclude imports of non-generated files
-                               continue
-                       }
-                       importedFiles = append(importedFiles, impFile)
-               }
-               file.loadTypeImports(g, importedFiles)
-       }
+       return gen, nil
+}
 
-       return g, nil
+func getFilename(file *vppapi.File) string {
+       if file.Path == "" {
+               return file.Name
+       }
+       return file.Path
 }
 
 func (g *Generator) Generate() error {
@@ -147,127 +148,191 @@ func (g *Generator) Generate() error {
        }
 
        logrus.Infof("Generating %d files", len(g.genfiles))
+
        for _, genfile := range g.genfiles {
-               if err := writeSourceTo(genfile.filename, genfile.Content()); err != nil {
-                       return fmt.Errorf("writing source for RPC package %s failed: %v", genfile.filename, err)
+               content, err := genfile.Content()
+               if err != nil {
+                       return err
+               }
+               if err := writeSourceTo(genfile.filename, content); err != nil {
+                       return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
                }
        }
        return nil
 }
 
 type GenFile struct {
-       *Generator
-       filename  string
-       file      *File
-       outputDir string
-       buf       bytes.Buffer
+       gen           *Generator
+       file          *File
+       filename      string
+       goImportPath  GoImportPath
+       buf           bytes.Buffer
+       manualImports map[GoImportPath]bool
+       packageNames  map[GoImportPath]GoPackageName
 }
 
-func (g *Generator) NewGenFile(filename string) *GenFile {
+// NewGenFile creates new generated file with
+func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
        f := &GenFile{
-               Generator: g,
-               filename:  filename,
+               gen:           g,
+               filename:      filename,
+               goImportPath:  importPath,
+               manualImports: make(map[GoImportPath]bool),
+               packageNames:  make(map[GoImportPath]GoPackageName),
        }
        g.genfiles = append(g.genfiles, f)
        return f
 }
 
-func (f *GenFile) Content() []byte {
-       return f.buf.Bytes()
+func (g *GenFile) Write(p []byte) (n int, err error) {
+       return g.buf.Write(p)
 }
 
-func writeSourceTo(outputFile string, b []byte) error {
-       // create output directory
-       packageDir := filepath.Dir(outputFile)
-       if err := os.MkdirAll(packageDir, 0775); err != nil {
-               return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
-       }
+func (g *GenFile) Import(importPath GoImportPath) {
+       g.manualImports[importPath] = true
+}
 
-       // format generated source code
-       gosrc, err := format.Source(b)
-       if err != nil {
-               _ = ioutil.WriteFile(outputFile, b, 0666)
-               return fmt.Errorf("formatting source code failed: %v", err)
+func (g *GenFile) GoIdent(ident GoIdent) string {
+       if ident.GoImportPath == g.goImportPath {
+               return ident.GoName
        }
-
-       // write generated code to output file
-       if err := ioutil.WriteFile(outputFile, gosrc, 0666); err != nil {
-               return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
+       if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
+               return string(packageName) + "." + ident.GoName
        }
+       packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
+       g.packageNames[ident.GoImportPath] = packageName
+       return string(packageName) + "." + ident.GoName
+}
 
-       lines := bytes.Count(gosrc, []byte("\n"))
-       logf("wrote %d lines (%d bytes) of code to: %q", lines, len(gosrc), outputFile)
-
-       return nil
+func (g *GenFile) P(v ...interface{}) {
+       for _, x := range v {
+               switch x := x.(type) {
+               case GoIdent:
+                       fmt.Fprint(&g.buf, g.GoIdent(x))
+               default:
+                       fmt.Fprint(&g.buf, x)
+               }
+       }
+       fmt.Fprintln(&g.buf)
 }
 
-func listImports(genfile *GenFile) map[string]string {
-       var importPath = genfile.ImportPrefix
-       if importPath == "" {
-               importPath = resolveImportPath(genfile.outputDir)
-               logrus.Debugf("resolved import path: %s", importPath)
+func (g *GenFile) Content() ([]byte, error) {
+       if !strings.HasSuffix(g.filename, ".go") {
+               return g.buf.Bytes(), nil
        }
-       imports := map[string]string{}
-       for _, imp := range genfile.file.imports {
-               if _, ok := imports[imp]; !ok {
-                       imports[imp] = path.Join(importPath, imp)
-               }
+       return g.injectImports(g.buf.Bytes())
+}
+
+func getImportClass(importPath string) int {
+       if !strings.Contains(importPath, ".") {
+               return 0 /* std */
        }
-       return imports
+       return 1 /* External */
 }
 
-func resolveImportPath(outputDir string) string {
-       absPath, err := filepath.Abs(outputDir)
+// injectImports parses source, injects import block declaration with all imports and return formatted
+func (g *GenFile) injectImports(original []byte) ([]byte, error) {
+       // Parse source code
+       fset := token.NewFileSet()
+       file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
        if err != nil {
-               panic(err)
-       }
-       modRoot := findModuleRoot(absPath)
-       if modRoot == "" {
-               logrus.Fatalf("module root not found at: %s", absPath)
+               var src bytes.Buffer
+               s := bufio.NewScanner(bytes.NewReader(original))
+               for line := 1; s.Scan(); line++ {
+                       fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
+               }
+               return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
        }
-       modPath := findModulePath(path.Join(modRoot, "go.mod"))
-       if modPath == "" {
-               logrus.Fatalf("module path not found")
+       type Import struct {
+               Name string
+               Path string
        }
-       relDir, err := filepath.Rel(modRoot, absPath)
-       if err != nil {
-               panic(err)
+       // Prepare list of all imports
+       var importPaths []Import
+       for importPath := range g.packageNames {
+               importPaths = append(importPaths, Import{
+                       Name: string(g.packageNames[importPath]),
+                       Path: string(importPath),
+               })
        }
-       return filepath.Join(modPath, relDir)
-}
-
-func findModuleRoot(dir string) (root string) {
-       if dir == "" {
-               panic("dir not set")
+       for importPath := range g.manualImports {
+               if _, ok := g.packageNames[importPath]; ok {
+                       continue
+               }
+               importPaths = append(importPaths, Import{
+                       Name: "_",
+                       Path: string(importPath),
+               })
        }
-       dir = filepath.Clean(dir)
-
-       // Look for enclosing go.mod.
-       for {
-               if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
-                       return dir
+       // Sort imports by import path
+       sort.Slice(importPaths, func(i, j int) bool {
+               ci := getImportClass(importPaths[i].Path)
+               cj := getImportClass(importPaths[j].Path)
+               if ci == cj {
+                       return importPaths[i].Path < importPaths[j].Path
                }
-               d := filepath.Dir(dir)
-               if d == dir {
-                       break
+               return ci < cj
+       })
+       // Inject new import block into parsed AST
+       if len(importPaths) > 0 {
+               // Find import block position
+               pos := file.Package
+               tokFile := fset.File(file.Package)
+               pkgLine := tokFile.Line(file.Package)
+               for _, c := range file.Comments {
+                       if tokFile.Line(c.Pos()) > pkgLine {
+                               break
+                       }
+                       pos = c.End()
                }
-               dir = d
+               // Prepare the import block
+               impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
+               for i, importPath := range importPaths {
+                       var name *ast.Ident
+                       if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
+                               name = &ast.Ident{Name: importPath.Name, NamePos: pos}
+                       }
+                       value := strconv.Quote(importPath.Path)
+                       if i < len(importPaths)-1 {
+                               if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
+                                       value += "\n"
+                               }
+                       }
+                       impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
+                               Name:   name,
+                               Path:   &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
+                               EndPos: pos,
+                       })
+               }
+
+               file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
+       }
+       // Reformat source code
+       var out bytes.Buffer
+       cfg := &printer.Config{
+               Mode:     printer.TabIndent | printer.UseSpaces,
+               Tabwidth: 8,
        }
-       return ""
+       if err = cfg.Fprint(&out, fset, file); err != nil {
+               return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
+       }
+       return out.Bytes(), nil
 }
 
-var (
-       modulePathRE = regexp.MustCompile(`module[ \t]+([^ \t\r\n]+)`)
-)
-
-func findModulePath(file string) string {
-       data, err := ioutil.ReadFile(file)
-       if err != nil {
-               return ""
+func writeSourceTo(outputFile string, b []byte) error {
+       // create output directory
+       packageDir := filepath.Dir(outputFile)
+       if err := os.MkdirAll(packageDir, 0775); err != nil {
+               return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
        }
-       m := modulePathRE.FindSubmatch(data)
-       if m == nil {
-               return ""
+
+       // write generated code to output file
+       if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
+               return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
        }
-       return string(m[1])
+
+       lines := bytes.Count(b, []byte("\n"))
+       logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)
+
+       return nil
 }