Improve doc & fix import ordering
[govpp.git] / binapigen / generator.go
index e42e7fb..e5eed5a 100644 (file)
@@ -19,7 +19,6 @@ import (
        "bytes"
        "fmt"
        "go/ast"
-       "go/format"
        "go/parser"
        "go/printer"
        "go/token"
@@ -39,6 +38,7 @@ import (
 type Generator struct {
        Files       []*File
        FilesByName map[string]*File
+       FilesByPath map[string]*File
 
        opts       Options
        apifiles   []*vppapi.File
@@ -54,11 +54,12 @@ type Generator struct {
        messagesByName map[string]*Message
 }
 
-func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator, error) {
+func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
        gen := &Generator{
                FilesByName:    make(map[string]*File),
+               FilesByPath:    make(map[string]*File),
                opts:           opts,
-               apifiles:       apifiles,
+               apifiles:       apiFiles,
                filesToGen:     filesToGen,
                enumsByName:    map[string]*Enum{},
                aliasesByName:  map[string]*Alias{},
@@ -69,9 +70,9 @@ func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator
 
        // Normalize API files
        SortFilesByImports(gen.apifiles)
-       for _, apifile := range apifiles {
-               RemoveImportedTypes(gen.apifiles, apifile)
-               SortFileObjectsByName(apifile)
+       for _, apiFile := range apiFiles {
+               RemoveImportedTypes(gen.apifiles, apiFile)
+               SortFileObjectsByName(apiFile)
        }
 
        // prepare package names and import paths
@@ -86,18 +87,18 @@ func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator
        logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
 
        for _, apifile := range gen.apifiles {
-               filename := getFilename(apifile)
-
                if _, ok := gen.FilesByName[apifile.Name]; ok {
                        return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
                }
 
+               filename := getFilename(apifile)
                file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
                if err != nil {
                        return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
                }
                gen.Files = append(gen.Files, file)
                gen.FilesByName[apifile.Name] = file
+               gen.FilesByPath[apifile.Path] = file
 
                logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
        }
@@ -105,16 +106,24 @@ func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator
        // mark files for generation
        if len(gen.filesToGen) > 0 {
                logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
-               for _, genfile := range gen.filesToGen {
-                       file, ok := gen.FilesByName[genfile]
-                       if !ok {
-                               return nil, fmt.Errorf("no API file found for: %v", genfile)
+               for _, genFile := range gen.filesToGen {
+                       markGen := func(file *File) {
+                               file.Generate = true
+                               // generate all imported files
+                               for _, impFile := range file.importedFiles(gen) {
+                                       impFile.Generate = true
+                               }
                        }
-                       file.Generate = true
-                       // generate all imported files
-                       for _, impFile := range file.importedFiles(gen) {
-                               impFile.Generate = true
+                       if file, ok := gen.FilesByName[genFile]; ok {
+                               markGen(file)
+                               continue
+                       }
+                       logrus.Debugf("File %s was not found by name", genFile)
+                       if file, ok := gen.FilesByPath[genFile]; ok {
+                               markGen(file)
+                               continue
                        }
+                       return nil, fmt.Errorf("no API file found for: %v", genFile)
                }
        } else {
                logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
@@ -162,6 +171,7 @@ type GenFile struct {
        packageNames  map[GoImportPath]GoPackageName
 }
 
+// NewGenFile creates new generated file with
 func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
        f := &GenFile{
                gen:           g,
@@ -213,6 +223,14 @@ func (g *GenFile) Content() ([]byte, error) {
        return g.injectImports(g.buf.Bytes())
 }
 
+func getImportClass(importPath string) int {
+       if !strings.Contains(importPath, ".") {
+               return 0 /* std */
+       }
+       return 1 /* External */
+}
+
+// injectImports parses source, injects import block declaration with all imports and return formatted
 func (g *GenFile) injectImports(original []byte) ([]byte, error) {
        // Parse source code
        fset := token.NewFileSet()
@@ -233,7 +251,7 @@ func (g *GenFile) injectImports(original []byte) ([]byte, error) {
        var importPaths []Import
        for importPath := range g.packageNames {
                importPaths = append(importPaths, Import{
-                       Name: string(g.packageNames[GoImportPath(importPath)]),
+                       Name: string(g.packageNames[importPath]),
                        Path: string(importPath),
                })
        }
@@ -248,7 +266,12 @@ func (g *GenFile) injectImports(original []byte) ([]byte, error) {
        }
        // Sort imports by import path
        sort.Slice(importPaths, func(i, j int) bool {
-               return importPaths[i].Path < importPaths[j].Path
+               ci := getImportClass(importPaths[i].Path)
+               cj := getImportClass(importPaths[j].Path)
+               if ci == cj {
+                       return importPaths[i].Path < importPaths[j].Path
+               }
+               return ci < cj
        })
        // Inject new import block into parsed AST
        if len(importPaths) > 0 {
@@ -264,14 +287,20 @@ func (g *GenFile) injectImports(original []byte) ([]byte, error) {
                }
                // Prepare the import block
                impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
-               for _, importPath := range importPaths {
+               for i, importPath := range importPaths {
                        var name *ast.Ident
                        if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
                                name = &ast.Ident{Name: importPath.Name, NamePos: pos}
                        }
+                       value := strconv.Quote(importPath.Path)
+                       if i < len(importPaths)-1 {
+                               if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
+                                       value += "\n"
+                               }
+                       }
                        impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
                                Name:   name,
-                               Path:   &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath.Path), ValuePos: pos},
+                               Path:   &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
                                EndPos: pos,
                        })
                }
@@ -290,66 +319,6 @@ func (g *GenFile) injectImports(original []byte) ([]byte, error) {
        return out.Bytes(), nil
 }
 
-// GoIdent is a Go identifier, consisting of a name and import path.
-// The name is a single identifier and may not be a dot-qualified selector.
-type GoIdent struct {
-       GoName       string
-       GoImportPath GoImportPath
-}
-
-func (id GoIdent) String() string {
-       return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName)
-}
-
-func newGoIdent(f *File, fullName string) GoIdent {
-       name := strings.TrimPrefix(fullName, string(f.PackageName)+".")
-       return GoIdent{
-               GoName:       camelCaseName(name),
-               GoImportPath: f.GoImportPath,
-       }
-}
-
-// GoImportPath is a Go import path for a package.
-type GoImportPath string
-
-func (p GoImportPath) String() string {
-       return strconv.Quote(string(p))
-}
-
-func (p GoImportPath) Ident(s string) GoIdent {
-       return GoIdent{GoName: s, GoImportPath: p}
-}
-
-type GoPackageName string
-
-func cleanPackageName(name string) GoPackageName {
-       return GoPackageName(sanitizedName(name))
-}
-
-func sanitizedName(name string) string {
-       switch name {
-       case "interface":
-               return "interfaces"
-       case "map":
-               return "maps"
-       default:
-               return name
-       }
-}
-
-// baseName returns the last path element of the name, with the last dotted suffix removed.
-func baseName(name string) string {
-       // First, find the last element
-       if i := strings.LastIndex(name, "/"); i >= 0 {
-               name = name[i+1:]
-       }
-       // Now drop the suffix
-       if i := strings.LastIndex(name, "."); i >= 0 {
-               name = name[:i]
-       }
-       return name
-}
-
 func writeSourceTo(outputFile string, b []byte) error {
        // create output directory
        packageDir := filepath.Dir(outputFile)
@@ -357,20 +326,13 @@ func writeSourceTo(outputFile string, b []byte) error {
                return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
        }
 
-       // format generated source code
-       gosrc, err := format.Source(b)
-       if err != nil {
-               _ = ioutil.WriteFile(outputFile, b, 0666)
-               return fmt.Errorf("formatting source code failed: %v", err)
-       }
-
        // write generated code to output file
-       if err := ioutil.WriteFile(outputFile, gosrc, 0666); err != nil {
+       if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
                return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
        }
 
-       lines := bytes.Count(gosrc, []byte("\n"))
-       logf("wrote %d lines (%d bytes) to: %q", lines, len(gosrc), outputFile)
+       lines := bytes.Count(b, []byte("\n"))
+       logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)
 
        return nil
 }