make api.Channel as interface
[govpp.git] / adapter / mock / mock_adapter.go
index ab49cef..a5cb62d 100644 (file)
@@ -22,40 +22,36 @@ import (
        "reflect"
        "sync"
 
-       "github.com/lunixbochs/struc"
-
        "git.fd.io/govpp.git/adapter"
-       "git.fd.io/govpp.git/adapter/mock/binapi_reflect"
+       "git.fd.io/govpp.git/adapter/mock/binapi"
        "git.fd.io/govpp.git/api"
+
+       "git.fd.io/govpp.git/codec"
+       "github.com/lunixbochs/struc"
+)
+
+type replyMode int
+
+const (
+       _                replyMode = 0
+       useRepliesQueue            = 1 // use replies in the queue
+       useReplyHandlers           = 2 // use reply handler
 )
 
 // VppAdapter represents a mock VPP adapter that can be used for unit/integration testing instead of the vppapiclient adapter.
 type VppAdapter struct {
        callback func(context uint32, msgId uint16, data []byte)
 
-       msgNameToIds *map[string]uint16
-       msgIdsToName *map[uint16]string
-       msgIdSeq     uint16
-       binApiTypes  map[string]reflect.Type
-       //TODO lock
-}
-
-// replyHeader represents a common header of each VPP request message.
-type requestHeader struct {
-       VlMsgID     uint16
-       ClientIndex uint32
-       Context     uint32
-}
-
-// replyHeader represents a common header of each VPP reply message.
-type replyHeader struct {
-       VlMsgID uint16
-       Context uint32
-}
+       msgNameToIds map[string]uint16
+       msgIDsToName map[uint16]string
+       msgIDSeq     uint16
+       binAPITypes  map[string]reflect.Type
+       access       sync.RWMutex
 
-// replyHeader represents a common header of each VPP reply message.
-type vppOtherHeader struct {
-       VlMsgID uint16
+       replies       []reply        // FIFO queue of messages
+       replyHandlers []ReplyHandler // callbacks that are able to calculate mock responses
+       repliesLock   sync.Mutex     // mutex for the queue
+       mode          replyMode      // mode in which the mock operates
 }
 
 // defaultReply is a default reply message that mock adapter returns for a request.
@@ -63,7 +59,7 @@ type defaultReply struct {
        Retval int32
 }
 
