22fe3e1c32eb2b1652c37dab8c7ea1246e1fa0fa
[govpp.git] / cmd / binapi-generator / generator.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "bufio"
19         "bytes"
20         "encoding/json"
21         "errors"
22         "flag"
23         "fmt"
24         "io"
25         "io/ioutil"
26         "os"
27         "os/exec"
28         "path/filepath"
29         "strings"
30         "unicode"
31
32         "github.com/bennyscetbun/jsongo"
33 )
34
35 // MessageType represents the type of a VPP message.
36 type messageType int
37
38 const (
39         requestMessage messageType = iota // VPP request message
40         replyMessage                      // VPP reply message
41         otherMessage                      // other VPP message
42 )
43
44 const (
45         apiImportPath = "git.fd.io/govpp.git/api" // import path of the govpp API
46         inputFileExt  = ".json"                   // filename extension of files that should be processed as the input
47 )
48
49 // context is a structure storing details of a particular code generation task
50 type context struct {
51         inputFile   string            // file with input JSON data
52         inputBuff   *bytes.Buffer     // contents of the input file
53         inputLine   int               // currently processed line in the input file
54         outputFile  string            // file with output data
55         packageName string            // name of the Go package being generated
56         packageDir  string            // directory where the package source files are located
57         types       map[string]string // map of the VPP typedef names to generated Go typedef names
58 }
59
60 func main() {
61         inputFile := flag.String("input-file", "", "Input JSON file.")
62         inputDir := flag.String("input-dir", ".", "Input directory with JSON files.")
63         outputDir := flag.String("output-dir", ".", "Output directory where package folders will be generated.")
64         flag.Parse()
65
66         if *inputFile == "" && *inputDir == "" {
67                 fmt.Fprintln(os.Stderr, "ERROR: input-file or input-dir must be specified")
68                 os.Exit(1)
69         }
70
71         var err, tmpErr error
72         if *inputFile != "" {
73                 // process one input file
74                 err = generateFromFile(*inputFile, *outputDir)
75                 if err != nil {
76                         fmt.Fprintf(os.Stderr, "ERROR: code generation from %s failed: %v\n", *inputFile, err)
77                 }
78         } else {
79                 // process all files in specified directory
80                 files, err := getInputFiles(*inputDir)
81                 if err != nil {
82                         fmt.Fprintf(os.Stderr, "ERROR: code generation failed: %v\n", err)
83                 }
84                 for _, file := range files {
85                         tmpErr = generateFromFile(file, *outputDir)
86                         if tmpErr != nil {
87                                 fmt.Fprintf(os.Stderr, "ERROR: code generation from %s failed: %v\n", file, err)
88                                 err = tmpErr // remember that the error occurred
89                         }
90                 }
91         }
92         if err != nil {
93                 os.Exit(1)
94         }
95 }
96
97 // getInputFiles returns all input files located in specified directory
98 func getInputFiles(inputDir string) ([]string, error) {
99         files, err := ioutil.ReadDir(inputDir)
100         if err != nil {
101                 return nil, fmt.Errorf("reading directory %s failed: %v", inputDir, err)
102         }
103         res := make([]string, 0)
104         for _, f := range files {
105                 if strings.HasSuffix(f.Name(), inputFileExt) {
106                         res = append(res, inputDir+"/"+f.Name())
107                 }
108         }
109         return res, nil
110 }
111
112 // generateFromFile generates Go bindings from one input JSON file
113 func generateFromFile(inputFile, outputDir string) error {
114         ctx, err := getContext(inputFile, outputDir)
115         if err != nil {
116                 return err
117         }
118         // read the file
119         inputData, err := readFile(inputFile)
120         if err != nil {
121                 return err
122         }
123         ctx.inputBuff = bytes.NewBuffer(inputData)
124
125         // parse JSON
126         jsonRoot, err := parseJSON(inputData)
127         if err != nil {
128                 return err
129         }
130
131         // create output directory
132         err = os.MkdirAll(ctx.packageDir, 0777)
133         if err != nil {
134                 return fmt.Errorf("creating output directory %s failed: %v", ctx.packageDir, err)
135         }
136
137         // open output file
138         f, err := os.Create(ctx.outputFile)
139         defer f.Close()
140         if err != nil {
141                 return fmt.Errorf("creating output file %s failed: %v", ctx.outputFile, err)
142         }
143         w := bufio.NewWriter(f)
144
145         // generate Go package code
146         err = generatePackage(ctx, w, jsonRoot)
147         if err != nil {
148                 return err
149         }
150
151         // go format the output file (non-fatal if fails)
152         exec.Command("gofmt", "-w", ctx.outputFile).Run()
153
154         return nil
155 }
156
157 // getContext returns context details of the code generation task
158 func getContext(inputFile, outputDir string) (*context, error) {
159         if !strings.HasSuffix(inputFile, inputFileExt) {
160                 return nil, fmt.Errorf("invalid input file name %s", inputFile)
161         }
162
163         ctx := &context{inputFile: inputFile}
164         inputFileName := filepath.Base(inputFile)
165
166         ctx.packageName = inputFileName[0:strings.Index(inputFileName, ".")]
167         if ctx.packageName == "interface" {
168                 // 'interface' cannot be a package name, it is a go keyword
169                 ctx.packageName = "interfaces"
170         }
171
172         ctx.packageDir = outputDir + "/" + ctx.packageName + "/"
173         ctx.outputFile = ctx.packageDir + ctx.packageName + ".go"
174
175         return ctx, nil
176 }
177
178 // readFile reads content of a file into memory
179 func readFile(inputFile string) ([]byte, error) {
180
181         inputData, err := ioutil.ReadFile(inputFile)
182
183         if err != nil {
184                 return nil, fmt.Errorf("reading data from file failed: %v", err)
185         }
186
187         return inputData, nil
188 }
189
190 // parseJSON parses a JSON data into an in-memory tree
191 func parseJSON(inputData []byte) (*jsongo.JSONNode, error) {
192         root := jsongo.JSONNode{}
193
194         err := json.Unmarshal(inputData, &root)
195         if err != nil {
196                 return nil, fmt.Errorf("JSON unmarshall failed: %v", err)
197         }
198
199         return &root, nil
200
201 }
202
203 // generatePackage generates Go code of a package from provided JSON
204 func generatePackage(ctx *context, w *bufio.Writer, jsonRoot *jsongo.JSONNode) error {
205         // generate file header
206         generatePackageHeader(ctx, w, jsonRoot)
207
208         // generate data types
209         ctx.types = make(map[string]string)
210         types := jsonRoot.Map("types")
211         for i := 0; i < types.Len(); i++ {
212                 typ := types.At(i)
213                 err := generateMessage(ctx, w, typ, true)
214                 if err != nil {
215                         return err
216                 }
217         }
218
219         // generate messages
220         messages := jsonRoot.Map("messages")
221         for i := 0; i < messages.Len(); i++ {
222                 msg := messages.At(i)
223                 err := generateMessage(ctx, w, msg, false)
224                 if err != nil {
225                         return err
226                 }
227         }
228
229         // flush the data:
230         err := w.Flush()
231         if err != nil {
232                 return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
233         }
234
235         return nil
236 }
237
238 // generateMessage generates Go code of one VPP message encoded in JSON into provided writer
239 func generateMessage(ctx *context, w io.Writer, msg *jsongo.JSONNode, isType bool) error {
240         if msg.Len() == 0 || msg.At(0).GetType() != jsongo.TypeValue {
241                 return errors.New("invalid JSON for message specified")
242         }
243
244         msgName, ok := msg.At(0).Get().(string)
245         if !ok {
246                 return fmt.Errorf("invalid JSON for message specified, message name is %T, not a string", msg.At(0).Get())
247         }
248         structName := camelCaseName(strings.Title(msgName))
249
250         // generate struct fields into the slice & determine message type
251         fields := make([]string, 0)
252         msgType := otherMessage
253         for j := 0; j < msg.Len(); j++ {
254                 if jsongo.TypeArray == msg.At(j).GetType() {
255                         fld := msg.At(j)
256                         err := processMessageField(ctx, &fields, fld)
257                         if err != nil {
258                                 return err
259                         }
260                         // determine whether ths is a request / reply / other message
261                         if j == 2 {
262                                 fieldName, ok := fld.At(1).Get().(string)
263                                 if ok {
264                                         if fieldName == "client_index" {
265                                                 msgType = requestMessage
266                                         } else if fieldName == "context" {
267                                                 msgType = replyMessage
268                                         } else {
269                                                 msgType = otherMessage
270                                         }
271                                 }
272                         }
273                 }
274         }
275
276         // generate struct comment
277         generateMessageComment(ctx, w, structName, msgName, isType)
278
279         // generate struct header
280         fmt.Fprintln(w, "type", structName, "struct {")
281
282         // print out the fields
283         for _, field := range fields {
284                 fmt.Fprintln(w, field)
285         }
286
287         // generate end of the struct
288         fmt.Fprintln(w, "}")
289
290         // generate name getter
291         if isType {
292                 generateTypeNameGetter(w, structName, msgName)
293         } else {
294                 generateMessageNameGetter(w, structName, msgName)
295         }
296
297         // generate message type getter method
298         if !isType {
299                 generateMessageTypeGetter(w, structName, msgType)
300         }
301
302         // generate CRC getter
303         crcIf := msg.At(msg.Len() - 1).At("crc").Get()
304         if crc, ok := crcIf.(string); ok {
305                 generateCrcGetter(w, structName, crc)
306         }
307
308         // generate message factory
309         if !isType {
310                 generateMessageFactory(w, structName)
311         }
312
313         // if this is a type, save it in the map for later use
314         if isType {
315                 ctx.types[fmt.Sprintf("vl_api_%s_t", msgName)] = structName
316         }
317
318         return nil
319 }
320
321 // processMessageField process JSON describing one message field into Go code emitted into provided slice of message fields
322 func processMessageField(ctx *context, fields *[]string, fld *jsongo.JSONNode) error {
323         if fld.Len() < 2 || fld.At(0).GetType() != jsongo.TypeValue || fld.At(1).GetType() != jsongo.TypeValue {
324                 return errors.New("invalid JSON for message field specified")
325         }
326         fieldVppType, ok := fld.At(0).Get().(string)
327         if !ok {
328                 return fmt.Errorf("invalid JSON for message specified, field type is %T, not a string", fld.At(0).Get())
329         }
330         fieldName, ok := fld.At(1).Get().(string)
331         if !ok {
332                 return fmt.Errorf("invalid JSON for message specified, field name is %T, not a string", fld.At(1).Get())
333         }
334
335         // skip internal fields
336         fieldNameLower := strings.ToLower(fieldName)
337         if fieldNameLower == "crc" || fieldNameLower == "_vl_msg_id" {
338                 return nil
339         }
340         if len(*fields) == 0 && (fieldNameLower == "client_index" || fieldNameLower == "context") {
341                 return nil
342         }
343
344         fieldName = strings.TrimPrefix(fieldName, "_")
345         fieldName = camelCaseName(strings.Title(fieldName))
346
347         fieldStr := ""
348         isArray := false
349         arraySize := 0
350
351         fieldStr += "\t" + fieldName + " "
352         if fld.Len() > 2 {
353                 isArray = true
354                 arraySize = int(fld.At(2).Get().(float64))
355                 fieldStr += "[]"
356         }
357
358         dataType := translateVppType(ctx, fieldVppType, isArray)
359         fieldStr += dataType
360
361         if isArray {
362                 if arraySize == 0 {
363                         // variable sized array
364                         if fld.Len() > 3 {
365                                 // array size is specified by another field
366                                 arraySizeField := string(fld.At(3).Get().(string))
367                                 arraySizeField = camelCaseName(strings.Title(arraySizeField))
368                                 // find & update the field that specifies the array size
369                                 for i, f := range *fields {
370                                         if strings.Contains(f, fmt.Sprintf("\t%s ", arraySizeField)) {
371                                                 (*fields)[i] += fmt.Sprintf("\t`struc:\"sizeof=%s\"`", fieldName)
372                                         }
373                                 }
374                         }
375                 } else {
376                         // fixed size array
377                         fieldStr += fmt.Sprintf("\t`struc:\"[%d]%s\"`", arraySize, dataType)
378                 }
379         }
380
381         *fields = append(*fields, fieldStr)
382         return nil
383 }
384
385 // generatePackageHeader generates package header into provider writer
386 func generatePackageHeader(ctx *context, w io.Writer, rootNode *jsongo.JSONNode) {
387         fmt.Fprintln(w, "// Package "+ctx.packageName+" represents the VPP binary API of the '"+ctx.packageName+"' VPP module.")
388         fmt.Fprintln(w, "// DO NOT EDIT. Generated from '"+ctx.inputFile+"'")
389
390         fmt.Fprintln(w, "package "+ctx.packageName)
391
392         fmt.Fprintln(w, "import \""+apiImportPath+"\"")
393
394         fmt.Fprintln(w)
395         fmt.Fprintln(w, "// VlApiVersion contains version of the API.")
396         vlAPIVersion := rootNode.Map("vl_api_version")
397         if vlAPIVersion != nil {
398                 fmt.Fprintln(w, "const VlAPIVersion = ", vlAPIVersion.Get())
399         }
400         fmt.Fprintln(w)
401 }
402
403 // generateMessageComment generates comment for a message into provider writer
404 func generateMessageComment(ctx *context, w io.Writer, structName string, msgName string, isType bool) {
405         fmt.Fprintln(w)
406         if isType {
407                 fmt.Fprintln(w, "// "+structName+" represents the VPP binary API data type '"+msgName+"'.")
408         } else {
409                 fmt.Fprintln(w, "// "+structName+" represents the VPP binary API message '"+msgName+"'.")
410         }
411
412         // print out the source of the generated message - the JSON
413         msgFound := false
414         for {
415                 lineBuff, err := ctx.inputBuff.ReadBytes('\n')
416                 if err != nil {
417                         break
418                 }
419                 ctx.inputLine++
420                 line := string(lineBuff)
421
422                 if !msgFound {
423                         if strings.Contains(line, msgName) {
424                                 fmt.Fprintf(w, "// Generated from '%s', line %d:\n", ctx.inputFile, ctx.inputLine)
425                                 fmt.Fprintln(w, "//")
426                                 fmt.Fprint(w, "//", line)
427                                 msgFound = true
428                         }
429                 } else {
430                         fmt.Fprint(w, "//", line)
431                         if len(strings.Trim(line, " ")) < 4 {
432                                 break // end of the message in JSON
433                         }
434                 }
435         }
436         fmt.Fprintln(w, "//")
437 }
438
439 // generateMessageNameGetter generates getter for original VPP message name into the provider writer
440 func generateMessageNameGetter(w io.Writer, structName string, msgName string) {
441         fmt.Fprintln(w, "func (*"+structName+") GetMessageName() string {")
442         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
443         fmt.Fprintln(w, "}")
444 }
445
446 // generateTypeNameGetter generates getter for original VPP type name into the provider writer
447 func generateTypeNameGetter(w io.Writer, structName string, msgName string) {
448         fmt.Fprintln(w, "func (*"+structName+") GetTypeName() string {")
449         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
450         fmt.Fprintln(w, "}")
451 }
452
453 // generateMessageTypeGetter generates message factory for the generated message into the provider writer
454 func generateMessageTypeGetter(w io.Writer, structName string, msgType messageType) {
455         fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
456         if msgType == requestMessage {
457                 fmt.Fprintln(w, "\treturn api.RequestMessage")
458         } else if msgType == replyMessage {
459                 fmt.Fprintln(w, "\treturn api.ReplyMessage")
460         } else {
461                 fmt.Fprintln(w, "\treturn api.OtherMessage")
462         }
463         fmt.Fprintln(w, "}")
464 }
465
466 // generateCrcGetter generates getter for CRC checksum of the message definition into the provider writer
467 func generateCrcGetter(w io.Writer, structName string, crc string) {
468         crc = strings.TrimPrefix(crc, "0x")
469         fmt.Fprintln(w, "func (*"+structName+") GetCrcString() string {")
470         fmt.Fprintln(w, "\treturn \""+crc+"\"")
471         fmt.Fprintln(w, "}")
472 }
473
474 // generateMessageFactory generates message factory for the generated message into the provider writer
475 func generateMessageFactory(w io.Writer, structName string) {
476         fmt.Fprintln(w, "func New"+structName+"() api.Message {")
477         fmt.Fprintln(w, "\treturn &"+structName+"{}")
478         fmt.Fprintln(w, "}")
479 }
480
481 // translateVppType translates the VPP data type into Go data type
482 func translateVppType(ctx *context, vppType string, isArray bool) string {
483         // basic types
484         switch vppType {
485         case "u8":
486                 if isArray {
487                         return "byte"
488                 }
489                 return "uint8"
490         case "i8":
491                 return "int8"
492         case "u16":
493                 return "uint16"
494         case "i16":
495                 return "int16"
496         case "u32":
497                 return "uint32"
498         case "i32":
499                 return "int32"
500         case "u64":
501                 return "uint64"
502         case "i64":
503                 return "int64"
504         case "f64":
505                 return "float64"
506         }
507
508         // typedefs
509         typ, ok := ctx.types[vppType]
510         if ok {
511                 return typ
512         }
513
514         panic(fmt.Sprintf("Unknown VPP type %s", vppType))
515 }
516
517 // camelCaseName returns correct name identifier (camelCase).
518 func camelCaseName(name string) (should string) {
519         // Fast path for simple cases: "_" and all lowercase.
520         if name == "_" {
521                 return name
522         }
523         allLower := true
524         for _, r := range name {
525                 if !unicode.IsLower(r) {
526                         allLower = false
527                         break
528                 }
529         }
530         if allLower {
531                 return name
532         }
533
534         // Split camelCase at any lower->upper transition, and split on underscores.
535         // Check each word for common initialisms.
536         runes := []rune(name)
537         w, i := 0, 0 // index of start of word, scan
538         for i+1 <= len(runes) {
539                 eow := false // whether we hit the end of a word
540                 if i+1 == len(runes) {
541                         eow = true
542                 } else if runes[i+1] == '_' {
543                         // underscore; shift the remainder forward over any run of underscores
544                         eow = true
545                         n := 1
546                         for i+n+1 < len(runes) && runes[i+n+1] == '_' {
547                                 n++
548                         }
549
550                         // Leave at most one underscore if the underscore is between two digits
551                         if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
552                                 n--
553                         }
554
555                         copy(runes[i+1:], runes[i+n+1:])
556                         runes = runes[:len(runes)-n]
557                 } else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
558                         // lower->non-lower
559                         eow = true
560                 }
561                 i++
562                 if !eow {
563                         continue
564                 }
565
566                 // [w,i) is a word.
567                 word := string(runes[w:i])
568                 if u := strings.ToUpper(word); commonInitialisms[u] {
569                         // Keep consistent case, which is lowercase only at the start.
570                         if w == 0 && unicode.IsLower(runes[w]) {
571                                 u = strings.ToLower(u)
572                         }
573                         // All the common initialisms are ASCII,
574                         // so we can replace the bytes exactly.
575                         copy(runes[w:], []rune(u))
576                 } else if w > 0 && strings.ToLower(word) == word {
577                         // already all lowercase, and not the first word, so uppercase the first character.
578                         runes[w] = unicode.ToUpper(runes[w])
579                 }
580                 w = i
581         }
582         return string(runes)
583 }
584
585 // commonInitialisms is a set of common initialisms that need to stay in upper case.
586 var commonInitialisms = map[string]bool{
587         "ACL":   true,
588         "API":   true,
589         "ASCII": true,
590         "CPU":   true,
591         "CSS":   true,
592         "DNS":   true,
593         "EOF":   true,
594         "GUID":  true,
595         "HTML":  true,
596         "HTTP":  true,
597         "HTTPS": true,
598         "ID":    true,
599         "IP":    true,
600         "ICMP":  true,
601         "JSON":  true,
602         "LHS":   true,
603         "QPS":   true,
604         "RAM":   true,
605         "RHS":   true,
606         "RPC":   true,
607         "SLA":   true,
608         "SMTP":  true,
609         "SQL":   true,
610         "SSH":   true,
611         "TCP":   true,
612         "TLS":   true,
613         "TTL":   true,
614         "UDP":   true,
615         "UI":    true,
616         "UID":   true,
617         "UUID":  true,
618         "URI":   true,
619         "URL":   true,
620         "UTF8":  true,
621         "VM":    true,
622         "XML":   true,
623         "XMPP":  true,
624         "XSRF":  true,
625         "XSS":   true,
626 }