Update binapi-generator for the new VPPAPIGEN.
[govpp.git] / cmd / binapi-generator / generator.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "bufio"
19         "bytes"
20         "encoding/json"
21         "errors"
22         "flag"
23         "fmt"
24         "io"
25         "io/ioutil"
26         "os"
27         "os/exec"
28         "path/filepath"
29         "strings"
30         "unicode"
31
32         "github.com/bennyscetbun/jsongo"
33 )
34
35 // MessageType represents the type of a VPP message.
36 type messageType int
37
38 const (
39         requestMessage messageType = iota // VPP request message
40         replyMessage                      // VPP reply message
41         eventMessage                      // VPP event message
42         otherMessage                      // other VPP message
43 )
44
45 const (
46         apiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API
47         inputFileExt  = ".json"                   // filename extension of files that should be processed as the input
48 )
49
50 // context is a structure storing details of a particular code generation task
51 type context struct {
52         inputFile   string            // file with input JSON data
53         inputData   []byte            // contents of the input file
54         inputBuff   *bytes.Buffer     // contents of the input file currently being read
55         inputLine   int               // currently processed line in the input file
56         outputFile  string            // file with output data
57         packageName string            // name of the Go package being generated
58         packageDir  string            // directory where the package source files are located
59         types       map[string]string // map of the VPP typedef names to generated Go typedef names
60 }
61
62 func main() {
63         inputFile := flag.String("input-file", "", "Input JSON file.")
64         inputDir := flag.String("input-dir", ".", "Input directory with JSON files.")
65         outputDir := flag.String("output-dir", ".", "Output directory where package folders will be generated.")
66         flag.Parse()
67
68         if *inputFile == "" && *inputDir == "" {
69                 fmt.Fprintln(os.Stderr, "ERROR: input-file or input-dir must be specified")
70                 os.Exit(1)
71         }
72
73         var err, tmpErr error
74         if *inputFile != "" {
75                 // process one input file
76                 err = generateFromFile(*inputFile, *outputDir)
77                 if err != nil {
78                         fmt.Fprintf(os.Stderr, "ERROR: code generation from %s failed: %v\n", *inputFile, err)
79                 }
80         } else {
81                 // process all files in specified directory
82                 files, err := getInputFiles(*inputDir)
83                 if err != nil {
84                         fmt.Fprintf(os.Stderr, "ERROR: code generation failed: %v\n", err)
85                 }
86                 for _, file := range files {
87                         tmpErr = generateFromFile(file, *outputDir)
88                         if tmpErr != nil {
89                                 fmt.Fprintf(os.Stderr, "ERROR: code generation from %s failed: %v\n", file, err)
90                                 err = tmpErr // remember that the error occurred
91                         }
92                 }
93         }
94         if err != nil {
95                 os.Exit(1)
96         }
97 }
98
99 // getInputFiles returns all input files located in specified directory
100 func getInputFiles(inputDir string) ([]string, error) {
101         files, err := ioutil.ReadDir(inputDir)
102         if err != nil {
103                 return nil, fmt.Errorf("reading directory %s failed: %v", inputDir, err)
104         }
105         res := make([]string, 0)
106         for _, f := range files {
107                 if strings.HasSuffix(f.Name(), inputFileExt) {
108                         res = append(res, inputDir+"/"+f.Name())
109                 }
110         }
111         return res, nil
112 }
113
114 // generateFromFile generates Go bindings from one input JSON file
115 func generateFromFile(inputFile, outputDir string) error {
116         ctx, err := getContext(inputFile, outputDir)
117         if err != nil {
118                 return err
119         }
120         // read the file
121         ctx.inputData, err = readFile(inputFile)
122         if err != nil {
123                 return err
124         }
125
126         // parse JSON
127         jsonRoot, err := parseJSON(ctx.inputData)
128         if err != nil {
129                 return err
130         }
131
132         // create output directory
133         err = os.MkdirAll(ctx.packageDir, 0777)
134         if err != nil {
135                 return fmt.Errorf("creating output directory %s failed: %v", ctx.packageDir, err)
136         }
137
138         // open output file
139         f, err := os.Create(ctx.outputFile)
140         defer f.Close()
141         if err != nil {
142                 return fmt.Errorf("creating output file %s failed: %v", ctx.outputFile, err)
143         }
144         w := bufio.NewWriter(f)
145
146         // generate Go package code
147         err = generatePackage(ctx, w, jsonRoot)
148         if err != nil {
149                 return err
150         }
151
152         // go format the output file (non-fatal if fails)
153         exec.Command("gofmt", "-w", ctx.outputFile).Run()
154
155         return nil
156 }
157
158 // getContext returns context details of the code generation task
159 func getContext(inputFile, outputDir string) (*context, error) {
160         if !strings.HasSuffix(inputFile, inputFileExt) {
161                 return nil, fmt.Errorf("invalid input file name %s", inputFile)
162         }
163
164         ctx := &context{inputFile: inputFile}
165         inputFileName := filepath.Base(inputFile)
166
167         ctx.packageName = inputFileName[0:strings.Index(inputFileName, ".")]
168         if ctx.packageName == "interface" {
169                 // 'interface' cannot be a package name, it is a go keyword
170                 ctx.packageName = "interfaces"
171         }
172
173         ctx.packageDir = outputDir + "/" + ctx.packageName + "/"
174         ctx.outputFile = ctx.packageDir + ctx.packageName + ".go"
175
176         return ctx, nil
177 }
178
179 // readFile reads content of a file into memory
180 func readFile(inputFile string) ([]byte, error) {
181
182         inputData, err := ioutil.ReadFile(inputFile)
183
184         if err != nil {
185                 return nil, fmt.Errorf("reading data from file failed: %v", err)
186         }
187
188         return inputData, nil
189 }
190
191 // parseJSON parses a JSON data into an in-memory tree
192 func parseJSON(inputData []byte) (*jsongo.JSONNode, error) {
193         root := jsongo.JSONNode{}
194
195         err := json.Unmarshal(inputData, &root)
196         if err != nil {
197                 return nil, fmt.Errorf("JSON unmarshall failed: %v", err)
198         }
199
200         return &root, nil
201
202 }
203
204 // generatePackage generates Go code of a package from provided JSON
205 func generatePackage(ctx *context, w *bufio.Writer, jsonRoot *jsongo.JSONNode) error {
206         // generate file header
207         generatePackageHeader(ctx, w, jsonRoot)
208
209         // generate data types
210         ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
211         ctx.inputLine = 0
212         ctx.types = make(map[string]string)
213         types := jsonRoot.Map("types")
214         for i := 0; i < types.Len(); i++ {
215                 typ := types.At(i)
216                 err := generateMessage(ctx, w, typ, true)
217                 if err != nil {
218                         return err
219                 }
220         }
221
222         // generate messages
223         ctx.inputBuff = bytes.NewBuffer(ctx.inputData)
224         ctx.inputLine = 0
225         messages := jsonRoot.Map("messages")
226         for i := 0; i < messages.Len(); i++ {
227                 msg := messages.At(i)
228                 err := generateMessage(ctx, w, msg, false)
229                 if err != nil {
230                         return err
231                 }
232         }
233
234         // flush the data:
235         err := w.Flush()
236         if err != nil {
237                 return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
238         }
239
240         return nil
241 }
242
243 // generateMessage generates Go code of one VPP message encoded in JSON into provided writer
244 func generateMessage(ctx *context, w io.Writer, msg *jsongo.JSONNode, isType bool) error {
245         if msg.Len() == 0 || msg.At(0).GetType() != jsongo.TypeValue {
246                 return errors.New("invalid JSON for message specified")
247         }
248
249         msgName, ok := msg.At(0).Get().(string)
250         if !ok {
251                 return fmt.Errorf("invalid JSON for message specified, message name is %T, not a string", msg.At(0).Get())
252         }
253         structName := camelCaseName(strings.Title(msgName))
254
255         // generate struct fields into the slice & determine message type
256         fields := make([]string, 0)
257         msgType := otherMessage
258         wasClientIndex := false
259         for j := 0; j < msg.Len(); j++ {
260                 if jsongo.TypeArray == msg.At(j).GetType() {
261                         fld := msg.At(j)
262                         if !isType {
263                                 // determine whether ths is a request / reply / other message
264                                 fieldName, ok := fld.At(1).Get().(string)
265                                 if ok {
266                                         if j == 2 {
267                                                 if fieldName == "client_index" {
268                                                         // "client_index" as the second member, this might be an event message or a request
269                                                         msgType = eventMessage
270                                                         wasClientIndex = true
271                                                 } else if fieldName == "context" {
272                                                         // reply needs "context" as the second member
273                                                         msgType = replyMessage
274                                                 }
275                                         } else if j == 3 {
276                                                 if wasClientIndex && fieldName == "context" {
277                                                         // request needs "client_index" as the second member and "context" as the third member
278                                                         msgType = requestMessage
279                                                 }
280                                         }
281                                 }
282                         }
283                         err := processMessageField(ctx, &fields, fld, isType)
284                         if err != nil {
285                                 return err
286                         }
287                 }
288         }
289
290         // generate struct comment
291         generateMessageComment(ctx, w, structName, msgName, isType)
292
293         // generate struct header
294         fmt.Fprintln(w, "type", structName, "struct {")
295
296         // print out the fields
297         for _, field := range fields {
298                 fmt.Fprintln(w, field)
299         }
300
301         // generate end of the struct
302         fmt.Fprintln(w, "}")
303
304         // generate name getter
305         if isType {
306                 generateTypeNameGetter(w, structName, msgName)
307         } else {
308                 generateMessageNameGetter(w, structName, msgName)
309         }
310
311         // generate message type getter method
312         if !isType {
313                 generateMessageTypeGetter(w, structName, msgType)
314         }
315
316         // generate CRC getter
317         crcIf := msg.At(msg.Len() - 1).At("crc").Get()
318         if crc, ok := crcIf.(string); ok {
319                 generateCrcGetter(w, structName, crc)
320         }
321
322         // generate message factory
323         if !isType {
324                 generateMessageFactory(w, structName)
325         }
326
327         // if this is a type, save it in the map for later use
328         if isType {
329                 ctx.types[fmt.Sprintf("vl_api_%s_t", msgName)] = structName
330         }
331
332         return nil
333 }
334
335 // processMessageField process JSON describing one message field into Go code emitted into provided slice of message fields
336 func processMessageField(ctx *context, fields *[]string, fld *jsongo.JSONNode, isType bool) error {
337         if fld.Len() < 2 || fld.At(0).GetType() != jsongo.TypeValue || fld.At(1).GetType() != jsongo.TypeValue {
338                 return errors.New("invalid JSON for message field specified")
339         }
340         fieldVppType, ok := fld.At(0).Get().(string)
341         if !ok {
342                 return fmt.Errorf("invalid JSON for message specified, field type is %T, not a string", fld.At(0).Get())
343         }
344         fieldName, ok := fld.At(1).Get().(string)
345         if !ok {
346                 return fmt.Errorf("invalid JSON for message specified, field name is %T, not a string", fld.At(1).Get())
347         }
348
349         // skip internal fields
350         fieldNameLower := strings.ToLower(fieldName)
351         if fieldNameLower == "crc" || fieldNameLower == "_vl_msg_id" {
352                 return nil
353         }
354         if !isType && len(*fields) == 0 && (fieldNameLower == "client_index" || fieldNameLower == "context") {
355                 return nil
356         }
357
358         fieldName = strings.TrimPrefix(fieldName, "_")
359         fieldName = camelCaseName(strings.Title(fieldName))
360
361         fieldStr := ""
362         isArray := false
363         arraySize := 0
364
365         fieldStr += "\t" + fieldName + " "
366         if fld.Len() > 2 {
367                 isArray = true
368                 arraySize = int(fld.At(2).Get().(float64))
369                 fieldStr += "[]"
370         }
371
372         dataType := translateVppType(ctx, fieldVppType, isArray)
373         fieldStr += dataType
374
375         if isArray {
376                 if arraySize == 0 {
377                         // variable sized array
378                         if fld.Len() > 3 {
379                                 // array size is specified by another field
380                                 arraySizeField := string(fld.At(3).Get().(string))
381                                 arraySizeField = camelCaseName(strings.Title(arraySizeField))
382                                 // find & update the field that specifies the array size
383                                 for i, f := range *fields {
384                                         if strings.Contains(f, fmt.Sprintf("\t%s ", arraySizeField)) {
385                                                 (*fields)[i] += fmt.Sprintf("\t`struc:\"sizeof=%s\"`", fieldName)
386                                         }
387                                 }
388                         }
389                 } else {
390                         // fixed size array
391                         fieldStr += fmt.Sprintf("\t`struc:\"[%d]%s\"`", arraySize, dataType)
392                 }
393         }
394
395         *fields = append(*fields, fieldStr)
396         return nil
397 }
398
399 // generatePackageHeader generates package header into provider writer
400 func generatePackageHeader(ctx *context, w io.Writer, rootNode *jsongo.JSONNode) {
401         fmt.Fprintln(w, "// Code generated by govpp binapi-generator DO NOT EDIT.")
402         fmt.Fprintln(w, "// Package "+ctx.packageName+" represents the VPP binary API of the '"+ctx.packageName+"' VPP module.")
403         fmt.Fprintln(w, "// Generated from '"+ctx.inputFile+"'")
404
405         fmt.Fprintln(w, "package "+ctx.packageName)
406
407         fmt.Fprintln(w, "import \""+apiImportPath+"\"")
408
409         fmt.Fprintln(w)
410         fmt.Fprintln(w, "// VlApiVersion contains version of the API.")
411         vlAPIVersion := rootNode.Map("vl_api_version")
412         if vlAPIVersion != nil {
413                 fmt.Fprintln(w, "const VlAPIVersion = ", vlAPIVersion.Get())
414         }
415         fmt.Fprintln(w)
416 }
417
418 // generateMessageComment generates comment for a message into provider writer
419 func generateMessageComment(ctx *context, w io.Writer, structName string, msgName string, isType bool) {
420         fmt.Fprintln(w)
421         if isType {
422                 fmt.Fprintln(w, "// "+structName+" represents the VPP binary API data type '"+msgName+"'.")
423         } else {
424                 fmt.Fprintln(w, "// "+structName+" represents the VPP binary API message '"+msgName+"'.")
425         }
426
427         // print out the source of the generated message - the JSON
428         msgFound := false
429         msgTitle := "\"" + msgName + "\","
430         var msgIndent int
431         for {
432                 lineBuff, err := ctx.inputBuff.ReadBytes('\n')
433                 if err != nil {
434                         break
435                 }
436                 ctx.inputLine++
437                 line := string(lineBuff)
438
439                 if !msgFound {
440                         msgIndent = strings.Index(line, msgTitle)
441                         if msgIndent > -1 {
442                                 prefix := line[:msgIndent]
443                                 suffix := line[msgIndent+len(msgTitle):]
444                                 // If no other non-whitespace character then we are at the message header.
445                                 if strings.IndexFunc(prefix, isNotSpace) == -1 && strings.IndexFunc(suffix, isNotSpace) == -1 {
446                                         fmt.Fprintf(w, "// Generated from '%s', line %d:\n", ctx.inputFile, ctx.inputLine)
447                                         fmt.Fprintln(w, "//")
448                                         fmt.Fprint(w, "//", line)
449                                         msgFound = true
450                                 }
451                         }
452                 } else {
453                         if strings.IndexFunc(line, isNotSpace) < msgIndent {
454                                 break // end of the message in JSON
455                         }
456                         fmt.Fprint(w, "//", line)
457                 }
458         }
459         fmt.Fprintln(w, "//")
460 }
461
462 // generateMessageNameGetter generates getter for original VPP message name into the provider writer
463 func generateMessageNameGetter(w io.Writer, structName string, msgName string) {
464         fmt.Fprintln(w, "func (*"+structName+") GetMessageName() string {")
465         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
466         fmt.Fprintln(w, "}")
467 }
468
469 // generateTypeNameGetter generates getter for original VPP type name into the provider writer
470 func generateTypeNameGetter(w io.Writer, structName string, msgName string) {
471         fmt.Fprintln(w, "func (*"+structName+") GetTypeName() string {")
472         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
473         fmt.Fprintln(w, "}")
474 }
475
476 // generateMessageTypeGetter generates message factory for the generated message into the provider writer
477 func generateMessageTypeGetter(w io.Writer, structName string, msgType messageType) {
478         fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
479         if msgType == requestMessage {
480                 fmt.Fprintln(w, "\treturn api.RequestMessage")
481         } else if msgType == replyMessage {
482                 fmt.Fprintln(w, "\treturn api.ReplyMessage")
483         } else if msgType == eventMessage {
484                 fmt.Fprintln(w, "\treturn api.EventMessage")
485         } else {
486                 fmt.Fprintln(w, "\treturn api.OtherMessage")
487         }
488         fmt.Fprintln(w, "}")
489 }
490
491 // generateCrcGetter generates getter for CRC checksum of the message definition into the provider writer
492 func generateCrcGetter(w io.Writer, structName string, crc string) {
493         crc = strings.TrimPrefix(crc, "0x")
494         fmt.Fprintln(w, "func (*"+structName+") GetCrcString() string {")
495         fmt.Fprintln(w, "\treturn \""+crc+"\"")
496         fmt.Fprintln(w, "}")
497 }
498
499 // generateMessageFactory generates message factory for the generated message into the provider writer
500 func generateMessageFactory(w io.Writer, structName string) {
501         fmt.Fprintln(w, "func New"+structName+"() api.Message {")
502         fmt.Fprintln(w, "\treturn &"+structName+"{}")
503         fmt.Fprintln(w, "}")
504 }
505
506 // translateVppType translates the VPP data type into Go data type
507 func translateVppType(ctx *context, vppType string, isArray bool) string {
508         // basic types
509         switch vppType {
510         case "u8":
511                 if isArray {
512                         return "byte"
513                 }
514                 return "uint8"
515         case "i8":
516                 return "int8"
517         case "u16":
518                 return "uint16"
519         case "i16":
520                 return "int16"
521         case "u32":
522                 return "uint32"
523         case "i32":
524                 return "int32"
525         case "u64":
526                 return "uint64"
527         case "i64":
528                 return "int64"
529         case "f64":
530                 return "float64"
531         }
532
533         // typedefs
534         typ, ok := ctx.types[vppType]
535         if ok {
536                 return typ
537         }
538
539         panic(fmt.Sprintf("Unknown VPP type %s", vppType))
540 }
541
542 // camelCaseName returns correct name identifier (camelCase).
543 func camelCaseName(name string) (should string) {
544         // Fast path for simple cases: "_" and all lowercase.
545         if name == "_" {
546                 return name
547         }
548         allLower := true
549         for _, r := range name {
550                 if !unicode.IsLower(r) {
551                         allLower = false
552                         break
553                 }
554         }
555         if allLower {
556                 return name
557         }
558
559         // Split camelCase at any lower->upper transition, and split on underscores.
560         // Check each word for common initialisms.
561         runes := []rune(name)
562         w, i := 0, 0 // index of start of word, scan
563         for i+1 <= len(runes) {
564                 eow := false // whether we hit the end of a word
565                 if i+1 == len(runes) {
566                         eow = true
567                 } else if runes[i+1] == '_' {
568                         // underscore; shift the remainder forward over any run of underscores
569                         eow = true
570                         n := 1
571                         for i+n+1 < len(runes) && runes[i+n+1] == '_' {
572                                 n++
573                         }
574
575                         // Leave at most one underscore if the underscore is between two digits
576                         if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
577                                 n--
578                         }
579
580                         copy(runes[i+1:], runes[i+n+1:])
581                         runes = runes[:len(runes)-n]
582                 } else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
583                         // lower->non-lower
584                         eow = true
585                 }
586                 i++
587                 if !eow {
588                         continue
589                 }
590
591                 // [w,i) is a word.
592                 word := string(runes[w:i])
593                 if u := strings.ToUpper(word); commonInitialisms[u] {
594                         // Keep consistent case, which is lowercase only at the start.
595                         if w == 0 && unicode.IsLower(runes[w]) {
596                                 u = strings.ToLower(u)
597                         }
598                         // All the common initialisms are ASCII,
599                         // so we can replace the bytes exactly.
600                         copy(runes[w:], []rune(u))
601                 } else if w > 0 && strings.ToLower(word) == word {
602                         // already all lowercase, and not the first word, so uppercase the first character.
603                         runes[w] = unicode.ToUpper(runes[w])
604                 }
605                 w = i
606         }
607         return string(runes)
608 }
609
610 // isNotSpace returns true if the rune is NOT a whitespace character.
611 func isNotSpace(r rune) bool {
612         return !unicode.IsSpace(r)
613 }
614
615 // commonInitialisms is a set of common initialisms that need to stay in upper case.
616 var commonInitialisms = map[string]bool{
617         "ACL":   true,
618         "API":   true,
619         "ASCII": true,
620         "CPU":   true,
621         "CSS":   true,
622         "DNS":   true,
623         "EOF":   true,
624         "GUID":  true,
625         "HTML":  true,
626         "HTTP":  true,
627         "HTTPS": true,
628         "ID":    true,
629         "IP":    true,
630         "ICMP":  true,
631         "JSON":  true,
632         "LHS":   true,
633         "QPS":   true,
634         "RAM":   true,
635         "RHS":   true,
636         "RPC":   true,
637         "SLA":   true,
638         "SMTP":  true,
639         "SQL":   true,
640         "SSH":   true,
641         "TCP":   true,
642         "TLS":   true,
643         "TTL":   true,
644         "UDP":   true,
645         "UI":    true,
646         "UID":   true,
647         "UUID":  true,
648         "URI":   true,
649         "URL":   true,
650         "UTF8":  true,
651         "VM":    true,
652         "XML":   true,
653         "XMPP":  true,
654         "XSRF":  true,
655         "XSS":   true,
656 }