Fix binapigen decoding and minor improvements
[govpp.git] / binapigen / binapigen.go
index 1b4c7e5..2dbd661 100644 (file)
@@ -17,6 +17,7 @@ package binapigen
 import (
        "fmt"
        "path"
+       "strconv"
        "strings"
 
        "git.fd.io/govpp.git/binapigen/vppapi"
@@ -154,14 +155,6 @@ func (file *File) dependsOnFile(gen *Generator, dep string) bool {
        return false
 }
 
-func normalizeImport(imp string) string {
-       imp = path.Base(imp)
-       if idx := strings.Index(imp, "."); idx >= 0 {
-               imp = imp[:idx]
-       }
-       return imp
-}
-
 const (
        enumFlagSuffix = "_flags"
 )
@@ -251,8 +244,7 @@ func newStruct(gen *Generator, file *File, apitype vppapi.StructType) *Struct {
        }
        gen.structsByName[typ.Name] = typ
        for _, fieldType := range apitype.Fields {
-               field := newField(gen, file, fieldType)
-               field.ParentStruct = typ
+               field := newField(gen, file, typ, fieldType)
                typ.Fields = append(typ.Fields, field)
        }
        return typ
@@ -285,8 +277,7 @@ func newUnion(gen *Generator, file *File, apitype vppapi.UnionType) *Union {
        }
        gen.unionsByName[typ.Name] = typ
        for _, fieldType := range apitype.Fields {
-               field := newField(gen, file, fieldType)
-               field.ParentUnion = typ
+               field := newField(gen, file, typ, fieldType)
                typ.Fields = append(typ.Fields, field)
        }
        return typ
@@ -311,25 +302,11 @@ const (
        msgTypeEvent                  // msg_id, client_index
 )
 
-func apiMsgType(t msgType) GoIdent {
-       switch t {
-       case msgTypeRequest:
-               return govppApiPkg.Ident("RequestMessage")
-       case msgTypeReply:
-               return govppApiPkg.Ident("ReplyMessage")
-       case msgTypeEvent:
-               return govppApiPkg.Ident("EventMessage")
-       default:
-               return govppApiPkg.Ident("OtherMessage")
-       }
-}
-
 // message fields
 const (
        fieldMsgID       = "_vl_msg_id"
        fieldClientIndex = "client_index"
        fieldContext     = "context"
-       fieldRetval      = "retval"
 )
 
 // field options
@@ -356,20 +333,17 @@ func newMessage(gen *Generator, file *File, apitype vppapi.Message) *Message {
                GoIdent: newGoIdent(file, apitype.Name),
        }
        gen.messagesByName[apitype.Name] = msg
-       n := 0
+       var n int
        for _, fieldType := range apitype.Fields {
-               // skip internal fields
-               switch strings.ToLower(fieldType.Name) {
-               case fieldMsgID:
-                       continue
-               case fieldClientIndex, fieldContext:
-                       if n == 0 {
+               if n == 0 {
+                       // skip header fields
+                       switch strings.ToLower(fieldType.Name) {
+                       case fieldMsgID, fieldClientIndex, fieldContext:
                                continue
                        }
                }
                n++
-               field := newField(gen, file, fieldType)
-               field.ParentMessage = msg
+               field := newField(gen, file, msg, fieldType)
                msg.Fields = append(msg.Fields, field)
        }
        return msg
@@ -389,16 +363,17 @@ func (m *Message) resolveDependencies(gen *Generator) (err error) {
 
 func getMsgType(m vppapi.Message) (msgType, error) {
        if len(m.Fields) == 0 {
-               return msgType(0), fmt.Errorf("message %s has no fields", m.Name)
+               return msgType(-1), fmt.Errorf("message %s has no fields", m.Name)
        }
-       typ := msgTypeBase
-       wasClientIndex := false
+       var typ msgType
+       var wasClientIndex bool
        for i, field := range m.Fields {
-               if i == 0 {
+               switch i {
+               case 0:
                        if field.Name != fieldMsgID {
-                               return msgType(0), fmt.Errorf("message %s is missing ID field", m.Name)
+                               return msgType(-1), fmt.Errorf("message %s is missing ID field", m.Name)
                        }
-               } else if i == 1 {
+               case 1:
                        if field.Name == fieldClientIndex {
                                // "client_index" as the second member,
                                // this might be an event message or a request
@@ -408,8 +383,8 @@ func getMsgType(m vppapi.Message) (msgType, error) {
                                // reply needs "context" as the second member
                                typ = msgTypeReply
                        }
-               } else if i == 2 {
-                       if wasClientIndex && field.Name == fieldContext {
+               case 2:
+                       if field.Name == fieldContext && wasClientIndex {
                                // request needs "client_index" as the second member
                                // and "context" as the third member
                                typ = msgTypeRequest
@@ -419,34 +394,50 @@ func getMsgType(m vppapi.Message) (msgType, error) {
        return typ, nil
 }
 
+// Field represents a field for message or struct/union types.
 type Field struct {
        vppapi.Field
 
        GoName string
 
+       // DefaultValue is a default value of field or
+       // nil if default value is not defined for field.
        DefaultValue interface{}
 
-       // Reference to actual type of this field
+       // Reference to actual type of this field.
+       //
+       // For fields with built-in types all of these are nil,
+       // otherwise only one is set to non-nil value.
        TypeEnum   *Enum
        TypeAlias  *Alias
        TypeStruct *Struct
        TypeUnion  *Union
 
-       // Parent in which this field is declared
+       // Parent in which this field is declared.
        ParentMessage *Message
        ParentStruct  *Struct
        ParentUnion   *Union
 
-       // Field reference for fields determining size
+       // Field reference for fields with variable size.
        FieldSizeOf   *Field
        FieldSizeFrom *Field
 }
 
-func newField(gen *Generator, file *File, apitype vppapi.Field) *Field {
+func newField(gen *Generator, file *File, parent interface{}, apitype vppapi.Field) *Field {
        typ := &Field{
                Field:  apitype,
                GoName: camelCaseName(apitype.Name),
        }
+       switch p := parent.(type) {
+       case *Message:
+               typ.ParentMessage = p
+       case *Struct:
+               typ.ParentStruct = p
+       case *Union:
+               typ.ParentUnion = p
+       default:
+               panic(fmt.Sprintf("invalid field parent: %T", parent))
+       }
        if apitype.Meta != nil {
                if val, ok := apitype.Meta[optFieldDefault]; ok {
                        typ.DefaultValue = val
@@ -578,3 +569,61 @@ func (rpc *RPC) resolveMessages(gen *Generator) error {
        }
        return nil
 }
+
+// GoIdent is a Go identifier, consisting of a name and import path.
+// The name is a single identifier and may not be a dot-qualified selector.
+type GoIdent struct {
+       GoName       string
+       GoImportPath GoImportPath
+}
+
+func (id GoIdent) String() string {
+       return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName)
+}
+
+func newGoIdent(f *File, fullName string) GoIdent {
+       name := strings.TrimPrefix(fullName, string(f.PackageName)+".")
+       return GoIdent{
+               GoName:       camelCaseName(name),
+               GoImportPath: f.GoImportPath,
+       }
+}
+
+// GoImportPath is a Go import path for a package.
+type GoImportPath string
+
+func (p GoImportPath) String() string {
+       return strconv.Quote(string(p))
+}
+
+func (p GoImportPath) Ident(s string) GoIdent {
+       return GoIdent{GoName: s, GoImportPath: p}
+}
+
+type GoPackageName string
+
+func cleanPackageName(name string) GoPackageName {
+       return GoPackageName(sanitizedName(name))
+}
+
+// baseName returns the last path element of the name, with the last dotted suffix removed.
+func baseName(name string) string {
+       // First, find the last element
+       if i := strings.LastIndex(name, "/"); i >= 0 {
+               name = name[i+1:]
+       }
+       // Now drop the suffix
+       if i := strings.LastIndex(name, "."); i >= 0 {
+               name = name[:i]
+       }
+       return name
+}
+
+// normalizeImport returns the last path element of the import, with all dotted suffixes removed.
+func normalizeImport(imp string) string {
+       imp = path.Base(imp)
+       if idx := strings.Index(imp, "."); idx >= 0 {
+               imp = imp[:idx]
+       }
+       return imp
+}