a6f161594b685a9836caae4182380d8f014dcd22
[govpp.git] / binapigen / generator.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "go/ast"
22         "go/parser"
23         "go/printer"
24         "go/token"
25         "io/ioutil"
26         "os"
27         "path"
28         "path/filepath"
29         "sort"
30         "strconv"
31         "strings"
32
33         "github.com/sirupsen/logrus"
34
35         "git.fd.io/govpp.git/binapigen/vppapi"
36 )
37
38 type Generator struct {
39         Files       []*File
40         FilesByName map[string]*File
41         FilesByPath map[string]*File
42
43         opts       Options
44         apifiles   []*vppapi.File
45         vppVersion string
46
47         filesToGen []string
48         genfiles   []*GenFile
49
50         enumsByName    map[string]*Enum
51         aliasesByName  map[string]*Alias
52         structsByName  map[string]*Struct
53         unionsByName   map[string]*Union
54         messagesByName map[string]*Message
55 }
56
57 func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
58         gen := &Generator{
59                 FilesByName:    make(map[string]*File),
60                 FilesByPath:    make(map[string]*File),
61                 opts:           opts,
62                 apifiles:       apiFiles,
63                 filesToGen:     filesToGen,
64                 enumsByName:    map[string]*Enum{},
65                 aliasesByName:  map[string]*Alias{},
66                 structsByName:  map[string]*Struct{},
67                 unionsByName:   map[string]*Union{},
68                 messagesByName: map[string]*Message{},
69         }
70
71         // Normalize API files
72         SortFilesByImports(gen.apifiles)
73         for _, apiFile := range apiFiles {
74                 RemoveImportedTypes(gen.apifiles, apiFile)
75                 SortFileObjectsByName(apiFile)
76         }
77
78         // prepare package names and import paths
79         packageNames := make(map[string]GoPackageName)
80         importPaths := make(map[string]GoImportPath)
81         for _, apifile := range gen.apifiles {
82                 filename := getFilename(apifile)
83                 packageNames[filename] = cleanPackageName(apifile.Name)
84                 importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
85         }
86
87         logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
88
89         for _, apifile := range gen.apifiles {
90                 if _, ok := gen.FilesByName[apifile.Name]; ok {
91                         return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
92                 }
93
94                 filename := getFilename(apifile)
95                 file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
96                 if err != nil {
97                         return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
98                 }
99                 gen.Files = append(gen.Files, file)
100                 gen.FilesByName[apifile.Name] = file
101                 gen.FilesByPath[apifile.Path] = file
102
103                 logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
104         }
105
106         // mark files for generation
107         if len(gen.filesToGen) > 0 {
108                 logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
109                 for _, genFile := range gen.filesToGen {
110                         markGen := func(file *File) {
111                                 file.Generate = true
112                                 // generate all imported files
113                                 for _, impFile := range file.importedFiles(gen) {
114                                         impFile.Generate = true
115                                 }
116                         }
117                         if file, ok := gen.FilesByName[genFile]; ok {
118                                 markGen(file)
119                                 continue
120                         }
121                         logrus.Debugf("File %s was not found by name", genFile)
122                         if file, ok := gen.FilesByPath[genFile]; ok {
123                                 markGen(file)
124                                 continue
125                         }
126                         return nil, fmt.Errorf("no API file found for: %v", genFile)
127                 }
128         } else {
129                 logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
130                 for _, file := range gen.Files {
131                         file.Generate = true
132                 }
133         }
134
135         return gen, nil
136 }
137
138 func getFilename(file *vppapi.File) string {
139         if file.Path == "" {
140                 return file.Name
141         }
142         return file.Path
143 }
144
145 func (g *Generator) Generate() error {
146         if len(g.genfiles) == 0 {
147                 return fmt.Errorf("no files to generate")
148         }
149
150         logrus.Infof("Generating %d files", len(g.genfiles))
151
152         for _, genfile := range g.genfiles {
153                 content, err := genfile.Content()
154                 if err != nil {
155                         return err
156                 }
157                 if err := writeSourceTo(genfile.filename, content); err != nil {
158                         return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
159                 }
160         }
161         return nil
162 }
163
164 type GenFile struct {
165         gen           *Generator
166         file          *File
167         filename      string
168         goImportPath  GoImportPath
169         buf           bytes.Buffer
170         manualImports map[GoImportPath]bool
171         packageNames  map[GoImportPath]GoPackageName
172 }
173
174 // NewGenFile creates new generated file with
175 func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
176         f := &GenFile{
177                 gen:           g,
178                 filename:      filename,
179                 goImportPath:  importPath,
180                 manualImports: make(map[GoImportPath]bool),
181                 packageNames:  make(map[GoImportPath]GoPackageName),
182         }
183         g.genfiles = append(g.genfiles, f)
184         return f
185 }
186
187 func (g *GenFile) Write(p []byte) (n int, err error) {
188         return g.buf.Write(p)
189 }
190
191 func (g *GenFile) Import(importPath GoImportPath) {
192         g.manualImports[importPath] = true
193 }
194
195 func (g *GenFile) GoIdent(ident GoIdent) string {
196         if ident.GoImportPath == g.goImportPath {
197                 return ident.GoName
198         }
199         if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
200                 return string(packageName) + "." + ident.GoName
201         }
202         packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
203         g.packageNames[ident.GoImportPath] = packageName
204         return string(packageName) + "." + ident.GoName
205 }
206
207 func (g *GenFile) P(v ...interface{}) {
208         for _, x := range v {
209                 switch x := x.(type) {
210                 case GoIdent:
211                         fmt.Fprint(&g.buf, g.GoIdent(x))
212                 default:
213                         fmt.Fprint(&g.buf, x)
214                 }
215         }
216         fmt.Fprintln(&g.buf)
217 }
218
219 func (g *GenFile) Content() ([]byte, error) {
220         if !strings.HasSuffix(g.filename, ".go") {
221                 return g.buf.Bytes(), nil
222         }
223         return g.injectImports(g.buf.Bytes())
224 }
225
226 // injectImports parses source, injects import block declaration with all imports and return formatted
227 func (g *GenFile) injectImports(original []byte) ([]byte, error) {
228         // Parse source code
229         fset := token.NewFileSet()
230         file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
231         if err != nil {
232                 var src bytes.Buffer
233                 s := bufio.NewScanner(bytes.NewReader(original))
234                 for line := 1; s.Scan(); line++ {
235                         fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
236                 }
237                 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
238         }
239         type Import struct {
240                 Name string
241                 Path string
242         }
243         // Prepare list of all imports
244         var importPaths []Import
245         for importPath := range g.packageNames {
246                 importPaths = append(importPaths, Import{
247                         Name: string(g.packageNames[importPath]),
248                         Path: string(importPath),
249                 })
250         }
251         for importPath := range g.manualImports {
252                 if _, ok := g.packageNames[importPath]; ok {
253                         continue
254                 }
255                 importPaths = append(importPaths, Import{
256                         Name: "_",
257                         Path: string(importPath),
258                 })
259         }
260         // Sort imports by import path
261         sort.Slice(importPaths, func(i, j int) bool {
262                 return importPaths[i].Path < importPaths[j].Path
263         })
264         // Inject new import block into parsed AST
265         if len(importPaths) > 0 {
266                 // Find import block position
267                 pos := file.Package
268                 tokFile := fset.File(file.Package)
269                 pkgLine := tokFile.Line(file.Package)
270                 for _, c := range file.Comments {
271                         if tokFile.Line(c.Pos()) > pkgLine {
272                                 break
273                         }
274                         pos = c.End()
275                 }
276                 // Prepare the import block
277                 impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
278                 for _, importPath := range importPaths {
279                         var name *ast.Ident
280                         if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
281                                 name = &ast.Ident{Name: importPath.Name, NamePos: pos}
282                         }
283                         impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
284                                 Name:   name,
285                                 Path:   &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(importPath.Path), ValuePos: pos},
286                                 EndPos: pos,
287                         })
288                 }
289
290                 file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
291         }
292         // Reformat source code
293         var out bytes.Buffer
294         cfg := &printer.Config{
295                 Mode:     printer.TabIndent | printer.UseSpaces,
296                 Tabwidth: 8,
297         }
298         if err = cfg.Fprint(&out, fset, file); err != nil {
299                 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
300         }
301         return out.Bytes(), nil
302 }
303
304 func writeSourceTo(outputFile string, b []byte) error {
305         // create output directory
306         packageDir := filepath.Dir(outputFile)
307         if err := os.MkdirAll(packageDir, 0775); err != nil {
308                 return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
309         }
310
311         // write generated code to output file
312         if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
313                 return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
314         }
315
316         lines := bytes.Count(b, []byte("\n"))
317         logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)
318
319         return nil
320 }