Improve binapi generator
[govpp.git] / binapigen / binapigen.go
index c5a976b..1b4c7e5 100644 (file)
@@ -17,85 +17,123 @@ package binapigen
 import (
        "fmt"
        "path"
-       "sort"
        "strings"
 
        "git.fd.io/govpp.git/binapigen/vppapi"
 )
 
+// generatedCodeVersion indicates a version of the generated code.
+// It is incremented whenever an incompatibility between the generated code and
+// GoVPP api package is introduced; the generated code references
+// a constant, api.GoVppAPIPackageIsVersionN (where N is generatedCodeVersion).
+const generatedCodeVersion = 2
+
+// file options
+const (
+       optFileVersion = "version"
+)
+
 type File struct {
-       vppapi.File
+       Desc vppapi.File
 
-       Generate bool
+       Generate       bool
+       FilenamePrefix string
+       PackageName    GoPackageName
+       GoImportPath   GoImportPath
 
-       PackageName string
-       Imports     []string
+       Version string
+       Imports []string
 
-       Enums    []*Enum
-       Unions   []*Union
-       Structs  []*Struct
-       Aliases  []*Alias
-       Messages []*Message
+       Enums   []*Enum
+       Unions  []*Union
+       Structs []*Struct
+       Aliases []*Alias
 
-       imports map[string]string
-       refmap  map[string]string
+       Messages []*Message
+       Service  *Service
 }
 
