1 // Copyright (c) 2020 Cisco and/or its affiliates.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
33 "github.com/sirupsen/logrus"
35 "git.fd.io/govpp.git/binapigen/vppapi"
// Generator holds the VPP API files loaded for code generation and the
// *File wrappers built from them by New.
38 type Generator struct {
// FilesByName and FilesByPath index the same *File values, keyed by the API
// file's Name and by its Path respectively (both filled in New).
40 FilesByName map[string]*File
41 FilesByPath map[string]*File
// apifiles are the input VPP API definitions; New normalizes them before use.
44 apifiles []*vppapi.File
// Per-kind lookup tables of API objects by name.
// NOTE(review): not populated anywhere in this view — presumably filled while
// building each File (newFile); confirm.
50 enumsByName map[string]*Enum
51 aliasesByName map[string]*Alias
52 structsByName map[string]*Struct
53 unionsByName map[string]*Union
54 messagesByName map[string]*Message
// New creates a Generator for the given parsed VPP API files.
//
// It normalizes the API files, derives a Go package name and import path for
// each one, wraps each in a *File (indexed by name and by path), and then marks
// files for generation. When filesToGen is non-empty, only the requested files
// (matched first by API file name, then by path) and the files they import are
// marked. Otherwise every file is marked.
//
// It returns an error if two API files share a name, if building a File fails,
// or if an entry in filesToGen matches no file by name or path.
57 func New(opts Options, apiFiles []*vppapi.File, filesToGen []string) (*Generator, error) {
59 FilesByName: make(map[string]*File),
60 FilesByPath: make(map[string]*File),
63 filesToGen: filesToGen,
64 enumsByName: map[string]*Enum{},
65 aliasesByName: map[string]*Alias{},
66 structsByName: map[string]*Struct{},
67 unionsByName: map[string]*Union{},
68 messagesByName: map[string]*Message{},
71 // Normalize API files
// NOTE(review): the sort is applied to gen.apifiles but the loop below ranges
// over apiFiles; these are presumably the same slice — confirm.
72 SortFilesByImports(gen.apifiles)
73 for _, apiFile := range apiFiles {
74 RemoveImportedTypes(gen.apifiles, apiFile)
75 SortFileObjectsByName(apiFile)
78 // prepare package names and import paths
// Both maps are keyed by getFilename(apifile) and consumed by newFile below.
// The import path is ImportPrefix joined with baseName of the API file name.
79 packageNames := make(map[string]GoPackageName)
80 importPaths := make(map[string]GoImportPath)
81 for _, apifile := range gen.apifiles {
82 filename := getFilename(apifile)
83 packageNames[filename] = cleanPackageName(apifile.Name)
84 importPaths[filename] = GoImportPath(path.Join(gen.opts.ImportPrefix, baseName(apifile.Name)))
87 logrus.Debugf("adding %d VPP API files to generator", len(gen.apifiles))
// Build a *File per API file; API file names must be unique.
89 for _, apifile := range gen.apifiles {
90 if _, ok := gen.FilesByName[apifile.Name]; ok {
91 return nil, fmt.Errorf("duplicate file: %q", apifile.Name)
94 filename := getFilename(apifile)
95 file, err := newFile(gen, apifile, packageNames[filename], importPaths[filename])
97 return nil, fmt.Errorf("loading file %s failed: %w", apifile.Name, err)
99 gen.Files = append(gen.Files, file)
100 gen.FilesByName[apifile.Name] = file
101 gen.FilesByPath[apifile.Path] = file
103 logrus.Debugf("added file %q (path: %v)", apifile.Name, apifile.Path)
106 // mark files for generation
107 if len(gen.filesToGen) > 0 {
108 logrus.Debugf("Checking %d files to generate: %v", len(gen.filesToGen), gen.filesToGen)
109 for _, genFile := range gen.filesToGen {
// markGen marks a requested file for generation, together with every file it
// imports. It is redefined on each iteration but captures no loop state.
110 markGen := func(file *File) {
112 // generate all imported files
113 for _, impFile := range file.importedFiles(gen) {
114 impFile.Generate = true
// Look the requested entry up by API file name first, then fall back to path.
117 if file, ok := gen.FilesByName[genFile]; ok {
121 logrus.Debugf("File %s was not found by name", genFile)
122 if file, ok := gen.FilesByPath[genFile]; ok {
126 return nil, fmt.Errorf("no API file found for: %v", genFile)
129 logrus.Debugf("Files to generate not specified, marking all %d files for generate", len(gen.Files))
130 for _, file := range gen.Files {
// getFilename returns the key New uses to look up an API file's Go package
// name and import path.
// NOTE(review): confirm how the key is derived (file.Path vs file.Name).
138 func getFilename(file *vppapi.File) string {
// Generate renders every GenFile registered through NewGenFile and writes each
// one to its filename on disk. It returns an error if no files were registered.
// A failure writing any file aborts the run and is returned.
145 func (g *Generator) Generate() error {
146 if len(g.genfiles) == 0 {
147 return fmt.Errorf("no files to generate")
150 logrus.Infof("Generating %d files", len(g.genfiles))
152 for _, genfile := range g.genfiles {
// Content injects imports and reformats .go files; other files are returned as buffered.
153 content, err := genfile.Content()
// NOTE(review): the write error is formatted with %v, so callers cannot unwrap
// it with errors.Is/As, unlike the %w wrapping used in New.
157 if err := writeSourceTo(genfile.filename, content); err != nil {
158 return fmt.Errorf("writing source package %s failed: %v", genfile.filename, err)
// GenFile is a single generated output file. Content is accumulated through
// Write and P, and the imports it needs are collected as identifiers are
// referenced (GoIdent) or requested explicitly (Import).
164 type GenFile struct {
// goImportPath is the import path of the package this file belongs to;
// GoIdent treats identifiers from this path as local.
168 goImportPath GoImportPath
// manualImports holds imports requested explicitly via Import.
170 manualImports map[GoImportPath]bool
// packageNames maps each import path referenced via GoIdent to the package
// name used to qualify its identifiers.
171 packageNames map[GoImportPath]GoPackageName
174 // NewGenFile creates a new generated file for the given output filename, belonging to the Go package at importPath.
// The file is registered with the generator, so Generate will render and write it.
175 func (g *Generator) NewGenFile(filename string, importPath GoImportPath) *GenFile {
179 goImportPath: importPath,
180 manualImports: make(map[GoImportPath]bool),
181 packageNames: make(map[GoImportPath]GoPackageName),
183 g.genfiles = append(g.genfiles, f)
// Write appends p to the file's content buffer, which makes GenFile an io.Writer.
187 func (g *GenFile) Write(p []byte) (n int, err error) {
188 return g.buf.Write(p)
// Import records importPath as an explicit import of this file, even when no
// identifier from that package is referenced via GoIdent. injectImports adds it
// to the import block unless GoIdent already recorded the same path.
191 func (g *GenFile) Import(importPath GoImportPath) {
192 g.manualImports[importPath] = true
// GoIdent returns ident as it should be spelled in this file. Identifiers from
// another package are qualified as "pkg.Name". The package name comes from the
// import path via baseName and cleanPackageName, and is cached in packageNames
// so injectImports will add the matching import.
195 func (g *GenFile) GoIdent(ident GoIdent) string {
// NOTE(review): identifiers from the file's own package are handled here —
// presumably returned unqualified (ident.GoName); confirm.
196 if ident.GoImportPath == g.goImportPath {
// Reuse the package name already chosen for this import path.
199 if packageName, ok := g.packageNames[ident.GoImportPath]; ok {
200 return string(packageName) + "." + ident.GoName
202 packageName := cleanPackageName(baseName(string(ident.GoImportPath)))
203 g.packageNames[ident.GoImportPath] = packageName
204 return string(packageName) + "." + ident.GoName
// P writes its arguments to the file's content buffer in order. GoIdent values
// are rendered through GoIdent, which qualifies them and records the needed
// import. All other values are written with fmt.Fprint.
207 func (g *GenFile) P(v ...interface{}) {
208 for _, x := range v {
209 switch x := x.(type) {
211 fmt.Fprint(&g.buf, g.GoIdent(x))
213 fmt.Fprint(&g.buf, x)
// Content returns the final bytes of the file. Files whose name does not end
// in ".go" are returned exactly as buffered. Go files are passed through
// injectImports, which adds the import block and reformats the source.
219 func (g *GenFile) Content() ([]byte, error) {
220 if !strings.HasSuffix(g.filename, ".go") {
221 return g.buf.Bytes(), nil
223 return g.injectImports(g.buf.Bytes())
// getImportClass returns the group an import path belongs to, used to order
// and separate imports. Paths containing a "." are class 1 (external).
// NOTE(review): paths without a "." (standard library by Go convention) take
// the first branch, presumably returning 0 so they sort first; confirm.
226 func getImportClass(importPath string) int {
227 if !strings.Contains(importPath, ".") {
230 return 1 /* External */
233 // injectImports parses the generated Go source, prepends an import declaration containing every import collected via GoIdent and Import, and returns the reformatted source.
// If the source cannot be parsed, the returned error includes a line-numbered listing of it.
234 func (g *GenFile) injectImports(original []byte) ([]byte, error) {
236 fset := token.NewFileSet()
237 file, err := parser.ParseFile(fset, "", original, parser.ParseComments)
// On parse failure, print the source with line numbers so the error can be located.
240 s := bufio.NewScanner(bytes.NewReader(original))
241 for line := 1; s.Scan(); line++ {
242 fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
244 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
250 // Prepare list of all imports
// Imports referenced through GoIdent carry the package name chosen there.
251 var importPaths []Import
252 for importPath := range g.packageNames {
253 importPaths = append(importPaths, Import{
254 Name: string(g.packageNames[importPath]),
255 Path: string(importPath),
// Manual imports are added unless GoIdent already recorded the same path
// (NOTE(review): the skip statement and the Name used here are not visible; confirm).
258 for importPath := range g.manualImports {
259 if _, ok := g.packageNames[importPath]; ok {
262 importPaths = append(importPaths, Import{
264 Path: string(importPath),
267 // Sort imports by class (see getImportClass), then alphabetically by import path
268 sort.Slice(importPaths, func(i, j int) bool {
269 ci := getImportClass(importPaths[i].Path)
270 cj := getImportClass(importPaths[j].Path)
272 return importPaths[i].Path < importPaths[j].Path
276 // Inject new import block into parsed AST
277 if len(importPaths) > 0 {
278 // Find import block position
// NOTE(review): pos is presumably advanced past comments that do not come after
// the package clause's line, so the import block lands after them; confirm.
280 tokFile := fset.File(file.Package)
281 pkgLine := tokFile.Line(file.Package)
282 for _, c := range file.Comments {
283 if tokFile.Line(c.Pos()) > pkgLine {
288 // Prepare the import block
289 impDecl := &ast.GenDecl{Tok: token.IMPORT, TokPos: pos, Lparen: pos, Rparen: pos}
290 for i, importPath := range importPaths {
// An explicit name is emitted only for blank ("_") imports and for paths
// containing a dot; other imports are written without a name.
292 if importPath.Name == "_" || strings.Contains(importPath.Path, ".") {
293 name = &ast.Ident{Name: importPath.Name, NamePos: pos}
295 value := strconv.Quote(importPath.Path)
// When the next import belongs to a different class, the groups are separated
// (NOTE(review): presumably via a newline appended to value; confirm).
296 if i < len(importPaths)-1 {
297 if getImportClass(importPath.Path) != getImportClass(importPaths[i+1].Path) {
301 impDecl.Specs = append(impDecl.Specs, &ast.ImportSpec{
303 Path: &ast.BasicLit{Kind: token.STRING, Value: value, ValuePos: pos},
// Make the import declaration the first declaration of the file.
308 file.Decls = append([]ast.Decl{impDecl}, file.Decls...)
310 // Reformat source code
312 cfg := &printer.Config{
313 Mode: printer.TabIndent | printer.UseSpaces,
316 if err = cfg.Fprint(&out, fset, file); err != nil {
317 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
319 return out.Bytes(), nil
// writeSourceTo writes b to outputFile, first creating the parent directory
// (mode 0775) if needed, and logs how many lines and bytes were written.
// NOTE(review): ioutil.WriteFile is deprecated since Go 1.16; os.WriteFile is a
// drop-in replacement. Errors here use %v rather than %w, so they cannot be unwrapped.
322 func writeSourceTo(outputFile string, b []byte) error {
323 // create output directory
324 packageDir := filepath.Dir(outputFile)
325 if err := os.MkdirAll(packageDir, 0775); err != nil {
326 return fmt.Errorf("creating output dir %s failed: %v", packageDir, err)
329 // write generated code to output file
330 if err := ioutil.WriteFile(outputFile, b, 0666); err != nil {
331 return fmt.Errorf("writing to output file %s failed: %v", outputFile, err)
// Line count is the number of newline characters in b.
334 lines := bytes.Count(b, []byte("\n"))
335 logf("wrote %d lines (%d bytes) to: %q", lines, len(b), outputFile)