Change module name to go.fd.io/govpp
[govpp.git] / adapter / mock / mock_vpp_adapter.go
index 9dca2ba..cb37dd2 100644 (file)
 package mock
 
 import (
-       "bytes"
+       "encoding/binary"
        "log"
        "reflect"
        "sync"
 
-       "git.fd.io/govpp.git/adapter"
-       "git.fd.io/govpp.git/adapter/mock/binapi"
-       "git.fd.io/govpp.git/api"
-       "git.fd.io/govpp.git/codec"
-       "github.com/lunixbochs/struc"
+       "go.fd.io/govpp/adapter"
+       "go.fd.io/govpp/adapter/mock/binapi"
+       "go.fd.io/govpp/api"
+       "go.fd.io/govpp/codec"
 )
 
 type replyMode int
@@ -45,7 +44,7 @@ type VppAdapter struct {
        access       sync.RWMutex
        msgNameToIds map[string]uint16
        msgIDsToName map[uint16]string
-       binAPITypes  map[string]reflect.Type
+       binAPITypes  map[string]map[string]reflect.Type
 
        repliesLock   sync.Mutex     // mutex for the queue
        replies       []reply        // FIFO queue of messages
@@ -58,6 +57,38 @@ type defaultReply struct {
        Retval int32
 }
 
+func (*defaultReply) GetMessageName() string { return "mock_default_reply" }
+func (*defaultReply) GetCrcString() string   { return "xxxxxxxx" }
+func (*defaultReply) GetMessageType() api.MessageType {
+       return api.ReplyMessage
+}
+func (m *defaultReply) Size() int {
+       if m == nil {
+               return 0
+       }
+       var size int
+       // field[1] m.Retval
+       size += 4
+       return size
+}
+func (m *defaultReply) Marshal(b []byte) ([]byte, error) {
+       var buf *codec.Buffer
+       if b == nil {
+               buf = codec.NewBuffer(make([]byte, m.Size()))
+       } else {
+               buf = codec.NewBuffer(b)
+       }
+       // field[1] m.Retval
+       buf.EncodeUint32(uint32(m.Retval))
+       return buf.Bytes(), nil
+}
+func (m *defaultReply) Unmarshal(b []byte) error {
+       buf := codec.NewBuffer(b)
+       // field[1] m.Retval
+       m.Retval = int32(buf.DecodeUint32())
+       return nil
+}
+
 // MessageDTO is a structure used for propagating information to ReplyHandlers.
 type MessageDTO struct {
        MsgID    uint16
@@ -95,7 +126,7 @@ func NewVppAdapter() *VppAdapter {
                msgIDSeq:     1000,
                msgIDsToName: make(map[uint16]string),
                msgNameToIds: make(map[string]uint16),
-               binAPITypes:  make(map[string]reflect.Type),
+               binAPITypes:  make(map[string]map[string]reflect.Type),
        }
        a.registerBinAPITypes()
        return a
@@ -134,19 +165,25 @@ func (a *VppAdapter) GetMsgNameByID(msgID uint16) (string, bool) {
 func (a *VppAdapter) registerBinAPITypes() {
        a.access.Lock()
        defer a.access.Unlock()
-       for _, msg := range api.GetRegisteredMessages() {
-               a.binAPITypes[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+       for pkg, msgs := range api.GetRegisteredMessages() {
+               msgMap := make(map[string]reflect.Type)
+               for _, msg := range msgs {
+                       msgMap[msg.GetMessageName()] = reflect.TypeOf(msg).Elem()
+               }
+               a.binAPITypes[pkg] = msgMap
        }
 }
 
 // ReplyTypeFor returns reply message type for given request message name.
-func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16, bool) {
+func (a *VppAdapter) ReplyTypeFor(pkg, requestMsgName string) (reflect.Type, uint16, bool) {
        replyName, foundName := binapi.ReplyNameFor(requestMsgName)
        if foundName {
-               if reply, found := a.binAPITypes[replyName]; found {
-                       msgID, err := a.GetMsgID(replyName, "")
-                       if err == nil {
-                               return reply, msgID, found
+               if messages, found := a.binAPITypes[pkg]; found {
+                       if reply, found := messages[replyName]; found {
+                               msgID, err := a.GetMsgID(replyName, "")
+                               if err == nil {
+                                       return reply, msgID, found
+                               }
                        }
                }
        }
@@ -155,8 +192,8 @@ func (a *VppAdapter) ReplyTypeFor(requestMsgName string) (reflect.Type, uint16,
 }
 
 // ReplyFor returns reply message for given request message name.
-func (a *VppAdapter) ReplyFor(requestMsgName string) (api.Message, uint16, bool) {
-       replType, msgID, foundReplType := a.ReplyTypeFor(requestMsgName)
+func (a *VppAdapter) ReplyFor(pkg, requestMsgName string) (api.Message, uint16, bool) {
+       replType, msgID, foundReplType := a.ReplyTypeFor(pkg, requestMsgName)
        if foundReplType {
                msgVal := reflect.New(replType)
                if msg, ok := msgVal.Interface().(api.Message); ok {
@@ -178,19 +215,16 @@ func (a *VppAdapter) ReplyBytes(request MessageDTO, reply api.Message) ([]byte,
        }
        log.Println("ReplyBytes ", replyMsgID, " ", reply.GetMessageName(), " clientId: ", request.ClientID)
 
-       buf := new(bytes.Buffer)
-       err = struc.Pack(buf, &codec.VppReplyHeader{
-               VlMsgID: replyMsgID,
-               Context: request.ClientID,
-       })
+       data, err := codec.DefaultCodec.EncodeMsg(reply, replyMsgID)
        if err != nil {
                return nil, err
        }
-       if err = struc.Pack(buf, reply); err != nil {
-               return nil, err
+       if reply.GetMessageType() == api.ReplyMessage {
+               binary.BigEndian.PutUint32(data[2:6], request.ClientID)
+       } else if reply.GetMessageType() == api.RequestMessage {
+               binary.BigEndian.PutUint32(data[6:10], request.ClientID)
        }
-
-       return buf.Bytes(), nil
+       return data, nil
 }
 
 // GetMsgID returns mocked message ID for the given message name and CRC.
@@ -224,21 +258,22 @@ func (a *VppAdapter) GetMsgID(msgName string, msgCrc string) (uint16, error) {
 
 // SendMsg emulates sending a binary-encoded message to VPP.
 func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
-       switch a.mode {
+       a.repliesLock.Lock()
+       mode := a.mode
+       a.repliesLock.Unlock()
+       switch mode {
        case useReplyHandlers:
                for i := len(a.replyHandlers) - 1; i >= 0; i-- {
                        replyHandler := a.replyHandlers[i]
 
-                       buf := bytes.NewReader(data)
-                       reqHeader := codec.VppRequestHeader{}
-                       struc.Unpack(buf, &reqHeader)
+                       msgID := binary.BigEndian.Uint16(data[0:2])
 
                        a.access.Lock()
-                       reqMsgName := a.msgIDsToName[reqHeader.VlMsgID]
+                       reqMsgName := a.msgIDsToName[msgID]
                        a.access.Unlock()
 
                        reply, msgID, finished := replyHandler(MessageDTO{
-                               MsgID:    reqHeader.VlMsgID,
+                               MsgID:    msgID,
                                MsgName:  reqMsgName,
                                ClientID: clientID,
                                Data:     data,
@@ -259,23 +294,21 @@ func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
                        reply := a.replies[0]
                        for _, msg := range reply.msgs {
                                msgID, _ := a.GetMsgID(msg.Msg.GetMessageName(), msg.Msg.GetCrcString())
-                               buf := new(bytes.Buffer)
                                context := clientID
                                if msg.hasCtx {
                                        context = setMultipart(context, msg.Multipart)
                                        context = setSeqNum(context, msg.SeqNum)
                                }
+                               data, err := codec.DefaultCodec.EncodeMsg(msg.Msg, msgID)
+                               if err != nil {
+                                       panic(err)
+                               }
                                if msg.Msg.GetMessageType() == api.ReplyMessage {
-                                       struc.Pack(buf, &codec.VppReplyHeader{VlMsgID: msgID, Context: context})
+                                       binary.BigEndian.PutUint32(data[2:6], context)
                                } else if msg.Msg.GetMessageType() == api.RequestMessage {
-                                       struc.Pack(buf, &codec.VppRequestHeader{VlMsgID: msgID, Context: context})
-                               } else if msg.Msg.GetMessageType() == api.EventMessage {
-                                       struc.Pack(buf, &codec.VppEventHeader{VlMsgID: msgID})
-                               } else {
-                                       struc.Pack(buf, &codec.VppOtherHeader{VlMsgID: msgID})
+                                       binary.BigEndian.PutUint32(data[6:10], context)
                                }
-                               struc.Pack(buf, msg.Msg)
-                               a.callback(msgID, buf.Bytes())
+                               a.callback(msgID, data)
                        }
 
                        a.replies = a.replies[1:]
@@ -290,11 +323,13 @@ func (a *VppAdapter) SendMsg(clientID uint32, data []byte) error {
                //fallthrough
        default:
                // return default reply
-               buf := new(bytes.Buffer)
                msgID := uint16(defaultReplyMsgID)
-               struc.Pack(buf, &codec.VppReplyHeader{VlMsgID: msgID, Context: clientID})
-               struc.Pack(buf, &defaultReply{})
-               a.callback(msgID, buf.Bytes())
+               data, err := codec.DefaultCodec.EncodeMsg(&defaultReply{}, msgID)
+               if err != nil {
+                       panic(err)
+               }
+               binary.BigEndian.PutUint32(data[2:6], clientID)
+               a.callback(msgID, data)
        }
        return nil
 }
@@ -372,6 +407,16 @@ func (a *VppAdapter) MockReplyHandler(replyHandler ReplyHandler) {
        a.mode = useReplyHandlers
 }
 
+// MockClearReplyHanders clears all reply handlers that were registered
+// Will also set the mode to useReplyHandlers
+func (a *VppAdapter) MockClearReplyHandlers() {
+       a.repliesLock.Lock()
+       defer a.repliesLock.Unlock()
+
+       a.replyHandlers = a.replyHandlers[:0]
+       a.mode = useReplyHandlers
+}
+
 func setSeqNum(context uint32, seqNum uint16) (newContext uint32) {
        context &= 0xffff0000
        context |= uint32(seqNum)