Improve doc & fix import ordering
[govpp.git] / binapigen / generator.go
index 9471462..e5eed5a 100644 (file)
 package binapigen
 
 import (
+       "bufio"
        "bytes"
        "fmt"
-       "go/format"
+       "go/ast"
+       "go/parser"
+       "go/printer"
+       "go/token"
        "io/ioutil"
        "os"
+       "path"
        "path/filepath"
+       "sort"
+       "strconv"
+       "strings"
 
        "github.com/sirupsen/logrus"
 
        "git.fd.io/govpp.git/binapigen/vppapi"
 )
 
-type Options struct {
-       VPPVersion string // version of VPP that produced API files
+type Generator struct {
+       Files       []*File
+       FilesByName map[string]*File
+       FilesByPath map[string]*File
+
+       opts       Options
+       apifiles   []*vppapi.File
+       vppVersion string
 
-       FilesToGenerate []string // list of API files to generate
+       filesToGen []string
+       genfiles   []*GenFile
 
-       ImportPrefix       string // defines import path prefix for importing types
-       ImportTypes        bool   // generate packages for import types
-       IncludeAPIVersion  bool   // include constant with API version string
-       IncludeComments    bool   // include parts of original source in comments
-       IncludeBinapiNames bool   // include binary API names as struct tag
-       IncludeServices    bool   // include service interface with client implementation
-       IncludeVppVersion  bool   // include info about used VPP version
+       enumsByName    map[string]*Enum
+       aliasesByName  map[string]*Alias
+       structsByName  map[string]*Struct
+       unionsByName   map[string]*Union
+       messagesByName map[string]*Message
 }
 
-type Generator struct {
-       Options
+func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
+       gen := &Generator{
+               FilesByName:    make(map[string]*File),
+               FilesByPath:    make(map[string]*File),
+               opts:           opts,
+               apifiles:       apiFiles,
+               filesToGen:     filesToGen,
+               enumsByName:    map[string]*Enum{},
+               aliasesByName:  map[string]*Alias{},
+               structsByName:  map[string]*Struct{},
+               unionsByName:   map[string]*Union{},
+               messagesByName: map[string]*Message{},
+       }
 
-       Files       []*File
-       FilesByPath map[string]*File
-       FilesByName map[string]*File
+       // Normalize API files
+       SortFilesByImports(gen.apifiles)
+       for _, apiFile := range apiFiles {
+               RemoveImportedTypes(gen.apifiles, apiFile)
+               SortFileObjectsByName(apiFile)
+       }
 
-       enumsByName   map[string]*Enum
-       aliasesByName map[string]*Alias
-       structsByName map[string]*Struct
-       unionsByName  map[string]*Union
+       // prepare package names and import paths
+       packageNames := make(map[string]GoPackageName)
+       importPaths := make(map[string]GoImportPath)
+       for _, apifile := range gen.apifiles {
+               filename := getFilename(apifile)
+               packageNames[filename] = cleanPackageName(apifile.Name)
+               importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
+       }
 
-       genfiles []*GenFile
-}
+       logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
 
-func New(opts Options, apifiles []*vppapi.File) (*Generator, error) {
-       g := &Generator{
-               Options:       opts,
-               FilesByPath:   make(map[string]*File),
-               FilesByName:   make(map[string]*File),
-               enumsByName:   map[string]*Enum{},
-               aliasesByName: map[string]*Alias{},
-               structsByName: map[string]*Struct{},
-               unionsByName:  map[string]*Union{},
-       }
-
-       logrus.Debugf("adding %d VPP API files to generator", len(apifiles))
-       for _, apifile := range apifiles {
-               filename := apifile.Path
-               if filename == "" {
-                       filename = apifile.Name
-               }
-               if _, ok := g.FilesByPath[filename]; ok {
-                       return nil, fmt.Errorf("duplicate file name: %q", filename)
-               }
-               if _, ok := g.FilesByName[apifile.Name]; ok {
+       for _, apifile := range gen.apifiles {
+               if _, ok := gen.FilesByName[apifile.Name]; ok {
                        return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
                }
 
-               file, err := newFile(g, apifile)
+               filename := getFilename(apifile)
+               file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
                if err != nil {
-                       return nil, err
+                       return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
                }
-               g.Files = append(g.Files, file)
-               g.FilesByPath[filename] = file
-               g.FilesByName[apifile.Name] = file
+               gen.Files = append(gen.Files, file)
+               gen.FilesByName[apifile.Name] = file
+               gen.FilesByPath[apifile.Path] = file
 
                logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
-               if len(file.Imports) > 0 {
-                       logrus.Debugf(" - %d imports: %v", len(file.Imports), file.Imports)
-               }
        }
 
-       logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate)
-       for _, genfile := range opts.FilesToGenerate {
-               file, ok := g.FilesByPath[genfile]
-               if !ok {
-                       file, ok = g.FilesByName[genfile]
-                       if !ok {
-                               return nil, fmt.Errorf("no API file found for: %v", genfile)
+       // mark files for generation
+       if len(gen.filesToGen) > 0 {
+               logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
+               for _, genFile := range gen.filesToGen {
+                       markGen := func(file *File) {
+                               file.Generate = true
+                               // generate all imported files
+                               for _, impFile := range file.importedFiles(gen) {
+                                       impFile.Generate = true
+                               }
                        }
-               }
-               file.Generate = true
-               if opts.ImportTypes {
-                       for _, impFile := range file.importedFiles(g) {
-                               impFile.Generate = true
+                       if file, ok := gen.FilesByName[genFile]; ok {
+                               markGen(file)
+                               continue
+                       }
+                       logrus.Debugf("File %s was not found by name", genFile)
+                       if file, ok := gen.FilesByPath[genFile]; ok {
+                               markGen(file)
+                               continue
                        }
+                       return nil, fmt.Errorf("no API file found for: %v", genFile)
                }
-       }
-
-       logrus.Debugf("Resolving imported types")
-       for _, file := range g.Files {
-               if !file.Generate {
-                       continue
+       } else {
+               logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
+               for _, file := range gen.Files {
+                       file.Generate = true
                }
-               importedFiles := file.importedFiles(g)
-               file.loadTypeImports(g, importedFiles)
        }
 
-       return g, nil
+       return gen, nil
+}
+
+func getFilename(file *vppapi.File) string {
+       if file.Path == "" {
+               return file.Name
+       }
+       return file.Path
 }
 
 func (g *Generator) Generate() error {
@@ -129,23 +148,177 @@ func (g *Generator) Generate() error {
        }
 
        logrus.Infof("Generating %d files", len(g.genfiles))
+
        for _, genfile := range g.genfiles {
-               if err := writeSourceTo(genfile.filename, genfile.buf.Bytes()); err != nil {
-                       return fmt.Errorf("writing source for RPC package %s failed: %v", genfile.filename, err)
+               content, err := genfile.Content()
+               if err != nil {
+                       return err
+               }
+               if err := writeSourceTo(genfile.filename, content); err != nil {
+                       return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
                }
        }
        return nil
 }
 
-func (g *Generator) NewGenFile(filename string) *GenFile {
+type GenFile struct {
+       gen           *Generator
+       file          *File
+       filename      string
+       goImportPath  GoImportPath
+       buf           bytes.Buffer
+       manualImports map[GoImportPath]bool
+       packageNames  map[GoImportPath]GoPackageName
+}
+
+// NewGenFile creates new generated file with
+func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
        f := &GenFile{
-               Generator: g,
-               filename:  filename,
+               gen:           g,
+               filename:      filename,
+               goImportPath:  importPath,
+               manualImports: make(map[GoImportPath]bool),
+               packageNames:  make(map[GoImportPath]GoPackageName),
        }
        g.genfiles = append(g.genfiles, f)
        return f
 }
 
+func (g *GenFile) Write(p []byte) (n int, err error) {
+       return g.buf.Write(p)
+}
+
+func (g *GenFile) Import(importPath GoImportPath) {
+       g.manualImports[importPath] = true
+}
+
+func (g *GenFile) GoIdent(ident GoIdent) string {
+       if ident.GoImportPath == g.goImportPath {
+               return ident.GoName
+       }
+       if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
+               return string(packageName) + "." + ident.GoName
+       }
+       packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
+       g.packageNames[ident.GoImportPath] = packageName
+       return string(packageName) + "." + ident.GoName
+}
+
+func (g *GenFile) P(v ...interface{}) {
+       for _, x := range v {
+               switch x := x.(type) {
+               case GoIdent:
+                       fmt.Fprint(&g.buf, g.GoIdent(x))
+               default:
+                       fmt.Fprint(&g.buf, x)
+               }
+       }
+       fmt.Fprintln(&g.buf)
+}
+
+func (g *GenFile) Content() ([]byte, error) {
+       if !strings.HasSuffix(g.filename, ".go") {
+               return g.buf.Bytes(), nil
+       }
+       return g.injectImports(g.buf.Bytes())
+}
+
+func getImportClass(importPath string) int {
+       if !strings.Contains(importPath, ".") {
+               return 0 /* std */
+       }
+       return 1 /* External */
+}
+
+// injectImports parses source, injects import block declaration with all imports and return formatted
+func (g *GenFile) injectImports(original []byte) ([]byte, error) {
+       // Parse source code
+       fset := token.NewFileSet()
+       file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
+       if err != nil {
+               var src bytes.Buffer
+               s := bufio.NewScanner(bytes.NewReader(original))
+               for line := 1; s.Scan(); line++ {
+                       fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
+               }
+               return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
+       }
+       type Import struct {
+               Name string
+               Path string
+       }
+       // Prepare list of all imports
+       var importPaths []Import
+       for importPath := range g.packageNames {
+               importPaths = append(importPaths, Import{
+                       Name: string(g.packageNames[importPath]),
+                       Path: string(importPath),
+               })
+       }
+       for importPath := range g.manualImports {
+               if _, ok := g.packageNames[importPath]; ok {
+                       continue
+               }
+               importPaths = append(importPaths, Import{
+                       Name: "_",
+                       Path: string(importPath),
+               })
+       }
+       // Sort imports by import path
+       sort.Slice(importPaths, func(i, j int) bool {
+               ci := getImportClass(importPaths[i].Path)
+               cj := getImportClass(importPaths[j].Path)
+               if ci == cj {
+                       return importPaths[i].Path < importPaths[j].Path
+               }
+               return ci < cj
+       })
+       // Inject new import block into parsed AST
+       if len(importPaths) > 0 {
+               // Find import block position
+               pos := file.Package
+               tokFile := fset.File(file.Package)
+               pkgLine := tokFile.Line(file.Package)
+               for _, c := range file.Comments {
+                       if tokFile.Line(c.Pos()) > pkgLine {
+                               break
+                       }
+                       pos = c.End()
+               }
+               // Prepare the import block
+               impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
+               for i, importPath := range importPaths {
+                       var name *ast.Ident
+                       if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
+                               name = &ast.Ident{Name: importPath.Name, NamePos: pos}
+                       }
+                       value := strconv.Quote(importPath.Path)
+                       if i < len(importPaths)-1 {
+                               if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
+                                       value += "\n"
+                               }
+                       }
+                       impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
+                               Name:   name,
+                               Path:   &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
+                               EndPos: pos,
+                       })
+               }
+
+               file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
+       }
+       // Reformat source code
+       var out bytes.Buffer
+       cfg := &printer.Config{
+               Mode:     printer.TabIndent | printer.UseSpaces,
+               Tabwidth: 8,
+       }
+       if err = cfg.Fprint(&out, fset, file); err != nil {
+               return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
+       }
+       return out.Bytes(), nil
+}
+
 func writeSourceTo(outputFile string, b []byte) error {
        // create output directory
        packageDir := filepath.Dir(outputFile)
@@ -153,20 +326,13 @@ func writeSourceTo(outputFile string, b []byte) error {
                return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
        }
 
-       // format generated source code
-       gosrc, err := format.Source(b)
-       if err != nil {
-               _ = ioutil.WriteFile(outputFile, b, 0666)
-               return fmt.Errorf("formatting source code failed: %v", err)
-       }
-
        // write generated code to output file
-       if err := ioutil.WriteFile(outputFile, gosrc, 0666); err != nil {
+       if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
                return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
        }
 
-       lines := bytes.Count(gosrc, []byte("\n"))
-       logf("wrote %d lines (%d bytes) of code to: %q", lines, len(gosrc), outputFile)
+       lines := bytes.Count(b, []byte("\n"))
+       logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)
 
        return nil
 }