initial commit
[govpp.git] / binapi_generator / generator.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package main
16
17 import (
18         "bufio"
19         "bytes"
20         "encoding/json"
21         "errors"
22         "flag"
23         "fmt"
24         "io"
25         "io/ioutil"
26         "os"
27         "os/exec"
28         "path/filepath"
29         "strings"
30         "time"
31         "unicode"
32
33         "github.com/bennyscetbun/jsongo"
34 )
35
36 // MessageType represents the type of a VPP message.
37 type messageType int
38
39 const (
40         requestMessage messageType = iota // VPP request message
41         replyMessage                      // VPP reply message
42         otherMessage                      // other VPP message
43 )
44
45 const (
46         apiImportPath = "gerrit.fd.io/r/govpp/api" // import path of the govpp API
47         inputFileExt  = ".json"                                               // filename extension of files that should be processed as the input
48 )
49
50 // context is a structure storing details of a particular code generation task
51 type context struct {
52         inputFile   string            // file with input JSON data
53         inputBuff   *bytes.Buffer     // contents of the input file
54         inputLine   int               // currently processed line in the input file
55         outputFile  string            // file with output data
56         packageName string            // name of the Go package being generated
57         packageDir  string            // directory where the package source files are located
58         types       map[string]string // map of the VPP typedef names to generated Go typedef names
59 }
60
61 func main() {
62         inputFile := flag.String("input-file", "", "Input JSON file.")
63         inputDir := flag.String("input-dir", ".", "Input directory with JSON files.")
64         outputDir := flag.String("output-dir", ".", "Output directory where package folders will be generated.")
65         flag.Parse()
66
67         if *inputFile == "" && *inputDir == "" {
68                 fmt.Fprintln(os.Stderr, "ERROR: input-file or input-dir must be specified")
69                 os.Exit(1)
70         }
71
72         var err, tmpErr error
73         if *inputFile != "" {
74                 // process one input file
75                 err = generateFromFile(*inputFile, *outputDir)
76                 if err != nil {
77                         fmt.Fprintf(os.Stderr, "ERROR: code generation from %s failed: %v\n", *inputFile, err)
78                 }
79         } else {
80                 // process all files in specified directory
81                 files, err := getInputFiles(*inputDir)
82                 if err != nil {
83                         fmt.Fprintf(os.Stderr, "ERROR: code generation failed: %v\n", err)
84                 }
85                 for _, file := range files {
86                         tmpErr = generateFromFile(file, *outputDir)
87                         if tmpErr != nil {
88                                 fmt.Fprintf(os.Stderr, "ERROR: code generation from %s failed: %v\n", file, err)
89                                 err = tmpErr // remember that the error occurred
90                         }
91                 }
92         }
93         if err != nil {
94                 os.Exit(1)
95         }
96 }
97
98 // getInputFiles returns all input files located in specified directory
99 func getInputFiles(inputDir string) ([]string, error) {
100         files, err := ioutil.ReadDir(inputDir)
101         if err != nil {
102                 return nil, fmt.Errorf("reading directory %s failed: %v", inputDir, err)
103         }
104         res := make([]string, 0)
105         for _, f := range files {
106                 if strings.HasSuffix(f.Name(), inputFileExt) {
107                         res = append(res, inputDir+"/"+f.Name())
108                 }
109         }
110         return res, nil
111 }
112
113 // generateFromFile generates Go bindings from one input JSON file
114 func generateFromFile(inputFile, outputDir string) error {
115         ctx, err := getContext(inputFile, outputDir)
116         if err != nil {
117                 return err
118         }
119         // read the file
120         inputData, err := readFile(inputFile)
121         if err != nil {
122                 return err
123         }
124         ctx.inputBuff = bytes.NewBuffer(inputData)
125
126         // parse JSON
127         jsonRoot, err := parseJSON(inputData)
128         if err != nil {
129                 return err
130         }
131
132         // create output directory
133         err = os.MkdirAll(ctx.packageDir, 0777)
134         if err != nil {
135                 return fmt.Errorf("creating output directory %s failed: %v", ctx.packageDir, err)
136         }
137
138         // open output file
139         f, err := os.Create(ctx.outputFile)
140         defer f.Close()
141         if err != nil {
142                 return fmt.Errorf("creating output file %s failed: %v", ctx.outputFile, err)
143         }
144         w := bufio.NewWriter(f)
145
146         // generate Go package code
147         err = generatePackage(ctx, w, jsonRoot)
148         if err != nil {
149                 return err
150         }
151
152         // go format the output file (non-fatal if fails)
153         exec.Command("gofmt", "-w", ctx.outputFile).Run()
154
155         return nil
156 }
157
158 // getContext returns context details of the code generation task
159 func getContext(inputFile, outputDir string) (*context, error) {
160         if !strings.HasSuffix(inputFile, inputFileExt) {
161                 return nil, fmt.Errorf("invalid input file name %s", inputFile)
162         }
163
164         ctx := &context{inputFile: inputFile}
165         inputFileName := filepath.Base(inputFile)
166
167         ctx.packageName = inputFileName[0:strings.Index(inputFileName, ".")]
168         if ctx.packageName == "interface" {
169                 // 'interface' cannot be a package name, it is a go keyword
170                 ctx.packageName = "interfaces"
171         }
172
173         ctx.packageDir = outputDir + "/" + ctx.packageName + "/"
174         ctx.outputFile = ctx.packageDir + ctx.packageName + ".go"
175
176         return ctx, nil
177 }
178
179 // readFile reads content of a file into memory
180 func readFile(inputFile string) ([]byte, error) {
181
182         inputData, err := ioutil.ReadFile(inputFile)
183
184         if err != nil {
185                 return nil, fmt.Errorf("reading data from file failed: %v", err)
186         }
187
188         return inputData, nil
189 }
190
191 // parseJSON parses a JSON data into an in-memory tree
192 func parseJSON(inputData []byte) (*jsongo.JSONNode, error) {
193         root := jsongo.JSONNode{}
194
195         err := json.Unmarshal(inputData, &root)
196         if err != nil {
197                 return nil, fmt.Errorf("JSON unmarshall failed: %v", err)
198         }
199
200         return &root, nil
201
202 }
203
204 // generatePackage generates Go code of a package from provided JSON
205 func generatePackage(ctx *context, w *bufio.Writer, jsonRoot *jsongo.JSONNode) error {
206         // generate file header
207         generatePackageHeader(ctx, w, jsonRoot)
208
209         // generate data types
210         ctx.types = make(map[string]string)
211         types := jsonRoot.Map("types")
212         for i := 0; i < types.Len(); i++ {
213                 typ := types.At(i)
214                 err := generateMessage(ctx, w, typ, true)
215                 if err != nil {
216                         return err
217                 }
218         }
219
220         // generate messages
221         messages := jsonRoot.Map("messages")
222         for i := 0; i < messages.Len(); i++ {
223                 msg := messages.At(i)
224                 err := generateMessage(ctx, w, msg, false)
225                 if err != nil {
226                         return err
227                 }
228         }
229
230         // flush the data:
231         err := w.Flush()
232         if err != nil {
233                 return fmt.Errorf("flushing data to %s failed: %v", ctx.outputFile, err)
234         }
235
236         return nil
237 }
238
239 // generateMessage generates Go code of one VPP message encoded in JSON into provided writer
240 func generateMessage(ctx *context, w io.Writer, msg *jsongo.JSONNode, isType bool) error {
241         if msg.Len() == 0 || msg.At(0).GetType() != jsongo.TypeValue {
242                 return errors.New("invalid JSON for message specified")
243         }
244
245         msgName, ok := msg.At(0).Get().(string)
246         if !ok {
247                 return fmt.Errorf("invalid JSON for message specified, message name is %T, not a string", msg.At(0).Get())
248         }
249         structName := camelCaseName(strings.Title(msgName))
250
251         // generate struct fields into the slice & determine message type
252         fields := make([]string, 0)
253         msgType := otherMessage
254         for j := 0; j < msg.Len(); j++ {
255                 if jsongo.TypeArray == msg.At(j).GetType() {
256                         fld := msg.At(j)
257                         err := processMessageField(ctx, &fields, fld)
258                         if err != nil {
259                                 return err
260                         }
261                         // determine whether ths is a request / reply / other message
262                         if j == 2 {
263                                 fieldName, ok := fld.At(1).Get().(string)
264                                 if ok {
265                                         if fieldName == "client_index" {
266                                                 msgType = requestMessage
267                                         } else if fieldName == "context" {
268                                                 msgType = replyMessage
269                                         } else {
270                                                 msgType = otherMessage
271                                         }
272                                 }
273                         }
274                 }
275         }
276
277         // generate struct comment
278         generateMessageComment(ctx, w, structName, msgName, isType)
279
280         // generate struct header
281         fmt.Fprintln(w, "type", structName, "struct {")
282
283         // print out the fields
284         for _, field := range fields {
285                 fmt.Fprintln(w, field)
286         }
287
288         // generate end of the struct
289         fmt.Fprintln(w, "}")
290
291         // generate name getter
292         if isType {
293                 generateTypeNameGetter(w, structName, msgName)
294         } else {
295                 generateMessageNameGetter(w, structName, msgName)
296         }
297
298         // generate message type getter method
299         if !isType {
300                 generateMessageTypeGetter(w, structName, msgType)
301         }
302
303         // generate CRC getter
304         crcIf := msg.At(msg.Len() - 1).At("crc").Get()
305         if crc, ok := crcIf.(string); ok {
306                 generateCrcGetter(w, structName, crc)
307         }
308
309         // generate message factory
310         if !isType {
311                 generateMessageFactory(w, structName)
312         }
313
314         // if this is a type, save it in the map for later use
315         if isType {
316                 ctx.types[fmt.Sprintf("vl_api_%s_t", msgName)] = structName
317         }
318
319         return nil
320 }
321
322 // processMessageField process JSON describing one message field into Go code emitted into provided slice of message fields
323 func processMessageField(ctx *context, fields *[]string, fld *jsongo.JSONNode) error {
324         if fld.Len() < 2 || fld.At(0).GetType() != jsongo.TypeValue || fld.At(1).GetType() != jsongo.TypeValue {
325                 return errors.New("invalid JSON for message field specified")
326         }
327         fieldVppType, ok := fld.At(0).Get().(string)
328         if !ok {
329                 return fmt.Errorf("invalid JSON for message specified, field type is %T, not a string", fld.At(0).Get())
330         }
331         fieldName, ok := fld.At(1).Get().(string)
332         if !ok {
333                 return fmt.Errorf("invalid JSON for message specified, field name is %T, not a string", fld.At(1).Get())
334         }
335
336         // skip internal fields
337         fieldNameLower := strings.ToLower(fieldName)
338         if fieldNameLower == "crc" || fieldNameLower == "_vl_msg_id" {
339                 return nil
340         }
341         if len(*fields) == 0 && (fieldNameLower == "client_index" || fieldNameLower == "context") {
342                 return nil
343         }
344
345         fieldName = strings.TrimPrefix(fieldName, "_")
346         fieldName = camelCaseName(strings.Title(fieldName))
347
348         fieldStr := ""
349         isArray := false
350         arraySize := 0
351
352         fieldStr += "\t" + fieldName + " "
353         if fld.Len() > 2 {
354                 isArray = true
355                 arraySize = int(fld.At(2).Get().(float64))
356                 fieldStr += "[]"
357         }
358
359         dataType := translateVppType(ctx, fieldVppType, isArray)
360         fieldStr += dataType
361
362         if isArray {
363                 if arraySize == 0 {
364                         // variable sized array
365                         if fld.Len() > 3 {
366                                 // array size is specified by another field
367                                 arraySizeField := string(fld.At(3).Get().(string))
368                                 arraySizeField = camelCaseName(strings.Title(arraySizeField))
369                                 // find & update the field that specifies the array size
370                                 for i, f := range *fields {
371                                         if strings.Contains(f, fmt.Sprintf("\t%s ", arraySizeField)) {
372                                                 (*fields)[i] += fmt.Sprintf("\t`struc:\"sizeof=%s\"`", fieldName)
373                                         }
374                                 }
375                         }
376                 } else {
377                         // fixed size array
378                         fieldStr += fmt.Sprintf("\t`struc:\"[%d]%s\"`", arraySize, dataType)
379                 }
380         }
381
382         *fields = append(*fields, fieldStr)
383         return nil
384 }
385
386 // generatePackageHeader generates package header into provider writer
387 func generatePackageHeader(ctx *context, w io.Writer, rootNode *jsongo.JSONNode) {
388         fmt.Fprintln(w, "// Package "+ctx.packageName+" represents the VPP binary API of the '"+ctx.packageName+"' VPP module.")
389         fmt.Fprintln(w, "// DO NOT EDIT. Generated from '"+ctx.inputFile+"' on "+time.Now().Format(time.RFC1123)+".")
390
391         fmt.Fprintln(w, "package "+ctx.packageName)
392
393         fmt.Fprintln(w, "import \""+apiImportPath+"\"")
394
395         fmt.Fprintln(w)
396         fmt.Fprintln(w, "// VlApiVersion contains version of the API.")
397         vlAPIVersion := rootNode.Map("vl_api_version")
398         if vlAPIVersion != nil {
399                 fmt.Fprintln(w, "const VlAPIVersion = ", vlAPIVersion.Get())
400         }
401         fmt.Fprintln(w)
402 }
403
404 // generateMessageComment generates comment for a message into provider writer
405 func generateMessageComment(ctx *context, w io.Writer, structName string, msgName string, isType bool) {
406         fmt.Fprintln(w)
407         if isType {
408                 fmt.Fprintln(w, "// "+structName+" represents the VPP binary API data type '"+msgName+"'.")
409         } else {
410                 fmt.Fprintln(w, "// "+structName+" represents the VPP binary API message '"+msgName+"'.")
411         }
412
413         // print out the source of the generated message - the JSON
414         msgFound := false
415         for {
416                 lineBuff, err := ctx.inputBuff.ReadBytes('\n')
417                 if err != nil {
418                         break
419                 }
420                 ctx.inputLine++
421                 line := string(lineBuff)
422
423                 if !msgFound {
424                         if strings.Contains(line, msgName) {
425                                 fmt.Fprintf(w, "// Generated from '%s', line %d:\n", ctx.inputFile, ctx.inputLine)
426                                 fmt.Fprintln(w, "//")
427                                 fmt.Fprint(w, "//", line)
428                                 msgFound = true
429                         }
430                 } else {
431                         fmt.Fprint(w, "//", line)
432                         if len(strings.Trim(line, " ")) < 4 {
433                                 break // end of the message in JSON
434                         }
435                 }
436         }
437         fmt.Fprintln(w, "//")
438 }
439
440 // generateMessageNameGetter generates getter for original VPP message name into the provider writer
441 func generateMessageNameGetter(w io.Writer, structName string, msgName string) {
442         fmt.Fprintln(w, "func (*"+structName+") GetMessageName() string {")
443         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
444         fmt.Fprintln(w, "}")
445 }
446
447 // generateTypeNameGetter generates getter for original VPP type name into the provider writer
448 func generateTypeNameGetter(w io.Writer, structName string, msgName string) {
449         fmt.Fprintln(w, "func (*"+structName+") GetTypeName() string {")
450         fmt.Fprintln(w, "\treturn \""+msgName+"\"")
451         fmt.Fprintln(w, "}")
452 }
453
454 // generateMessageTypeGetter generates message factory for the generated message into the provider writer
455 func generateMessageTypeGetter(w io.Writer, structName string, msgType messageType) {
456         fmt.Fprintln(w, "func (*"+structName+") GetMessageType() api.MessageType {")
457         if msgType == requestMessage {
458                 fmt.Fprintln(w, "\treturn api.RequestMessage")
459         } else if msgType == replyMessage {
460                 fmt.Fprintln(w, "\treturn api.ReplyMessage")
461         } else {
462                 fmt.Fprintln(w, "\treturn api.OtherMessage")
463         }
464         fmt.Fprintln(w, "}")
465 }
466
467 // generateCrcGetter generates getter for CRC checksum of the message definition into the provider writer
468 func generateCrcGetter(w io.Writer, structName string, crc string) {
469         crc = strings.TrimPrefix(crc, "0x")
470         fmt.Fprintln(w, "func (*"+structName+") GetCrcString() string {")
471         fmt.Fprintln(w, "\treturn \""+crc+"\"")
472         fmt.Fprintln(w, "}")
473 }
474
475 // generateMessageFactory generates message factory for the generated message into the provider writer
476 func generateMessageFactory(w io.Writer, structName string) {
477         fmt.Fprintln(w, "func New"+structName+"() api.Message {")
478         fmt.Fprintln(w, "\treturn &"+structName+"{}")
479         fmt.Fprintln(w, "}")
480 }
481
482 // translateVppType translates the VPP data type into Go data type
483 func translateVppType(ctx *context, vppType string, isArray bool) string {
484         // basic types
485         switch vppType {
486         case "u8":
487                 if isArray {
488                         return "byte"
489                 }
490                 return "uint8"
491         case "i8":
492                 return "int8"
493         case "u16":
494                 return "uint16"
495         case "i16":
496                 return "int16"
497         case "u32":
498                 return "uint32"
499         case "i32":
500                 return "int32"
501         case "u64":
502                 return "uint64"
503         case "i64":
504                 return "int64"
505         case "f64":
506                 return "float64"
507         }
508
509         // typedefs
510         typ, ok := ctx.types[vppType]
511         if ok {
512                 return typ
513         }
514
515         panic(fmt.Sprintf("Unknown VPP type %s", vppType))
516 }
517
518 // camelCaseName returns correct name identifier (camelCase).
519 func camelCaseName(name string) (should string) {
520         // Fast path for simple cases: "_" and all lowercase.
521         if name == "_" {
522                 return name
523         }
524         allLower := true
525         for _, r := range name {
526                 if !unicode.IsLower(r) {
527                         allLower = false
528                         break
529                 }
530         }
531         if allLower {
532                 return name
533         }
534
535         // Split camelCase at any lower->upper transition, and split on underscores.
536         // Check each word for common initialisms.
537         runes := []rune(name)
538         w, i := 0, 0 // index of start of word, scan
539         for i+1 <= len(runes) {
540                 eow := false // whether we hit the end of a word
541                 if i+1 == len(runes) {
542                         eow = true
543                 } else if runes[i+1] == '_' {
544                         // underscore; shift the remainder forward over any run of underscores
545                         eow = true
546                         n := 1
547                         for i+n+1 < len(runes) && runes[i+n+1] == '_' {
548                                 n++
549                         }
550
551                         // Leave at most one underscore if the underscore is between two digits
552                         if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
553                                 n--
554                         }
555
556                         copy(runes[i+1:], runes[i+n+1:])
557                         runes = runes[:len(runes)-n]
558                 } else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
559                         // lower->non-lower
560                         eow = true
561                 }
562                 i++
563                 if !eow {
564                         continue
565                 }
566
567                 // [w,i) is a word.
568                 word := string(runes[w:i])
569                 if u := strings.ToUpper(word); commonInitialisms[u] {
570                         // Keep consistent case, which is lowercase only at the start.
571                         if w == 0 && unicode.IsLower(runes[w]) {
572                                 u = strings.ToLower(u)
573                         }
574                         // All the common initialisms are ASCII,
575                         // so we can replace the bytes exactly.
576                         copy(runes[w:], []rune(u))
577                 } else if w > 0 && strings.ToLower(word) == word {
578                         // already all lowercase, and not the first word, so uppercase the first character.
579                         runes[w] = unicode.ToUpper(runes[w])
580                 }
581                 w = i
582         }
583         return string(runes)
584 }
585
586 // commonInitialisms is a set of common initialisms that need to stay in upper case.
587 var commonInitialisms = map[string]bool{
588         "ACL":   true,
589         "API":   true,
590         "ASCII": true,
591         "CPU":   true,
592         "CSS":   true,
593         "DNS":   true,
594         "EOF":   true,
595         "GUID":  true,
596         "HTML":  true,
597         "HTTP":  true,
598         "HTTPS": true,
599         "ID":    true,
600         "IP":    true,
601         "ICMP":  true,
602         "JSON":  true,
603         "LHS":   true,
604         "QPS":   true,
605         "RAM":   true,
606         "RHS":   true,
607         "RPC":   true,
608         "SLA":   true,
609         "SMTP":  true,
610         "SQL":   true,
611         "SSH":   true,
612         "TCP":   true,
613         "TLS":   true,
614         "TTL":   true,
615         "UDP":   true,
616         "UI":    true,
617         "UID":   true,
618         "UUID":  true,
619         "URI":   true,
620         "URL":   true,
621         "UTF8":  true,
622         "VM":    true,
623         "XML":   true,
624         "XMPP":  true,
625         "XSRF":  true,
626         "XSS":   true,
627 }