Improve doc & fix import ordering
[govpp.git] / binapigen / generator.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "go/ast"
22         "go/parser"
23         "go/printer"
24         "go/token"
25         "io/ioutil"
26         "os"
27         "path"
28         "path/filepath"
29         "sort"
30         "strconv"
31         "strings"
32
33         "github.com/sirupsen/logrus"
34
35         "git.fd.io/govpp.git/binapigen/vppapi"
36 )
37
38 type Generator struct {
39         Files       []*File
40         FilesByName map[string]*File
41         FilesByPath map[string]*File
42
43         opts       Options
44         apifiles   []*vppapi.File
45         vppVersion string
46
47         filesToGen []string
48         genfiles   []*GenFile
49
50         enumsByName    map[string]*Enum
51         aliasesByName  map[string]*Alias
52         structsByName  map[string]*Struct
53         unionsByName   map[string]*Union
54         messagesByName map[string]*Message
55 }
56
57 func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
58         gen := &Generator{
59                 FilesByName:    make(map[string]*File),
60                 FilesByPath:    make(map[string]*File),
61                 opts:           opts,
62                 apifiles:       apiFiles,
63                 filesToGen:     filesToGen,
64                 enumsByName:    map[string]*Enum{},
65                 aliasesByName:  map[string]*Alias{},
66                 structsByName:  map[string]*Struct{},
67                 unionsByName:   map[string]*Union{},
68                 messagesByName: map[string]*Message{},
69         }
70
71         // Normalize API files
72         SortFilesByImports(gen.apifiles)
73         for _, apiFile := range apiFiles {
74                 RemoveImportedTypes(gen.apifiles, apiFile)
75                 SortFileObjectsByName(apiFile)
76         }
77
78         // prepare package names and import paths
79         packageNames := make(map[string]GoPackageName)
80         importPaths := make(map[string]GoImportPath)
81         for _, apifile := range gen.apifiles {
82                 filename := getFilename(apifile)
83                 packageNames[filename] = cleanPackageName(apifile.Name)
84                 importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
85         }
86
87         logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
88
89         for _, apifile := range gen.apifiles {
90                 if _, ok := gen.FilesByName[apifile.Name]; ok {
91                         return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
92                 }
93
94                 filename := getFilename(apifile)
95                 file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
96                 if err != nil {
97                         return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
98                 }
99                 gen.Files = append(gen.Files, file)
100                 gen.FilesByName[apifile.Name] = file
101                 gen.FilesByPath[apifile.Path] = file
102
103                 logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
104         }
105
106         // mark files for generation
107         if len(gen.filesToGen) > 0 {
108                 logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
109                 for _, genFile := range gen.filesToGen {
110                         markGen := func(file *File) {
111                                 file.Generate = true
112                                 // generate all imported files
113                                 for _, impFile := range file.importedFiles(gen) {
114                                         impFile.Generate = true
115                                 }
116                         }
117                         if file, ok := gen.FilesByName[genFile]; ok {
118                                 markGen(file)
119                                 continue
120                         }
121                         logrus.Debugf("File %s was not found by name", genFile)
122                         if file, ok := gen.FilesByPath[genFile]; ok {
123                                 markGen(file)
124                                 continue
125                         }
126                         return nil, fmt.Errorf("no API file found for: %v", genFile)
127                 }
128         } else {
129                 logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
130                 for _, file := range gen.Files {
131                         file.Generate = true
132                 }
133         }
134
135         return gen, nil
136 }
137
138 func getFilename(file *vppapi.File) string {
139         if file.Path == "" {
140                 return file.Name
141         }
142         return file.Path
143 }
144
145 func (g *Generator) Generate() error {
146         if len(g.genfiles) == 0 {
147                 return fmt.Errorf("no files to generate")
148         }
149
150         logrus.Infof("Generating %d files", len(g.genfiles))
151
152         for _, genfile := range g.genfiles {
153                 content, err := genfile.Content()
154                 if err != nil {
155                         return err
156                 }
157                 if err := writeSourceTo(genfile.filename, content); err != nil {
158                         return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
159                 }
160         }
161         return nil
162 }
163
164 type GenFile struct {
165         gen           *Generator
166         file          *File
167         filename      string
168         goImportPath  GoImportPath
169         buf           bytes.Buffer
170         manualImports map[GoImportPath]bool
171         packageNames  map[GoImportPath]GoPackageName
172 }
173
174 // NewGenFile creates new generated file with
175 func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
176         f := &GenFile{
177                 gen:           g,
178                 filename:      filename,
179                 goImportPath:  importPath,
180                 manualImports: make(map[GoImportPath]bool),
181                 packageNames:  make(map[GoImportPath]GoPackageName),
182         }
183         g.genfiles = append(g.genfiles, f)
184         return f
185 }
186
187 func (g *GenFile) Write(p []byte) (n int, err error) {
188         return g.buf.Write(p)
189 }
190
191 func (g *GenFile) Import(importPath GoImportPath) {
192         g.manualImports[importPath] = true
193 }
194
195 func (g *GenFile) GoIdent(ident GoIdent) string {
196         if ident.GoImportPath == g.goImportPath {
197                 return ident.GoName
198         }
199         if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
200                 return string(packageName) + "." + ident.GoName
201         }
202         packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
203         g.packageNames[ident.GoImportPath] = packageName
204         return string(packageName) + "." + ident.GoName
205 }
206
207 func (g *GenFile) P(v ...interface{}) {
208         for _, x := range v {
209                 switch x := x.(type) {
210                 case GoIdent:
211                         fmt.Fprint(&g.buf, g.GoIdent(x))
212                 default:
213                         fmt.Fprint(&g.buf, x)
214                 }
215         }
216         fmt.Fprintln(&g.buf)
217 }
218
219 func (g *GenFile) Content() ([]byte, error) {
220         if !strings.HasSuffix(g.filename, ".go") {
221                 return g.buf.Bytes(), nil
222         }
223         return g.injectImports(g.buf.Bytes())
224 }
225
226 func getImportClass(importPath string) int {
227         if !strings.Contains(importPath, ".") {
228                 return 0 /* std */
229         }
230         return 1 /* External */
231 }
232
233 // injectImports parses source, injects import block declaration with all imports and return formatted
234 func (g *GenFile) injectImports(original []byte) ([]byte, error) {
235         // Parse source code
236         fset := token.NewFileSet()
237         file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
238         if err != nil {
239                 var src bytes.Buffer
240                 s := bufio.NewScanner(bytes.NewReader(original))
241                 for line := 1; s.Scan(); line++ {
242                         fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
243                 }
244                 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
245         }
246         type Import struct {
247                 Name string
248                 Path string
249         }
250         // Prepare list of all imports
251         var importPaths []Import
252         for importPath := range g.packageNames {
253                 importPaths = append(importPaths, Import{
254                         Name: string(g.packageNames[importPath]),
255                         Path: string(importPath),
256                 })
257         }
258         for importPath := range g.manualImports {
259                 if _, ok := g.packageNames[importPath]; ok {
260                         continue
261                 }
262                 importPaths = append(importPaths, Import{
263                         Name: "_",
264                         Path: string(importPath),
265                 })
266         }
267         // Sort imports by import path
268         sort.Slice(importPaths, func(i, j int) bool {
269                 ci := getImportClass(importPaths[i].Path)
270                 cj := getImportClass(importPaths[j].Path)
271                 if ci == cj {
272                         return importPaths[i].Path < importPaths[j].Path
273                 }
274                 return ci < cj
275         })
276         // Inject new import block into parsed AST
277         if len(importPaths) > 0 {
278                 // Find import block position
279                 pos := file.Package
280                 tokFile := fset.File(file.Package)
281                 pkgLine := tokFile.Line(file.Package)
282                 for _, c := range file.Comments {
283                         if tokFile.Line(c.Pos()) > pkgLine {
284                                 break
285                         }
286                         pos = c.End()
287                 }
288                 // Prepare the import block
289                 impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
290                 for i, importPath := range importPaths {
291                         var name *ast.Ident
292                         if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
293                                 name = &ast.Ident{Name: importPath.Name, NamePos: pos}
294                         }
295                         value := strconv.Quote(importPath.Path)
296                         if i < len(importPaths)-1 {
297                                 if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
298                                         value += "\n"
299                                 }
300                         }
301                         impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
302                                 Name:   name,
303                                 Path:   &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
304                                 EndPos: pos,
305                         })
306                 }
307
308                 file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
309         }
310         // Reformat source code
311         var out bytes.Buffer
312         cfg := &printer.Config{
313                 Mode:     printer.TabIndent | printer.UseSpaces,
314                 Tabwidth: 8,
315         }
316         if err = cfg.Fprint(&out, fset, file); err != nil {
317                 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
318         }
319         return out.Bytes(), nil
320 }
321
322 func writeSourceTo(outputFile string, b []byte) error {
323         // create output directory
324         packageDir := filepath.Dir(outputFile)
325         if err := os.MkdirAll(packageDir, 0775); err != nil {
326                 return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
327         }
328
329         // write generated code to output file
330         if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
331                 return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
332         }
333
334         lines := bytes.Count(b, []byte("\n"))
335         logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)
336
337         return nil
338 }