Pair requests with replies using sequence numbers
[govpp.git] / adapter / mock / mock_adapter.go
index dab51a6..ae4caf0 100644 (file)
@@ -48,10 +48,10 @@ type VppAdapter struct {
        binAPITypes  map[string]reflect.Type
        access       sync.RWMutex
 
-       replies       []api.Message  // FIFO queue of messages
-       replyHandlers []ReplyHandler // callbacks that are able to calculate mock responses
-       repliesLock   sync.Mutex     // mutex for the queue
-       mode          replyMode      // mode in which the mock operates
+       replies       []reply            // FIFO queue of messages
+       replyHandlers []ReplyHandler     // callbacks that are able to calculate mock responses
+       repliesLock   sync.Mutex         // mutex for the queue
+       mode          replyMode          // mode in which the mock operates
 }
 
 // defaultReply is a default reply message that mock adapter returns for a request.
@@ -67,6 +67,13 @@ type MessageDTO struct {
        Data     []byte
 }
 
+type reply struct {
+       msg       api.Message
+       multipart bool
+       hasSeqNum bool
+       seqNum    uint16
+}
+
 // ReplyHandler is a type that allows to extend the behaviour of VPP mock.
 // Return value ok is used to signalize that mock reply is calculated and ready to be used.
 type ReplyHandler func(request MessageDTO) (reply []byte, msgID uint16, ok bool)
@@ -247,28 +254,35 @@ func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
                defer a.repliesLock.Unlock()
 
                // pop all replies from queue
+               withSeqNums := false
                for i, reply := range a.replies {
-                       if i > 0 && reply.GetMessageName() == "control_ping_reply" {
+                       if i > 0 && reply.msg.GetMessageName() == "control_ping_reply" && !withSeqNums {
                                // hack - do not send control_ping_reply immediately, leave it for the the next callback
                                a.replies = a.replies[i:]
                                return nil
                        }
-                       msgID, _ := a.GetMsgID(reply.GetMessageName(), reply.GetCrcString())
+                       msgID, _ := a.GetMsgID(reply.msg.GetMessageName(), reply.msg.GetCrcString())
                        buf := new(bytes.Buffer)
-                       if reply.GetMessageType() == api.ReplyMessage {
-                               struc.Pack(buf, &core.VppReplyHeader{VlMsgID: msgID, Context: clientID})
-                       } else if reply.GetMessageType() == api.EventMessage {
-                               struc.Pack(buf, &core.VppEventHeader{VlMsgID: msgID, Context: clientID})
-                       } else if reply.GetMessageType() == api.RequestMessage {
-                               struc.Pack(buf, &core.VppRequestHeader{VlMsgID: msgID, Context: clientID})
+                       context := clientID
+                       context = setMultipart(context, reply.multipart)
+                       if reply.hasSeqNum {
+                               withSeqNums = true
+                               context = setSeqNum(context, reply.seqNum)
+                       }
+                       if reply.msg.GetMessageType() == api.ReplyMessage {
+                               struc.Pack(buf, &core.VppReplyHeader{VlMsgID: msgID, Context: context})
+                       } else if reply.msg.GetMessageType() == api.EventMessage {
+                               struc.Pack(buf, &core.VppEventHeader{VlMsgID: msgID, Context: context})
+                       } else if reply.msg.GetMessageType() == api.RequestMessage {
+                               struc.Pack(buf, &core.VppRequestHeader{VlMsgID: msgID, Context: context})
                        } else {
                                struc.Pack(buf, &core.VppOtherHeader{VlMsgID: msgID})
                        }
-                       struc.Pack(buf, reply)
-                       a.callback(clientID, msgID, buf.Bytes())
+                       struc.Pack(buf, reply.msg)
+                       a.callback(context, msgID, buf.Bytes())
                }
                if len(a.replies) > 0 {
-                       a.replies = []api.Message{}
+                       a.replies = []reply{}
                        if len(a.replyHandlers) > 0 {
                                // Switch back to handlers once the queue is empty to revert back
                                // the fallthrough effect.
@@ -302,11 +316,21 @@ func (a *VppAdapter) WaitReady() error {
 // MockReply stores a message to be returned when the next request comes. It is a FIFO queue - multiple replies
 // can be pushed into it, the first one will be popped when some request comes.
 // Using of this method automatically switches the mock into th useRepliesQueue mode.
-func (a *VppAdapter) MockReply(msg api.Message) {
+func (a *VppAdapter) MockReply(replyMsg api.Message, isMultipart bool) {
        a.repliesLock.Lock()
        defer a.repliesLock.Unlock()
 
-       a.replies = append(a.replies, msg)
+       a.replies = append(a.replies, reply{msg: replyMsg, multipart: isMultipart, hasSeqNum: false})
+       a.mode = useRepliesQueue
+}
+
+// MockReplyWithSeqNum queues next reply like MockReply() does, except that the
+// sequence number can be customized and not necessarily match with the request.
+func (a *VppAdapter) MockReplyWithSeqNum(replyMsg api.Message, isMultipart bool, sequenceNum uint16) {
+       a.repliesLock.Lock()
+       defer a.repliesLock.Unlock()
+
+       a.replies = append(a.replies, reply{msg: replyMsg, multipart: isMultipart, hasSeqNum: true, seqNum: sequenceNum})
        a.mode = useRepliesQueue
 }
 
@@ -319,3 +343,18 @@ func (a *VppAdapter) MockReplyHandler(replyHandler ReplyHandler) {
        a.replyHandlers = append(a.replyHandlers, replyHandler)
        a.mode = useReplyHandlers
 }
+
+func setSeqNum(context uint32, seqNum uint16) (newContext uint32) {
+       context &= 0xffff0000
+       context |= uint32(seqNum)
+       return context
+}
+
+func setMultipart(context uint32, isMultipart bool) (newContext uint32) {
+       context &= 0xfffeffff
+       if isMultipart {
+               context |= 1 << 16
+       }
+       return context
+}
+