Refactor GoVPP
[govpp.git] / cmd / binapi-generator / generate.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "path/filepath"
23         "strings"
24         "unicode"
25 )
26
27 const (
28         govppApiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API package
29         inputFileExt       = ".api.json"               // file extension of the VPP binary API files
30         outputFileExt      = ".ba.go"                  // file extension of the Go generated files
31 )
32
33 // context is a structure storing data for code generation
34 type context struct {
35         inputFile  string // input file with VPP API in JSON
36         outputFile string // output file with generated Go package
37
38         inputData []byte        // contents of the input file
39         inputBuff *bytes.Buffer // contents of the input file currently being read
40         inputLine int           // currently processed line in the input file
41
42         moduleName  string // name of the source VPP module
43         packageName string // name of the Go package being generated
44
45         packageData *Package // parsed package data
46 }
47
48 // getContext returns context details of the code generation task
49 func getContext(inputFile, outputDir string) (*context, error) {
50         if !strings.HasSuffix(inputFile, inputFileExt) {
51                 return nil, fmt.Errorf("invalid input file name: %q", inputFile)
52         }
53
54         ctx := &context{
55                 inputFile: inputFile,
56         }
57
58         // package name
59         inputFileName := filepath.Base(inputFile)
60         ctx.moduleName = inputFileName[:strings.Index(inputFileName, ".")]
61
62         // alter package names for modules that are reserved keywords in Go
63         switch ctx.moduleName {
64         case "interface":
65                 ctx.packageName = "interfaces"
66         case "map":
67                 ctx.packageName = "maps"
68         default:
69                 ctx.packageName = ctx.moduleName
70         }
71
72         // output file
73         packageDir := filepath.Join(outputDir, ctx.packageName)
74         outputFileName := ctx.packageName + outputFileExt
75         ctx.outputFile = filepath.Join(packageDir, outputFileName)
76
77         return ctx, nil
78 }
79
80 // generatePackage generates code for the parsed package data and writes it into w
81 func generatePackage(ctx *context, w *bufio.Writer) error {
82         logf("generating package %q", ctx.packageName)
83
84         // generate file header
85         generateHeader(ctx, w)
86         generateImports(ctx, w)
87
88         if *includeAPIVer {
89                 const APIVerConstName = "VlAPIVersion"
90                 fmt.Fprintf(w, "// %s represents version of the API.\n", APIVerConstName)
91                 fmt.Fprintf(w, "const %s = %v\n", APIVerConstName, ctx.packageData.APIVersion)
92                 fmt.Fprintln(w)
93         }
94
95         // generate enums
96         if len(ctx.packageData.Enums) > 0 {
97                 fmt.Fprintf(w, "/* Enums */\n\n")
98
99                 ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
100                 ctx.inputLine = 0
101                 for _, enum := range ctx.packageData.Enums {
102                         generateEnum(ctx, w, &enum)
103                 }
104         }
105
106         // generate types
107         if len(ctx.packageData.Types) > 0 {
108                 fmt.Fprintf(w, "/* Types */\n\n")
109
110                 ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
111                 ctx.inputLine = 0
112                 for _, typ := range ctx.packageData.Types {
113                         generateType(ctx, w, &typ)
114                 }
115         }
116
117         // generate unions
118         if len(ctx.packageData.Unions) > 0 {
119                 fmt.Fprintf(w, "/* Unions */\n\n")
120
121                 ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
122                 ctx.inputLine = 0
123                 for _, union := range ctx.packageData.Unions {
124                         generateUnion(ctx, w, &union)
125                 }
126         }
127
128         // generate messages
129         if len(ctx.packageData.Messages) > 0 {
130                 fmt.Fprintf(w, "/* Messages */\n\n")
131
132                 ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
133                 ctx.inputLine = 0
134                 for _, msg := range ctx.packageData.Messages {
135                         generateMessage(ctx, w, &msg)
136                 }
137         }
138
139         // generate services
140         if len(ctx.packageData.Services) > 0 {
141                 fmt.Fprintf(w, "/* Services */\n\n")
142
143                 fmt.Fprintf(w, "type %s interface {\n", "Services")
144                 ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
145                 ctx.inputLine = 0
146                 for _, svc := range ctx.packageData.Services {
147                         generateService(ctx, w, &svc)
148                 }
149                 fmt.Fprintln(w, "}")
150         }
151
152         // TODO: generate implementation for Services interface
153
154         // generate message registrations
155         fmt.Fprintln(w)
156         fmt.Fprintln(w, "func init() {")
157         for _, msg := range ctx.packageData.Messages {
158                 name := camelCaseName(msg.Name)
159                 fmt.Fprintf(w, "\tapi.RegisterMessage((*%s)(nil), \"%s\")\n", name, ctx.moduleName+"."+name)
160         }
161         fmt.Fprintln(w, "}")
162
163         // flush the data:
164         if err := w.Flush(); err != nil {
165                 return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
166         }
167
168         return nil
169 }
170
171 // generateHeader writes generated package header into w
172 func generateHeader(ctx *context, w io.Writer) {
173         fmt.Fprintln(w, "// Code generated by GoVPP binapi-generator. DO NOT EDIT.")
174         fmt.Fprintf(w, "// source: %s\n", ctx.inputFile)
175         fmt.Fprintln(w)
176
177         fmt.Fprintln(w, "/*")
178         fmt.Fprintf(w, "Package %s is a generated VPP binary API of the '%s' VPP module.\n", ctx.packageName, ctx.moduleName)
179         fmt.Fprintln(w)
180         fmt.Fprintln(w, "It is generated from this file:")
181         fmt.Fprintf(w, "\t%s\n", filepath.Base(ctx.inputFile))
182         fmt.Fprintln(w)
183         fmt.Fprintln(w, "It contains these VPP binary API objects:")
184         var printObjNum = func(obj string, num int) {
185                 if num > 0 {
186                         if num > 1 {
187                                 obj += "s"
188                         }
189                         fmt.Fprintf(w, "\t%d %s\n", num, obj)
190                 }
191         }
192         printObjNum("message", len(ctx.packageData.Messages))
193         printObjNum("type", len(ctx.packageData.Types))
194         printObjNum("enum", len(ctx.packageData.Enums))
195         printObjNum("union", len(ctx.packageData.Unions))
196         printObjNum("service", len(ctx.packageData.Services))
197         fmt.Fprintln(w, "*/")
198         fmt.Fprintf(w, "package %s\n", ctx.packageName)
199         fmt.Fprintln(w)
200 }
201
202 // generateImports writes generated package imports into w
203 func generateImports(ctx *context, w io.Writer) {
204         fmt.Fprintf(w, "import \"%s\"\n", govppApiImportPath)
205         fmt.Fprintf(w, "import \"%s\"\n", "github.com/lunixbochs/struc")
206         fmt.Fprintf(w, "import \"%s\"\n", "bytes")
207         fmt.Fprintln(w)
208
209         fmt.Fprintf(w, "// Reference imports to suppress errors if they are not otherwise used.\n")
210         fmt.Fprintf(w, "var _ = struc.Pack\n")
211         fmt.Fprintf(w, "var _ = bytes.NewBuffer\n")
212         fmt.Fprintln(w)
213 }
214
215 // generateComment writes generated comment for the object into w
216 func generateComment(ctx *context, w io.Writer, goName string, vppName string, objKind string) {
217         fmt.Fprintf(w, "// %s represents the VPP binary API %s '%s'.\n", goName, objKind, vppName)
218
219         var isNotSpace = func(r rune) bool {
220                 return !unicode.IsSpace(r)
221         }
222
223         // print out the source of the generated object
224         objFound := false
225         objTitle := fmt.Sprintf(`"%s",`, vppName)
226         var indent int
227         for {
228                 line, err := ctx.inputBuff.ReadString('\n')
229                 if err != nil {
230                         break
231                 }
232                 ctx.inputLine++
233
234                 if !objFound {
235                         indent = strings.Index(line, objTitle)
236                         if indent == -1 {
237                                 continue
238                         }
239                         // If no other non-whitespace character then we are at the message header.
240                         if trimmed := strings.TrimSpace(line); trimmed == objTitle {
241                                 objFound = true
242                                 fmt.Fprintf(w, "// Generated from '%s', line %d:\n", filepath.Base(ctx.inputFile), ctx.inputLine)
243                                 fmt.Fprintln(w, "//")
244                         }
245                 } else {
246                         if strings.IndexFunc(line, isNotSpace) < indent {
247                                 break // end of the object definition in JSON
248                         }
249                 }
250                 fmt.Fprint(w, "//", line)
251         }
252
253         fmt.Fprintln(w, "//")
254 }
255
256 // generateEnum writes generated code for the enum into w
257 func generateEnum(ctx *context, w io.Writer, enum *Enum) {
258         name := camelCaseName(enum.Name)
259         typ := binapiTypes[enum.Type]
260
261         logf(" writing enum %q (%s) with %d entries", enum.Name, name, len(enum.Entries))
262
263         // generate enum comment
264         generateComment(ctx, w, name, enum.Name, "enum")
265
266         // generate enum definition
267         fmt.Fprintf(w, "type %s %s\n", name, typ)
268         fmt.Fprintln(w)
269
270         fmt.Fprintln(w, "const (")
271
272         // generate enum entries
273         for _, entry := range enum.Entries {
274                 fmt.Fprintf(w, "\t%s %s = %v\n", entry.Name, name, entry.Value)
275         }
276
277         fmt.Fprintln(w, ")")
278
279         fmt.Fprintln(w)
280 }
281
282 // generateType writes generated code for the type into w
283 func generateType(ctx *context, w io.Writer, typ *Type) {
284         name := camelCaseName(typ.Name)
285
286         logf(" writing type %q (%s) with %d fields", typ.Name, name, len(typ.Fields))
287
288         // generate struct comment
289         generateComment(ctx, w, name, typ.Name, "type")
290
291         // generate struct definition
292         fmt.Fprintf(w, "type %s struct {\n", name)
293
294         // generate struct fields
295         for i, field := range typ.Fields {
296                 // skip internal fields
297                 switch strings.ToLower(field.Name) {
298                 case "crc", "_vl_msg_id":
299                         continue
300                 }
301
302                 generateField(ctx, w, typ.Fields, i)
303         }
304
305         // generate end of the struct
306         fmt.Fprintln(w, "}")
307
308         // generate name getter
309         generateTypeNameGetter(w, name, typ.Name)
310
311         // generate CRC getter
312         generateCrcGetter(w, name, typ.CRC)
313
314         fmt.Fprintln(w)
315 }
316
317 // generateUnion writes generated code for the union into w
318 func generateUnion(ctx *context, w io.Writer, union *Union) {
319         name := camelCaseName(union.Name)
320
321         logf(" writing union %q (%s) with %d fields", union.Name, name, len(union.Fields))
322
323         // generate struct comment
324         generateComment(ctx, w, name, union.Name, "union")
325
326         // generate struct definition
327         fmt.Fprintln(w, "type", name, "struct {")
328
329         // maximum size for union
330         maxSize := getUnionSize(ctx, union)
331
332         // generate data field
333         fieldName := "Union_data"
334         fmt.Fprintf(w, "\t%s [%d]byte\n", fieldName, maxSize)
335
336         // generate end of the struct
337         fmt.Fprintln(w, "}")
338
339         // generate name getter
340         generateTypeNameGetter(w, name, union.Name)
341
342         // generate CRC getter
343         generateCrcGetter(w, name, union.CRC)
344
345         // generate getters for fields
346         for _, field := range union.Fields {
347                 fieldName := camelCaseName(field.Name)
348                 fieldType := convertToGoType(ctx, field.Type)
349                 generateUnionGetterSetter(w, name, fieldName, fieldType)
350         }
351
352         // generate union methods
353         //generateUnionMethods(w, name)
354
355         fmt.Fprintln(w)
356 }
357
358 // generateUnionMethods generates methods that implement struc.Custom
359 // interface to allow having Union_data field unexported
360 // TODO: do more testing when unions are actually used in some messages
361 func generateUnionMethods(w io.Writer, structName string) {
362         // generate struc.Custom implementation for union
363         fmt.Fprintf(w, `
364 func (u *%[1]s) Pack(p []byte, opt *struc.Options) (int, error) {
365         var b = new(bytes.Buffer)
366         if err := struc.PackWithOptions(b, u.union_data, opt); err != nil {
367                 return 0, err
368         }
369         copy(p, b.Bytes())
370         return b.Len(), nil
371 }
372 func (u *%[1]s) Unpack(r io.Reader, length int, opt *struc.Options) error {
373         return struc.UnpackWithOptions(r, u.union_data[:], opt)
374 }
375 func (u *%[1]s) Size(opt *struc.Options) int {
376         return len(u.union_data)
377 }
378 func (u *%[1]s) String() string {
379         return string(u.union_data[:])
380 }
381 `, structName)
382 }
383
384 func generateUnionGetterSetter(w io.Writer, structName string, getterField, getterStruct string) {
385         fmt.Fprintf(w, `
386 func (u *%[1]s) Set%[2]s(a %[3]s) {
387         var b = new(bytes.Buffer)
388         if err := struc.Pack(b, &a); err != nil {
389                 return
390         }
391         copy(u.Union_data[:], b.Bytes())
392 }
393 func (u *%[1]s) Get%[2]s() (a %[3]s) {
394         var b = bytes.NewReader(u.Union_data[:])
395         struc.Unpack(b, &a)
396         return
397 }
398 `, structName, getterField, getterStruct)
399 }
400
401 // generateMessage writes generated code for the message into w
402 func generateMessage(ctx *context, w io.Writer, msg *Message) {
403         name := camelCaseName(msg.Name)
404
405         logf(" writing message %q (%s) with %d fields", msg.Name, name, len(msg.Fields))
406
407         // generate struct comment
408         generateComment(ctx, w, name, msg.Name, "message")
409
410         // generate struct definition
411         fmt.Fprintf(w, "type %s struct {", name)
412
413         msgType := otherMessage
414         wasClientIndex := false
415
416         // generate struct fields
417         n := 0
418         for i, field := range msg.Fields {
419                 if i == 1 {
420                         if field.Name == "client_index" {
421                                 // "client_index" as the second member, this might be an event message or a request
422                                 msgType = eventMessage
423                                 wasClientIndex = true
424                         } else if field.Name == "context" {
425                                 // reply needs "context" as the second member
426                                 msgType = replyMessage
427                         }
428                 } else if i == 2 {
429                         if wasClientIndex && field.Name == "context" {
430                                 // request needs "client_index" as the second member and "context" as the third member
431                                 msgType = requestMessage
432                         }
433                 }
434
435                 // skip internal fields
436                 switch strings.ToLower(field.Name) {
437                 case "crc", "_vl_msg_id":
438                         continue
439                 case "client_index", "context":
440                         if n == 0 {
441                                 continue
442                         }
443                 }
444                 n++
445                 if n == 1 {
446                         fmt.Fprintln(w)
447                 }
448
449                 generateField(ctx, w, msg.Fields, i)
450         }
451
452         // generate end of the struct
453         fmt.Fprintln(w, "}")
454
455         // generate name getter
456         generateMessageNameGetter(w, name, msg.Name)
457
458         // generate CRC getter
459         generateCrcGetter(w, name, msg.CRC)
460
461         // generate message type getter method
462         generateMessageTypeGetter(w, name, msgType)
463
464         // generate message factory
465         generateMessageFactory(w, name)
466 }
467
468 // generateField writes generated code for the field into w
469 func generateField(ctx *context, w io.Writer, fields []Field, i int) {
470         field := fields[i]
471
472         fieldName := strings.TrimPrefix(field.Name, "_")
473         fieldName = camelCaseName(fieldName)
474
475         dataType := convertToGoType(ctx, field.Type)
476
477         fieldType := dataType
478         if field.IsArray() {
479                 if dataType == "uint8" {
480                         dataType = "byte"
481                 }
482                 fieldType = "[]" + dataType
483         }
484         fmt.Fprintf(w, "\t%s %s", fieldName, fieldType)
485
486         if field.Length > 0 {
487                 // fixed size array
488                 fmt.Fprintf(w, "\t`struc:\"[%d]%s\"`", field.Length, dataType)
489         } else {
490                 for _, f := range fields {
491                         if f.SizeFrom == field.Name {
492                                 // variable sized array
493                                 sizeOfName := camelCaseName(f.Name)
494                                 fmt.Fprintf(w, "\t`struc:\"sizeof=%s\"`", sizeOfName)
495                         }
496                 }
497         }
498
499         fmt.Fprintln(w)
500 }
501
502 // generateService writes generated code for the service into w
503 func generateService(ctx *context, w io.Writer, svc *Service) {
504         reqTyp := camelCaseName(svc.RequestType)
505
506         // method name is same as parameter type name by default
507         method := reqTyp
508         if svc.Stream {
509                 // use Dump as prefix instead of suffix for stream services
510                 if m := strings.TrimSuffix(method, "Dump"); method != m {
511                         method = "Dump" + m
512                 }
513         }
514         params := fmt.Sprintf("*%s", reqTyp)
515         returns := "error"
516         if replyTyp := camelCaseName(svc.ReplyType); replyTyp != "" {
517                 returns = fmt.Sprintf("(*%s, error)", replyTyp)
518         }
519
520         fmt.Fprintf(w, "\t%s(%s) %s\n", method, params, returns)
521 }
522
523 // generateMessageNameGetter generates getter for original VPP message name into the provider writer
524 func generateMessageNameGetter(w io.Writer, structName string, msgName string) {
525         fmt.Fprintln(w, "func (*"+structName+") GetMessageName() string {")
526         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
527         fmt.Fprintln(w, "}")
528 }
529
530 // generateTypeNameGetter generates getter for original VPP type name into the provider writer
531 func generateTypeNameGetter(w io.Writer, structName string, msgName string) {
532         fmt.Fprintln(w, "func (*"+structName+") GetTypeName() string {")
533         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
534         fmt.Fprintln(w, "}")
535 }
536
537 // generateCrcGetter generates getter for CRC checksum of the message definition into the provider writer
538 func generateCrcGetter(w io.Writer, structName string, crc string) {
539         crc = strings.TrimPrefix(crc, "0x")
540         fmt.Fprintln(w, "func (*"+structName+") GetCrcString() string {")
541         fmt.Fprintln(w, "\treturn \""+crc+"\"")
542         fmt.Fprintln(w, "}")
543 }
544
545 // generateMessageTypeGetter generates message factory for the generated message into the provider writer
546 func generateMessageTypeGetter(w io.Writer, structName string, msgType MessageType) {
547         fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
548         if msgType == requestMessage {
549                 fmt.Fprintln(w, "\treturn api.RequestMessage")
550         } else if msgType == replyMessage {
551                 fmt.Fprintln(w, "\treturn api.ReplyMessage")
552         } else if msgType == eventMessage {
553                 fmt.Fprintln(w, "\treturn api.EventMessage")
554         } else {
555                 fmt.Fprintln(w, "\treturn api.OtherMessage")
556         }
557         fmt.Fprintln(w, "}")
558 }
559
560 // generateMessageFactory generates message factory for the generated message into the provider writer
561 func generateMessageFactory(w io.Writer, structName string) {
562         fmt.Fprintln(w, "func New"+structName+"() api.Message {")
563         fmt.Fprintln(w, "\treturn &"+structName+"{}")
564         fmt.Fprintln(w, "}")
565 }