-// MessageDTO is a structure used for propageating informations to ReplyHandlers
+// MessageDTO is a structure used for propagating information to ReplyHandlers.
 type MessageDTO struct {
        MsgID    uint16
        MsgName  string
@@ -71,23 +67,29 @@ type MessageDTO struct {
        Data     []byte
 }
 
+// reply for one request (can be multipart, contain replies to previously timeouted requests, etc.)
+type reply struct {
+       msgs []MsgWithContext
+}
+
+// MsgWithContext encapsulates reply message with possibly sequence number and is-multipart flag.
+type MsgWithContext struct {
+       Msg       api.Message
+       SeqNum    uint16
+       Multipart bool
+
+       /* set by mock adapter */
+       hasCtx bool
+}
+
 // ReplyHandler is a type that allows to extend the behaviour of VPP mock.
-// Return value prepared is used to signalize that mock reply is calculated.
-type ReplyHandler func(request MessageDTO) (reply []byte, msgID uint16, prepared bool)
+// Return value ok is used to signalize that mock reply is calculated and ready to be used.
+type ReplyHandler func(request MessageDTO) (reply []byte, msgID uint16, ok bool)
 
 const (
-       //defaultMsgID      = 1 // default message ID to be returned from GetMsgId
-       defaultReplyMsgID = 2 // default message ID for the reply to be sent back via callback
+       defaultReplyMsgID = 1 // default message ID for the reply to be sent back via callback
 )
 
-var replies []api.Message        // FIFO queue of messages
-var replyHandlers []ReplyHandler // callbacks that are able to calculate mock responses
-var repliesLock sync.Mutex       // mutex for the queue
-var mode = 0
-
-const useRepliesQueue = 1  // use replies in the queue instead of the default one
-const useReplyHandlers = 2 //use ReplyHandler
-
 // NewVppAdapter returns a new mock adapter.
 func NewVppAdapter() adapter.VppAdapter {
        return &VppAdapter{}
@@ -103,10 +105,9 @@ func (a *VppAdapter) Disconnect() {
        // no op
 }
 
-func (a *VppAdapter) GetMsgNameByID(msgId uint16) (string, bool) {
-       a.initMaps()
-
-       switch msgId {
+// GetMsgNameByID returns message name for specified message ID.
+func (a *VppAdapter) GetMsgNameByID(msgID uint16) (string, bool) {
+       switch msgID {
        case 100:
                return "control_ping", true
        case 101:
@@ -117,24 +118,31 @@ func (a *VppAdapter) GetMsgNameByID(msgId uint16) (string, bool) {
                return "sw_interface_details", true
        }
 
-       msgName, found := (*a.msgIdsToName)[msgId]
+       a.access.Lock()
+       defer a.access.Unlock()
+       a.initMaps()
+       msgName, found := a.msgIDsToName[msgID]
 
        return msgName, found
 }
 
-func (a *VppAdapter) RegisterBinApiTypes(binApiTypes map[string]reflect.Type) {
+// RegisterBinAPITypes registers binary API message types in the mock adapter.
+func (a *VppAdapter) RegisterBinAPITypes(binAPITypes map[string]reflect.Type) {
+       a.access.Lock()
+       defer a.access.Unlock()
        a.initMaps()
-       for _, v := range binApiTypes {
+       for _, v := range binAPITypes {
                if msg, ok := reflect.New(v).Interface().(api.Message); ok {
-                       a.binApiTypes[msg.GetMessageName()] = v
+                       a.binAPITypes[msg.GetMessageName()] = v
                }
        }
 }
 
+// ReplyTypeFor returns reply message type for given request message name.
 func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, bool) {
-       replyName, foundName := binapi_reflect.ReplyNameFor(requestMsgName)
+       replyName, foundName := binapi.ReplyNameFor(requestMsgName)
        if foundName {
-               if reply, found := a.binApiTypes[replyName]; found {
+               if reply, found := a.binAPITypes[replyName]; found {
                        msgID, err := a.GetMsgID(replyName, "")
                        if err == nil {
                                return reply, msgID, found
@@ -145,6 +153,7 @@ func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16,
        return nil, 0, false
 }
 
+// ReplyFor returns reply message for given request message name.
 func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool) {
        replType, msgID, foundReplType := a.ReplyTypeFor(requestMsgName)
        if foundReplType {
@@ -158,17 +167,18 @@ func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool)
        return nil, 0, false
 }
 
+// ReplyBytes encodes the mocked reply into binary format.
 func (a *VppAdapter) ReplyBytes(request MessageDTO, reply api.Message) ([]byte, error) {
-       replyMsgId, err := a.GetMsgID(reply.GetMessageName(), reply.GetCrcString())
+       replyMsgID, err := a.GetMsgID(reply.GetMessageName(), reply.GetCrcString())
        if err != nil {
-               log.Println("ReplyBytesE ", replyMsgId, " ", reply.GetMessageName(), " clientId: ", request.ClientID,
+               log.Println("ReplyBytesE ", replyMsgID, " ", reply.GetMessageName(), " clientId: ", request.ClientID,
                        " ", err)
                return nil, err
        }
-       log.Println("ReplyBytes ", replyMsgId, " ", reply.GetMessageName(), " clientId: ", request.ClientID)
+       log.Println("ReplyBytes ", replyMsgID, " ", reply.GetMessageName(), " clientId: ", request.ClientID)
 
        buf := new(bytes.Buffer)
-       struc.Pack(buf, &replyHeader{VlMsgID: replyMsgId, Context: request.ClientID})
+       struc.Pack(buf, &codec.VppReplyHeader{VlMsgID: replyMsgID, Context: request.ClientID})
        struc.Pack(buf, reply)
 
        return buf.Bytes(), nil
@@ -187,50 +197,60 @@ func (a *VppAdapter) GetMsgID(msgName string, msgCrc string) (uint16, error) {
                return 201, nil
        }
 
+       a.access.Lock()
+       defer a.access.Unlock()
        a.initMaps()
 
-       if msgId, found := (*a.msgNameToIds)[msgName]; found {
-               return msgId, nil
-       } else {
-               a.msgIdSeq++
-               msgId = a.msgIdSeq
-               (*a.msgNameToIds)[msgName] = msgId
-               (*a.msgIdsToName)[msgId] = msgName
+       msgID, found := a.msgNameToIds[msgName]
+       if found {
+               return msgID, nil
+       }
 
-               log.Println("VPP GetMessageId ", msgId, " name:", msgName, " crc:", msgCrc)
+       a.msgIDSeq++
+       msgID = a.msgIDSeq
+       a.msgNameToIds[msgName] = msgID
+       a.msgIDsToName[msgID] = msgName
 
-               return msgId, nil
-       }
+       log.Println("VPP GetMessageId ", msgID, " name:", msgName, " crc:", msgCrc)
+
+       return msgID, nil
 }
 
+// initMaps initializes internal maps (if not already initialized).
 func (a *VppAdapter) initMaps() {
-       if a.msgIdsToName == nil {
-               a.msgIdsToName = &map[uint16]string{}
-               a.msgNameToIds = &map[string]uint16{}
-               a.msgIdSeq = 1000
+       if a.msgIDsToName == nil {
+               a.msgIDsToName = map[uint16]string{}
+               a.msgNameToIds = map[string]uint16{}
+               a.msgIDSeq = 1000
        }
 
-       if a.binApiTypes == nil {
-               a.binApiTypes = map[string]reflect.Type{}
+       if a.binAPITypes == nil {
+               a.binAPITypes = map[string]reflect.Type{}
        }
 }
 
 // SendMsg emulates sending a binary-encoded message to VPP.
 func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
-       switch mode {
+       switch a.mode {
        case useReplyHandlers:
-               for i := len(replyHandlers) - 1; i >= 0; i-- {
-                       replyHandler := replyHandlers[i]
+               a.initMaps()
+               for i := len(a.replyHandlers) - 1; i >= 0; i-- {
+                       replyHandler := a.replyHandlers[i]
 
                        buf := bytes.NewReader(data)
-                       reqHeader := requestHeader{}
+                       reqHeader := codec.VppRequestHeader{}
                        struc.Unpack(buf, &reqHeader)
 
-                       a.initMaps()
-                       reqMsgName, _ := (*a.msgIdsToName)[reqHeader.VlMsgID]
+                       a.access.Lock()
+                       reqMsgName := a.msgIDsToName[reqHeader.VlMsgID]
+                       a.access.Unlock()
 
-                       reply, msgID, finished := replyHandler(MessageDTO{reqHeader.VlMsgID, reqMsgName,
-                               clientID, data})
+                       reply, msgID, finished := replyHandler(MessageDTO{
+                               MsgID:    reqHeader.VlMsgID,
+                               MsgName:  reqMsgName,
+                               ClientID: clientID,
+                               Data:     data,
+                       })
                        if finished {
                                a.callback(clientID, msgID, reply)
                                return nil
@@ -238,29 +258,39 @@ func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
                }
                fallthrough
        case useRepliesQueue:
-               repliesLock.Lock()
-               defer repliesLock.Unlock()
-
-               // pop all replies from queue
-               for i, reply := range replies {
-                       if i > 0 && reply.GetMessageName() == "control_ping_reply" {
-                               // hack - do not send control_ping_reply immediately, leave it for the the next callback
-                               replies = []api.Message{}
-                               replies = append(replies, reply)
-                               return nil
+               a.repliesLock.Lock()
+               defer a.repliesLock.Unlock()
+
+               // pop the first reply
+               if len(a.replies) > 0 {
+                       reply := a.replies[0]
+                       for _, msg := range reply.msgs {
+                               msgID, _ := a.GetMsgID(msg.Msg.GetMessageName(), msg.Msg.GetCrcString())
+                               buf := new(bytes.Buffer)
+                               context := clientID
+                               if msg.hasCtx {
+                                       context = setMultipart(context, msg.Multipart)
+                                       context = setSeqNum(context, msg.SeqNum)
+                               }
+                               if msg.Msg.GetMessageType() == api.ReplyMessage {
+                                       struc.Pack(buf, &codec.VppReplyHeader{VlMsgID: msgID, Context: context})
+                               } else if msg.Msg.GetMessageType() == api.EventMessage {
+                                       struc.Pack(buf, &codec.VppEventHeader{VlMsgID: msgID, Context: context})
+                               } else if msg.Msg.GetMessageType() == api.RequestMessage {
+                                       struc.Pack(buf, &codec.VppRequestHeader{VlMsgID: msgID, Context: context})
+                               } else {
+                                       struc.Pack(buf, &codec.VppOtherHeader{VlMsgID: msgID})
+                               }
+                               struc.Pack(buf, msg.Msg)
+                               a.callback(context, msgID, buf.Bytes())
                        }
-                       msgID, _ := a.GetMsgID(reply.GetMessageName(), reply.GetCrcString())
-                       buf := new(bytes.Buffer)
-                       if reply.GetMessageType() == api.ReplyMessage {
-                               struc.Pack(buf, &replyHeader{VlMsgID: msgID, Context: clientID})
-                       } else {
-                               struc.Pack(buf, &requestHeader{VlMsgID: msgID, Context: clientID})
+
+                       a.replies = a.replies[1:]
+                       if len(a.replies) == 0 && len(a.replyHandlers) > 0 {
+                               // Switch back to handlers once the queue is empty to revert back
+                               // the fallthrough effect.
+                               a.mode = useReplyHandlers
                        }
-                       struc.Pack(buf, reply)
-                       a.callback(clientID, msgID, buf.Bytes())
-               }
-               if len(replies) > 0 {
-                       replies = []api.Message{}
                        return nil
                }
 
@@ -269,7 +299,7 @@ func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
                // return default reply
                buf := new(bytes.Buffer)
                msgID := uint16(defaultReplyMsgID)
-               struc.Pack(buf, &replyHeader{VlMsgID: msgID, Context: clientID})
+               struc.Pack(buf, &codec.VppReplyHeader{VlMsgID: msgID, Context: clientID})
                struc.Pack(buf, &defaultReply{})
                a.callback(clientID, msgID, buf.Bytes())
        }
@@ -281,22 +311,84 @@ func (a *VppAdapter) SetMsgCallback(cb func(context uint32, msgID uint16, data [
        a.callback = cb
 }
 
-// MockReply stores a message to be returned when the next request comes. It is a FIFO queue - multiple replies
-// can be pushed into it, the first one will be popped when some request comes.
+// WaitReady mocks waiting for VPP
+func (a *VppAdapter) WaitReady() error {
+       return nil
+}
+
+// MockReply stores a message or a list of multipart messages to be returned when
+// the next request comes. It is a FIFO queue - multiple replies can be pushed into it,
+// the first message or the first set of multi-part messages will be popped when
+// some request comes.
+// Using of this method automatically switches the mock into the useRepliesQueue mode.
+//
+// Note: multipart requests are implemented using two requests actually - the multipart
+// request itself followed by control ping used to tell which multipart message
+// is the last one. A mock reply to a multipart request has to thus consist of
+// exactly two calls of this method.
+// For example:
 //
-// It is able to also receive callback that calculates the reply
-func (a *VppAdapter) MockReply(msg api.Message) {
-       repliesLock.Lock()
-       defer repliesLock.Unlock()
+//    mockVpp.MockReply(  // push multipart messages all at once
+//                     &interfaces.SwInterfaceDetails{SwIfIndex:1},
+//                     &interfaces.SwInterfaceDetails{SwIfIndex:2},
+//                     &interfaces.SwInterfaceDetails{SwIfIndex:3},
+//    )
+//    mockVpp.MockReply(&vpe.ControlPingReply{})
+//
+// Even if the multipart request has no replies, MockReply has to be called twice:
+//
+//    mockVpp.MockReply()  // zero multipart messages
+//    mockVpp.MockReply(&vpe.ControlPingReply{})
+func (a *VppAdapter) MockReply(msgs ...api.Message) {
+       a.repliesLock.Lock()
+       defer a.repliesLock.Unlock()
+
+       r := reply{}
+       for _, msg := range msgs {
+               r.msgs = append(r.msgs, MsgWithContext{Msg: msg, hasCtx: false})
+       }
+       a.replies = append(a.replies, r)
+       a.mode = useRepliesQueue
+}
 
-       replies = append(replies, msg)
-       mode = useRepliesQueue
+// MockReplyWithContext queues next reply like MockReply() does, except that the
+// sequence number and multipart flag (= context minus channel ID) can be customized
+// and not necessarily match with the request.
+// The purpose of this function is to test handling of sequence numbers and as such
+// it is not really meant to be used outside the govpp UTs.
+func (a *VppAdapter) MockReplyWithContext(msgs ...MsgWithContext) {
+       a.repliesLock.Lock()
+       defer a.repliesLock.Unlock()
+
+       r := reply{}
+       for _, msg := range msgs {
+               r.msgs = append(r.msgs,
+                       MsgWithContext{Msg: msg.Msg, SeqNum: msg.SeqNum, Multipart: msg.Multipart, hasCtx: true})
+       }
+       a.replies = append(a.replies, r)
+       a.mode = useRepliesQueue
 }
 
+// MockReplyHandler registers a handler function that is supposed to generate mock responses to incoming requests.
+// Using of this method automatically switches the mock into th useReplyHandlers mode.
 func (a *VppAdapter) MockReplyHandler(replyHandler ReplyHandler) {
-       repliesLock.Lock()
-       defer repliesLock.Unlock()
+       a.repliesLock.Lock()
+       defer a.repliesLock.Unlock()
 
-       replyHandlers = append(replyHandlers, replyHandler)
-       mode = useReplyHandlers
+       a.replyHandlers = append(a.replyHandlers, replyHandler)
+       a.mode = useReplyHandlers
+}
+
+func setSeqNum(context uint32, seqNum uint16) (newContext uint32) {
+       context &= 0xffff0000
+       context |= uint32(seqNum)
+       return context
+}
+
+func setMultipart(context uint32, isMultipart bool) (newContext uint32) {
+       context &= 0xfffeffff
+       if isMultipart {
+               context |= 1 << 16
+       }
+       return context
 }