Fix union data size for types with enums
[govpp.git] / cmd / binapi-generator / generate.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "path/filepath"
23         "sort"
24         "strings"
25         "unicode"
26 )
27
28 const (
29         inputFileExt  = ".api.json" // file extension of the VPP API files
30         outputFileExt = ".ba.go"    // file extension of the Go generated files
31
32         govppApiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API package
33         constAPIVersionCrc = "APIVersionCrc"           // name for the API version CRC constant
34 )
35
36 // context is a structure storing data for code generation
37 type context struct {
38         inputFile  string // input file with VPP API in JSON
39         outputFile string // output file with generated Go package
40
41         inputData []byte // contents of the input file
42
43         includeAPIVersionCrc bool // include constant with API version CRC string
44         includeComments      bool // include parts of original source in comments
45         includeBinapiNames   bool // include binary API names as struct tag
46
47         moduleName  string // name of the source VPP module
48         packageName string // name of the Go package being generated
49
50         packageData *Package // parsed package data
51 }
52
53 // getContext returns context details of the code generation task
54 func getContext(inputFile, outputDir string) (*context, error) {
55         if !strings.HasSuffix(inputFile, inputFileExt) {
56                 return nil, fmt.Errorf("invalid input file name: %q", inputFile)
57         }
58
59         ctx := &context{
60                 inputFile: inputFile,
61         }
62
63         // package name
64         inputFileName := filepath.Base(inputFile)
65         ctx.moduleName = inputFileName[:strings.Index(inputFileName, ".")]
66
67         // alter package names for modules that are reserved keywords in Go
68         switch ctx.moduleName {
69         case "interface":
70                 ctx.packageName = "interfaces"
71         case "map":
72                 ctx.packageName = "maps"
73         default:
74                 ctx.packageName = ctx.moduleName
75         }
76
77         // output file
78         packageDir := filepath.Join(outputDir, ctx.packageName)
79         outputFileName := ctx.packageName + outputFileExt
80         ctx.outputFile = filepath.Join(packageDir, outputFileName)
81
82         return ctx, nil
83 }
84
85 // generatePackage generates code for the parsed package data and writes it into w
86 func generatePackage(ctx *context, w *bufio.Writer) error {
87         logf("generating package %q", ctx.packageName)
88
89         // generate file header
90         generateHeader(ctx, w)
91         generateImports(ctx, w)
92
93         if ctx.includeAPIVersionCrc {
94                 fmt.Fprintf(w, "// %s defines API version CRC of the VPP binary API module.\n", constAPIVersionCrc)
95                 fmt.Fprintf(w, "const %s = %v\n", constAPIVersionCrc, ctx.packageData.APIVersion)
96                 fmt.Fprintln(w)
97         }
98
99         // generate services
100         if len(ctx.packageData.Services) > 0 {
101                 generateServices(ctx, w, ctx.packageData.Services)
102
103                 // TODO: generate implementation for Services interface
104         }
105
106         // generate enums
107         if len(ctx.packageData.Enums) > 0 {
108                 fmt.Fprintf(w, "/* Enums */\n\n")
109
110                 for _, enum := range ctx.packageData.Enums {
111                         generateEnum(ctx, w, &enum)
112                 }
113         }
114
115         // generate aliases
116         if len(ctx.packageData.Aliases) > 0 {
117                 fmt.Fprintf(w, "/* Aliases */\n\n")
118
119                 for _, alias := range ctx.packageData.Aliases {
120                         generateAlias(ctx, w, &alias)
121                 }
122         }
123
124         // generate types
125         if len(ctx.packageData.Types) > 0 {
126                 fmt.Fprintf(w, "/* Types */\n\n")
127
128                 for _, typ := range ctx.packageData.Types {
129                         generateType(ctx, w, &typ)
130                 }
131         }
132
133         // generate unions
134         if len(ctx.packageData.Unions) > 0 {
135                 fmt.Fprintf(w, "/* Unions */\n\n")
136
137                 for _, union := range ctx.packageData.Unions {
138                         generateUnion(ctx, w, &union)
139                 }
140         }
141
142         // generate messages
143         if len(ctx.packageData.Messages) > 0 {
144                 fmt.Fprintf(w, "/* Messages */\n\n")
145
146                 for _, msg := range ctx.packageData.Messages {
147                         generateMessage(ctx, w, &msg)
148                 }
149
150                 // generate message registrations
151                 fmt.Fprintln(w, "func init() {")
152                 for _, msg := range ctx.packageData.Messages {
153                         name := camelCaseName(msg.Name)
154                         fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n", name, ctx.moduleName+"."+name)
155                 }
156                 fmt.Fprintln(w, "}")
157                 fmt.Fprintln(w)
158
159                 fmt.Fprintln(w, "var Messages = []api.Message{")
160                 for _, msg := range ctx.packageData.Messages {
161                         name := camelCaseName(msg.Name)
162                         fmt.Fprintf(w, "\t(*%s)(nil),\n", name)
163                 }
164                 fmt.Fprintln(w, "}")
165         }
166
167         // flush the data:
168         if err := w.Flush(); err != nil {
169                 return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
170         }
171
172         return nil
173 }
174
175 // generateHeader writes generated package header into w
176 func generateHeader(ctx *context, w io.Writer) {
177         fmt.Fprintln(w, "// Code generated by GoVPP binapi-generator. DO NOT EDIT.")
178         fmt.Fprintf(w, "//  source: %s\n", ctx.inputFile)
179         fmt.Fprintln(w)
180
181         fmt.Fprintln(w, "/*")
182         fmt.Fprintf(w, " Package %s is a generated from VPP binary API module '%s'.\n", ctx.packageName, ctx.moduleName)
183         fmt.Fprintln(w)
184         fmt.Fprintln(w, " It contains following objects:")
185         var printObjNum = func(obj string, num int) {
186                 if num > 0 {
187                         if num > 1 {
188                                 if strings.HasSuffix(obj, "s") {
189                                         obj += "es"
190                                 } else {
191                                         obj += "s"
192                                 }
193                         }
194                         fmt.Fprintf(w, "\t%3d %s\n", num, obj)
195                 }
196         }
197
198         printObjNum("service", len(ctx.packageData.Services))
199         printObjNum("enum", len(ctx.packageData.Enums))
200         printObjNum("alias", len(ctx.packageData.Aliases))
201         printObjNum("type", len(ctx.packageData.Types))
202         printObjNum("union", len(ctx.packageData.Unions))
203         printObjNum("message", len(ctx.packageData.Messages))
204         fmt.Fprintln(w, "*/")
205         fmt.Fprintf(w, "package %s\n", ctx.packageName)
206         fmt.Fprintln(w)
207 }
208
209 // generateImports writes generated package imports into w
210 func generateImports(ctx *context, w io.Writer) {
211         fmt.Fprintf(w, "import api \"%s\"\n", govppApiImportPath)
212         fmt.Fprintf(w, "import struc \"%s\"\n", "github.com/lunixbochs/struc")
213         fmt.Fprintf(w, "import bytes \"%s\"\n", "bytes")
214         fmt.Fprintln(w)
215
216         fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n")
217         fmt.Fprintf(w, "var _ = api.RegisterMessage\n")
218         fmt.Fprintf(w, "var _ = struc.Pack\n")
219         fmt.Fprintf(w, "var _ = bytes.NewBuffer\n")
220         fmt.Fprintln(w)
221
222         /*fmt.Fprintln(w, "// This is a compile-time assertion to ensure that this generated file")
223         fmt.Fprintln(w, "// is compatible with the GoVPP api package it is being compiled against.")
224         fmt.Fprintln(w, "// A compilation error at this line likely means your copy of the")
225         fmt.Fprintln(w, "// GoVPP api package needs to be updated.")
226         fmt.Fprintln(w, "const _ = api.GoVppAPIPackageIsVersion1 // please upgrade the GoVPP api package")
227         fmt.Fprintln(w)*/
228 }
229
230 // generateComment writes generated comment for the object into w
231 func generateComment(ctx *context, w io.Writer, goName string, vppName string, objKind string) {
232         if objKind == "service" {
233                 fmt.Fprintf(w, "// %s represents VPP binary API services:\n", goName)
234         } else {
235                 fmt.Fprintf(w, "// %s represents VPP binary API %s '%s':\n", goName, objKind, vppName)
236         }
237
238         if !ctx.includeComments {
239                 return
240         }
241
242         var isNotSpace = func(r rune) bool {
243                 return !unicode.IsSpace(r)
244         }
245
246         // print out the source of the generated object
247         mapType := false
248         objFound := false
249         objTitle := fmt.Sprintf(`"%s",`, vppName)
250         switch objKind {
251         case "alias", "service":
252                 objTitle = fmt.Sprintf(`"%s": {`, vppName)
253                 mapType = true
254         }
255
256         inputBuff := bytes.NewBuffer(ctx.inputData)
257         inputLine := 0
258
259         var trimIndent string
260         var indent int
261         for {
262                 line, err := inputBuff.ReadString('\n')
263                 if err != nil {
264                         break
265                 }
266                 inputLine++
267
268                 noSpaceAt := strings.IndexFunc(line, isNotSpace)
269                 if !objFound {
270                         indent = strings.Index(line, objTitle)
271                         if indent == -1 {
272                                 continue
273                         }
274                         trimIndent = line[:indent]
275                         // If no other non-whitespace character then we are at the message header.
276                         if trimmed := strings.TrimSpace(line); trimmed == objTitle {
277                                 objFound = true
278                                 fmt.Fprintln(w, "//")
279                         }
280                 } else if noSpaceAt < indent {
281                         break // end of the definition in JSON for array types
282                 } else if objFound && mapType && noSpaceAt <= indent {
283                         fmt.Fprintf(w, "//\t%s", strings.TrimPrefix(line, trimIndent))
284                         break // end of the definition in JSON for map types (aliases, services)
285                 }
286                 fmt.Fprintf(w, "//\t%s", strings.TrimPrefix(line, trimIndent))
287         }
288
289         fmt.Fprintln(w, "//")
290 }
291
292 // generateServices writes generated code for the Services interface into w
293 func generateServices(ctx *context, w *bufio.Writer, services []Service) {
294         // generate services comment
295         generateComment(ctx, w, "Services", "services", "service")
296
297         // generate interface
298         fmt.Fprintf(w, "type %s interface {\n", "Services")
299         for _, svc := range services {
300                 generateService(ctx, w, &svc)
301         }
302         fmt.Fprintln(w, "}")
303
304         fmt.Fprintln(w)
305 }
306
307 // generateService writes generated code for the service into w
308 func generateService(ctx *context, w io.Writer, svc *Service) {
309         reqTyp := camelCaseName(svc.RequestType)
310
311         // method name is same as parameter type name by default
312         method := reqTyp
313         if svc.Stream {
314                 // use Dump as prefix instead of suffix for stream services
315                 if m := strings.TrimSuffix(method, "Dump"); method != m {
316                         method = "Dump" + m
317                 }
318         }
319
320         params := fmt.Sprintf("*%s", reqTyp)
321         returns := "error"
322         if replyType := camelCaseName(svc.ReplyType); replyType != "" {
323                 repTyp := fmt.Sprintf("*%s", replyType)
324                 if svc.Stream {
325                         repTyp = fmt.Sprintf("[]%s", repTyp)
326                 }
327                 returns = fmt.Sprintf("(%s, error)", repTyp)
328         }
329
330         fmt.Fprintf(w, "\t%s(%s) %s\n", method, params, returns)
331 }
332
333 // generateEnum writes generated code for the enum into w
334 func generateEnum(ctx *context, w io.Writer, enum *Enum) {
335         name := camelCaseName(enum.Name)
336         typ := binapiTypes[enum.Type]
337
338         logf(" writing enum %q (%s) with %d entries", enum.Name, name, len(enum.Entries))
339
340         // generate enum comment
341         generateComment(ctx, w, name, enum.Name, "enum")
342
343         // generate enum definition
344         fmt.Fprintf(w, "type %s %s\n", name, typ)
345         fmt.Fprintln(w)
346
347         fmt.Fprintln(w, "const (")
348
349         // generate enum entries
350         for _, entry := range enum.Entries {
351                 fmt.Fprintf(w, "\t%s %s = %v\n", entry.Name, name, entry.Value)
352         }
353
354         fmt.Fprintln(w, ")")
355
356         fmt.Fprintln(w)
357 }
358
359 // generateAlias writes generated code for the alias into w
360 func generateAlias(ctx *context, w io.Writer, alias *Alias) {
361         name := camelCaseName(alias.Name)
362
363         logf(" writing type %q (%s), length: %d", alias.Name, name, alias.Length)
364
365         // generate struct comment
366         generateComment(ctx, w, name, alias.Name, "alias")
367
368         // generate struct definition
369         fmt.Fprintf(w, "type %s ", name)
370
371         if alias.Length > 0 {
372                 fmt.Fprintf(w, "[%d]", alias.Length)
373         }
374
375         dataType := convertToGoType(ctx, alias.Type)
376         fmt.Fprintf(w, "%s\n", dataType)
377
378         fmt.Fprintln(w)
379 }
380
381 // generateUnion writes generated code for the union into w
382 func generateUnion(ctx *context, w io.Writer, union *Union) {
383         name := camelCaseName(union.Name)
384
385         logf(" writing union %q (%s) with %d fields", union.Name, name, len(union.Fields))
386
387         // generate struct comment
388         generateComment(ctx, w, name, union.Name, "union")
389
390         // generate struct definition
391         fmt.Fprintln(w, "type", name, "struct {")
392
393         // maximum size for union
394         maxSize := getUnionSize(ctx, union)
395
396         // generate data field
397         fieldName := "Union_data"
398         fmt.Fprintf(w, "\t%s [%d]byte\n", fieldName, maxSize)
399
400         // generate end of the struct
401         fmt.Fprintln(w, "}")
402
403         // generate name getter
404         generateTypeNameGetter(w, name, union.Name)
405
406         // generate CRC getter
407         generateCrcGetter(w, name, union.CRC)
408
409         // generate getters for fields
410         for _, field := range union.Fields {
411                 fieldName := camelCaseName(field.Name)
412                 fieldType := convertToGoType(ctx, field.Type)
413                 generateUnionGetterSetter(w, name, fieldName, fieldType)
414         }
415
416         // generate union methods
417         //generateUnionMethods(w, name)
418
419         fmt.Fprintln(w)
420 }
421
422 // generateUnionMethods generates methods that implement struc.Custom
423 // interface to allow having Union_data field unexported
424 // TODO: do more testing when unions are actually used in some messages
425 func generateUnionMethods(w io.Writer, structName string) {
426         // generate struc.Custom implementation for union
427         fmt.Fprintf(w, `
428 func (u *%[1]s) Pack(p []byte, opt *struc.Options) (int, error) {
429         var b = new(bytes.Buffer)
430         if err := struc.PackWithOptions(b, u.union_data, opt); err != nil {
431                 return 0, err
432         }
433         copy(p, b.Bytes())
434         return b.Len(), nil
435 }
436 func (u *%[1]s) Unpack(r io.Reader, length int, opt *struc.Options) error {
437         return struc.UnpackWithOptions(r, u.union_data[:], opt)
438 }
439 func (u *%[1]s) Size(opt *struc.Options) int {
440         return len(u.union_data)
441 }
442 func (u *%[1]s) String() string {
443         return string(u.union_data[:])
444 }
445 `, structName)
446 }
447
448 func generateUnionGetterSetter(w io.Writer, structName string, getterField, getterStruct string) {
449         fmt.Fprintf(w, `
450 func %[1]s%[2]s(a %[3]s) (u %[1]s) {
451         u.Set%[2]s(a)
452         return
453 }
454 func (u *%[1]s) Set%[2]s(a %[3]s) {
455         var b = new(bytes.Buffer)
456         if err := struc.Pack(b, &a); err != nil {
457                 return
458         }
459         copy(u.Union_data[:], b.Bytes())
460 }
461 func (u *%[1]s) Get%[2]s() (a %[3]s) {
462         var b = bytes.NewReader(u.Union_data[:])
463         struc.Unpack(b, &a)
464         return
465 }
466 `, structName, getterField, getterStruct)
467 }
468
469 // generateType writes generated code for the type into w
470 func generateType(ctx *context, w io.Writer, typ *Type) {
471         name := camelCaseName(typ.Name)
472
473         logf(" writing type %q (%s) with %d fields", typ.Name, name, len(typ.Fields))
474
475         // generate struct comment
476         generateComment(ctx, w, name, typ.Name, "type")
477
478         // generate struct definition
479         fmt.Fprintf(w, "type %s struct {\n", name)
480
481         // generate struct fields
482         for i, field := range typ.Fields {
483                 // skip internal fields
484                 switch strings.ToLower(field.Name) {
485                 case crcField, msgIdField:
486                         continue
487                 }
488
489                 generateField(ctx, w, typ.Fields, i)
490         }
491
492         // generate end of the struct
493         fmt.Fprintln(w, "}")
494
495         // generate name getter
496         generateTypeNameGetter(w, name, typ.Name)
497
498         // generate CRC getter
499         generateCrcGetter(w, name, typ.CRC)
500
501         fmt.Fprintln(w)
502 }
503
504 // generateMessage writes generated code for the message into w
505 func generateMessage(ctx *context, w io.Writer, msg *Message) {
506         name := camelCaseName(msg.Name)
507
508         logf(" writing message %q (%s) with %d fields", msg.Name, name, len(msg.Fields))
509
510         // generate struct comment
511         generateComment(ctx, w, name, msg.Name, "message")
512
513         // generate struct definition
514         fmt.Fprintf(w, "type %s struct {", name)
515
516         msgType := otherMessage
517         wasClientIndex := false
518
519         // generate struct fields
520         n := 0
521         for i, field := range msg.Fields {
522                 if i == 1 {
523                         if field.Name == clientIndexField {
524                                 // "client_index" as the second member,
525                                 // this might be an event message or a request
526                                 msgType = eventMessage
527                                 wasClientIndex = true
528                         } else if field.Name == contextField {
529                                 // reply needs "context" as the second member
530                                 msgType = replyMessage
531                         }
532                 } else if i == 2 {
533                         if wasClientIndex && field.Name == contextField {
534                                 // request needs "client_index" as the second member
535                                 // and "context" as the third member
536                                 msgType = requestMessage
537                         }
538                 }
539
540                 // skip internal fields
541                 switch strings.ToLower(field.Name) {
542                 case crcField, msgIdField:
543                         continue
544                 case clientIndexField, contextField:
545                         if n == 0 {
546                                 continue
547                         }
548                 }
549                 n++
550                 if n == 1 {
551                         fmt.Fprintln(w)
552                 }
553
554                 generateField(ctx, w, msg.Fields, i)
555         }
556
557         // generate end of the struct
558         fmt.Fprintln(w, "}")
559
560         // generate name getter
561         generateMessageNameGetter(w, name, msg.Name)
562
563         // generate CRC getter
564         generateCrcGetter(w, name, msg.CRC)
565
566         // generate message type getter method
567         generateMessageTypeGetter(w, name, msgType)
568
569         fmt.Fprintln(w)
570 }
571
572 // generateField writes generated code for the field into w
573 func generateField(ctx *context, w io.Writer, fields []Field, i int) {
574         field := fields[i]
575
576         fieldName := strings.TrimPrefix(field.Name, "_")
577         fieldName = camelCaseName(fieldName)
578
579         // generate length field for strings
580         if field.Type == "string" {
581                 fmt.Fprintf(w, "\tXXX_%sLen uint32 `struc:\"sizeof=%s\"`\n", fieldName, fieldName)
582         }
583
584         dataType := convertToGoType(ctx, field.Type)
585         fieldType := dataType
586
587         // check if it is array
588         if field.Length > 0 || field.SizeFrom != "" {
589                 if dataType == "uint8" {
590                         dataType = "byte"
591                 }
592                 fieldType = "[]" + dataType
593         }
594         fmt.Fprintf(w, "\t%s %s", fieldName, fieldType)
595
596         fieldTags := map[string]string{}
597
598         if field.Length > 0 {
599                 // fixed size array
600                 fieldTags["struc"] = fmt.Sprintf("[%d]%s", field.Length, dataType)
601         } else {
602                 for _, f := range fields {
603                         if f.SizeFrom == field.Name {
604                                 // variable sized array
605                                 sizeOfName := camelCaseName(f.Name)
606                                 fieldTags["struc"] = fmt.Sprintf("sizeof=%s", sizeOfName)
607                         }
608                 }
609         }
610
611         if ctx.includeBinapiNames {
612                 fieldTags["binapi"] = field.Name
613         }
614         if field.Meta.Limit > 0 {
615                 fieldTags["binapi"] = fmt.Sprintf("%s,limit=%d", fieldTags["binapi"], field.Meta.Limit)
616         }
617
618         if len(fieldTags) > 0 {
619                 fmt.Fprintf(w, "\t`")
620                 var keys []string
621                 for k := range fieldTags {
622                         keys = append(keys, k)
623                 }
624                 sort.Strings(keys)
625                 var n int
626                 for _, tt := range keys {
627                         t, ok := fieldTags[tt]
628                         if !ok {
629                                 continue
630                         }
631                         if n > 0 {
632                                 fmt.Fprintf(w, " ")
633                         }
634                         n++
635                         fmt.Fprintf(w, `%s:"%s"`, tt, t)
636                 }
637                 fmt.Fprintf(w, "`")
638         }
639
640         fmt.Fprintln(w)
641 }
642
643 // generateMessageNameGetter generates getter for original VPP message name into the provider writer
644 func generateMessageNameGetter(w io.Writer, structName, msgName string) {
645         fmt.Fprintf(w, `func (*%s) GetMessageName() string {
646         return %q
647 }
648 `, structName, msgName)
649 }
650
651 // generateTypeNameGetter generates getter for original VPP type name into the provider writer
652 func generateTypeNameGetter(w io.Writer, structName, msgName string) {
653         fmt.Fprintf(w, `func (*%s) GetTypeName() string {
654         return %q
655 }
656 `, structName, msgName)
657 }
658
659 // generateCrcGetter generates getter for CRC checksum of the message definition into the provider writer
660 func generateCrcGetter(w io.Writer, structName, crc string) {
661         crc = strings.TrimPrefix(crc, "0x")
662         fmt.Fprintf(w, `func (*%s) GetCrcString() string {
663         return %q
664 }
665 `, structName, crc)
666 }
667
668 // generateMessageTypeGetter generates message factory for the generated message into the provider writer
669 func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageType) {
670         fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
671         if msgType == requestMessage {
672                 fmt.Fprintln(w, "\treturn api.RequestMessage")
673         } else if msgType == replyMessage {
674                 fmt.Fprintln(w, "\treturn api.ReplyMessage")
675         } else if msgType == eventMessage {
676                 fmt.Fprintln(w, "\treturn api.EventMessage")
677         } else {
678                 fmt.Fprintln(w, "\treturn api.OtherMessage")
679         }
680         fmt.Fprintln(w, "}")
681 }