Decode message context using the message type only
[govpp.git] / core / request_handler.go
index c1683d3..29685f6 100644 (file)
@@ -18,200 +18,405 @@ import (
        "errors"
        "fmt"
        "sync/atomic"
+       "time"
 
        logger "github.com/sirupsen/logrus"
 
        "git.fd.io/govpp.git/api"
-       "git.fd.io/govpp.git/core/bin_api/vpe"
+)
+
+var ReplyChannelTimeout = time.Millisecond * 100
+
+var (
+       ErrNotConnected = errors.New("not connected to VPP, ignoring the request")
+       ErrProbeTimeout = errors.New("probe reply not received within timeout period")
 )
 
 // watchRequests watches for requests on the request API channel and forwards them as messages to VPP.
-func (c *Connection) watchRequests(ch *api.Channel, chMeta *channelMetadata) {
+func (c *Connection) watchRequests(ch *Channel) {
        for {
                select {
-               case req, ok := <-ch.ReqChan:
+               case req, ok := <-ch.reqChan:
                        // new request on the request channel
                        if !ok {
                                // after closing the request channel, release API channel and return
-                               c.releaseAPIChannel(ch, chMeta)
+                               c.releaseAPIChannel(ch)
                                return
                        }
-                       c.processRequest(ch, chMeta, req)
-
-               case req := <-ch.NotifSubsChan:
-                       // new request on the notification subscribe channel
-                       c.processNotifSubscribeRequest(ch, req)
+                       if err := c.processRequest(ch, req); err != nil {
+                               sendReply(ch, &vppReply{
+                                       seqNum: req.seqNum,
+                                       err:    fmt.Errorf("unable to process request: %w", err),
+                               })
+                       }
                }
        }
 }
 
 // processRequest processes a single request received on the request channel.
-func (c *Connection) processRequest(ch *api.Channel, chMeta *channelMetadata, req *api.VppRequest) error {
+func (c *Connection) sendMessage(context uint32, msg api.Message) error {
        // check whether we are connected to VPP
-       if atomic.LoadUint32(&c.connected) == 0 {
-               error := errors.New("not connected to VPP, ignoring the request")
-               log.Error(error)
-               sendReply(ch, &api.VppReply{Error: error})
-               return error
+       if atomic.LoadUint32(&c.vppConnected) == 0 {
+               return ErrNotConnected
        }
 
+       /*log := log.WithFields(logger.Fields{
+               "context":  context,
+               "msg_name": msg.GetMessageName(),
+               "msg_crc":  msg.GetCrcString(),
+       })*/
+
        // retrieve message ID
-       msgID, err := c.GetMessageID(req.Message)
+       msgID, err := c.GetMessageID(msg)
+       if err != nil {
+               //log.WithError(err).Debugf("unable to retrieve message ID: %#v", msg)
+               return err
+       }
+
+       //log = log.WithField("msg_id", msgID)
+
+       // encode the message
+       data, err := c.codec.EncodeMsg(msg, msgID)
+       if err != nil {
+               log.WithError(err).Debugf("unable to encode message: %#v", msg)
+               return err
+       }
+
+       //log = log.WithField("msg_length", len(data))
+
+       if log.Level >= logger.DebugLevel {
+               log.Debugf("--> SEND: MSG %T %+v", msg, msg)
+       }
+
+       // send message to VPP
+       err = c.vppClient.SendMsg(context, data)
        if err != nil {
-               error := fmt.Errorf("unable to retrieve message ID: %v", err)
+               log.WithError(err).Debugf("unable to send message: %#v", msg)
+               return err
+       }
+
+       return nil
+}
+
+// processRequest processes a single request received on the request channel.
+func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
+       // check whether we are connected to VPP
+       if atomic.LoadUint32(&c.vppConnected) == 0 {
+               err := ErrNotConnected
                log.WithFields(logger.Fields{
-                       "msg_name": req.Message.GetMessageName(),
-                       "msg_crc":  req.Message.GetCrcString(),
-               }).Error(err)
-               sendReply(ch, &api.VppReply{Error: error})
-               return error
+                       "channel":  ch.id,
+                       "seq_num":  req.seqNum,
+                       "msg_name": req.msg.GetMessageName(),
+                       "msg_crc":  req.msg.GetCrcString(),
+                       "error":    err,
+               }).Warnf("Unable to process request")
+               return err
        }
 
-       // encode the message into binary
-       data, err := c.codec.EncodeMsg(req.Message, msgID)
+       // retrieve message ID
+       msgID, err := c.GetMessageID(req.msg)
        if err != nil {
-               error := fmt.Errorf("unable to encode the messge: %v", err)
                log.WithFields(logger.Fields{
-                       "context": chMeta.id,
-                       "msg_id":  msgID,
-               }).Error(error)
-               sendReply(ch, &api.VppReply{Error: error})
-               return error
+                       "channel":  ch.id,
+                       "msg_name": req.msg.GetMessageName(),
+                       "msg_crc":  req.msg.GetCrcString(),
+                       "seq_num":  req.seqNum,
+                       "error":    err,
+               }).Warnf("Unable to retrieve message ID")
+               return err
        }
 
