From bcf3fbd21aa22d1546bc85ffb887ae5ba557808e Mon Sep 17 00:00:00 2001 From: Vladimir Lavor Date: Tue, 1 Dec 2020 13:57:29 +0100 Subject: [PATCH] Fixed incorrect message error in the stream API The message package is passed to the stream object and used to evaluate correct reply message type Change-Id: I2c9844d6447d024af1693205efd5721e2f89f22d Signed-off-by: Vladimir Lavor --- adapter/mock/mock_vpp_adapter.go | 28 ++++++++------ api/binapi.go | 22 +++++------ cmd/vpp-proxy/main.go | 6 ++- core/channel.go | 7 +++- core/connection.go | 84 +++++++++++++++++++++------------------- core/request_handler.go | 22 +++++++++-- core/stream.go | 12 +++++- proxy/server.go | 42 ++++++++++++-------- 8 files changed, 137 insertions(+), 86 deletions(-) diff --git a/adapter/mock/mock_vpp_adapter.go b/adapter/mock/mock_vpp_adapter.go index f79bb8b..90195e7 100644 --- a/adapter/mock/mock_vpp_adapter.go +++ b/adapter/mock/mock_vpp_adapter.go @@ -44,7 +44,7 @@ type VppAdapter struct { access sync.RWMutex msgNameToIds map[string]uint16 msgIDsToName map[uint16]string - binAPITypes map[string]reflect.Type + binAPITypes map[string]map[string]reflect.Type repliesLock sync.Mutex // mutex for the queue replies []reply // FIFO queue of messages @@ -126,7 +126,7 @@ func NewVppAdapter() *VppAdapter { msgIDSeq: 1000, msgIDsToName: make(map[uint16]string), msgNameToIds: make(map[string]uint16), - binAPITypes: make(map[string]reflect.Type), + binAPITypes: make(map[string]map[string]reflect.Type), } a.registerBinAPITypes() return a @@ -165,19 +165,25 @@ func (a *VppAdapter) GetMsgNameByID(msgID uint16) (string, bool) { func (a *VppAdapter) registerBinAPITypes() { a.access.Lock() defer a.access.Unlock() - for _, msg := range api.GetRegisteredMessages() { - a.binAPITypes[msg.GetMessageName()] = reflect.TypeOf(msg).Elem() + for pkg, msgs := range api.GetRegisteredMessages() { + msgMap := make(map[string]reflect.Type) + for _, msg := range msgs { + msgMap[msg.GetMessageName()] = reflect.TypeOf(msg).Elem() + } + a.binAPITypes[pkg] = msgMap } } // ReplyTypeFor returns reply message type for given request message name. -func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, bool) { +func (a *VppAdapter) ReplyTypeFor(pkg, requestMsgName string) (reflect.Type, uint16, bool) { replyName, foundName := binapi.ReplyNameFor(requestMsgName) if foundName { - if reply, found := a.binAPITypes[replyName]; found { - msgID, err := a.GetMsgID(replyName, "") - if err == nil { - return reply, msgID, found + if messages, found := a.binAPITypes[pkg]; found { + if reply, found := messages[replyName]; found { + msgID, err := a.GetMsgID(replyName, "") + if err == nil { + return reply, msgID, found + } } } } @@ -186,8 +192,8 @@ func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, } // ReplyFor returns reply message for given request message name. -func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool) { - replType, msgID, foundReplType := a.ReplyTypeFor(requestMsgName) +func (a *VppAdapter) ReplyFor(pkg, requestMsgName string) (api.Message, uint16, bool) { + replType, msgID, foundReplType := a.ReplyTypeFor(pkg, requestMsgName) if foundReplType { msgVal := reflect.New(replType) if msg, ok := msgVal.Interface().(api.Message); ok { diff --git a/api/binapi.go b/api/binapi.go index cb4ab85..1b07a7e 100644 --- a/api/binapi.go +++ b/api/binapi.go @@ -15,7 +15,7 @@ package api import ( - "fmt" + "path" "reflect" ) @@ -59,27 +59,27 @@ type DataType interface { } var ( - registeredMessageTypes = make(map[reflect.Type]string) - registeredMessages = make(map[string]Message) + registeredMessages = make(map[string]map[string]Message) + registeredMessageTypes = make(map[string]map[reflect.Type]string) ) // RegisterMessage is called from generated code to register message. func RegisterMessage(x Message, name string) { - typ := reflect.TypeOf(x) - namecrc := x.GetMessageName() + "_" + x.GetCrcString() - if _, ok := registeredMessageTypes[typ]; ok { - panic(fmt.Errorf("govpp: message type %v already registered as %s (%s)", typ, name, namecrc)) + binapiPath := path.Dir(reflect.TypeOf(x).Elem().PkgPath()) + if _, ok := registeredMessages[binapiPath]; !ok { + registeredMessages[binapiPath] = make(map[string]Message) + registeredMessageTypes[binapiPath] = make(map[reflect.Type]string) } - registeredMessages[namecrc] = x - registeredMessageTypes[typ] = name + registeredMessages[binapiPath][x.GetMessageName()+"_"+x.GetCrcString()] = x + registeredMessageTypes[binapiPath][reflect.TypeOf(x)] = name } // GetRegisteredMessages returns list of all registered messages. -func GetRegisteredMessages() map[string]Message { +func GetRegisteredMessages() map[string]map[string]Message { return registeredMessages } // GetRegisteredMessageTypes returns list of all registered message types. -func GetRegisteredMessageTypes() map[reflect.Type]string { +func GetRegisteredMessageTypes() map[string]map[reflect.Type]string { return registeredMessageTypes } diff --git a/cmd/vpp-proxy/main.go b/cmd/vpp-proxy/main.go index d1af5df..3c85bcf 100644 --- a/cmd/vpp-proxy/main.go +++ b/cmd/vpp-proxy/main.go @@ -35,8 +35,10 @@ var ( ) func init() { - for _, msg := range api.GetRegisteredMessages() { - gob.Register(msg) + for _, msgList := range api.GetRegisteredMessages() { + for _, msg := range msgList { + gob.Register(msg) + } } } diff --git a/core/channel.go b/core/channel.go index 28d0710..fbb3e59 100644 --- a/core/channel.go +++ b/core/channel.go @@ -45,8 +45,10 @@ type MessageCodec interface { type MessageIdentifier interface { // GetMessageID returns message identifier of given API message. GetMessageID(msg api.Message) (uint16, error) + // GetMessagePath returns path for the given message + GetMessagePath(msg api.Message) string // LookupByID looks up message name and crc by ID - LookupByID(msgID uint16) (api.Message, error) + LookupByID(path string, msgID uint16) (api.Message, error) } // vppRequest is a request that will be sent to VPP. @@ -329,7 +331,8 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa if reply.msgID != expMsgID { var msgNameCrc string - if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil { + pkgPath := ch.msgIdentifier.GetMessagePath(msg) + if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil { msgNameCrc = err.Error() } else { msgNameCrc = getMsgNameWithCrc(replyMsg) diff --git a/core/connection.go b/core/connection.go index 0f54f38..f3ff964 100644 --- a/core/connection.go +++ b/core/connection.go @@ -17,6 +17,7 @@ package core import ( "errors" "fmt" + "path" "reflect" "sync" "sync/atomic" @@ -103,9 +104,9 @@ type Connection struct { connChan chan ConnectionEvent // connection status events are sent to this channel - codec MessageCodec // message codec - msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC - msgMap map[uint16]api.Message // map of messages indexed by message ID + codec MessageCodec // message codec + msgIDs map[string]uint16 // map of message IDs indexed by message name + CRC + msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path maxChannelID uint32 // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations) channelsLock sync.RWMutex // lock for the channels map @@ -139,7 +140,7 @@ func newConnection(binapi adapter.VppAPI, attempts int, interval time.Duration) connChan: make(chan ConnectionEvent, NotificationChanBufSize), codec: codec.DefaultCodec, msgIDs: make(map[string]uint16), - msgMap: make(map[uint16]api.Message), + msgMapByPath: make(map[string]map[uint16]api.Message), channels: make(map[uint16]*Channel), subscriptions: make(map[uint16][]*subscriptionCtx), msgControlPing: msgControlPing, @@ -400,69 +401,74 @@ func (c *Connection) GetMessageID(msg api.Message) (uint16, error) { if c == nil { return 0, errors.New("nil connection passed in") } - - if msgID, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok { - return msgID, nil - } - + pkgPath := c.GetMessagePath(msg) msgID, err := c.vppClient.GetMsgID(msg.GetMessageName(), msg.GetCrcString()) if err != nil { return 0, err } - + if pathMsgs, pathOk := c.msgMapByPath[pkgPath]; !pathOk { + c.msgMapByPath[pkgPath] = make(map[uint16]api.Message) + c.msgMapByPath[pkgPath][msgID] = msg + } else if _, msgOk := pathMsgs[msgID]; !msgOk { + c.msgMapByPath[pkgPath][msgID] = msg + } + if _, ok := c.msgIDs[getMsgNameWithCrc(msg)]; ok { + return msgID, nil + } c.msgIDs[getMsgNameWithCrc(msg)] = msgID - c.msgMap[msgID] = msg - return msgID, nil } // LookupByID looks up message name and crc by ID. -func (c *Connection) LookupByID(msgID uint16) (api.Message, error) { +func (c *Connection) LookupByID(path string, msgID uint16) (api.Message, error) { if c == nil { return nil, errors.New("nil connection passed in") } - - if msg, ok := c.msgMap[msgID]; ok { + if msg, ok := c.msgMapByPath[path][msgID]; ok { return msg, nil } + return nil, fmt.Errorf("unknown message ID %d for path '%s'", msgID, path) +} - return nil, fmt.Errorf("unknown message ID: %d", msgID) +// GetMessagePath returns path for the given message +func (c *Connection) GetMessagePath(msg api.Message) string { + return path.Dir(reflect.TypeOf(msg).Elem().PkgPath()) } // retrieveMessageIDs retrieves IDs for all registered messages and stores them in map func (c *Connection) retrieveMessageIDs() (err error) { t := time.Now() - msgs := api.GetRegisteredMessages() + msgsByPath := api.GetRegisteredMessages() var n int - for name, msg := range msgs { - typ := reflect.TypeOf(msg).Elem() - path := fmt.Sprintf("%s.%s", typ.PkgPath(), typ.Name()) + for pkgPath, msgs := range msgsByPath { + for _, msg := range msgs { + msgID, err := c.GetMessageID(msg) + if err != nil { + if debugMsgIDs { + log.Debugf("retrieving message ID for %s.%s failed: %v", + pkgPath, msg.GetMessageName(), err) + } + continue + } + n++ + + if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() { + c.pingReqID = msgID + c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() { + c.pingReplyID = msgID + c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + } - msgID, err := c.GetMessageID(msg) - if err != nil { if debugMsgIDs { - log.Debugf("retrieving message ID for %s failed: %v", path, err) + log.Debugf("message %q (%s) has ID: %d", msg.GetMessageName(), getMsgNameWithCrc(msg), msgID) } - continue - } - n++ - - if c.pingReqID == 0 && msg.GetMessageName() == c.msgControlPing.GetMessageName() { - c.pingReqID = msgID - c.msgControlPing = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - } else if c.pingReplyID == 0 && msg.GetMessageName() == c.msgControlPingReply.GetMessageName() { - c.pingReplyID = msgID - c.msgControlPingReply = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - } - - if debugMsgIDs { - log.Debugf("message %q (%s) has ID: %d", name, getMsgNameWithCrc(msg), msgID) } + log.WithField("took", time.Since(t)). + Debugf("retrieved IDs for %d messages (registered %d) from path %s", n, len(msgs), pkgPath) } - log.WithField("took", time.Since(t)). - Debugf("retrieved IDs for %d messages (registered %d)", n, len(msgs)) return nil } diff --git a/core/request_handler.go b/core/request_handler.go index fc704cb..f9d972a 100644 --- a/core/request_handler.go +++ b/core/request_handler.go @@ -210,9 +210,9 @@ func (c *Connection) msgCallback(msgID uint16, data []byte) { return } - msg, ok := c.msgMap[msgID] - if !ok { - log.Warnf("Unknown message received, ID: %d", msgID) + msg, err := c.getMessageByID(msgID) + if err != nil { + log.Warnln(err) return } @@ -419,3 +419,19 @@ func compareSeqNumbers(seqNum1, seqNum2 uint16) int { } return 1 } + +// Returns first message from any package where the message ID matches +// Note: the msg is further used only for its MessageType which is not +// affected by the message's package +func (c *Connection) getMessageByID(msgID uint16) (msg api.Message, err error) { + var ok bool + for _, msgs := range c.msgMapByPath { + if msg, ok = msgs[msgID]; ok { + break + } + } + if !ok { + return nil, fmt.Errorf("unknown message received, ID: %d", msgID) + } + return msg, nil +} diff --git a/core/stream.go b/core/stream.go index abe9d55..3d417f1 100644 --- a/core/stream.go +++ b/core/stream.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "reflect" + "sync" "sync/atomic" "time" @@ -34,6 +35,9 @@ type Stream struct { requestSize int replySize int replyTimeout time.Duration + // per-request context + pkgPath string + sync.Mutex } func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) { @@ -109,6 +113,9 @@ func (s *Stream) SendMsg(msg api.Message) error { if err := s.conn.processRequest(s.channel, req); err != nil { return err } + s.Lock() + s.pkgPath = s.conn.GetMessagePath(msg) + s.Unlock() return nil } @@ -118,7 +125,10 @@ func (s *Stream) RecvMsg() (api.Message, error) { return nil, err } // resolve message type - msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID) + s.Lock() + path := s.pkgPath + s.Unlock() + msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID) if err != nil { return nil, err } diff --git a/proxy/server.go b/proxy/server.go index 21d8e1b..e395468 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -226,8 +226,8 @@ type BinapiCompatibilityRequest struct { } type BinapiCompatibilityResponse struct { - CompatibleMsgs []string - IncompatibleMsgs []string + CompatibleMsgs map[string][]string + IncompatibleMsgs map[string][]string } // BinapiRPC is a RPC server for proxying client request to api.Channel. @@ -379,25 +379,33 @@ func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCo } defer ch.Close() - resp.CompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs)) - resp.IncompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs)) + resp.CompatibleMsgs = make(map[string][]string) + resp.IncompatibleMsgs = make(map[string][]string) - for _, msg := range req.MsgNameCrcs { - val, ok := api.GetRegisteredMessages()[msg] - if !ok { - resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg) - continue + for path, messages := range api.GetRegisteredMessages() { + if resp.IncompatibleMsgs[path] == nil { + resp.IncompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs)) } - - if err = ch.CheckCompatiblity(val); err != nil { - resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg) - } else { - resp.CompatibleMsgs = append(resp.CompatibleMsgs, msg) + if resp.CompatibleMsgs[path] == nil { + resp.CompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs)) + } + for _, msg := range req.MsgNameCrcs { + val, ok := messages[msg] + if !ok { + resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg) + continue + } + if err = ch.CheckCompatiblity(val); err != nil { + resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg) + } else { + resp.CompatibleMsgs[path] = append(resp.CompatibleMsgs[path], msg) + } } } - - if len(resp.IncompatibleMsgs) > 0 { - return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs) + for _, messages := range resp.IncompatibleMsgs { + if len(messages) > 0 { + return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs) + } } return nil -- 2.16.6