f9d972aef89e90a59e011aab002210fdcdefaa59
[govpp.git] / core / request_handler.go
1 // Copyright (c) 2017 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "reflect"
21         "sync/atomic"
22         "time"
23
24         logger "github.com/sirupsen/logrus"
25
26         "git.fd.io/govpp.git/api"
27 )
28
29 var ReplyChannelTimeout = time.Millisecond * 100
30
31 var (
32         ErrNotConnected = errors.New("not connected to VPP, ignoring the request")
33         ErrProbeTimeout = errors.New("probe reply not received within timeout period")
34 )
35
36 // watchRequests watches for requests on the request API channel and forwards them as messages to VPP.
37 func (c *Connection) watchRequests(ch *Channel) {
38         for {
39                 select {
40                 case req, ok := <-ch.reqChan:
41                         // new request on the request channel
42                         if !ok {
43                                 // after closing the request channel, release API channel and return
44                                 c.releaseAPIChannel(ch)
45                                 return
46                         }
47                         if err := c.processRequest(ch, req); err != nil {
48                                 sendReply(ch, &vppReply{
49                                         seqNum: req.seqNum,
50                                         err:    fmt.Errorf("unable to process request: %w", err),
51                                 })
52                         }
53                 }
54         }
55 }
56
57 // processRequest processes a single request received on the request channel.
58 func (c *Connection) sendMessage(context uint32, msg api.Message) error {
59         // check whether we are connected to VPP
60         if atomic.LoadUint32(&c.vppConnected) == 0 {
61                 return ErrNotConnected
62         }
63
64         /*log := log.WithFields(logger.Fields{
65                 "context":  context,
66                 "msg_name": msg.GetMessageName(),
67                 "msg_crc":  msg.GetCrcString(),
68         })*/
69
70         // retrieve message ID
71         msgID, err := c.GetMessageID(msg)
72         if err != nil {
73                 //log.WithError(err).Debugf("unable to retrieve message ID: %#v", msg)
74                 return err
75         }
76
77         //log = log.WithField("msg_id", msgID)
78
79         // encode the message
80         data, err := c.codec.EncodeMsg(msg, msgID)
81         if err != nil {
82                 log.WithError(err).Debugf("unable to encode message: %#v", msg)
83                 return err
84         }
85
86         //log = log.WithField("msg_length", len(data))
87
88         if log.Level >= logger.DebugLevel {
89                 log.Debugf("--> SEND: MSG %T %+v", msg, msg)
90         }
91
92         // send message to VPP
93         err = c.vppClient.SendMsg(context, data)
94         if err != nil {
95                 log.WithError(err).Debugf("unable to send message: %#v", msg)
96                 return err
97         }
98
99         return nil
100 }
101
102 // processRequest processes a single request received on the request channel.
103 func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
104         // check whether we are connected to VPP
105         if atomic.LoadUint32(&c.vppConnected) == 0 {
106                 err := ErrNotConnected
107                 log.WithFields(logger.Fields{
108                         "channel":  ch.id,
109                         "seq_num":  req.seqNum,
110                         "msg_name": req.msg.GetMessageName(),
111                         "msg_crc":  req.msg.GetCrcString(),
112                         "error":    err,
113                 }).Warnf("Unable to process request")
114                 return err
115         }
116
117         // retrieve message ID
118         msgID, err := c.GetMessageID(req.msg)
119         if err != nil {
120                 log.WithFields(logger.Fields{
121                         "channel":  ch.id,
122                         "msg_name": req.msg.GetMessageName(),
123                         "msg_crc":  req.msg.GetCrcString(),
124                         "seq_num":  req.seqNum,
125                         "error":    err,
126                 }).Warnf("Unable to retrieve message ID")
127                 return err
128         }
129
130         // encode the message into binary
131         data, err := c.codec.EncodeMsg(req.msg, msgID)
132         if err != nil {
133                 log.WithFields(logger.Fields{
134                         "channel":  ch.id,
135                         "msg_id":   msgID,
136                         "msg_name": req.msg.GetMessageName(),
137                         "msg_crc":  req.msg.GetCrcString(),
138                         "seq_num":  req.seqNum,
139                         "error":    err,
140                 }).Warnf("Unable to encode message: %T %+v", req.msg, req.msg)
141                 return err
142         }
143
144         context := packRequestContext(ch.id, req.multi, req.seqNum)
145
146         if log.Level >= logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
147                 log.WithFields(logger.Fields{
148                         "channel":  ch.id,
149                         "msg_id":   msgID,
150                         "msg_name": req.msg.GetMessageName(),
151                         "msg_crc":  req.msg.GetCrcString(),
152                         "seq_num":  req.seqNum,
153                         "is_multi": req.multi,
154                         "context":  context,
155                         "data_len": len(data),
156                 }).Debugf("--> SEND MSG: %T %+v", req.msg, req.msg)
157         }
158
159         // send the request to VPP
160         err = c.vppClient.SendMsg(context, data)
161         if err != nil {
162                 log.WithFields(logger.Fields{
163                         "channel":  ch.id,
164                         "msg_id":   msgID,
165                         "msg_name": req.msg.GetMessageName(),
166                         "msg_crc":  req.msg.GetCrcString(),
167                         "seq_num":  req.seqNum,
168                         "is_multi": req.multi,
169                         "context":  context,
170                         "data_len": len(data),
171                         "error":    err,
172                 }).Warnf("Unable to send message")
173                 return err
174         }
175
176         if req.multi {
177                 // send a control ping to determine end of the multipart response
178                 pingData, _ := c.codec.EncodeMsg(c.msgControlPing, c.pingReqID)
179
180                 if log.Level >= logger.DebugLevel {
181                         log.WithFields(logger.Fields{
182                                 "channel":  ch.id,
183                                 "msg_id":   c.pingReqID,
184                                 "msg_name": c.msgControlPing.GetMessageName(),
185                                 "msg_crc":  c.msgControlPing.GetCrcString(),
186                                 "seq_num":  req.seqNum,
187                                 "context":  context,
188                                 "data_len": len(pingData),
189                         }).Debugf(" -> SEND MSG: %T", c.msgControlPing)
190                 }
191
192                 if err := c.vppClient.SendMsg(context, pingData); err != nil {
193                         log.WithFields(logger.Fields{
194                                 "context": context,
195                                 "seq_num": req.seqNum,
196                                 "error":   err,
197                         }).Warnf("unable to send control ping")
198                 }
199         }
200
201         return nil
202 }
203
204 // msgCallback is called whenever any binary API message comes from VPP.
205 func (c *Connection) msgCallback(msgID uint16, data []byte) {
206         if c == nil {
207                 log.WithField(
208                         "msg_id", msgID,
209                 ).Warn("Connection already disconnected, ignoring the message.")
210                 return
211         }
212
213         msg, err := c.getMessageByID(msgID)
214         if err != nil {
215                 log.Warnln(err)
216                 return
217         }
218
219         // decode message context to fix for special cases of messages,
220         // for example:
221         // - replies that don't have context as first field (comes as zero)
222         // - events that don't have context at all (comes as non zero)
223         //
224         context, err := c.codec.DecodeMsgContext(data, msg)
225         if err != nil {
226                 log.WithField("msg_id", msgID).Warnf("Unable to decode message context: %v", err)
227                 return
228         }
229
230         chanID, isMulti, seqNum := unpackRequestContext(context)
231
232         if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
233                 msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
234
235                 // decode the message
236                 if err = c.codec.DecodeMsg(data, msg); err != nil {
237                         err = fmt.Errorf("decoding message failed: %w", err)
238                         return
239                 }
240
241                 log.WithFields(logger.Fields{
242                         "context":  context,
243                         "msg_id":   msgID,
244                         "msg_size": len(data),
245                         "channel":  chanID,
246                         "is_multi": isMulti,
247                         "seq_num":  seqNum,
248                         "msg_crc":  msg.GetCrcString(),
249                 }).Debugf("<-- govpp RECEIVE: %s %+v", msg.GetMessageName(), msg)
250         }
251
252         if context == 0 || c.isNotificationMessage(msgID) {
253                 // process the message as a notification
254                 c.sendNotifications(msgID, data)
255                 return
256         }
257
258         // match ch according to the context
259         c.channelsLock.RLock()
260         ch, ok := c.channels[chanID]
261         c.channelsLock.RUnlock()
262         if !ok {
263                 log.WithFields(logger.Fields{
264                         "channel": chanID,
265                         "msg_id":  msgID,
266                 }).Error("Channel ID not known, ignoring the message.")
267                 return
268         }
269
270         // if this is a control ping reply to a multipart request,
271         // treat this as a last part of the reply
272         lastReplyReceived := isMulti && msgID == c.pingReplyID
273
274         // send the data to the channel, it needs to be copied,
275         // because it will be freed after this function returns
276         sendReply(ch, &vppReply{
277                 msgID:        msgID,
278                 seqNum:       seqNum,
279                 data:         append([]byte(nil), data...),
280                 lastReceived: lastReplyReceived,
281         })
282
283         // store actual time of this reply
284         c.lastReplyLock.Lock()
285         c.lastReply = time.Now()
286         c.lastReplyLock.Unlock()
287 }
288
289 // sendReply sends the reply into the go channel, if it cannot be completed without blocking, otherwise
290 // it logs the error and do not send the message.
291 func sendReply(ch *Channel, reply *vppReply) {
292         // first try to avoid creating timer
293         select {
294         case ch.replyChan <- reply:
295                 return // reply sent ok
296         default:
297                 // reply channel full
298         }
299         if ch.receiveReplyTimeout == 0 {
300                 log.WithFields(logger.Fields{
301                         "channel": ch.id,
302                         "msg_id":  reply.msgID,
303                         "seq_num": reply.seqNum,
304                         "err":     reply.err,
305                 }).Warn("Reply channel full, dropping reply.")
306                 return
307         }
308         select {
309         case ch.replyChan <- reply:
310                 return // reply sent ok
311         case <-time.After(ch.receiveReplyTimeout):
312                 // receiver still not ready
313                 log.WithFields(logger.Fields{
314                         "channel": ch.id,
315                         "msg_id":  reply.msgID,
316                         "seq_num": reply.seqNum,
317                         "err":     reply.err,
318                 }).Warnf("Unable to send reply (reciever end not ready in %v).", ch.receiveReplyTimeout)
319         }
320 }
321
322 // isNotificationMessage returns true if someone has subscribed to provided message ID.
323 func (c *Connection) isNotificationMessage(msgID uint16) bool {
324         c.subscriptionsLock.RLock()
325         defer c.subscriptionsLock.RUnlock()
326
327         _, exists := c.subscriptions[msgID]
328         return exists
329 }
330
331 // sendNotifications send a notification message to all subscribers subscribed for that message.
332 func (c *Connection) sendNotifications(msgID uint16, data []byte) {
333         c.subscriptionsLock.RLock()
334         defer c.subscriptionsLock.RUnlock()
335
336         matched := false
337
338         // send to notification to each subscriber
339         for _, sub := range c.subscriptions[msgID] {
340                 log.WithFields(logger.Fields{
341                         "msg_name": sub.event.GetMessageName(),
342                         "msg_id":   msgID,
343                         "msg_size": len(data),
344                 }).Debug("Sending a notification to the subscription channel.")
345
346                 event := sub.msgFactory()
347                 if err := c.codec.DecodeMsg(data, event); err != nil {
348                         log.WithFields(logger.Fields{
349                                 "msg_name": sub.event.GetMessageName(),
350                                 "msg_id":   msgID,
351                                 "msg_size": len(data),
352                                 "error":    err,
353                         }).Warnf("Unable to decode the notification message")
354                         continue
355                 }
356
357                 // send the message into the go channel of the subscription
358                 select {
359                 case sub.notifChan <- event:
360                         // message sent successfully
361                 default:
362                         // unable to write into the channel without blocking
363                         log.WithFields(logger.Fields{
364                                 "msg_name": sub.event.GetMessageName(),
365                                 "msg_id":   msgID,
366                                 "msg_size": len(data),
367                         }).Warn("Unable to deliver the notification, reciever end not ready.")
368                 }
369
370                 matched = true
371         }
372
373         if !matched {
374                 log.WithFields(logger.Fields{
375                         "msg_id":   msgID,
376                         "msg_size": len(data),
377                 }).Info("No subscription found for the notification message.")
378         }
379 }
380
381 // +------------------+-------------------+-----------------------+
382 // | 15b = channel ID | 1b = is multipart | 16b = sequence number |
383 // +------------------+-------------------+-----------------------+
384 func packRequestContext(chanID uint16, isMultipart bool, seqNum uint16) uint32 {
385         context := uint32(chanID) << 17
386         if isMultipart {
387                 context |= 1 << 16
388         }
389         context |= uint32(seqNum)
390         return context
391 }
392
393 func unpackRequestContext(context uint32) (chanID uint16, isMulipart bool, seqNum uint16) {
394         chanID = uint16(context >> 17)
395         if ((context >> 16) & 0x1) != 0 {
396                 isMulipart = true
397         }
398         seqNum = uint16(context & 0xffff)
399         return
400 }
401
402 // compareSeqNumbers returns -1, 0, 1 if sequence number <seqNum1> precedes, equals to,
403 // or succeeds seq. number <seqNum2>.
404 // Since sequence numbers cycle in the finite set of size 2^16, the function
405 // must assume that the distance between compared sequence numbers is less than
406 // (2^16)/2 to determine the order.
407 func compareSeqNumbers(seqNum1, seqNum2 uint16) int {
408         // calculate distance from seqNum1 to seqNum2
409         var dist uint16
410         if seqNum1 <= seqNum2 {
411                 dist = seqNum2 - seqNum1
412         } else {
413                 dist = 0xffff - (seqNum1 - seqNum2 - 1)
414         }
415         if dist == 0 {
416                 return 0
417         } else if dist <= 0x8000 {
418                 return -1
419         }
420         return 1
421 }
422
423 // Returns first message from any package where the message ID matches
424 // Note: the msg is further used only for its MessageType which is not
425 // affected by the message's package
426 func (c *Connection) getMessageByID(msgID uint16) (msg api.Message, err error) {
427         var ok bool
428         for _, msgs := range c.msgMapByPath {
429                 if msg, ok = msgs[msgID]; ok {
430                         break
431                 }
432         }
433         if !ok {
434                 return nil, fmt.Errorf("unknown message received, ID: %d", msgID)
435         }
436         return msg, nil
437 }