-       if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
+       // encode the message into binary
+       data, err := c.codec.EncodeMsg(req.msg, msgID)
+       if err != nil {
                log.WithFields(logger.Fields{
-                       "context":  chMeta.id,
+                       "channel":  ch.id,
                        "msg_id":   msgID,
-                       "msg_size": len(data),
-               }).Debug("Sending a message to VPP.")
+                       "msg_name": req.msg.GetMessageName(),
+                       "msg_crc":  req.msg.GetCrcString(),
+                       "seq_num":  req.seqNum,
+                       "error":    err,
+               }).Warnf("Unable to encode message: %T %+v", req.msg, req.msg)
+               return err
        }
 
-       // send the message
-       if req.Multipart {
-               // expect multipart response
-               atomic.StoreUint32(&chMeta.multipart, 1)
+       context := packRequestContext(ch.id, req.multi, req.seqNum)
+
+       if log.Level >= logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
+               log.WithFields(logger.Fields{
+                       "channel":  ch.id,
+                       "msg_id":   msgID,
+                       "msg_name": req.msg.GetMessageName(),
+                       "msg_crc":  req.msg.GetCrcString(),
+                       "seq_num":  req.seqNum,
+                       "is_multi": req.multi,
+                       "context":  context,
+                       "data_len": len(data),
+               }).Debugf("--> SEND MSG: %T %+v", req.msg, req.msg)
        }
 
        // send the request to VPP
-       c.vpp.SendMsg(chMeta.id, data)
+       err = c.vppClient.SendMsg(context, data)
+       if err != nil {
+               log.WithFields(logger.Fields{
+                       "channel":  ch.id,
+                       "msg_id":   msgID,
+                       "msg_name": req.msg.GetMessageName(),
+                       "msg_crc":  req.msg.GetCrcString(),
+                       "seq_num":  req.seqNum,
+                       "is_multi": req.multi,
+                       "context":  context,
+                       "data_len": len(data),
+                       "error":    err,
+               }).Warnf("Unable to send message")
+               return err
+       }
 
-       if req.Multipart {
+       if req.multi {
                // send a control ping to determine end of the multipart response
-               ping := &vpe.ControlPing{}
-               pingData, _ := c.codec.EncodeMsg(ping, c.pingReqID)
+               pingData, _ := c.codec.EncodeMsg(c.msgControlPing, c.pingReqID)
 
-               log.WithFields(logger.Fields{
-                       "context":  chMeta.id,
-                       "msg_id":   c.pingReqID,
-                       "msg_size": len(pingData),
-               }).Debug("Sending a control ping to VPP.")
+               if log.Level >= logger.DebugLevel {
+                       log.WithFields(logger.Fields{
+                               "channel":  ch.id,
+                               "msg_id":   c.pingReqID,
+                               "msg_name": c.msgControlPing.GetMessageName(),
+                               "msg_crc":  c.msgControlPing.GetCrcString(),
+                               "seq_num":  req.seqNum,
+                               "context":  context,
+                               "data_len": len(pingData),
+                       }).Debugf(" -> SEND MSG: %T", c.msgControlPing)
+               }
 
-               c.vpp.SendMsg(chMeta.id, pingData)
+               if err := c.vppClient.SendMsg(context, pingData); err != nil {
+                       log.WithFields(logger.Fields{
+                               "context": context,
+                               "seq_num": req.seqNum,
+                               "error":   err,
+                       }).Warnf("unable to send control ping")
+               }
        }
 
        return nil
 }
 
 // msgCallback is called whenever any binary API message comes from VPP.
-func msgCallback(context uint32, msgID uint16, data []byte) {
-       connLock.RLock()
-       defer connLock.RUnlock()
+func (c *Connection) msgCallback(msgID uint16, data []byte) {
+       if c == nil {
+               log.WithField(
+                       "msg_id", msgID,
+               ).Warn("Connection already disconnected, ignoring the message.")
+               return
+       }
+
+       msgType, name, crc, err := c.getMessageDataByID(msgID)
+       if err != nil {
+               log.Warnln(err)
+               return
+       }
 
-       if conn == nil {
-               log.Warn("Already disconnected, ignoring the message.")
+       // decode message context to fix for special cases of messages,
+       // for example:
+       // - replies that don't have context as first field (comes as zero)
+       // - events that don't have context at all (comes as non zero)
+       //
+       context, err := c.codec.DecodeMsgContext(data, msgType)
+       if err != nil {
+               log.WithField("msg_id", msgID).Warnf("Unable to decode message context: %v", err)
                return
        }
 
+       chanID, isMulti, seqNum := unpackRequestContext(context)
+
        if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
                log.WithFields(logger.Fields{
                        "context":  context,
                        "msg_id":   msgID,
                        "msg_size": len(data),
-               }).Debug("Received a message from VPP.")
+                       "channel":  chanID,
+                       "is_multi": isMulti,
+                       "seq_num":  seqNum,
+                       "msg_crc":  crc,
+               }).Debugf("<-- govpp RECEIVE: %s %+v", name)
        }
 
