Decode message context using the message type only
[govpp.git] / core / channel.go
index 87b3e29..4cb5761 100644 (file)
 package core
 
 import (
+       "errors"
        "fmt"
+       "reflect"
+       "strings"
        "time"
 
-       "errors"
+       "github.com/sirupsen/logrus"
 
+       "git.fd.io/govpp.git/adapter"
        "git.fd.io/govpp.git/api"
-       "github.com/sirupsen/logrus"
 )
 
-const defaultReplyTimeout = time.Second * 1 // default timeout for replies from VPP, can be changed with SetReplyTimeout
+var (
+       ErrInvalidRequestCtx = errors.New("invalid request context")
+)
 
-// requestCtxData is a context of a ongoing request (simple one - only one response is expected).
-type requestCtxData struct {
-       ch     *channel
-       seqNum uint16
+// MessageCodec provides functionality for decoding binary data to generated API messages.
+type MessageCodec interface {
+       // EncodeMsg encodes message into binary data.
+       EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
+       // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
+       DecodeMsg(data []byte, msg api.Message) error
+       // DecodeMsgContext decodes context from message data and type.
+       DecodeMsgContext(data []byte, msgType api.MessageType) (context uint32, err error)
 }
 
-// multiRequestCtxData is a context of a ongoing multipart request (multiple responses are expected).
-type multiRequestCtxData struct {
-       ch     *channel
-       seqNum uint16
+// MessageIdentifier provides identification of generated API messages.
+type MessageIdentifier interface {
+       // GetMessageID returns message identifier of given API message.
+       GetMessageID(msg api.Message) (uint16, error)
+       // GetMessagePath returns path for the given message
+       GetMessagePath(msg api.Message) string
+       // LookupByID looks up message name and crc by ID
+       LookupByID(path string, msgID uint16) (api.Message, error)
 }
 
-func (req *requestCtxData) ReceiveReply(msg api.Message) error {
-       if req == nil || req.ch == nil {
-               return errors.New("invalid request context")
-       }
+// vppRequest is a request that will be sent to VPP.
+type vppRequest struct {
+       seqNum uint16      // sequence number
+       msg    api.Message // binary API message to be send to VPP
+       multi  bool        // true if multipart response is expected
+}
 
-       lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
+// vppReply is a reply received from VPP.
+type vppReply struct {
+       seqNum       uint16 // sequence number
+       msgID        uint16 // ID of the message
+       data         []byte // encoded data with the message
+       lastReceived bool   // for multi request, true if the last reply has been already received
+       err          error  // in case of error, data is nil and this member contains error
+}
 
-       if lastReplyReceived {
-               err = errors.New("multipart reply recieved while a simple reply expected")
-       }
-       return err
+// requestCtx is a context for request with single reply
+type requestCtx struct {
+       ch     *Channel
+       seqNum uint16
 }
 
-func (req *multiRequestCtxData) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
-       if req == nil || req.ch == nil {
-               return false, errors.New("invalid request context")
-       }
+// multiRequestCtx is a context for request with multiple responses
+type multiRequestCtx struct {
+       ch     *Channel
+       seqNum uint16
+}
 
-       return req.ch.receiveReplyInternal(msg, req.seqNum)
+// subscriptionCtx is a context of subscription for delivery of specific notification messages.
+type subscriptionCtx struct {
+       ch         *Channel
+       notifChan  chan api.Message   // channel where notification messages will be delivered to
+       msgID      uint16             // message ID for the subscribed event message
+       event      api.Message        // event message that this subscription is for
+       msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
 }
 
-// channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
+// Channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
 // concurrently, otherwise the responses could mix! Use multiple channels instead.
-type channel struct {
-       id uint16 // channel ID
-
-       reqChan   chan *api.VppRequest // channel for sending the requests to VPP, closing this channel releases all resources in the ChannelProvider
-       replyChan chan *api.VppReply   // channel where VPP replies are delivered to
+type Channel struct {
+       id   uint16
+       conn *Connection
 
-       notifSubsChan      chan *api.NotifSubscribeRequest // channel for sending notification subscribe requests
-       notifSubsReplyChan chan error                      // channel where replies to notification subscribe requests are delivered to
+       reqChan   chan *vppRequest // channel for sending the requests to VPP
+       replyChan chan *vppReply   // channel where VPP replies are delivered to
 
-       msgDecoder    api.MessageDecoder    // used to decode binary data to generated API messages
-       msgIdentifier api.MessageIdentifier // used to retrieve message ID of a message
+       msgCodec      MessageCodec      // used to decode binary data to generated API messages
+       msgIdentifier MessageIdentifier // used to retrieve message ID of a message
 
        lastSeqNum uint16 // sequence number of the last sent request
 
-       delayedReply *api.VppReply // reply already taken from ReplyChan, buffered for later delivery
-       replyTimeout time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
+       delayedReply        *vppReply     // reply already taken from ReplyChan, buffered for later delivery
+       replyTimeout        time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
+       receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
 }
 
-func (ch *channel) SendRequest(msg api.Message) api.RequestCtx {
-       ch.lastSeqNum++
-       ch.reqChan <- &api.VppRequest{
-               Message: msg,
-               SeqNum:  ch.lastSeqNum,
+func newChannel(id uint16, conn *Connection, codec MessageCodec, identifier MessageIdentifier, reqSize, replySize int) *Channel {
+       return &Channel{
+               id:                  id,
+               conn:                conn,
+               msgCodec:            codec,
+               msgIdentifier:       identifier,
+               reqChan:             make(chan *vppRequest, reqSize),
+               replyChan:           make(chan *vppReply, replySize),
+               replyTimeout:        DefaultReplyTimeout,
+               receiveReplyTimeout: ReplyChannelTimeout,
        }
-       return &requestCtxData{ch: ch, seqNum: ch.lastSeqNum}
 }
 
-func (ch *channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
-       ch.lastSeqNum++
-       ch.reqChan <- &api.VppRequest{
-               Message:   msg,
-               Multipart: true,
-               SeqNum:    ch.lastSeqNum,
-       }
-       return &multiRequestCtxData{ch: ch, seqNum: ch.lastSeqNum}
+func (ch *Channel) GetID() uint16 {
+       return ch.id
 }
 
-func (ch *channel) SubscribeNotification(notifChan chan api.Message, msgFactory func() api.Message) (*api.NotifSubscription, error) {
-       subscription := &api.NotifSubscription{
-               NotifChan:  notifChan,
-               MsgFactory: msgFactory,
-       }
-       ch.notifSubsChan <- &api.NotifSubscribeRequest{
-               Subscription: subscription,
-               Subscribe:    true,
-       }
-       return subscription, <-ch.notifSubsReplyChan
+func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
+       req := ch.newRequest(msg, false)
+       ch.reqChan <- req
+       return &requestCtx{ch: ch, seqNum: req.seqNum}
+}
+
+func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
+       req := ch.newRequest(msg, true)
+       ch.reqChan <- req
+       return &multiRequestCtx{ch: ch, seqNum: req.seqNum}
 }
 
-func (ch *channel) UnsubscribeNotification(subscription *api.NotifSubscription) error {
-       ch.notifSubsChan <- &api.NotifSubscribeRequest{
-               Subscription: subscription,
-               Subscribe:    false,
+func (ch *Channel) nextSeqNum() uint16 {
+       ch.lastSeqNum++
+       return ch.lastSeqNum
+}
+
+func (ch *Channel) newRequest(msg api.Message, multi bool) *vppRequest {
+       return &vppRequest{
+               msg:    msg,
+               seqNum: ch.nextSeqNum(),
+               multi:  multi,
        }
-       return <-ch.notifSubsReplyChan
 }
 
-func (ch *channel) CheckMessageCompatibility(messages ...api.Message) error {
-       for _, msg := range messages {
+func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
+       var comperr api.CompatibilityError
+       for _, msg := range msgs {
                _, err := ch.msgIdentifier.GetMessageID(msg)
                if err != nil {
-                       return fmt.Errorf("message %s with CRC %s is not compatible with the VPP we are connected to",
-                               msg.GetMessageName(), msg.GetCrcString())
+                       if uerr, ok := err.(*adapter.UnknownMsgError); ok {
+                               comperr.IncompatibleMessages = append(comperr.IncompatibleMessages, getMsgID(uerr.MsgName, uerr.MsgCrc))
+                               continue
+                       }
+                       // other errors return immediatelly
+                       return err
                }
+               comperr.CompatibleMessages = append(comperr.CompatibleMessages, getMsgNameWithCrc(msg))
        }
-       return nil
+       if len(comperr.IncompatibleMessages) == 0 {
+               return nil
+       }
+       return &comperr
 }
 
-func (ch *channel) SetReplyTimeout(timeout time.Duration) {
-       ch.replyTimeout = timeout
-}
+func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
+       msgID, err := ch.msgIdentifier.GetMessageID(event)
+       if err != nil {
+               log.WithFields(logrus.Fields{
+                       "msg_name": event.GetMessageName(),
+                       "msg_crc":  event.GetCrcString(),
+               }).Errorf("unable to retrieve message ID: %v", err)
+               return nil, fmt.Errorf("unable to retrieve event message ID: %v", err)
+       }
 
-func (ch *channel) GetRequestChannel() chan<- *api.VppRequest {
-       return ch.reqChan
-}
+       sub := &subscriptionCtx{
+               ch:         ch,
+               notifChan:  notifChan,
+               msgID:      msgID,
+               event:      event,
+               msgFactory: getMsgFactory(event),
+       }
 
-func (ch *channel) GetReplyChannel() <-chan *api.VppReply {
-       return ch.replyChan
+       // add the subscription into map
+       ch.conn.subscriptionsLock.Lock()
+       defer ch.conn.subscriptionsLock.Unlock()
+
+       ch.conn.subscriptions[msgID] = append(ch.conn.subscriptions[msgID], sub)
+
+       return sub, nil
 }
 
-func (ch *channel) GetNotificationChannel() chan<- *api.NotifSubscribeRequest {
-       return ch.notifSubsChan
+func (ch *Channel) SetReplyTimeout(timeout time.Duration) {
+       ch.replyTimeout = timeout
 }
 
-func (ch *channel) GetNotificationReplyChannel() <-chan error {
-       return ch.notifSubsReplyChan
+func (ch *Channel) Close() {
+       close(ch.reqChan)
 }
 
-func (ch *channel) GetMessageDecoder() api.MessageDecoder {
-       return ch.msgDecoder
+func (req *requestCtx) ReceiveReply(msg api.Message) error {
+       if req == nil || req.ch == nil {
+               return ErrInvalidRequestCtx
+       }
+
+       lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
+       if err != nil {
+               return err
+       } else if lastReplyReceived {
+               return errors.New("multipart reply recieved while a single reply expected")
+       }
+
+       return nil
 }
 
-func (ch *channel) GetID() uint16 {
-       return ch.id
+func (req *multiRequestCtx) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
+       if req == nil || req.ch == nil {
+               return false, ErrInvalidRequestCtx
+       }
+
+       return req.ch.receiveReplyInternal(msg, req.seqNum)
 }
 
-func (ch *channel) Close() {
-       if ch.reqChan != nil {
-               close(ch.reqChan)
+func (sub *subscriptionCtx) Unsubscribe() error {
+       log.WithFields(logrus.Fields{
+               "msg_name": sub.event.GetMessageName(),
+               "msg_id":   sub.msgID,
+       }).Debug("Removing notification subscription.")
+
+       // remove the subscription from the map
+       sub.ch.conn.subscriptionsLock.Lock()
+       defer sub.ch.conn.subscriptionsLock.Unlock()
+
+       for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
+               if item == sub {
+                       // close notification channel
+                       close(sub.ch.conn.subscriptions[sub.msgID][i].notifChan)
+                       // remove i-th item in the slice
+                       sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
+                       return nil
+               }
        }
+
+       return fmt.Errorf("subscription for %q not found", sub.event.GetMessageName())
 }
 
 // receiveReplyInternal receives a reply from the reply channel into the provided msg structure.
-func (ch *channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
-       var ignore bool
+func (ch *Channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
        if msg == nil {
                return false, errors.New("nil message passed in")
        }
 
-       if ch.delayedReply != nil {
+       var ignore bool
+
+       if vppReply := ch.delayedReply; vppReply != nil {
                // try the delayed reply
-               vppReply := ch.delayedReply
                ch.delayedReply = nil
                ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
                if !ignore {
@@ -189,25 +276,32 @@ func (ch *channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (last
                case vppReply := <-ch.replyChan:
                        ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
                        if ignore {
+                               log.WithFields(logrus.Fields{
+                                       "expSeqNum": expSeqNum,
+                                       "channel":   ch.id,
+                               }).Warnf("ignoring received reply: %+v (expecting: %s)", vppReply, msg.GetMessageName())
                                continue
                        }
                        return lastReplyReceived, err
 
                case <-timer.C:
+                       log.WithFields(logrus.Fields{
+                               "expSeqNum": expSeqNum,
+                               "channel":   ch.id,
+                       }).Debugf("timeout (%v) waiting for reply: %s", ch.replyTimeout, msg.GetMessageName())
                        err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout)
                        return false, err
                }
        }
-       return
 }
 
-func (ch *channel) processReply(reply *api.VppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
+func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
        // check the sequence number
-       cmpSeqNums := compareSeqNumbers(reply.SeqNum, expSeqNum)
+       cmpSeqNums := compareSeqNumbers(reply.seqNum, expSeqNum)
        if cmpSeqNums == -1 {
                // reply received too late, ignore the message
-               logrus.WithField("sequence-number", reply.SeqNum).Warn(
-                       "Received reply to an already closed binary API request")
+               log.WithField("seqNum", reply.seqNum).
+                       Warn("Received reply to an already closed binary API request")
                ignore = true
                return
        }
@@ -217,11 +311,11 @@ func (ch *channel) processReply(reply *api.VppReply, expSeqNum uint16, msg api.M
                return
        }
 
-       if reply.Error != nil {
-               err = reply.Error
+       if reply.err != nil {
+               err = reply.err
                return
        }
-       if reply.LastReplyReceived {
+       if reply.lastReceived {
                lastReplyReceived = true
                return
        }
@@ -235,42 +329,42 @@ func (ch *channel) processReply(reply *api.VppReply, expSeqNum uint16, msg api.M
                return
        }
 
-       if reply.MessageID != expMsgID {
+       if reply.msgID != expMsgID {
                var msgNameCrc string
-               if nameCrc, err := ch.msgIdentifier.LookupByID(reply.MessageID); err != nil {
+               pkgPath := ch.msgIdentifier.GetMessagePath(msg)
+               if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
                        msgNameCrc = err.Error()
                } else {
-                       msgNameCrc = nameCrc
+                       msgNameCrc = getMsgNameWithCrc(replyMsg)
                }
 
-               err = fmt.Errorf("received invalid message ID (seq-num=%d), expected %d (%s), but got %d (%s) "+
+               err = fmt.Errorf("received unexpected message (seqNum=%d), expected %s (ID %d), but got %s (ID %d) "+
                        "(check if multiple goroutines are not sharing single GoVPP channel)",
-                       reply.SeqNum, expMsgID, msg.GetMessageName(), reply.MessageID, msgNameCrc)
+                       reply.seqNum, msg.GetMessageName(), expMsgID, msgNameCrc, reply.msgID)
                return
        }
 
        // decode the message
-       err = ch.msgDecoder.DecodeMsg(reply.Data, msg)
-       return
-}
-
-// compareSeqNumbers returns -1, 0, 1 if sequence number <seqNum1> precedes, equals to,
-// or succeeds seq. number <seqNum2>.
-// Since sequence numbers cycle in the finite set of size 2^16, the function
-// must assume that the distance between compared sequence numbers is less than
-// (2^16)/2 to determine the order.
-func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
-       // calculate distance from seqNum1 to seqNum2
-       var dist uint16
-       if seqNum1 <= seqNum2 {
-               dist = seqNum2 - seqNum1
-       } else {
-               dist = 0xffff - (seqNum1 - seqNum2 - 1)
+       if err = ch.msgCodec.DecodeMsg(reply.data, msg); err != nil {
+               return
        }
-       if dist == 0 {
-               return 0
-       } else if dist <= 0x8000 {
-               return -1
+
+       // check Retval and convert it into VnetAPIError error
+       if strings.HasSuffix(msg.GetMessageName(), "_reply") {
+               // TODO: use categories for messages to avoid checking message name
+               if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
+                       var retval int32
+                       switch f.Kind() {
+                       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                               retval = int32(f.Int())
+                       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                               retval = int32(f.Uint())
+                       default:
+                               logrus.Warnf("invalid kind (%v) for Retval field of message %v", f.Kind(), msg.GetMessageName())
+                       }
+                       err = api.RetvalToVPPApiError(retval)
+               }
        }
-       return 1
+
+       return
 }