connection: prevent channel ID overlap
[govpp.git] / core / channel.go
index 5b513e3..112c14e 100644 (file)
@@ -23,6 +23,7 @@ import (
 
        "github.com/sirupsen/logrus"
 
+       "git.fd.io/govpp.git/adapter"
        "git.fd.io/govpp.git/api"
 )
 
@@ -36,14 +37,18 @@ type MessageCodec interface {
        EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
        // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
        DecodeMsg(data []byte, msg api.Message) error
+       // DecodeMsgContext decodes context from message data and type.
+       DecodeMsgContext(data []byte, msgType api.MessageType) (context uint32, err error)
 }
 
 // MessageIdentifier provides identification of generated API messages.
 type MessageIdentifier interface {
        // GetMessageID returns message identifier of given API message.
        GetMessageID(msg api.Message) (uint16, error)
+       // GetMessagePath returns path for the given message
+       GetMessagePath(msg api.Message) string
        // LookupByID looks up message name and crc by ID
-       LookupByID(msgID uint16) (api.Message, error)
+       LookupByID(path string, msgID uint16) (api.Message, error)
 }
 
 // vppRequest is a request that will be sent to VPP.
@@ -83,7 +88,7 @@ type subscriptionCtx struct {
        msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
 }
 
-// channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
+// Channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
 // concurrently, otherwise the responses could mix! Use multiple channels instead.
@@ -99,59 +104,90 @@ type Channel struct {
 
        lastSeqNum uint16 // sequence number of the last sent request
 
-       delayedReply *vppReply     // reply already taken from ReplyChan, buffered for later delivery
-       replyTimeout time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
+       delayedReply        *vppReply     // reply already taken from ReplyChan, buffered for later delivery
+       replyTimeout        time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
+       receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
 }
 
-func newChannel(id uint16, conn *Connection, codec MessageCodec, identifier MessageIdentifier, reqSize, replySize int) *Channel {
-       return &Channel{
-               id:            id,
-               conn:          conn,
-               msgCodec:      codec,
-               msgIdentifier: identifier,
-               reqChan:       make(chan *vppRequest, reqSize),
-               replyChan:     make(chan *vppReply, replySize),
-               replyTimeout:  DefaultReplyTimeout,
+func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) (*Channel, error) {
+       // create new channel
+       channel := &Channel{
+               conn:                c,
+               msgCodec:            c.codec,
+               msgIdentifier:       c,
+               reqChan:             make(chan *vppRequest, reqChanBufSize),
+               replyChan:           make(chan *vppReply, replyChanBufSize),
+               replyTimeout:        DefaultReplyTimeout,
+               receiveReplyTimeout: ReplyChannelTimeout,
        }
+
+       // store API channel within the client
+       c.channelsLock.Lock()
+       if len(c.channels) >= 0x7fff {
+               return nil, errors.New("all channel IDs are used")
+       }
+       for {
+               c.nextChannelID++
+               chID := c.nextChannelID & 0x7fff
+               _, ok := c.channels[chID]
+               if !ok {
+                       channel.id = chID
+                       c.channels[chID] = channel
+                       break
+               }
+       }
+       c.channelsLock.Unlock()
+
+       return channel, nil
 }
 
 func (ch *Channel) GetID() uint16 {
        return ch.id
 }
 
+func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
+       req := ch.newRequest(msg, false)
+       ch.reqChan <- req
+       return &requestCtx{ch: ch, seqNum: req.seqNum}
+}
+
+func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
+       req := ch.newRequest(msg, true)
+       ch.reqChan <- req
+       return &multiRequestCtx{ch: ch, seqNum: req.seqNum}
+}
+
 func (ch *Channel) nextSeqNum() uint16 {
        ch.lastSeqNum++
        return ch.lastSeqNum
 }
 
-func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
-       seqNum := ch.nextSeqNum()
-       ch.reqChan <- &vppRequest{
-               msg:    msg,
-               seqNum: seqNum,
-       }
-       return &requestCtx{ch: ch, seqNum: seqNum}
-}
-
-func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
-       seqNum := ch.nextSeqNum()
-       ch.reqChan <- &vppRequest{
+func (ch *Channel) newRequest(msg api.Message, multi bool) *vppRequest {
+       return &vppRequest{
                msg:    msg,
-               seqNum: seqNum,
-               multi:  true,
+               seqNum: ch.nextSeqNum(),
+               multi:  multi,
        }
-       return &multiRequestCtx{ch: ch, seqNum: seqNum}
 }
 
 func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
+       var comperr api.CompatibilityError
        for _, msg := range msgs {
-               // TODO: collect all incompatible messages and return summarized error
                _, err := ch.msgIdentifier.GetMessageID(msg)
                if err != nil {
+                       if uerr, ok := err.(*adapter.UnknownMsgError); ok {
+                               comperr.IncompatibleMessages = append(comperr.IncompatibleMessages, getMsgID(uerr.MsgName, uerr.MsgCrc))
+                               continue
+                       }
+                       // other errors return immediatelly
                        return err
                }
+               comperr.CompatibleMessages = append(comperr.CompatibleMessages, getMsgNameWithCrc(msg))
        }
-       return nil
+       if len(comperr.IncompatibleMessages) == 0 {
+               return nil
+       }
+       return &comperr
 }
 
 func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
@@ -224,6 +260,8 @@ func (sub *subscriptionCtx) Unsubscribe() error {
 
        for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
                if item == sub {
+                       // close notification channel
+                       close(sub.ch.conn.subscriptions[sub.msgID][i].notifChan)
                        // remove i-th item in the slice
                        sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
                        return nil
@@ -312,15 +350,16 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa
 
        if reply.msgID != expMsgID {
                var msgNameCrc string
-               if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
+               pkgPath := ch.msgIdentifier.GetMessagePath(msg)
+               if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
                        msgNameCrc = err.Error()
                } else {
                        msgNameCrc = getMsgNameWithCrc(replyMsg)
                }
 
-               err = fmt.Errorf("received invalid message ID (seqNum=%d), expected %d (%s), but got %d (%s) "+
+               err = fmt.Errorf("received unexpected message (seqNum=%d), expected %s (ID %d), but got %s (ID %d) "+
                        "(check if multiple goroutines are not sharing single GoVPP channel)",
-                       reply.seqNum, expMsgID, msg.GetMessageName(), reply.msgID, msgNameCrc)
+                       reply.seqNum, msg.GetMessageName(), expMsgID, msgNameCrc, reply.msgID)
                return
        }
 
@@ -333,7 +372,15 @@ func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Messa
        if strings.HasSuffix(msg.GetMessageName(), "_reply") {
                // TODO: use categories for messages to avoid checking message name
                if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
-                       retval := int32(f.Int())
+                       var retval int32
+                       switch f.Kind() {
+                       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                               retval = int32(f.Int())
+                       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                               retval = int32(f.Uint())
+                       default:
+                               logrus.Warnf("invalid kind (%v) for Retval field of message %v", f.Kind(), msg.GetMessageName())
+                       }
                        err = api.RetvalToVPPApiError(retval)
                }
        }