-       if context == 0 || conn.isNotificationMessage(msgID) {
+       if context == 0 || c.isNotificationMessage(msgID) {
                // process the message as a notification
-               conn.sendNotifications(msgID, data)
+               c.sendNotifications(msgID, data)
                return
        }
 
        // match ch according to the context
-       conn.channelsLock.RLock()
-       ch, ok := conn.channels[context]
-       conn.channelsLock.RUnlock()
-
+       c.channelsLock.RLock()
+       ch, ok := c.channels[chanID]
+       c.channelsLock.RUnlock()
        if !ok {
                log.WithFields(logger.Fields{
-                       "context": context,
+                       "channel": chanID,
                        "msg_id":  msgID,
-               }).Error("Context ID not known, ignoring the message.")
+               }).Error("Channel ID not known, ignoring the message.")
                return
        }
 
-       chMeta := ch.Metadata().(*channelMetadata)
-       lastReplyReceived := false
-       // if this is a control ping reply and multipart request is being processed, treat this as a last part of the reply
-       if msgID == conn.pingReplyID && atomic.CompareAndSwapUint32(&chMeta.multipart, 1, 0) {
-               lastReplyReceived = true
-       }
+       // if this is a control ping reply to a multipart request,
+       // treat this as a last part of the reply
+       lastReplyReceived := isMulti && msgID == c.pingReplyID
 
-       // send the data to the channel
-       sendReply(ch, &api.VppReply{
-               MessageID:         msgID,
-               Data:              data,
-               LastReplyReceived: lastReplyReceived,
+       // send the data to the channel, it needs to be copied,
+       // because it will be freed after this function returns
+       sendReply(ch, &vppReply{
+               msgID:        msgID,
+               seqNum:       seqNum,
+               data:         append([]byte(nil), data...),
+               lastReceived: lastReplyReceived,
        })
+
+       // store actual time of this reply
+       c.lastReplyLock.Lock()
+       c.lastReply = time.Now()
+       c.lastReplyLock.Unlock()
 }
 
 // sendReply sends the reply into the go channel, if it cannot be completed without blocking, otherwise
 // it logs the error and do not send the message.
-func sendReply(ch *api.Channel, reply *api.VppReply) {
+func sendReply(ch *Channel, reply *vppReply) {
+       // first try to avoid creating timer
        select {
-       case ch.ReplyChan <- reply:
-               // reply sent successfully
+       case ch.replyChan <- reply:
+               return // reply sent ok
        default:
-               // unable to write into the channel without blocking
+               // reply channel full
+       }
+       if ch.receiveReplyTimeout == 0 {
                log.WithFields(logger.Fields{
-                       "channel": ch,
-                       "msg_id":  reply.MessageID,
-               }).Warn("Unable to send the reply, reciever end not ready.")
+                       "channel": ch.id,
+                       "msg_id":  reply.msgID,
+                       "seq_num": reply.seqNum,
+                       "err":     reply.err,
+               }).Warn("Reply channel full, dropping reply.")
+               return
+       }
+       select {
+       case ch.replyChan <- reply:
+               return // reply sent ok
+       case <-time.After(ch.receiveReplyTimeout):
+               // receiver still not ready
+               log.WithFields(logger.Fields{
+                       "channel": ch.id,
+                       "msg_id":  reply.msgID,
+                       "seq_num": reply.seqNum,
+                       "err":     reply.err,
+               }).Warnf("Unable to send reply (reciever end not ready in %v).", ch.receiveReplyTimeout)
        }
 }
 
