Decode message context using the message type only
[govpp.git] / core / request_handler.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "sync/atomic"
21         "time"
22
23         logger "github.com/sirupsen/logrus"
24
25         "git.fd.io/govpp.git/api"
26 )
27
28 var ReplyChannelTimeout = time.Millisecond * 100
29
30 var (
31         ErrNotConnected = errors.New("not connected to VPP, ignoring the request")
32         ErrProbeTimeout = errors.New("probe reply not received within timeout period")
33 )
34
35 // watchRequests watches for requests on the request API channel and forwards them as messages to VPP.
36 func (c *Connection) watchRequests(ch *Channel) {
37         for {
38                 select {
39                 case req, ok := <-ch.reqChan:
40                         // new request on the request channel
41                         if !ok {
42                                 // after closing the request channel, release API channel and return
43                                 c.releaseAPIChannel(ch)
44                                 return
45                         }
46                         if err := c.processRequest(ch, req); err != nil {
47                                 sendReply(ch, &vppReply{
48                                         seqNum: req.seqNum,
49                                         err:    fmt.Errorf("unable to process request: %w", err),
50                                 })
51                         }
52                 }
53         }
54 }
55
56 // processRequest processes a single request received on the request channel.
57 func (c *Connection) sendMessage(context uint32, msg api.Message) error {
58         // check whether we are connected to VPP
59         if atomic.LoadUint32(&c.vppConnected) == 0 {
60                 return ErrNotConnected
61         }
62
63         /*log := log.WithFields(logger.Fields{
64                 "context":  context,
65                 "msg_name": msg.GetMessageName(),
66                 "msg_crc":  msg.GetCrcString(),
67         })*/
68
69         // retrieve message ID
70         msgID, err := c.GetMessageID(msg)
71         if err != nil {
72                 //log.WithError(err).Debugf("unable to retrieve message ID: %#v", msg)
73                 return err
74         }
75
76         //log = log.WithField("msg_id", msgID)
77
78         // encode the message
79         data, err := c.codec.EncodeMsg(msg, msgID)
80         if err != nil {
81                 log.WithError(err).Debugf("unable to encode message: %#v", msg)
82                 return err
83         }
84
85         //log = log.WithField("msg_length", len(data))
86
87         if log.Level >= logger.DebugLevel {
88                 log.Debugf("--> SEND: MSG %T %+v", msg, msg)
89         }
90
91         // send message to VPP
92         err = c.vppClient.SendMsg(context, data)
93         if err != nil {
94                 log.WithError(err).Debugf("unable to send message: %#v", msg)
95                 return err
96         }
97
98         return nil
99 }
100
101 // processRequest processes a single request received on the request channel.
102 func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
103         // check whether we are connected to VPP
104         if atomic.LoadUint32(&c.vppConnected) == 0 {
105                 err := ErrNotConnected
106                 log.WithFields(logger.Fields{
107                         "channel":  ch.id,
108                         "seq_num":  req.seqNum,
109                         "msg_name": req.msg.GetMessageName(),
110                         "msg_crc":  req.msg.GetCrcString(),
111                         "error":    err,
112                 }).Warnf("Unable to process request")
113                 return err
114         }
115
116         // retrieve message ID
117         msgID, err := c.GetMessageID(req.msg)
118         if err != nil {
119                 log.WithFields(logger.Fields{
120                         "channel":  ch.id,
121                         "msg_name": req.msg.GetMessageName(),
122                         "msg_crc":  req.msg.GetCrcString(),
123                         "seq_num":  req.seqNum,
124                         "error":    err,
125                 }).Warnf("Unable to retrieve message ID")
126                 return err
127         }
128
129         // encode the message into binary
130         data, err := c.codec.EncodeMsg(req.msg, msgID)
131         if err != nil {
132                 log.WithFields(logger.Fields{
133                         "channel":  ch.id,
134                         "msg_id":   msgID,
135                         "msg_name": req.msg.GetMessageName(),
136                         "msg_crc":  req.msg.GetCrcString(),
137                         "seq_num":  req.seqNum,
138                         "error":    err,
139                 }).Warnf("Unable to encode message: %T %+v", req.msg, req.msg)
140                 return err
141         }
142
143         context := packRequestContext(ch.id, req.multi, req.seqNum)
144
145         if log.Level >= logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
146                 log.WithFields(logger.Fields{
147                         "channel":  ch.id,
148                         "msg_id":   msgID,
149                         "msg_name": req.msg.GetMessageName(),
150                         "msg_crc":  req.msg.GetCrcString(),
151                         "seq_num":  req.seqNum,
152                         "is_multi": req.multi,
153                         "context":  context,
154                         "data_len": len(data),
155                 }).Debugf("--> SEND MSG: %T %+v", req.msg, req.msg)
156         }
157
158         // send the request to VPP
159         err = c.vppClient.SendMsg(context, data)
160         if err != nil {
161                 log.WithFields(logger.Fields{
162                         "channel":  ch.id,
163                         "msg_id":   msgID,
164                         "msg_name": req.msg.GetMessageName(),
165                         "msg_crc":  req.msg.GetCrcString(),
166                         "seq_num":  req.seqNum,
167                         "is_multi": req.multi,
168                         "context":  context,
169                         "data_len": len(data),
170                         "error":    err,
171                 }).Warnf("Unable to send message")
172                 return err
173         }
174
175         if req.multi {
176                 // send a control ping to determine end of the multipart response
177                 pingData, _ := c.codec.EncodeMsg(c.msgControlPing, c.pingReqID)
178
179                 if log.Level >= logger.DebugLevel {
180                         log.WithFields(logger.Fields{
181                                 "channel":  ch.id,
182                                 "msg_id":   c.pingReqID,
183                                 "msg_name": c.msgControlPing.GetMessageName(),
184                                 "msg_crc":  c.msgControlPing.GetCrcString(),
185                                 "seq_num":  req.seqNum,
186                                 "context":  context,
187                                 "data_len": len(pingData),
188                         }).Debugf(" -> SEND MSG: %T", c.msgControlPing)
189                 }
190
191                 if err := c.vppClient.SendMsg(context, pingData); err != nil {
192                         log.WithFields(logger.Fields{
193                                 "context": context,
194                                 "seq_num": req.seqNum,
195                                 "error":   err,
196                         }).Warnf("unable to send control ping")
197                 }
198         }
199
200         return nil
201 }
202
203 // msgCallback is called whenever any binary API message comes from VPP.
204 func (c *Connection) msgCallback(msgID uint16, data []byte) {
205         if c == nil {
206                 log.WithField(
207                         "msg_id", msgID,
208                 ).Warn("Connection already disconnected, ignoring the message.")
209                 return
210         }
211
212         msgType, name, crc, err := c.getMessageDataByID(msgID)
213         if err != nil {
214                 log.Warnln(err)
215                 return
216         }
217
218         // decode message context to fix for special cases of messages,
219         // for example:
220         // - replies that don't have context as first field (comes as zero)
221         // - events that don't have context at all (comes as non zero)
222         //
223         context, err := c.codec.DecodeMsgContext(data, msgType)
224         if err != nil {
225                 log.WithField("msg_id", msgID).Warnf("Unable to decode message context: %v", err)
226                 return
227         }
228
229         chanID, isMulti, seqNum := unpackRequestContext(context)
230
231         if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
232                 log.WithFields(logger.Fields{
233                         "context":  context,
234                         "msg_id":   msgID,
235                         "msg_size": len(data),
236                         "channel":  chanID,
237                         "is_multi": isMulti,
238                         "seq_num":  seqNum,
239                         "msg_crc":  crc,
240                 }).Debugf("<-- govpp RECEIVE: %s %+v", name)
241         }
242
243         if context == 0 || c.isNotificationMessage(msgID) {
244                 // process the message as a notification
245                 c.sendNotifications(msgID, data)
246                 return
247         }
248
249         // match ch according to the context
250         c.channelsLock.RLock()
251         ch, ok := c.channels[chanID]
252         c.channelsLock.RUnlock()
253         if !ok {
254                 log.WithFields(logger.Fields{
255                         "channel": chanID,
256                         "msg_id":  msgID,
257                 }).Error("Channel ID not known, ignoring the message.")
258                 return
259         }
260
261         // if this is a control ping reply to a multipart request,
262         // treat this as a last part of the reply
263         lastReplyReceived := isMulti && msgID == c.pingReplyID
264
265         // send the data to the channel, it needs to be copied,
266         // because it will be freed after this function returns
267         sendReply(ch, &vppReply{
268                 msgID:        msgID,
269                 seqNum:       seqNum,
270                 data:         append([]byte(nil), data...),
271                 lastReceived: lastReplyReceived,
272         })
273
274         // store actual time of this reply
275         c.lastReplyLock.Lock()
276         c.lastReply = time.Now()
277         c.lastReplyLock.Unlock()
278 }
279
280 // sendReply sends the reply into the go channel, if it cannot be completed without blocking, otherwise
281 // it logs the error and do not send the message.
282 func sendReply(ch *Channel, reply *vppReply) {
283         // first try to avoid creating timer
284         select {
285         case ch.replyChan <- reply:
286                 return // reply sent ok
287         default:
288                 // reply channel full
289         }
290         if ch.receiveReplyTimeout == 0 {
291                 log.WithFields(logger.Fields{
292                         "channel": ch.id,
293                         "msg_id":  reply.msgID,
294                         "seq_num": reply.seqNum,
295                         "err":     reply.err,
296                 }).Warn("Reply channel full, dropping reply.")
297                 return
298         }
299         select {
300         case ch.replyChan <- reply:
301                 return // reply sent ok
302         case <-time.After(ch.receiveReplyTimeout):
303                 // receiver still not ready
304                 log.WithFields(logger.Fields{
305                         "channel": ch.id,
306                         "msg_id":  reply.msgID,
307                         "seq_num": reply.seqNum,
308                         "err":     reply.err,
309                 }).Warnf("Unable to send reply (reciever end not ready in %v).", ch.receiveReplyTimeout)
310         }
311 }
312
313 // isNotificationMessage returns true if someone has subscribed to provided message ID.
314 func (c *Connection) isNotificationMessage(msgID uint16) bool {
315         c.subscriptionsLock.RLock()
316         defer c.subscriptionsLock.RUnlock()
317
318         _, exists := c.subscriptions[msgID]
319         return exists
320 }
321
322 // sendNotifications send a notification message to all subscribers subscribed for that message.
323 func (c *Connection) sendNotifications(msgID uint16, data []byte) {
324         c.subscriptionsLock.RLock()
325         defer c.subscriptionsLock.RUnlock()
326
327         matched := false
328
329         // send to notification to each subscriber
330         for _, sub := range c.subscriptions[msgID] {
331                 log.WithFields(logger.Fields{
332                         "msg_name": sub.event.GetMessageName(),
333                         "msg_id":   msgID,
334                         "msg_size": len(data),
335                 }).Debug("Sending a notification to the subscription channel.")
336
337                 event := sub.msgFactory()
338                 if err := c.codec.DecodeMsg(data, event); err != nil {
339                         log.WithFields(logger.Fields{
340                                 "msg_name": sub.event.GetMessageName(),
341                                 "msg_id":   msgID,
342                                 "msg_size": len(data),
343                                 "error":    err,
344                         }).Warnf("Unable to decode the notification message")
345                         continue
346                 }
347
348                 // send the message into the go channel of the subscription
349                 select {
350                 case sub.notifChan <- event:
351                         // message sent successfully
352                 default:
353                         // unable to write into the channel without blocking
354                         log.WithFields(logger.Fields{
355                                 "msg_name": sub.event.GetMessageName(),
356                                 "msg_id":   msgID,
357                                 "msg_size": len(data),
358                         }).Warn("Unable to deliver the notification, reciever end not ready.")
359                 }
360
361                 matched = true
362         }
363
364         if !matched {
365                 log.WithFields(logger.Fields{
366                         "msg_id":   msgID,
367                         "msg_size": len(data),
368                 }).Info("No subscription found for the notification message.")
369         }
370 }
371
372 // +------------------+-------------------+-----------------------+
373 // | 15b = channel ID | 1b = is multipart | 16b = sequence number |
374 // +------------------+-------------------+-----------------------+
375 func packRequestContext(chanID uint16, isMultipart bool, seqNum uint16) uint32 {
376         context := uint32(chanID) << 17
377         if isMultipart {
378                 context |= 1 << 16
379         }
380         context |= uint32(seqNum)
381         return context
382 }
383
384 func unpackRequestContext(context uint32) (chanID uint16, isMulipart bool, seqNum uint16) {
385         chanID = uint16(context >> 17)
386         if ((context >> 16) & 0x1) != 0 {
387                 isMulipart = true
388         }
389         seqNum = uint16(context & 0xffff)
390         return
391 }
392
393 // compareSeqNumbers returns -1, 0, 1 if sequence number <seqNum1> precedes, equals to,
394 // or succeeds seq. number <seqNum2>.
395 // Since sequence numbers cycle in the finite set of size 2^16, the function
396 // must assume that the distance between compared sequence numbers is less than
397 // (2^16)/2 to determine the order.
398 func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
399         // calculate distance from seqNum1 to seqNum2
400         var dist uint16
401         if seqNum1 <= seqNum2 {
402                 dist = seqNum2 - seqNum1
403         } else {
404                 dist = 0xffff - (seqNum1 - seqNum2 - 1)
405         }
406         if dist == 0 {
407                 return 0
408         } else if dist <= 0x8000 {
409                 return -1
410         }
411         return 1
412 }
413
414 // Returns message data based on the message ID not depending on the message path
415 func (c *Connection) getMessageDataByID(msgID uint16) (typ api.MessageType, name, crc string, err error) {
416         for _, msgs := range c.msgMapByPath {
417                 if msg, ok := msgs[msgID]; ok {
418                         return msg.GetMessageType(), msg.GetMessageName(), msg.GetCrcString(), nil
419                 }
420         }
421         return typ, name, crc, fmt.Errorf("unknown message received, ID: %d", msgID)
422 }