Fix codec fallback and generate type imports
[govpp.git] / binapigen / generator.go
index 9471462..07c1b13 100644 (file)
@@ -20,7 +20,9 @@ import (
        "go/format"
        "io/ioutil"
        "os"
+       "path"
        "path/filepath"
+       "regexp"
 
        "github.com/sirupsen/logrus"
 
@@ -94,29 +96,45 @@ func New(opts Options, apifiles []*vppapi.File) (*Generator, error) {
                }
        }
 
-       logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate)
-       for _, genfile := range opts.FilesToGenerate {
-               file, ok := g.FilesByPath[genfile]
-               if !ok {
-                       file, ok = g.FilesByName[genfile]
+       if len(opts.FilesToGenerate) > 0 {
+               logrus.Debugf("Checking %d files to generate: %v", len(opts.FilesToGenerate), opts.FilesToGenerate)
+               for _, genfile := range opts.FilesToGenerate {
+                       file, ok := g.FilesByPath[genfile]
                        if !ok {
-                               return nil, fmt.Errorf("no API file found for: %v", genfile)
+                               file, ok = g.FilesByName[genfile]
+                               if !ok {
+                                       return nil, fmt.Errorf("no API file found for: %v", genfile)
+                               }
                        }
-               }
-               file.Generate = true
-               if opts.ImportTypes {
-                       for _, impFile := range file.importedFiles(g) {
-                               impFile.Generate = true
+                       file.Generate = true
+                       if opts.ImportTypes {
+                               // generate all imported files
+                               for _, impFile := range file.importedFiles(g) {
+                                       impFile.Generate = true
+                               }
                        }
                }
+       } else {
+               logrus.Debugf("Files to generate not specified, marking all %d files to generate", len(g.Files))
+               for _, file := range g.Files {
+                       file.Generate = true
+               }
        }
 
        logrus.Debugf("Resolving imported types")
        for _, file := range g.Files {
                if !file.Generate {
+                       // skip resolving for non-generated files
                        continue
                }
-               importedFiles := file.importedFiles(g)
+               var importedFiles []*File
+               for _, impFile := range file.importedFiles(g) {
+                       if !impFile.Generate {
+                               // exclude imports of non-generated files
+                               continue
+                       }
+                       importedFiles = append(importedFiles, impFile)
+               }
                file.loadTypeImports(g, importedFiles)
        }
 
@@ -130,13 +148,21 @@ func (g *Generator) Generate() error {
 
        logrus.Infof("Generating %d files", len(g.genfiles))
        for _, genfile := range g.genfiles {
-               if err := writeSourceTo(genfile.filename, genfile.buf.Bytes()); err != nil {
+               if err := writeSourceTo(genfile.filename, genfile.Content()); err != nil {
                        return fmt.Errorf("writing source for RPC package %s failed: %v", genfile.filename, err)
                }
        }
        return nil
 }
 
+type GenFile struct {
+       *Generator
+       filename  string
+       file      *File
+       outputDir string
+       buf       bytes.Buffer
+}
+
 func (g *Generator) NewGenFile(filename string) *GenFile {
        f := &GenFile{
                Generator: g,
@@ -146,6 +172,10 @@ func (g *Generator) NewGenFile(filename string) *GenFile {
        return f
 }
 
+func (f *GenFile) Content() []byte {
+       return f.buf.Bytes()
+}
+
 func writeSourceTo(outputFile string, b []byte) error {
        // create output directory
        packageDir := filepath.Dir(outputFile)
@@ -170,3 +200,74 @@ func writeSourceTo(outputFile string, b []byte) error {
 
        return nil
 }
+
+func listImports(genfile *GenFile) map[string]string {
+       var importPath = genfile.ImportPrefix
+       if importPath == "" {
+               importPath = resolveImportPath(genfile.outputDir)
+               logrus.Debugf("resolved import path: %s", importPath)
+       }
+       imports := map[string]string{}
+       for _, imp := range genfile.file.imports {
+               if _, ok := imports[imp]; !ok {
+                       imports[imp] = path.Join(importPath, imp)
+               }
+       }
+       return imports
+}
+
+func resolveImportPath(outputDir string) string {
+       absPath, err := filepath.Abs(outputDir)
+       if err != nil {
+               panic(err)
+       }
+       modRoot := findModuleRoot(absPath)
+       if modRoot == "" {
+               logrus.Fatalf("module root not found at: %s", absPath)
+       }
+       modPath := findModulePath(path.Join(modRoot, "go.mod"))
+       if modPath == "" {
+               logrus.Fatalf("module path not found")
+       }
+       relDir, err := filepath.Rel(modRoot, absPath)
+       if err != nil {
+               panic(err)
+       }
+       return filepath.Join(modPath, relDir)
+}
+
+func findModuleRoot(dir string) (root string) {
+       if dir == "" {
+               panic("dir not set")
+       }
+       dir = filepath.Clean(dir)
+
+       // Look for enclosing go.mod.
+       for {
+               if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
+                       return dir
+               }
+               d := filepath.Dir(dir)
+               if d == dir {
+                       break
+               }
+               dir = d
+       }
+       return ""
+}
+
+var (
+       modulePathRE = regexp.MustCompile(`module[ \t]+([^ \t\r\n]+)`)
+)
+
+func findModulePath(file string) string {
+       data, err := ioutil.ReadFile(file)
+       if err != nil {
+               return ""
+       }
+       m := modulePathRE.FindSubmatch(data)
+       if m == nil {
+               return ""
+       }
+       return string(m[1])
+}