-// GetMessageID returns message identifier of given API message.
-func (c *Connection) GetMessageID(msg api.Message) (uint16, error) {
-       if c == nil {
-               return 0, errors.New("nil connection passed in")
-       }
-       return c.messageNameToID(msg.GetMessageName(), msg.GetCrcString())
+// isNotificationMessage returns true if someone has subscribed to provided message ID.
+func (c *Connection) isNotificationMessage(msgID uint16) bool {
+       c.subscriptionsLock.RLock()
+       defer c.subscriptionsLock.RUnlock()
+
+       _, exists := c.subscriptions[msgID]
+       return exists
 }
 
-// messageNameToID returns message ID of a message identified by its name and CRC.
-func (c *Connection) messageNameToID(msgName string, msgCrc string) (uint16, error) {
-       // try to get the ID from the map
-       c.msgIDsLock.RLock()
-       id, ok := c.msgIDs[msgName+msgCrc]
-       c.msgIDsLock.RUnlock()
-       if ok {
-               return id, nil
+// sendNotifications send a notification message to all subscribers subscribed for that message.
+func (c *Connection) sendNotifications(msgID uint16, data []byte) {
+       c.subscriptionsLock.RLock()
+       defer c.subscriptionsLock.RUnlock()
+
+       matched := false
+
+       // send to notification to each subscriber
+       for _, sub := range c.subscriptions[msgID] {
+               log.WithFields(logger.Fields{
+                       "msg_name": sub.event.GetMessageName(),
+                       "msg_id":   msgID,
+                       "msg_size": len(data),
+               }).Debug("Sending a notification to the subscription channel.")
+
+               event := sub.msgFactory()
+               if err := c.codec.DecodeMsg(data, event); err != nil {
+                       log.WithFields(logger.Fields{
+                               "msg_name": sub.event.GetMessageName(),
+                               "msg_id":   msgID,
+                               "msg_size": len(data),
+                               "error":    err,
+                       }).Warnf("Unable to decode the notification message")
+                       continue
+               }
+
+               // send the message into the go channel of the subscription
+               select {
+               case sub.notifChan <- event:
+                       // message sent successfully
+               default:
+                       // unable to write into the channel without blocking
+                       log.WithFields(logger.Fields{
+                               "msg_name": sub.event.GetMessageName(),
+                               "msg_id":   msgID,
+                               "msg_size": len(data),
+                       }).Warn("Unable to deliver the notification, reciever end not ready.")
+               }
+
+               matched = true
        }
 
-       // get the ID using VPP API
-       id, err := c.vpp.GetMsgID(msgName, msgCrc)
-       if err != nil {
-               error := fmt.Errorf("unable to retrieve message ID: %v", err)
+       if !matched {
                log.WithFields(logger.Fields{
-                       "msg_name": msgName,
-                       "msg_crc":  msgCrc,
-               }).Errorf("unable to retrieve message ID: %v", err)
-               return id, error
+                       "msg_id":   msgID,
+                       "msg_size": len(data),
+               }).Info("No subscription found for the notification message.")
+       }
+}
+
+// +------------------+-------------------+-----------------------+
+// | 15b = channel ID | 1b = is multipart | 16b = sequence number |
+// +------------------+-------------------+-----------------------+
+func packRequestContext(chanID uint16, isMultipart bool, seqNum uint16) uint32 {
+       context := uint32(chanID) << 17
+       if isMultipart {
+               context |= 1 << 16
        }
+       context |= uint32(seqNum)
+       return context
+}
 
-       c.msgIDsLock.Lock()
-       c.msgIDs[msgName+msgCrc] = id
-       c.msgIDsLock.Unlock()
+func unpackRequestContext(context uint32) (chanID uint16, isMulipart bool, seqNum uint16) {
+       chanID = uint16(context >> 17)
+       if ((context >> 16) & 0x1) != 0 {
+               isMulipart = true
+       }
+       seqNum = uint16(context & 0xffff)
+       return
+}
 
-       return id, nil
+// compareSeqNumbers returns -1, 0, 1 if sequence number <seqNum1> precedes, equals to,
+// or succeeds seq. number <seqNum2>.
+// Since sequence numbers cycle in the finite set of size 2^16, the function
+// must assume that the distance between compared sequence numbers is less than
+// (2^16)/2 to determine the order.
+func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
+       // calculate distance from seqNum1 to seqNum2
+       var dist uint16
+       if seqNum1 <= seqNum2 {
+               dist = seqNum2 - seqNum1
+       } else {
+               dist = 0xffff - (seqNum1 - seqNum2 - 1)
+       }
+       if dist == 0 {
+               return 0
+       } else if dist <= 0x8000 {
+               return -1
+       }
+       return 1
+}
+
+// Returns message data based on the message ID not depending on the message path
+func (c *Connection) getMessageDataByID(msgID uint16) (typ api.MessageType, name, crc string, err error) {
+       for _, msgs := range c.msgMapByPath {
+               if msg, ok := msgs[msgID]; ok {
+                       return msg.GetMessageType(), msg.GetMessageName(), msg.GetCrcString(), nil
+               }
+       }
+       return typ, name, crc, fmt.Errorf("unknown message received, ID: %d", msgID)
 }