-func newFile(gen *Generator, apifile *vppapi.File) (*File, error) {
+func newFile(gen *Generator, apifile *vppapi.File, packageName GoPackageName, importPath GoImportPath) (*File, error) {
        file := &File{
-               File:        *apifile,
-               PackageName: sanitizedName(apifile.Name),
-               imports:     make(map[string]string),
-               refmap:      make(map[string]string),
+               Desc:         *apifile,
+               PackageName:  packageName,
+               GoImportPath: importPath,
+       }
+       if apifile.Options != nil {
+               file.Version = apifile.Options[optFileVersion]
        }
 
-       sortFileObjects(&file.File)
+       file.FilenamePrefix = path.Join(gen.opts.OutputDir, file.Desc.Name)
 
        for _, imp := range apifile.Imports {
                file.Imports = append(file.Imports, normalizeImport(imp))
        }
-       for _, enum := range apifile.EnumTypes {
-               file.Enums = append(file.Enums, newEnum(gen, file, enum))
+
+       for _, enumType := range apifile.EnumTypes {
+               file.Enums = append(file.Enums, newEnum(gen, file, enumType))
        }
-       for _, alias := range apifile.AliasTypes {
-               file.Aliases = append(file.Aliases, newAlias(gen, file, alias))
+       for _, aliasType := range apifile.AliasTypes {
+               file.Aliases = append(file.Aliases, newAlias(gen, file, aliasType))
        }
        for _, structType := range apifile.StructTypes {
                file.Structs = append(file.Structs, newStruct(gen, file, structType))
        }
-       for _, union := range apifile.UnionTypes {
-               file.Unions = append(file.Unions, newUnion(gen, file, union))
+       for _, unionType := range apifile.UnionTypes {
+               file.Unions = append(file.Unions, newUnion(gen, file, unionType))
        }
+
        for _, msg := range apifile.Messages {
                file.Messages = append(file.Messages, newMessage(gen, file, msg))
        }
+       if apifile.Service != nil {
+               file.Service = newService(gen, file, *apifile.Service)
+       }
+
+       for _, t := range file.Aliases {
+               if err := t.resolveDependencies(gen); err != nil {
+                       return nil, err
+               }
+       }
+       for _, t := range file.Structs {
+               if err := t.resolveDependencies(gen); err != nil {
+                       return nil, err
+               }
+       }
+       for _, t := range file.Unions {
+               if err := t.resolveDependencies(gen); err != nil {
+                       return nil, err
+               }
+       }
+       for _, m := range file.Messages {
+               if err := m.resolveDependencies(gen); err != nil {
+                       return nil, err
+               }
+       }
+       if file.Service != nil {
+               for _, rpc := range file.Service.RPCs {
+                       if err := rpc.resolveMessages(gen); err != nil {
+                               return nil, err
+                       }
+               }
+       }
 
        return file, nil
 }
 
-func (file *File) isTypes() bool {
-       return strings.HasSuffix(file.File.Name, "_types")
+func (file *File) isTypesFile() bool {
+       return strings.HasSuffix(file.Desc.Name, "_types")
 }
 
 func (file *File) hasService() bool {
        return file.Service != nil && len(file.Service.RPCs) > 0
 }
 
-func (file *File) addRef(typ string, name string, ref interface{}) {
-       apiName := toApiType(name)
-       if _, ok := file.refmap[apiName]; ok {
-               logf("%s type %v already in refmap", typ, apiName)
-               return
-       }
-       file.refmap[apiName] = name
-}
-
 func (file *File) importedFiles(gen *Generator) []*File {
        var files []*File
        for _, imp := range file.Imports {
                impFile, ok := gen.FilesByName[imp]
                if !ok {
-                       logf("file %s import %s not found API files", file.Name, imp)
+                       logf("file %s import %s not found API files", file.Desc.Name, imp)
                        continue
                }
                files = append(files, impFile)
@@ -103,115 +141,102 @@ func (file *File) importedFiles(gen *Generator) []*File {
        return files
 }
 
-func (file *File) loadTypeImports(gen *Generator, typeFiles []*File) {
-       if len(typeFiles) == 0 {
-               return
-       }
-       for _, t := range file.Structs {
-               for _, imp := range typeFiles {
-                       if _, ok := file.imports[t.Name]; ok {
-                               break
-                       }
-                       for _, at := range imp.File.StructTypes {
-                               if at.Name != t.Name {
-                                       continue
-                               }
-                               if len(at.Fields) != len(t.Fields) {
-                                       continue
-                               }
-                               file.imports[t.Name] = imp.PackageName
-                       }
-               }
-       }
-       for _, t := range file.AliasTypes {
-               for _, imp := range typeFiles {
-                       if _, ok := file.imports[t.Name]; ok {
-                               break
-                       }
-                       for _, at := range imp.File.AliasTypes {
-                               if at.Name != t.Name {
-                                       continue
-                               }
-                               if at.Length != t.Length {
-                                       continue
-                               }
-                               if at.Type != t.Type {
-                                       continue
-                               }
-                               file.imports[t.Name] = imp.PackageName
-                       }
+func (file *File) dependsOnFile(gen *Generator, dep string) bool {
+       for _, imp := range file.Imports {
+               if imp == dep {
+                       return true
                }
-       }
-       for _, t := range file.EnumTypes {
-               for _, imp := range typeFiles {
-                       if _, ok := file.imports[t.Name]; ok {
-                               break
-                       }
-                       for _, at := range imp.File.EnumTypes {
-                               if at.Name != t.Name {
-                                       continue
-                               }
-                               if at.Type != t.Type {
-                                       continue
-                               }
-                               file.imports[t.Name] = imp.PackageName
-                       }
+               impFile, ok := gen.FilesByName[imp]
+               if ok && impFile.dependsOnFile(gen, dep) {
+                       return true
                }
        }
-       for _, t := range file.UnionTypes {
-               for _, imp := range typeFiles {
-                       if _, ok := file.imports[t.Name]; ok {
-                               break
-                       }
-                       for _, at := range imp.File.UnionTypes {
-                               if at.Name != t.Name {
-                                       continue
-                               }
-                               file.imports[t.Name] = imp.PackageName
-                               /*if gen.ImportTypes {
-                                       imp.Generate = true
-                               }*/
-                       }
-               }
+       return false
+}
+
+func normalizeImport(imp string) string {
+       imp = path.Base(imp)
+       if idx := strings.Index(imp, "."); idx >= 0 {
+               imp = imp[:idx]
        }
+       return imp
+}
+
+const (
+       enumFlagSuffix = "_flags"
+)
+
+func isEnumFlag(enum *Enum) bool {
+       return strings.HasSuffix(enum.Name, enumFlagSuffix)
 }
 
 type Enum struct {
        vppapi.EnumType
 
-       GoName string
+       GoIdent
 }
 
 func newEnum(gen *Generator, file *File, apitype vppapi.EnumType) *Enum {
        typ := &Enum{
                EnumType: apitype,
-               GoName:   camelCaseName(apitype.Name),
+               GoIdent: GoIdent{
+                       GoName:       camelCaseName(apitype.Name),
+                       GoImportPath: file.GoImportPath,
+               },
        }
-       gen.enumsByName[fmt.Sprintf("%s.%s", file.Name, typ.Name)] = typ
-       file.addRef("enum", typ.Name, typ)
+       gen.enumsByName[typ.Name] = typ
        return typ
 }
 
 type Alias struct {
        vppapi.AliasType
 
-       GoName string
+       GoIdent
+
+       TypeBasic  *string
+       TypeStruct *Struct
+       TypeUnion  *Union
 }
 
 func newAlias(gen *Generator, file *File, apitype vppapi.AliasType) *Alias {
        typ := &Alias{
                AliasType: apitype,
-               GoName:    camelCaseName(apitype.Name),
+               GoIdent: GoIdent{
+                       GoName:       camelCaseName(apitype.Name),
+                       GoImportPath: file.GoImportPath,
+               },
        }
-       gen.aliasesByName[fmt.Sprintf("%s.%s", file.Name, typ.Name)] = typ
-       file.addRef("alias", typ.Name, typ)
+       gen.aliasesByName[typ.Name] = typ
        return typ
 }
 
+func (a *Alias) resolveDependencies(gen *Generator) error {
+       if err := a.resolveType(gen); err != nil {
+               return fmt.Errorf("unable to resolve field: %w", err)
+       }
+       return nil
+}
+
+func (a *Alias) resolveType(gen *Generator) error {
+       if _, ok := BaseTypesGo[a.Type]; ok {
+               return nil
+       }
+       typ := fromApiType(a.Type)
+       if t, ok := gen.structsByName[typ]; ok {
+               a.TypeStruct = t
+               return nil
+       }
+       if t, ok := gen.unionsByName[typ]; ok {
+               a.TypeUnion = t
+               return nil
+       }
+       return fmt.Errorf("unknown type: %q", a.Type)
+}
+
 type Struct struct {
        vppapi.StructType
 
-       GoName string
+       GoIdent
 
        Fields []*Field
 }
@@ -219,22 +244,33 @@ type Struct struct {
 func newStruct(gen *Generator, file *File, apitype vppapi.StructType) *Struct {
        typ := &Struct{
                StructType: apitype,
-               GoName:     camelCaseName(apitype.Name),
+               GoIdent: GoIdent{
+                       GoName:       camelCaseName(apitype.Name),
+                       GoImportPath: file.GoImportPath,
+               },
        }
+       gen.structsByName[typ.Name] = typ
        for _, fieldType := range apitype.Fields {
                field := newField(gen, file, fieldType)
                field.ParentStruct = typ
                typ.Fields = append(typ.Fields, field)
        }
-       gen.structsByName[fmt.Sprintf("%s.%s", file.Name, typ.Name)] = typ
-       file.addRef("struct", typ.Name, typ)
        return typ
 }
 
+func (m *Struct) resolveDependencies(gen *Generator) (err error) {
+       for _, field := range m.Fields {
+               if err := field.resolveDependencies(gen); err != nil {
+                       return fmt.Errorf("unable to resolve for struct %s: %w", m.Name, err)
+               }
+       }
+       return nil
+}
+
 type Union struct {
        vppapi.UnionType
 
-       GoName string
+       GoIdent
 
        Fields []*Field
 }
@@ -242,32 +278,96 @@ type Union struct {
 func newUnion(gen *Generator, file *File, apitype vppapi.UnionType) *Union {
        typ := &Union{
                UnionType: apitype,
-               GoName:    camelCaseName(apitype.Name),
+               GoIdent: GoIdent{
+                       GoName:       camelCaseName(apitype.Name),
+                       GoImportPath: file.GoImportPath,
+               },
        }
-       gen.unionsByName[fmt.Sprintf("%s.%s", file.Name, typ.Name)] = typ
+       gen.unionsByName[typ.Name] = typ
        for _, fieldType := range apitype.Fields {
                field := newField(gen, file, fieldType)
                field.ParentUnion = typ
                typ.Fields = append(typ.Fields, field)
        }
-       file.addRef("union", typ.Name, typ)
        return typ
 }
 
+func (m *Union) resolveDependencies(gen *Generator) (err error) {
+       for _, field := range m.Fields {
+               if err := field.resolveDependencies(gen); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+// msgType determines message header fields
+type msgType int
+
+const (
+       msgTypeBase    msgType = iota // msg_id
+       msgTypeRequest                // msg_id, client_index, context
+       msgTypeReply                  // msg_id, context
+       msgTypeEvent                  // msg_id, client_index
+)
+
+func apiMsgType(t msgType) GoIdent {
+       switch t {
+       case msgTypeRequest:
+               return govppApiPkg.Ident("RequestMessage")
+       case msgTypeReply:
+               return govppApiPkg.Ident("ReplyMessage")
+       case msgTypeEvent:
+               return govppApiPkg.Ident("EventMessage")
+       default:
+               return govppApiPkg.Ident("OtherMessage")
+       }
+}
+
+// message fields
+const (
+       fieldMsgID       = "_vl_msg_id"
+       fieldClientIndex = "client_index"
+       fieldContext     = "context"
+       fieldRetval      = "retval"
+)
+
+// field options
+const (
+       optFieldDefault = "default"
+)
+
 type Message struct {
        vppapi.Message
 
-       GoName string
+       CRC string
+
+       GoIdent
 
        Fields []*Field
+
+       msgType msgType
 }
 
 func newMessage(gen *Generator, file *File, apitype vppapi.Message) *Message {
        msg := &Message{
                Message: apitype,
-               GoName:  camelCaseName(apitype.Name),
+               CRC:     strings.TrimPrefix(apitype.CRC, "0x"),
+               GoIdent: newGoIdent(file, apitype.Name),
        }
+       gen.messagesByName[apitype.Name] = msg
+       n := 0
        for _, fieldType := range apitype.Fields {
+               // skip internal fields
+               switch strings.ToLower(fieldType.Name) {
+               case fieldMsgID:
+                       continue
+               case fieldClientIndex, fieldContext:
+                       if n == 0 {
+                               continue
+                       }
+               }
+               n++
                field := newField(gen, file, fieldType)
                field.ParentMessage = msg
                msg.Fields = append(msg.Fields, field)
@@ -275,21 +375,71 @@ func newMessage(gen *Generator, file *File, apitype vppapi.Message) *Message {
        return msg
 }
 
+func (m *Message) resolveDependencies(gen *Generator) (err error) {
+       if m.msgType, err = getMsgType(m.Message); err != nil {
+               return err
+       }
+       for _, field := range m.Fields {
+               if err := field.resolveDependencies(gen); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func getMsgType(m vppapi.Message) (msgType, error) {
+       if len(m.Fields) == 0 {
+               return msgType(0), fmt.Errorf("message %s has no fields", m.Name)
+       }
+       typ := msgTypeBase
+       wasClientIndex := false
+       for i, field := range m.Fields {
+               if i == 0 {
+                       if field.Name != fieldMsgID {
+                               return msgType(0), fmt.Errorf("message %s is missing ID field", m.Name)
+                       }
+               } else if i == 1 {
+                       if field.Name == fieldClientIndex {
+                               // "client_index" as the second member,
+                               // this might be an event message or a request
+                               typ = msgTypeEvent
+                               wasClientIndex = true
+                       } else if field.Name == fieldContext {
+                               // reply needs "context" as the second member
+                               typ = msgTypeReply
+                       }
+               } else if i == 2 {
+                       if wasClientIndex && field.Name == fieldContext {
+                               // request needs "client_index" as the second member
+                               // and "context" as the third member
+                               typ = msgTypeRequest
+                       }
+               }
+       }
+       return typ, nil
+}
+
 type Field struct {
        vppapi.Field
 
        GoName string
 
-       // Field parent
+       DefaultValue interface{}
+
+       // Reference to actual type of this field
+       TypeEnum   *Enum
+       TypeAlias  *Alias
+       TypeStruct *Struct
+       TypeUnion  *Union
+
+       // Parent in which this field is declared
        ParentMessage *Message
        ParentStruct  *Struct
        ParentUnion   *Union
 
-       // Type reference
-       Enum   *Enum
-       Alias  *Alias
-       Struct *Struct
-       Union  *Union
+       // Field reference for fields determining size
+       FieldSizeOf   *Field
+       FieldSizeFrom *Field
 }
 
 func newField(gen *Generator, file *File, apitype vppapi.Field) *Field {
@@ -297,64 +447,134 @@ func newField(gen *Generator, file *File, apitype vppapi.Field) *Field {
                Field:  apitype,
                GoName: camelCaseName(apitype.Name),
        }
+       if apitype.Meta != nil {
+               if val, ok := apitype.Meta[optFieldDefault]; ok {
+                       typ.DefaultValue = val
+               }
+       }
        return typ
 }
 
-type Service = vppapi.Service
-type RPC = vppapi.RPC
-
-func sortFileObjects(file *vppapi.File) {
-       // sort imports
-       sort.SliceStable(file.Imports, func(i, j int) bool {
-               return file.Imports[i] < file.Imports[j]
-       })
-       // sort enum types
-       sort.SliceStable(file.EnumTypes, func(i, j int) bool {
-               return file.EnumTypes[i].Name < file.EnumTypes[j].Name
-       })
-       // sort alias types
-       sort.Slice(file.AliasTypes, func(i, j int) bool {
-               return file.AliasTypes[i].Name < file.AliasTypes[j].Name
-       })
-       // sort struct types
-       sort.SliceStable(file.StructTypes, func(i, j int) bool {
-               return file.StructTypes[i].Name < file.StructTypes[j].Name
-       })
-       // sort union types
-       sort.SliceStable(file.UnionTypes, func(i, j int) bool {
-               return file.UnionTypes[i].Name < file.UnionTypes[j].Name
-       })
-       // sort messages
-       sort.SliceStable(file.Messages, func(i, j int) bool {
-               return file.Messages[i].Name < file.Messages[j].Name
-       })
-       // sort services
-       if file.Service != nil {
-               sort.Slice(file.Service.RPCs, func(i, j int) bool {
-                       // dumps first
-                       if file.Service.RPCs[i].Stream != file.Service.RPCs[j].Stream {
-                               return file.Service.RPCs[i].Stream
+func (f *Field) resolveDependencies(gen *Generator) error {
+       if err := f.resolveType(gen); err != nil {
+               return fmt.Errorf("unable to resolve field type: %w", err)
+       }
+       if err := f.resolveFields(gen); err != nil {
+               return fmt.Errorf("unable to resolve fields: %w", err)
+       }
+       return nil
+}
+
+func (f *Field) resolveType(gen *Generator) error {
+       if _, ok := BaseTypesGo[f.Type]; ok {
+               return nil
+       }
+       typ := fromApiType(f.Type)
+       if t, ok := gen.structsByName[typ]; ok {
+               f.TypeStruct = t
+               return nil
+       }
+       if t, ok := gen.enumsByName[typ]; ok {
+               f.TypeEnum = t
+               return nil
+       }
+       if t, ok := gen.aliasesByName[typ]; ok {
+               f.TypeAlias = t
+               return nil
+       }
+       if t, ok := gen.unionsByName[typ]; ok {
+               f.TypeUnion = t
+               return nil
+       }
+       return fmt.Errorf("unknown type: %q", f.Type)
+}
+
+func (f *Field) resolveFields(gen *Generator) error {
+       var fields []*Field
+       if f.ParentMessage != nil {
+               fields = f.ParentMessage.Fields
+       } else if f.ParentStruct != nil {
+               fields = f.ParentStruct.Fields
+       }
+       if f.SizeFrom != "" {
+               for _, field := range fields {
+                       if field.Name == f.SizeFrom {
+                               f.FieldSizeFrom = field
+                               break
+                       }
+               }
+       } else {
+               for _, field := range fields {
+                       if field.SizeFrom == f.Name {
+                               f.FieldSizeOf = field
+                               break
                        }
-                       return file.Service.RPCs[i].RequestMsg < file.Service.RPCs[j].RequestMsg
-               })
+               }
        }
+       return nil
 }
 
-func sanitizedName(name string) string {
-       switch name {
-       case "interface":
-               return "interfaces"
-       case "map":
-               return "maps"
-       default:
-               return name
+type Service struct {
+       vppapi.Service
+
+       RPCs []*RPC
+}
+
+func newService(gen *Generator, file *File, apitype vppapi.Service) *Service {
+       svc := &Service{
+               Service: apitype,
        }
+       for _, rpc := range apitype.RPCs {
+               svc.RPCs = append(svc.RPCs, newRpc(file, svc, rpc))
+       }
+       return svc
 }
 
-func normalizeImport(imp string) string {
-       imp = path.Base(imp)
-       if idx := strings.Index(imp, "."); idx >= 0 {
-               imp = imp[:idx]
+const (
+       serviceNoReply = "null"
+)
+
+type RPC struct {
+       VPP vppapi.RPC
+
+       GoName string
+
+       Service *Service
+
+       MsgRequest *Message
+       MsgReply   *Message
+       MsgStream  *Message
+}
+
+func newRpc(file *File, service *Service, apitype vppapi.RPC) *RPC {
+       rpc := &RPC{
+               VPP:     apitype,
+               GoName:  camelCaseName(apitype.Request),
+               Service: service,
        }
-       return imp
+       return rpc
+}
+
+func (rpc *RPC) resolveMessages(gen *Generator) error {
+       msg, ok := gen.messagesByName[rpc.VPP.Request]
+       if !ok {
+               return fmt.Errorf("rpc %v: no message for request type %v", rpc.GoName, rpc.VPP.Request)
+       }
+       rpc.MsgRequest = msg
+
+       if rpc.VPP.Reply != "" && rpc.VPP.Reply != serviceNoReply {
+               msg, ok := gen.messagesByName[rpc.VPP.Reply]
+               if !ok {
+                       return fmt.Errorf("rpc %v: no message for reply type %v", rpc.GoName, rpc.VPP.Reply)
+               }
+               rpc.MsgReply = msg
+       }
+       if rpc.VPP.StreamMsg != "" {
+               msg, ok := gen.messagesByName[rpc.VPP.StreamMsg]
+               if !ok {
+                       return fmt.Errorf("rpc %v: no message for stream type %v", rpc.GoName, rpc.VPP.StreamMsg)
+               }
+               rpc.MsgStream = msg
+       }
+       return nil
 }