48c3a41b1bd66605793ae657f2d85b50f34cbc46
[govpp.git] / cmd / binapi-generator / generate.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "path/filepath"
23         "strings"
24         "unicode"
25 )
26
27 const (
28         govppApiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API package
29         inputFileExt       = ".api.json"               // file extension of the VPP binary API files
30         outputFileExt      = ".ba.go"                  // file extension of the Go generated files
31 )
32
33 // context is a structure storing data for code generation
34 type context struct {
35         inputFile  string // input file with VPP API in JSON
36         outputFile string // output file with generated Go package
37
38         inputData []byte // contents of the input file
39
40         moduleName  string // name of the source VPP module
41         packageName string // name of the Go package being generated
42
43         packageData *Package // parsed package data
44 }
45
46 // getContext returns context details of the code generation task
47 func getContext(inputFile, outputDir string) (*context, error) {
48         if !strings.HasSuffix(inputFile, inputFileExt) {
49                 return nil, fmt.Errorf("invalid input file name: %q", inputFile)
50         }
51
52         ctx := &context{
53                 inputFile: inputFile,
54         }
55
56         // package name
57         inputFileName := filepath.Base(inputFile)
58         ctx.moduleName = inputFileName[:strings.Index(inputFileName, ".")]
59
60         // alter package names for modules that are reserved keywords in Go
61         switch ctx.moduleName {
62         case "interface":
63                 ctx.packageName = "interfaces"
64         case "map":
65                 ctx.packageName = "maps"
66         default:
67                 ctx.packageName = ctx.moduleName
68         }
69
70         // output file
71         packageDir := filepath.Join(outputDir, ctx.packageName)
72         outputFileName := ctx.packageName + outputFileExt
73         ctx.outputFile = filepath.Join(packageDir, outputFileName)
74
75         return ctx, nil
76 }
77
78 // generatePackage generates code for the parsed package data and writes it into w
79 func generatePackage(ctx *context, w *bufio.Writer) error {
80         logf("generating package %q", ctx.packageName)
81
82         // generate file header
83         generateHeader(ctx, w)
84         generateImports(ctx, w)
85
86         if *includeAPIVer {
87                 const APIVerConstName = "VlAPIVersion"
88                 fmt.Fprintf(w, "// %s represents version of the binary API module.\n", APIVerConstName)
89                 fmt.Fprintf(w, "const %s = %v\n", APIVerConstName, ctx.packageData.APIVersion)
90                 fmt.Fprintln(w)
91         }
92
93         // generate services
94         if len(ctx.packageData.Services) > 0 {
95                 generateServices(ctx, w, ctx.packageData.Services)
96
97                 // TODO: generate implementation for Services interface
98         }
99
100         // generate enums
101         if len(ctx.packageData.Enums) > 0 {
102                 fmt.Fprintf(w, "/* Enums */\n\n")
103
104                 for _, enum := range ctx.packageData.Enums {
105                         generateEnum(ctx, w, &enum)
106                 }
107         }
108
109         // generate aliases
110         if len(ctx.packageData.Aliases) > 0 {
111                 fmt.Fprintf(w, "/* Aliases */\n\n")
112
113                 for _, alias := range ctx.packageData.Aliases {
114                         generateAlias(ctx, w, &alias)
115                 }
116         }
117
118         // generate types
119         if len(ctx.packageData.Types) > 0 {
120                 fmt.Fprintf(w, "/* Types */\n\n")
121
122                 for _, typ := range ctx.packageData.Types {
123                         generateType(ctx, w, &typ)
124                 }
125         }
126
127         // generate unions
128         if len(ctx.packageData.Unions) > 0 {
129                 fmt.Fprintf(w, "/* Unions */\n\n")
130
131                 for _, union := range ctx.packageData.Unions {
132                         generateUnion(ctx, w, &union)
133                 }
134         }
135
136         // generate messages
137         if len(ctx.packageData.Messages) > 0 {
138                 fmt.Fprintf(w, "/* Messages */\n\n")
139
140                 for _, msg := range ctx.packageData.Messages {
141                         generateMessage(ctx, w, &msg)
142                 }
143
144                 // generate message registrations
145                 fmt.Fprintln(w, "func init() {")
146                 for _, msg := range ctx.packageData.Messages {
147                         name := camelCaseName(msg.Name)
148                         fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n", name, ctx.moduleName+"."+name)
149                 }
150                 fmt.Fprintln(w, "}")
151         }
152
153         // flush the data:
154         if err := w.Flush(); err != nil {
155                 return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
156         }
157
158         return nil
159 }
160
161 // generateHeader writes generated package header into w
162 func generateHeader(ctx *context, w io.Writer) {
163         fmt.Fprintln(w, "// Code generated by GoVPP binapi-generator. DO NOT EDIT.")
164         fmt.Fprintf(w, "//  source: %s\n", ctx.inputFile)
165         fmt.Fprintln(w)
166
167         fmt.Fprintln(w, "/*")
168         fmt.Fprintf(w, " Package %s is a generated from VPP binary API module '%s'.\n", ctx.packageName, ctx.moduleName)
169         fmt.Fprintln(w)
170         fmt.Fprintln(w, " It contains following objects:")
171         var printObjNum = func(obj string, num int) {
172                 if num > 0 {
173                         if num > 1 {
174                                 if strings.HasSuffix(obj, "s") {
175                                         obj += "es"
176                                 } else {
177                                         obj += "s"
178                                 }
179                         }
180                         fmt.Fprintf(w, "\t%3d %s\n", num, obj)
181                 }
182         }
183
184         printObjNum("service", len(ctx.packageData.Services))
185         printObjNum("enum", len(ctx.packageData.Enums))
186         printObjNum("alias", len(ctx.packageData.Aliases))
187         printObjNum("type", len(ctx.packageData.Types))
188         printObjNum("union", len(ctx.packageData.Unions))
189         printObjNum("message", len(ctx.packageData.Messages))
190         fmt.Fprintln(w, "*/")
191         fmt.Fprintf(w, "package %s\n", ctx.packageName)
192         fmt.Fprintln(w)
193 }
194
195 // generateImports writes generated package imports into w
196 func generateImports(ctx *context, w io.Writer) {
197         fmt.Fprintf(w, "import \"%s\"\n", govppApiImportPath)
198         fmt.Fprintf(w, "import \"%s\"\n", "github.com/lunixbochs/struc")
199         fmt.Fprintf(w, "import \"%s\"\n", "bytes")
200         fmt.Fprintln(w)
201
202         fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n")
203         fmt.Fprintf(w, "var _ = api.RegisterMessage\n")
204         fmt.Fprintf(w, "var _ = struc.Pack\n")
205         fmt.Fprintf(w, "var _ = bytes.NewBuffer\n")
206         fmt.Fprintln(w)
207 }
208
209 // generateComment writes generated comment for the object into w
210 func generateComment(ctx *context, w io.Writer, goName string, vppName string, objKind string) {
211         if objKind == "service" {
212                 fmt.Fprintf(w, "// %s represents VPP binary API services:\n", goName)
213         } else {
214                 fmt.Fprintf(w, "// %s represents VPP binary API %s '%s':\n", goName, objKind, vppName)
215         }
216
217         var isNotSpace = func(r rune) bool {
218                 return !unicode.IsSpace(r)
219         }
220
221         // print out the source of the generated object
222         mapType := false
223         objFound := false
224         objTitle := fmt.Sprintf(`"%s",`, vppName)
225         switch objKind {
226         case "alias", "service":
227                 objTitle = fmt.Sprintf(`"%s": {`, vppName)
228                 mapType = true
229         }
230
231         inputBuff := bytes.NewBuffer(ctx.inputData)
232         inputLine := 0
233
234         var trimIndent string
235         var indent int
236         for {
237                 line, err := inputBuff.ReadString('\n')
238                 if err != nil {
239                         break
240                 }
241                 inputLine++
242
243                 noSpaceAt := strings.IndexFunc(line, isNotSpace)
244                 if !objFound {
245                         indent = strings.Index(line, objTitle)
246                         if indent == -1 {
247                                 continue
248                         }
249                         trimIndent = line[:indent]
250                         // If no other non-whitespace character then we are at the message header.
251                         if trimmed := strings.TrimSpace(line); trimmed == objTitle {
252                                 objFound = true
253                                 fmt.Fprintln(w, "//")
254                         }
255                 } else if noSpaceAt < indent {
256                         break // end of the definition in JSON for array types
257                 } else if objFound && mapType && noSpaceAt <= indent {
258                         fmt.Fprintf(w, "//\t%s", strings.TrimPrefix(line, trimIndent))
259                         break // end of the definition in JSON for map types (aliases, services)
260                 }
261                 fmt.Fprintf(w, "//\t%s", strings.TrimPrefix(line, trimIndent))
262         }
263
264         fmt.Fprintln(w, "//")
265 }
266
267 // generateServices writes generated code for the Services interface into w
268 func generateServices(ctx *context, w *bufio.Writer, services []Service) {
269         // generate services comment
270         generateComment(ctx, w, "Services", "services", "service")
271
272         // generate interface
273         fmt.Fprintf(w, "type %s interface {\n", "Services")
274         for _, svc := range ctx.packageData.Services {
275                 generateService(ctx, w, &svc)
276         }
277         fmt.Fprintln(w, "}")
278
279         fmt.Fprintln(w)
280 }
281
282 // generateService writes generated code for the service into w
283 func generateService(ctx *context, w io.Writer, svc *Service) {
284         reqTyp := camelCaseName(svc.RequestType)
285
286         // method name is same as parameter type name by default
287         method := svc.MethodName()
288         params := fmt.Sprintf("*%s", reqTyp)
289         returns := "error"
290         if replyType := camelCaseName(svc.ReplyType); replyType != "" {
291                 repTyp := fmt.Sprintf("*%s", replyType)
292                 if svc.Stream {
293                         repTyp = fmt.Sprintf("[]%s", repTyp)
294                 }
295                 returns = fmt.Sprintf("(%s, error)", repTyp)
296         }
297
298         fmt.Fprintf(w, "\t%s(%s) %s\n", method, params, returns)
299 }
300
301 // generateEnum writes generated code for the enum into w
302 func generateEnum(ctx *context, w io.Writer, enum *Enum) {
303         name := camelCaseName(enum.Name)
304         typ := binapiTypes[enum.Type]
305
306         logf(" writing enum %q (%s) with %d entries", enum.Name, name, len(enum.Entries))
307
308         // generate enum comment
309         generateComment(ctx, w, name, enum.Name, "enum")
310
311         // generate enum definition
312         fmt.Fprintf(w, "type %s %s\n", name, typ)
313         fmt.Fprintln(w)
314
315         fmt.Fprintln(w, "const (")
316
317         // generate enum entries
318         for _, entry := range enum.Entries {
319                 fmt.Fprintf(w, "\t%s %s = %v\n", entry.Name, name, entry.Value)
320         }
321
322         fmt.Fprintln(w, ")")
323
324         fmt.Fprintln(w)
325 }
326
327 // generateAlias writes generated code for the alias into w
328 func generateAlias(ctx *context, w io.Writer, alias *Alias) {
329         name := camelCaseName(alias.Name)
330
331         logf(" writing type %q (%s), length: %d", alias.Name, name, alias.Length)
332
333         // generate struct comment
334         generateComment(ctx, w, name, alias.Name, "alias")
335
336         // generate struct definition
337         fmt.Fprintf(w, "type %s ", name)
338
339         if alias.Length > 0 {
340                 fmt.Fprintf(w, "[%d]", alias.Length)
341         }
342
343         dataType := convertToGoType(ctx, alias.Type)
344         fmt.Fprintf(w, "%s\n", dataType)
345
346         fmt.Fprintln(w)
347 }
348
349 // generateUnion writes generated code for the union into w
350 func generateUnion(ctx *context, w io.Writer, union *Union) {
351         name := camelCaseName(union.Name)
352
353         logf(" writing union %q (%s) with %d fields", union.Name, name, len(union.Fields))
354
355         // generate struct comment
356         generateComment(ctx, w, name, union.Name, "union")
357
358         // generate struct definition
359         fmt.Fprintln(w, "type", name, "struct {")
360
361         // maximum size for union
362         maxSize := getUnionSize(ctx, union)
363
364         // generate data field
365         fieldName := "Union_data"
366         fmt.Fprintf(w, "\t%s [%d]byte\n", fieldName, maxSize)
367
368         // generate end of the struct
369         fmt.Fprintln(w, "}")
370
371         // generate name getter
372         generateTypeNameGetter(w, name, union.Name)
373
374         // generate CRC getter
375         generateCrcGetter(w, name, union.CRC)
376
377         // generate getters for fields
378         for _, field := range union.Fields {
379                 fieldName := camelCaseName(field.Name)
380                 fieldType := convertToGoType(ctx, field.Type)
381                 generateUnionGetterSetter(w, name, fieldName, fieldType)
382         }
383
384         // generate union methods
385         //generateUnionMethods(w, name)
386
387         fmt.Fprintln(w)
388 }
389
390 // generateUnionMethods generates methods that implement struc.Custom
391 // interface to allow having Union_data field unexported
392 // TODO: do more testing when unions are actually used in some messages
393 func generateUnionMethods(w io.Writer, structName string) {
394         // generate struc.Custom implementation for union
395         fmt.Fprintf(w, `
396 func (u *%[1]s) Pack(p []byte, opt *struc.Options) (int, error) {
397         var b = new(bytes.Buffer)
398         if err := struc.PackWithOptions(b, u.union_data, opt); err != nil {
399                 return 0, err
400         }
401         copy(p, b.Bytes())
402         return b.Len(), nil
403 }
404 func (u *%[1]s) Unpack(r io.Reader, length int, opt *struc.Options) error {
405         return struc.UnpackWithOptions(r, u.union_data[:], opt)
406 }
407 func (u *%[1]s) Size(opt *struc.Options) int {
408         return len(u.union_data)
409 }
410 func (u *%[1]s) String() string {
411         return string(u.union_data[:])
412 }
413 `, structName)
414 }
415
416 func generateUnionGetterSetter(w io.Writer, structName string, getterField, getterStruct string) {
417         fmt.Fprintf(w, `
418 func %[1]s%[2]s(a %[3]s) (u %[1]s) {
419         u.Set%[2]s(a)
420         return
421 }
422 func (u *%[1]s) Set%[2]s(a %[3]s) {
423         var b = new(bytes.Buffer)
424         if err := struc.Pack(b, &a); err != nil {
425                 return
426         }
427         copy(u.Union_data[:], b.Bytes())
428 }
429 func (u *%[1]s) Get%[2]s() (a %[3]s) {
430         var b = bytes.NewReader(u.Union_data[:])
431         struc.Unpack(b, &a)
432         return
433 }
434 `, structName, getterField, getterStruct)
435 }
436
437 // generateType writes generated code for the type into w
438 func generateType(ctx *context, w io.Writer, typ *Type) {
439         name := camelCaseName(typ.Name)
440
441         logf(" writing type %q (%s) with %d fields", typ.Name, name, len(typ.Fields))
442
443         // generate struct comment
444         generateComment(ctx, w, name, typ.Name, "type")
445
446         // generate struct definition
447         fmt.Fprintf(w, "type %s struct {\n", name)
448
449         // generate struct fields
450         for i, field := range typ.Fields {
451                 // skip internal fields
452                 switch strings.ToLower(field.Name) {
453                 case "crc", "_vl_msg_id":
454                         continue
455                 }
456
457                 generateField(ctx, w, typ.Fields, i)
458         }
459
460         // generate end of the struct
461         fmt.Fprintln(w, "}")
462
463         // generate name getter
464         generateTypeNameGetter(w, name, typ.Name)
465
466         // generate CRC getter
467         generateCrcGetter(w, name, typ.CRC)
468
469         fmt.Fprintln(w)
470 }
471
472 // generateMessage writes generated code for the message into w
473 func generateMessage(ctx *context, w io.Writer, msg *Message) {
474         name := camelCaseName(msg.Name)
475
476         logf(" writing message %q (%s) with %d fields", msg.Name, name, len(msg.Fields))
477
478         // generate struct comment
479         generateComment(ctx, w, name, msg.Name, "message")
480
481         // generate struct definition
482         fmt.Fprintf(w, "type %s struct {", name)
483
484         msgType := otherMessage
485         wasClientIndex := false
486
487         // generate struct fields
488         n := 0
489         for i, field := range msg.Fields {
490                 if i == 1 {
491                         if field.Name == "client_index" {
492                                 // "client_index" as the second member,
493                                 // this might be an event message or a request
494                                 msgType = eventMessage
495                                 wasClientIndex = true
496                         } else if field.Name == "context" {
497                                 // reply needs "context" as the second member
498                                 msgType = replyMessage
499                         }
500                 } else if i == 2 {
501                         if wasClientIndex && field.Name == "context" {
502                                 // request needs "client_index" as the second member
503                                 // and "context" as the third member
504                                 msgType = requestMessage
505                         }
506                 }
507
508                 // skip internal fields
509                 switch strings.ToLower(field.Name) {
510                 case "crc", "_vl_msg_id":
511                         continue
512                 case "client_index", "context":
513                         if n == 0 {
514                                 continue
515                         }
516                 }
517                 n++
518                 if n == 1 {
519                         fmt.Fprintln(w)
520                 }
521
522                 generateField(ctx, w, msg.Fields, i)
523         }
524
525         // generate end of the struct
526         fmt.Fprintln(w, "}")
527
528         // generate name getter
529         generateMessageNameGetter(w, name, msg.Name)
530
531         // generate CRC getter
532         generateCrcGetter(w, name, msg.CRC)
533
534         // generate message type getter method
535         generateMessageTypeGetter(w, name, msgType)
536
537         fmt.Fprintln(w)
538 }
539
540 // generateField writes generated code for the field into w
541 func generateField(ctx *context, w io.Writer, fields []Field, i int) {
542         field := fields[i]
543
544         fieldName := strings.TrimPrefix(field.Name, "_")
545         fieldName = camelCaseName(fieldName)
546
547         // generate length field for strings
548         if field.Type == "string" {
549                 fmt.Fprintf(w, "\tXXX_%sLen uint32 `struc:\"sizeof=%s\"`\n", fieldName, fieldName)
550         }
551
552         dataType := convertToGoType(ctx, field.Type)
553
554         fieldType := dataType
555         if field.IsArray() {
556                 if dataType == "uint8" {
557                         dataType = "byte"
558                 }
559                 fieldType = "[]" + dataType
560         }
561         fmt.Fprintf(w, "\t%s %s", fieldName, fieldType)
562
563         if field.Length > 0 {
564                 // fixed size array
565                 fmt.Fprintf(w, "\t`struc:\"[%d]%s\"`", field.Length, dataType)
566         } else {
567                 for _, f := range fields {
568                         if f.SizeFrom == field.Name {
569                                 // variable sized array
570                                 sizeOfName := camelCaseName(f.Name)
571                                 fmt.Fprintf(w, "\t`struc:\"sizeof=%s\"`", sizeOfName)
572                         }
573                 }
574         }
575
576         fmt.Fprintln(w)
577 }
578
579 // generateMessageNameGetter generates getter for original VPP message name into the provider writer
580 func generateMessageNameGetter(w io.Writer, structName, msgName string) {
581         fmt.Fprintf(w, `func (*%s) GetMessageName() string {
582         return %q
583 }
584 `, structName, msgName)
585 }
586
587 // generateTypeNameGetter generates getter for original VPP type name into the provider writer
588 func generateTypeNameGetter(w io.Writer, structName, msgName string) {
589         fmt.Fprintf(w, `func (*%s) GetTypeName() string {
590         return %q
591 }
592 `, structName, msgName)
593 }
594
595 // generateCrcGetter generates getter for CRC checksum of the message definition into the provider writer
596 func generateCrcGetter(w io.Writer, structName, crc string) {
597         crc = strings.TrimPrefix(crc, "0x")
598         fmt.Fprintf(w, `func (*%s) GetCrcString() string {
599         return %q
600 }
601 `, structName, crc)
602 }
603
604 // generateMessageTypeGetter generates message factory for the generated message into the provider writer
605 func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageType) {
606         fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
607         if msgType == requestMessage {
608                 fmt.Fprintln(w, "\treturn api.RequestMessage")
609         } else if msgType == replyMessage {
610                 fmt.Fprintln(w, "\treturn api.ReplyMessage")
611         } else if msgType == eventMessage {
612                 fmt.Fprintln(w, "\treturn api.EventMessage")
613         } else {
614                 fmt.Fprintln(w, "\treturn api.OtherMessage")
615         }
616         fmt.Fprintln(w, "}")
617 }