Fix binapigen decoding and minor improvements
[govpp.git] / binapigen / generator.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "go/ast"
22         "go/parser"
23         "go/printer"
24         "go/token"
25         "io/ioutil"
26         "os"
27         "path"
28         "path/filepath"
29         "sort"
30         "strconv"
31         "strings"
32
33         "github.com/sirupsen/logrus"
34
35         "git.fd.io/govpp.git/binapigen/vppapi"
36 )
37
38 type Generator struct {
39         Files       []*File
40         FilesByName map[string]*File
41
42         opts       Options
43         apifiles   []*vppapi.File
44         vppVersion string
45
46         filesToGen []string
47         genfiles   []*GenFile
48
49         enumsByName    map[string]*Enum
50         aliasesByName  map[string]*Alias
51         structsByName  map[string]*Struct
52         unionsByName   map[string]*Union
53         messagesByName map[string]*Message
54 }
55
56 func New(opts Options, apifiles []*vppapi.File, filesToGen []string) (*Generator, error) {
57         gen := &Generator{
58                 FilesByName:    make(map[string]*File),
59                 opts:           opts,
60                 apifiles:       apifiles,
61                 filesToGen:     filesToGen,
62                 enumsByName:    map[string]*Enum{},
63                 aliasesByName:  map[string]*Alias{},
64                 structsByName:  map[string]*Struct{},
65                 unionsByName:   map[string]*Union{},
66                 messagesByName: map[string]*Message{},
67         }
68
69         // Normalize API files
70         SortFilesByImports(gen.apifiles)
71         for _, apifile := range apifiles {
72                 RemoveImportedTypes(gen.apifiles, apifile)
73                 SortFileObjectsByName(apifile)
74         }
75
76         // prepare package names and import paths
77         packageNames := make(map[string]GoPackageName)
78         importPaths := make(map[string]GoImportPath)
79         for _, apifile := range gen.apifiles {
80                 filename := getFilename(apifile)
81                 packageNames[filename] = cleanPackageName(apifile.Name)
82                 importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
83         }
84
85         logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
86
87         for _, apifile := range gen.apifiles {
88                 if _, ok := gen.FilesByName[apifile.Name]; ok {
89                         return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
90                 }
91
92                 filename := getFilename(apifile)
93                 file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
94                 if err != nil {
95                         return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
96                 }
97                 gen.Files = append(gen.Files, file)
98                 gen.FilesByName[apifile.Name] = file
99
100                 logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
101         }
102
103         // mark files for generation
104         if len(gen.filesToGen) > 0 {
105                 logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
106                 for _, genfile := range gen.filesToGen {
107                         file, ok := gen.FilesByName[genfile]
108                         if !ok {
109                                 return nil, fmt.Errorf("nol API file found for: %v", genfile)
110                         }
111                         file.Generate = true
112                         // generate all imported files
113                         for _, impFile := range file.importedFiles(gen) {
114                                 impFile.Generate = true
115                         }
116                 }
117         } else {
118                 logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
119                 for _, file := range gen.Files {
120                         file.Generate = true
121                 }
122         }
123
124         return gen, nil
125 }
126
127 func getFilename(file *vppapi.File) string {
128         if file.Path == "" {
129                 return file.Name
130         }
131         return file.Path
132 }
133
134 func (g *Generator) Generate() error {
135         if len(g.genfiles) == 0 {
136                 return fmt.Errorf("no files to generate")
137         }
138
139         logrus.Infof("Generating %d files", len(g.genfiles))
140
141         for _, genfile := range g.genfiles {
142                 content, err := genfile.Content()
143                 if err != nil {
144                         return err
145                 }
146                 if err := writeSourceTo(genfile.filename, content); err != nil {
147                         return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
148                 }
149         }
150         return nil
151 }
152
153 type GenFile struct {
154         gen           *Generator
155         file          *File
156         filename      string
157         goImportPath  GoImportPath
158         buf           bytes.Buffer
159         manualImports map[GoImportPath]bool
160         packageNames  map[GoImportPath]GoPackageName
161 }
162
163 // NewGenFile creates new generated file with
164 func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
165         f := &GenFile{
166                 gen:           g,
167                 filename:      filename,
168                 goImportPath:  importPath,
169                 manualImports: make(map[GoImportPath]bool),
170                 packageNames:  make(map[GoImportPath]GoPackageName),
171         }
172         g.genfiles = append(g.genfiles, f)
173         return f
174 }
175
176 func (g *GenFile) Write(p []byte) (n int, err error) {
177         return g.buf.Write(p)
178 }
179
180 func (g *GenFile) Import(importPath GoImportPath) {
181         g.manualImports[importPath] = true
182 }
183
184 func (g *GenFile) GoIdent(ident GoIdent) string {
185         if ident.GoImportPath == g.goImportPath {
186                 return ident.GoName
187         }
188         if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
189                 return string(packageName) + "." + ident.GoName
190         }
191         packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
192         g.packageNames[ident.GoImportPath] = packageName
193         return string(packageName) + "." + ident.GoName
194 }
195
196 func (g *GenFile) P(v ...interface{}) {
197         for _, x := range v {
198                 switch x := x.(type) {
199                 case GoIdent:
200                         fmt.Fprint(&g.buf, g.GoIdent(x))
201                 default:
202                         fmt.Fprint(&g.buf, x)
203                 }
204         }
205         fmt.Fprintln(&g.buf)
206 }
207
208 func (g *GenFile) Content() ([]byte, error) {
209         if !strings.HasSuffix(g.filename, ".go") {
210                 return g.buf.Bytes(), nil
211         }
212         return g.injectImports(g.buf.Bytes())
213 }
214
215 // injectImports parses source, injects import block declaration with all imports and return formatted
216 func (g *GenFile) injectImports(original []byte) ([]byte, error) {
217         // Parse source code
218         fset := token.NewFileSet()
219         file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
220         if err != nil {
221                 var src bytes.Buffer
222                 s := bufio.NewScanner(bytes.NewReader(original))
223                 for line := 1; s.Scan(); line++ {
224                         fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
225                 }
226                 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
227         }
228         type Import struct {
229                 Name string
230                 Path string
231         }
232         // Prepare list of all imports
233         var importPaths []Import
234         for importPath := range g.packageNames {
235                 importPaths = append(importPaths, Import{
236                         Name: string(g.packageNames[GoImportPath(importPath)]),
237                         Path: string(importPath),
238                 })
239         }
240         for importPath := range g.manualImports {
241                 if _, ok := g.packageNames[importPath]; ok {
242                         continue
243                 }
244                 importPaths = append(importPaths, Import{
245                         Name: "_",
246                         Path: string(importPath),
247                 })
248         }
249         // Sort imports by import path
250         sort.Slice(importPaths, func(i, j int) bool {
251                 return importPaths[i].Path < importPaths[j].Path
252         })
253         // Inject new import block into parsed AST
254         if len(importPaths) > 0 {
255                 // Find import block position
256                 pos := file.Package
257                 tokFile := fset.File(file.Package)
258                 pkgLine := tokFile.Line(file.Package)
259                 for _, c := range file.Comments {
260                         if tokFile.Line(c.Pos()) > pkgLine {
261                                 break
262                         }
263                         pos = c.End()
264                 }
265                 // Prepare the import block
266                 impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
267                 for _, importPath := range importPaths {
268                         var name *ast.Ident
269                         if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
270                                 name = &ast.Ident{Name: importPath.Name, NamePos: pos}
271                         }
272                         impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
273                                 Name:   name,
274                                 Path:   &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath.Path), ValuePos: pos},
275                                 EndPos: pos,
276                         })
277                 }
278
279                 file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
280         }
281         // Reformat source code
282         var out bytes.Buffer
283         cfg := &printer.Config{
284                 Mode:     printer.TabIndent | printer.UseSpaces,
285                 Tabwidth: 8,
286         }
287         if err = cfg.Fprint(&out, fset, file); err != nil {
288                 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
289         }
290         return out.Bytes(), nil
291 }
292
293 func writeSourceTo(outputFile string, b []byte) error {
294         // create output directory
295         packageDir := filepath.Dir(outputFile)
296         if err := os.MkdirAll(packageDir, 0775); err != nil {
297                 return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
298         }
299
300         // write generated code to output file
301         if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
302                 return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
303         }
304
305         lines := bytes.Count(b, []byte("\n"))
306         logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)
307
308         return nil
309 }