Fixed incorrect message error in the stream API 18/30218/3
authorVladimir Lavor <vlavor@cisco.com>
Tue, 1 Dec 2020 12:57:29 +0000 (13:57 +0100)
committerVladimir Lavor <vlavor@cisco.com>
Thu, 3 Dec 2020 09:14:12 +0000 (10:14 +0100)
The message package is passed to the stream object and
used to evaluate correct reply message type

Change-Id: I2c9844d6447d024af1693205efd5721e2f89f22d
Signed-off-by: Vladimir Lavor <vlavor@cisco.com>
adapter/mock/mock_vpp_adapter.go
api/binapi.go
cmd/vpp-proxy/main.go
core/channel.go
core/connection.go
core/request_handler.go
core/stream.go
proxy/server.go

index f79bb8b..90195e7 100644 (file)
@@ -44,7 +44,7 @@ type VppAdapter struct {
        access       sync.RWMutex
        msgNameToIds map[string]uint16
        msgIDsToName map[uint16]string
-       binAPITypes  map[string]reflect.Type
+       binAPITypes  map[string]map[string]reflect.Type
 
        repliesLock   sync.Mutex     // mutex for the queue
        replies       []reply        // FIFO queue of messages
@@ -126,7 +126,7 @@ func NewVppAdapter() *VppAdapter {
                msgIDSeq:     1000,
                msgIDsToName: make(map[uint16]string),
                msgNameToIds: make(map[string]uint16),
-               binAPITypes:  make(map[string]reflect.Type),
+               binAPITypes:  make(map[string]map[string]reflect.Type),
        }
        a.registerBinAPITypes()
        return a
@@ -165,19 +165,25 @@ func (a *VppAdapter) GetMsgNameByID(msgID uint16) (string, bool) {
 func (a *VppAdapter) registerBinAPITypes() {
        a.access.Lock()
        defer a.access.Unlock()
-       for _, msg := range api.GetRegisteredMessages() {
-               a.binAPITypes[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+       for pkg, msgs := range api.GetRegisteredMessages() {
+               msgMap := make(map[string]reflect.Type)
+               for _, msg := range msgs {
+                       msgMap[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+               }
+               a.binAPITypes[pkg] = msgMap
        }
 }
 
 // ReplyTypeFor returns reply message type for given request message name.
-func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, bool) {
+func (a *VppAdapter) ReplyTypeFor(pkg, requestMsgName string) (reflect.Type, uint16, bool) {
        replyName, foundName := binapi.ReplyNameFor(requestMsgName)
        if foundName {
-               if reply, found := a.binAPITypes[replyName]; found {
-                       msgID, err := a.GetMsgID(replyName, "")
-                       if err == nil {
-                               return reply, msgID, found
+               if messages, found := a.binAPITypes[pkg]; found {
+                       if reply, found := messages[replyName]; found {
+                               msgID, err := a.GetMsgID(replyName, "")
+                               if err == nil {
+                                       return reply, msgID, found
+                               }
                        }
                }
        }
@@ -186,8 +192,8 @@ func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16,
 }
 
 // ReplyFor returns reply message for given request message name.
-func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool) {
-       replType, msgID, foundReplType := a.ReplyTypeFor(requestMsgName)
+func (a *VppAdapter) ReplyFor(pkg, requestMsgName string) (api.Message, uint16, bool) {
+       replType, msgID, foundReplType := a.ReplyTypeFor(pkg, requestMsgName)
        if foundReplType {
                msgVal := reflect.New(replType)
                if msg, ok := msgVal.Interface().(api.Message); ok {
index cb4ab85..1b07a7e 100644 (file)
@@ -15,7 +15,7 @@
 package api
 
 import (
-       "fmt"
+       "path"
        "reflect"
 )
 
@@ -59,27 +59,27 @@ type DataType interface {
 }
 
 var (
-       registeredMessageTypes = make(map[reflect.Type]string)
-       registeredMessages     = make(map[string]Message)
+       registeredMessages     = make(map[string]map[string]Message)
+       registeredMessageTypes = make(map[string]map[reflect.Type]string)
 )
 
 // RegisterMessage is called from generated code to register message.
 func RegisterMessage(x Message, name string) {
-       typ := reflect.TypeOf(x)
-       namecrc := x.GetMessageName() + "_" + x.GetCrcString()
-       if _, ok := registeredMessageTypes[typ]; ok {
-               panic(fmt.Errorf("govpp: message type %v already registered as %s (%s)", typ, name, namecrc))
+       binapiPath := path.Dir(reflect.TypeOf(x).Elem().PkgPath())
+       if _, ok := registeredMessages[binapiPath]; !ok {
+               registeredMessages[binapiPath] = make(map[string]Message)
+               registeredMessageTypes[binapiPath] = make(map[reflect.Type]string)
        }
-       registeredMessages[namecrc] = x
-       registeredMessageTypes[typ] = name
+       registeredMessages[binapiPath][x.GetMessageName()+"_"+x.GetCrcString()] = x
+       registeredMessageTypes[binapiPath][reflect.TypeOf(x)] = name
 }
 
 // GetRegisteredMessages returns list of all registered messages.
-func GetRegisteredMessages() map[string]Message {
+func GetRegisteredMessages() map[string]map[string]Message {
        return registeredMessages
 }
 
 // GetRegisteredMessageTypes returns list of all registered message types.
-func GetRegisteredMessageTypes() map[reflect.Type]string {
+func GetRegisteredMessageTypes() map[string]map[reflect.Type]string {
        return registeredMessageTypes
 }
index d1af5df..3c85bcf 100644 (file)
@@ -35,8 +35,10 @@ var (
 )
 
 func init() {
-       for _, msg := range api.GetRegisteredMessages() {
-               gob.Register(msg)
+       for _, msgList := range api.GetRegisteredMessages() {
+               for _, msg := range msgList {
+                       gob.Register(msg)
+               }
        }
 }
 
index 28d0710..fbb3e59 100644 (file)
@@ -45,8 +45,10 @@ type MessageCodec interface {
 type MessageIdentifier interface {
        // GetMessageID returns message identifier of given API message.
        GetMessageID(msg api.Message) (uint16, error)
+       // GetMessagePath returns path for the given message
+       GetMessagePath(msg api.Message) string
        // LookupByID looks up message name and crc by ID
-       LookupByID(msgID uint16) (api.Message, error)
+       LookupByID(path string, msgID uint16) (api.Message, error)
 }
 
 // vppRequest is a request that will be sent to VPP.
@@ -329,7 +331,8 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa
 
        if reply.msgID != expMsgID {
                var msgNameCrc string
-               if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
+               pkgPath := ch.msgIdentifier.GetMessagePath(msg)
+               if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
                        msgNameCrc = err.Error()
                } else {
                        msgNameCrc = getMsgNameWithCrc(replyMsg)
index 0f54f38..f3ff964 100644 (file)
@@ -17,6 +17,7 @@ package core
 import (
        "errors"
        "fmt"
+       "path"
        "reflect"
        "sync"
        "sync/atomic"
@@ -103,9 +104,9 @@ type Connection struct {
 
        connChan chan ConnectionEvent // connection status events are sent to this channel
 
-       codec  MessageCodec           // message codec
-       msgIDs map[string]uint16      // map of message IDs indexed by message name + CRC
-       msgMap map[uint16]api.Message // map of messages indexed by message ID
+       codec        MessageCodec                      // message codec
+       msgIDs       map[string]uint16                 // map of message IDs indexed by message name + CRC
+       msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path
 
        maxChannelID uint32              // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations)
        channelsLock sync.RWMutex        // lock for the channels map
@@ -139,7 +140,7 @@ func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration)
                connChan:            make(chan ConnectionEvent, NotificationChanBufSize),
                codec:               codec.DefaultCodec,
                msgIDs:              make(map[string]uint16),
-               msgMap:              make(map[uint16]api.Message),
+               msgMapByPath:        make(map[string]map[uint16]api.Message),
                channels:            make(map[uint16]*Channel),
                subscriptions:       make(map[uint16][]*subscriptionCtx),
                msgControlPing:      msgControlPing,
@@ -400,69 +401,74 @@ func (c *Connection) GetMessageID(msg api.Message) (uint16, error) {
        if c == nil {
                return 0, errors.New("nil connection passed in")
        }
-
-       if msgID, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok {
-               return msgID, nil
-       }
-
+       pkgPath := c.GetMessagePath(msg)
        msgID, err := c.vppClient.GetMsgID(msg.GetMessageName(), msg.GetCrcString())
        if err != nil {
                return 0, err
        }
-
+       if pathMsgs, pathOk := c.msgMapByPath[pkgPath]; !pathOk {
+               c.msgMapByPath[pkgPath] = make(map[uint16]api.Message)
+               c.msgMapByPath[pkgPath][msgID] = msg
+       } else if _, msgOk := pathMsgs[msgID]; !msgOk {
+               c.msgMapByPath[pkgPath][msgID] = msg
+       }
+       if _, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok {
+               return msgID, nil
+       }
        c.msgIDs[getMsgNameWithCrc(msg)] = msgID
-       c.msgMap[msgID] = msg
-
        return msgID, nil
 }
 
 // LookupByID looks up message name and crc by ID.
-func (c *Connection) LookupByID(msgID uint16) (api.Message, error) {
+func (c *Connection) LookupByID(path string, msgID uint16) (api.Message, error) {
        if c == nil {
                return nil, errors.New("nil connection passed in")
        }
-
-       if msg, ok := c.msgMap[msgID]; ok {
+       if msg, ok := c.msgMapByPath[path][msgID]; ok {
                return msg, nil
        }
+       return nil, fmt.Errorf("unknown message ID %d for path '%s'", msgID, path)
+}
 
-       return nil, fmt.Errorf("unknown message ID: %d", msgID)
+// GetMessagePath returns path for the given message
+func (c *Connection) GetMessagePath(msg api.Message) string {
+       return path.Dir(reflect.TypeOf(msg).Elem().PkgPath())
 }
 
 // retrieveMessageIDs retrieves IDs for all registered messages and stores them in map
 func (c *Connection) retrieveMessageIDs() (err error) {
        t := time.Now()
 
-       msgs := api.GetRegisteredMessages()
+       msgsByPath := api.GetRegisteredMessages()
 
        var n int
-       for name, msg := range msgs {
-               typ := reflect.TypeOf(msg).Elem()
-               path := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name())
+       for pkgPath, msgs := range msgsByPath {
+               for _, msg := range msgs {
+                       msgID, err := c.GetMessageID(msg)
+                       if err != nil {
+                               if debugMsgIDs {
+                                       log.Debugf("retrieving message ID for %s.%s failed: %v",
+                                               pkgPath, msg.GetMessageName(), err)
+                               }
+                               continue
+                       }
+                       n++
+
+                       if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() {
+                               c.pingReqID = msgID
+                               c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+                       } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() {
+                               c.pingReplyID = msgID
+                               c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+                       }
 
-               msgID, err := c.GetMessageID(msg)
-               if err != nil {
                        if debugMsgIDs {
-                               log.Debugf("retrieving message ID for %s failed: %v", path, err)
+                               log.Debugf("message %q (%s) has ID: %d", msg.GetMessageName(), getMsgNameWithCrc(msg), msgID)
                        }
-                       continue
-               }
-               n++
-
-               if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() {
-                       c.pingReqID = msgID
-                       c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
-               } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() {
-                       c.pingReplyID = msgID
-                       c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
-               }
-
-               if debugMsgIDs {
-                       log.Debugf("message %q (%s) has ID: %d", name, getMsgNameWithCrc(msg), msgID)
                }
+               log.WithField("took", time.Since(t)).
+                       Debugf("retrieved IDs for %d messages (registered %d) from path %s", n, len(msgs), pkgPath)
        }
-       log.WithField("took", time.Since(t)).
-               Debugf("retrieved IDs for %d messages (registered %d)", n, len(msgs))
 
        return nil
 }
index fc704cb..f9d972a 100644 (file)
@@ -210,9 +210,9 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) {
                return
        }
 
-       msg, ok := c.msgMap[msgID]
-       if !ok {
-               log.Warnf("Unknown message received, ID: %d", msgID)
+       msg, err := c.getMessageByID(msgID)
+       if err != nil {
+               log.Warnln(err)
                return
        }
 
@@ -419,3 +419,19 @@ func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
        }
        return 1
 }
+
+// Returns first message from any package where the message ID matches
+// Note: the msg is further used only for its MessageType which is not
+// affected by the message's package
+func (c *Connection) getMessageByID(msgID uint16) (msg api.Message, err error) {
+       var ok bool
+       for _, msgs := range c.msgMapByPath {
+               if msg, ok = msgs[msgID]; ok {
+                       break
+               }
+       }
+       if !ok {
+               return nil, fmt.Errorf("unknown message received, ID: %d", msgID)
+       }
+       return msg, nil
+}
index abe9d55..3d417f1 100644 (file)
@@ -19,6 +19,7 @@ import (
        "errors"
        "fmt"
        "reflect"
+       "sync"
        "sync/atomic"
        "time"
 
@@ -34,6 +35,9 @@ type Stream struct {
        requestSize  int
        replySize    int
        replyTimeout time.Duration
+       // per-request context
+       pkgPath string
+       sync.Mutex
 }
 
 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
@@ -109,6 +113,9 @@ func (s *Stream) SendMsg(msg api.Message) error {
        if err := s.conn.processRequest(s.channel, req); err != nil {
                return err
        }
+       s.Lock()
+       s.pkgPath = s.conn.GetMessagePath(msg)
+       s.Unlock()
        return nil
 }
 
@@ -118,7 +125,10 @@ func (s *Stream) RecvMsg() (api.Message, error) {
                return nil, err
        }
        // resolve message type
-       msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
+       s.Lock()
+       path := s.pkgPath
+       s.Unlock()
+       msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
        if err != nil {
                return nil, err
        }
index 21d8e1b..e395468 100644 (file)
@@ -226,8 +226,8 @@ type BinapiCompatibilityRequest struct {
 }
 
 type BinapiCompatibilityResponse struct {
-       CompatibleMsgs   []string
-       IncompatibleMsgs []string
+       CompatibleMsgs   map[string][]string
+       IncompatibleMsgs map[string][]string
 }
 
 // BinapiRPC is a RPC server for proxying client request to api.Channel.
@@ -379,25 +379,33 @@ func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCo
        }
        defer ch.Close()
 
-       resp.CompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
-       resp.IncompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
+       resp.CompatibleMsgs = make(map[string][]string)
+       resp.IncompatibleMsgs = make(map[string][]string)
 
-       for _, msg := range req.MsgNameCrcs {
-               val, ok := api.GetRegisteredMessages()[msg]
-               if !ok {
-                       resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
-                       continue
+       for path, messages := range api.GetRegisteredMessages() {
+               if resp.IncompatibleMsgs[path] == nil {
+                       resp.IncompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
                }
-
-               if err = ch.CheckCompatiblity(val); err != nil {
-                       resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
-               } else {
-                       resp.CompatibleMsgs = append(resp.CompatibleMsgs, msg)
+               if resp.CompatibleMsgs[path] == nil {
+                       resp.CompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
+               }
+               for _, msg := range req.MsgNameCrcs {
+                       val, ok := messages[msg]
+                       if !ok {
+                               resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+                               continue
+                       }
+                       if err = ch.CheckCompatiblity(val); err != nil {
+                               resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+                       } else {
+                               resp.CompatibleMsgs[path] = append(resp.CompatibleMsgs[path], msg)
+                       }
                }
        }
-
-       if len(resp.IncompatibleMsgs) > 0 {
-               return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+       for _, messages := range resp.IncompatibleMsgs {
+               if len(messages) > 0 {
+                       return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+               }
        }
 
        return nil