Improve binapi generator
[govpp.git] / binapigen / generator.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "go/ast"
22         "go/format"
23         "go/parser"
24         "go/printer"
25         "go/token"
26         "io/ioutil"
27         "os"
28         "path"
29         "path/filepath"
30         "sort"
31         "strconv"
32         "strings"
33
34         "github.com/sirupsen/logrus"
35
36         "git.fd.io/govpp.git/binapigen/vppapi"
37 )
38
39 type Generator struct {
40         Files       []*File
41         FilesByName map[string]*File
42
43         opts       Options
44         apifiles   []*vppapi.File
45         vppVersion string
46
47         filesToGen []string
48         genfiles   []*GenFile
49
50         enumsByName    map[string]*Enum
51         aliasesByName  map[string]*Alias
52         structsByName  map[string]*Struct
53         unionsByName   map[string]*Union
54         messagesByName map[string]*Message
55 }
56
57 func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator, error) {
58         gen := &Generator{
59                 FilesByName:    make(map[string]*File),
60                 opts:           opts,
61                 apifiles:       apifiles,
62                 filesToGen:     filesToGen,
63                 enumsByName:    map[string]*Enum{},
64                 aliasesByName:  map[string]*Alias{},
65                 structsByName:  map[string]*Struct{},
66                 unionsByName:   map[string]*Union{},
67                 messagesByName: map[string]*Message{},
68         }
69
70         // Normalize API files
71         SortFilesByImports(gen.apifiles)
72         for _, apifile := range apifiles {
73                 RemoveImportedTypes(gen.apifiles, apifile)
74                 SortFileObjectsByName(apifile)
75         }
76
77         // prepare package names and import paths
78         packageNames := make(map[string]GoPackageName)
79         importPaths := make(map[string]GoImportPath)
80         for _, apifile := range gen.apifiles {
81                 filename := getFilename(apifile)
82                 packageNames[filename] = cleanPackageName(apifile.Name)
83                 importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
84         }
85
86         logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
87
88         for _, apifile := range gen.apifiles {
89                 filename := getFilename(apifile)
90
91                 if _, ok := gen.FilesByName[apifile.Name]; ok {
92                         return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
93                 }
94
95                 file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
96                 if err != nil {
97                         return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
98                 }
99                 gen.Files = append(gen.Files, file)
100                 gen.FilesByName[apifile.Name] = file
101
102                 logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
103         }
104
105         // mark files for generation
106         if len(gen.filesToGen) > 0 {
107                 logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
108                 for _, genfile := range gen.filesToGen {
109                         file, ok := gen.FilesByName[genfile]
110                         if !ok {
111                                 return nil, fmt.Errorf("no API file found for: %v", genfile)
112                         }
113                         file.Generate = true
114                         // generate all imported files
115                         for _, impFile := range file.importedFiles(gen) {
116                                 impFile.Generate = true
117                         }
118                 }
119         } else {
120                 logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
121                 for _, file := range gen.Files {
122                         file.Generate = true
123                 }
124         }
125
126         return gen, nil
127 }
128
129 func getFilename(file *vppapi.File) string {
130         if file.Path == "" {
131                 return file.Name
132         }
133         return file.Path
134 }
135
136 func (g *Generator) Generate() error {
137         if len(g.genfiles) == 0 {
138                 return fmt.Errorf("no files to generate")
139         }
140
141         logrus.Infof("Generating %d files", len(g.genfiles))
142
143         for _, genfile := range g.genfiles {
144                 content, err := genfile.Content()
145                 if err != nil {
146                         return err
147                 }
148                 if err := writeSourceTo(genfile.filename, content); err != nil {
149                         return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
150                 }
151         }
152         return nil
153 }
154
155 type GenFile struct {
156         gen           *Generator
157         file          *File
158         filename      string
159         goImportPath  GoImportPath
160         buf           bytes.Buffer
161         manualImports map[GoImportPath]bool
162         packageNames  map[GoImportPath]GoPackageName
163 }
164
165 func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
166         f := &GenFile{
167                 gen:           g,
168                 filename:      filename,
169                 goImportPath:  importPath,
170                 manualImports: make(map[GoImportPath]bool),
171                 packageNames:  make(map[GoImportPath]GoPackageName),
172         }
173         g.genfiles = append(g.genfiles, f)
174         return f
175 }
176
177 func (g *GenFile) Write(p []byte) (n int, err error) {
178         return g.buf.Write(p)
179 }
180
181 func (g *GenFile) Import(importPath GoImportPath) {
182         g.manualImports[importPath] = true
183 }
184
185 func (g *GenFile) GoIdent(ident GoIdent) string {
186         if ident.GoImportPath == g.goImportPath {
187                 return ident.GoName
188         }
189         if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
190                 return string(packageName) + "." + ident.GoName
191         }
192         packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
193         g.packageNames[ident.GoImportPath] = packageName
194         return string(packageName) + "." + ident.GoName
195 }
196
197 func (g *GenFile) P(v ...interface{}) {
198         for _, x := range v {
199                 switch x := x.(type) {
200                 case GoIdent:
201                         fmt.Fprint(&g.buf, g.GoIdent(x))
202                 default:
203                         fmt.Fprint(&g.buf, x)
204                 }
205         }
206         fmt.Fprintln(&g.buf)
207 }
208
209 func (g *GenFile) Content() ([]byte, error) {
210         if !strings.HasSuffix(g.filename, ".go") {
211                 return g.buf.Bytes(), nil
212         }
213         return g.injectImports(g.buf.Bytes())
214 }
215
216 func (g *GenFile) injectImports(original []byte) ([]byte, error) {
217         // Parse source code
218         fset := token.NewFileSet()
219         file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
220         if err != nil {
221                 var src bytes.Buffer
222                 s := bufio.NewScanner(bytes.NewReader(original))
223                 for line := 1; s.Scan(); line++ {
224                         fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
225                 }
226                 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
227         }
228         type Import struct {
229                 Name string
230                 Path string
231         }
232         // Prepare list of all imports
233         var importPaths []Import
234         for importPath := range g.packageNames {
235                 importPaths = append(importPaths, Import{
236                         Name: string(g.packageNames[GoImportPath(importPath)]),
237                         Path: string(importPath),
238                 })
239         }
240         for importPath := range g.manualImports {
241                 if _, ok := g.packageNames[importPath]; ok {
242                         continue
243                 }
244                 importPaths = append(importPaths, Import{
245                         Name: "_",
246                         Path: string(importPath),
247                 })
248         }
249         // Sort imports by import path
250         sort.Slice(importPaths, func(i, j int) bool {
251                 return importPaths[i].Path < importPaths[j].Path
252         })
253         // Inject new import block into parsed AST
254         if len(importPaths) > 0 {
255                 // Find import block position
256                 pos := file.Package
257                 tokFile := fset.File(file.Package)
258                 pkgLine := tokFile.Line(file.Package)
259                 for _, c := range file.Comments {
260                         if tokFile.Line(c.Pos()) > pkgLine {
261                                 break
262                         }
263                         pos = c.End()
264                 }
265                 // Prepare the import block
266                 impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
267                 for _, importPath := range importPaths {
268                         var name *ast.Ident
269                         if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
270                                 name = &ast.Ident{Name: importPath.Name, NamePos: pos}
271                         }
272                         impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
273                                 Name:   name,
274                                 Path:   &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath.Path), ValuePos: pos},
275                                 EndPos: pos,
276                         })
277                 }
278
279                 file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
280         }
281         // Reformat source code
282         var out bytes.Buffer
283         cfg := &printer.Config{
284                 Mode:     printer.TabIndent | printer.UseSpaces,
285                 Tabwidth: 8,
286         }
287         if err = cfg.Fprint(&out, fset, file); err != nil {
288                 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
289         }
290         return out.Bytes(), nil
291 }
292
293 // GoIdent is a Go identifier, consisting of a name and import path.
294 // The name is a single identifier and may not be a dot-qualified selector.
295 type GoIdent struct {
296         GoName       string
297         GoImportPath GoImportPath
298 }
299
300 func (id GoIdent) String() string {
301         return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName)
302 }
303
304 func newGoIdent(f *File, fullName string) GoIdent {
305         name := strings.TrimPrefix(fullName, string(f.PackageName)+".")
306         return GoIdent{
307                 GoName:       camelCaseName(name),
308                 GoImportPath: f.GoImportPath,
309         }
310 }
311
312 // GoImportPath is a Go import path for a package.
313 type GoImportPath string
314
315 func (p GoImportPath) String() string {
316         return strconv.Quote(string(p))
317 }
318
319 func (p GoImportPath) Ident(s string) GoIdent {
320         return GoIdent{GoName: s, GoImportPath: p}
321 }
322
323 type GoPackageName string
324
325 func cleanPackageName(name string) GoPackageName {
326         return GoPackageName(sanitizedName(name))
327 }
328
329 func sanitizedName(name string) string {
330         switch name {
331         case "interface":
332                 return "interfaces"
333         case "map":
334                 return "maps"
335         default:
336                 return name
337         }
338 }
339
340 // baseName returns the last path element of the name, with the last dotted suffix removed.
341 func baseName(name string) string {
342         // First, find the last element
343         if i := strings.LastIndex(name, "/"); i >= 0 {
344                 name = name[i+1:]
345         }
346         // Now drop the suffix
347         if i := strings.LastIndex(name, "."); i >= 0 {
348                 name = name[:i]
349         }
350         return name
351 }
352
353 func writeSourceTo(outputFile string, b []byte) error {
354         // create output directory
355         packageDir := filepath.Dir(outputFile)
356         if err := os.MkdirAll(packageDir, 0775); err != nil {
357                 return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
358         }
359
360         // format generated source code
361         gosrc, err := format.Source(b)
362         if err != nil {
363                 _ = ioutil.WriteFile(outputFile, b, 0666)
364                 return fmt.Errorf("formatting source code failed: %v", err)
365         }
366
367         // write generated code to output file
368         if err := ioutil.WriteFile(outputFile, gosrc, 0666); err != nil {
369                 return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
370         }
371
372         lines := bytes.Count(gosrc, []byte("\n"))
373         logf("wrote %d lines (%d bytes) to: %q", lines, len(gosrc), outputFile)
374
375         return nil
376 }