1086c3639dbaf887f2a49bfdf8abc6cda671846d
[govpp.git] / core / channel.go
1 // Copyright (c) 2018 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "reflect"
21         "strings"
22         "sync/atomic"
23         "time"
24
25         "github.com/sirupsen/logrus"
26
27         "git.fd.io/govpp.git/adapter"
28         "git.fd.io/govpp.git/api"
29 )
30
31 var (
32         ErrInvalidRequestCtx = errors.New("invalid request context")
33 )
34
35 // MessageCodec provides functionality for decoding binary data to generated API messages.
36 type MessageCodec interface {
37         // EncodeMsg encodes message into binary data.
38         EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
39         // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
40         DecodeMsg(data []byte, msg api.Message) error
41         // DecodeMsgContext decodes context from message data and type.
42         DecodeMsgContext(data []byte, msgType api.MessageType) (context uint32, err error)
43 }
44
45 // MessageIdentifier provides identification of generated API messages.
46 type MessageIdentifier interface {
47         // GetMessageID returns message identifier of given API message.
48         GetMessageID(msg api.Message) (uint16, error)
49         // GetMessagePath returns path for the given message
50         GetMessagePath(msg api.Message) string
51         // LookupByID looks up message name and crc by ID
52         LookupByID(path string, msgID uint16) (api.Message, error)
53 }
54
55 // vppRequest is a request that will be sent to VPP.
56 type vppRequest struct {
57         seqNum uint16      // sequence number
58         msg    api.Message // binary API message to be send to VPP
59         multi  bool        // true if multipart response is expected
60 }
61
62 // vppReply is a reply received from VPP.
63 type vppReply struct {
64         seqNum       uint16 // sequence number
65         msgID        uint16 // ID of the message
66         data         []byte // encoded data with the message
67         lastReceived bool   // for multi request, true if the last reply has been already received
68         err          error  // in case of error, data is nil and this member contains error
69 }
70
71 // requestCtx is a context for request with single reply
72 type requestCtx struct {
73         ch     *Channel
74         seqNum uint16
75 }
76
77 // multiRequestCtx is a context for request with multiple responses
78 type multiRequestCtx struct {
79         ch     *Channel
80         seqNum uint16
81 }
82
83 // subscriptionCtx is a context of subscription for delivery of specific notification messages.
84 type subscriptionCtx struct {
85         ch         *Channel
86         notifChan  chan api.Message   // channel where notification messages will be delivered to
87         msgID      uint16             // message ID for the subscribed event message
88         event      api.Message        // event message that this subscription is for
89         msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
90 }
91
92 // Channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
93 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
94 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
95 // concurrently, otherwise the responses could mix! Use multiple channels instead.
96 type Channel struct {
97         id   uint16
98         conn *Connection
99
100         reqChan   chan *vppRequest // channel for sending the requests to VPP
101         replyChan chan *vppReply   // channel where VPP replies are delivered to
102
103         msgCodec      MessageCodec      // used to decode binary data to generated API messages
104         msgIdentifier MessageIdentifier // used to retrieve message ID of a message
105
106         lastSeqNum uint16 // sequence number of the last sent request
107
108         delayedReply        *vppReply     // reply already taken from ReplyChan, buffered for later delivery
109         replyTimeout        time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
110         receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
111 }
112
113 func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) *Channel {
114         // create new channel
115         chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff)
116         channel := &Channel{
117                 id:                  chID,
118                 conn:                c,
119                 msgCodec:            c.codec,
120                 msgIdentifier:       c,
121                 reqChan:             make(chan *vppRequest, reqChanBufSize),
122                 replyChan:           make(chan *vppReply, replyChanBufSize),
123                 replyTimeout:        DefaultReplyTimeout,
124                 receiveReplyTimeout: ReplyChannelTimeout,
125         }
126
127         // store API channel within the client
128         c.channelsLock.Lock()
129         c.channels[chID] = channel
130         c.channelsLock.Unlock()
131
132         return channel
133 }
134
135 func (ch *Channel) GetID() uint16 {
136         return ch.id
137 }
138
139 func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
140         req := ch.newRequest(msg, false)
141         ch.reqChan <- req
142         return &requestCtx{ch: ch, seqNum: req.seqNum}
143 }
144
145 func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
146         req := ch.newRequest(msg, true)
147         ch.reqChan <- req
148         return &multiRequestCtx{ch: ch, seqNum: req.seqNum}
149 }
150
151 func (ch *Channel) nextSeqNum() uint16 {
152         ch.lastSeqNum++
153         return ch.lastSeqNum
154 }
155
156 func (ch *Channel) newRequest(msg api.Message, multi bool) *vppRequest {
157         return &vppRequest{
158                 msg:    msg,
159                 seqNum: ch.nextSeqNum(),
160                 multi:  multi,
161         }
162 }
163
164 func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
165         var comperr api.CompatibilityError
166         for _, msg := range msgs {
167                 _, err := ch.msgIdentifier.GetMessageID(msg)
168                 if err != nil {
169                         if uerr, ok := err.(*adapter.UnknownMsgError); ok {
170                                 comperr.IncompatibleMessages = append(comperr.IncompatibleMessages, getMsgID(uerr.MsgName, uerr.MsgCrc))
171                                 continue
172                         }
173                         // other errors return immediatelly
174                         return err
175                 }
176                 comperr.CompatibleMessages = append(comperr.CompatibleMessages, getMsgNameWithCrc(msg))
177         }
178         if len(comperr.IncompatibleMessages) == 0 {
179                 return nil
180         }
181         return &comperr
182 }
183
184 func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
185         msgID, err := ch.msgIdentifier.GetMessageID(event)
186         if err != nil {
187                 log.WithFields(logrus.Fields{
188                         "msg_name": event.GetMessageName(),
189                         "msg_crc":  event.GetCrcString(),
190                 }).Errorf("unable to retrieve message ID: %v", err)
191                 return nil, fmt.Errorf("unable to retrieve event message ID: %v", err)
192         }
193
194         sub := &subscriptionCtx{
195                 ch:         ch,
196                 notifChan:  notifChan,
197                 msgID:      msgID,
198                 event:      event,
199                 msgFactory: getMsgFactory(event),
200         }
201
202         // add the subscription into map
203         ch.conn.subscriptionsLock.Lock()
204         defer ch.conn.subscriptionsLock.Unlock()
205
206         ch.conn.subscriptions[msgID] = append(ch.conn.subscriptions[msgID], sub)
207
208         return sub, nil
209 }
210
211 func (ch *Channel) SetReplyTimeout(timeout time.Duration) {
212         ch.replyTimeout = timeout
213 }
214
215 func (ch *Channel) Close() {
216         close(ch.reqChan)
217 }
218
219 func (req *requestCtx) ReceiveReply(msg api.Message) error {
220         if req == nil || req.ch == nil {
221                 return ErrInvalidRequestCtx
222         }
223
224         lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
225         if err != nil {
226                 return err
227         } else if lastReplyReceived {
228                 return errors.New("multipart reply recieved while a single reply expected")
229         }
230
231         return nil
232 }
233
234 func (req *multiRequestCtx) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
235         if req == nil || req.ch == nil {
236                 return false, ErrInvalidRequestCtx
237         }
238
239         return req.ch.receiveReplyInternal(msg, req.seqNum)
240 }
241
242 func (sub *subscriptionCtx) Unsubscribe() error {
243         log.WithFields(logrus.Fields{
244                 "msg_name": sub.event.GetMessageName(),
245                 "msg_id":   sub.msgID,
246         }).Debug("Removing notification subscription.")
247
248         // remove the subscription from the map
249         sub.ch.conn.subscriptionsLock.Lock()
250         defer sub.ch.conn.subscriptionsLock.Unlock()
251
252         for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
253                 if item == sub {
254                         // close notification channel
255                         close(sub.ch.conn.subscriptions[sub.msgID][i].notifChan)
256                         // remove i-th item in the slice
257                         sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
258                         return nil
259                 }
260         }
261
262         return fmt.Errorf("subscription for %q not found", sub.event.GetMessageName())
263 }
264
265 // receiveReplyInternal receives a reply from the reply channel into the provided msg structure.
266 func (ch *Channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
267         if msg == nil {
268                 return false, errors.New("nil message passed in")
269         }
270
271         var ignore bool
272
273         if vppReply := ch.delayedReply; vppReply != nil {
274                 // try the delayed reply
275                 ch.delayedReply = nil
276                 ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
277                 if !ignore {
278                         return lastReplyReceived, err
279                 }
280         }
281
282         timer := time.NewTimer(ch.replyTimeout)
283         for {
284                 select {
285                 // blocks until a reply comes to ReplyChan or until timeout expires
286                 case vppReply := <-ch.replyChan:
287                         ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
288                         if ignore {
289                                 log.WithFields(logrus.Fields{
290                                         "expSeqNum": expSeqNum,
291                                         "channel":   ch.id,
292                                 }).Warnf("ignoring received reply: %+v (expecting: %s)", vppReply, msg.GetMessageName())
293                                 continue
294                         }
295                         return lastReplyReceived, err
296
297                 case <-timer.C:
298                         log.WithFields(logrus.Fields{
299                                 "expSeqNum": expSeqNum,
300                                 "channel":   ch.id,
301                         }).Debugf("timeout (%v) waiting for reply: %s", ch.replyTimeout, msg.GetMessageName())
302                         err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout)
303                         return false, err
304                 }
305         }
306 }
307
308 func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
309         // check the sequence number
310         cmpSeqNums := compareSeqNumbers(reply.seqNum, expSeqNum)
311         if cmpSeqNums == -1 {
312                 // reply received too late, ignore the message
313                 log.WithField("seqNum", reply.seqNum).
314                         Warn("Received reply to an already closed binary API request")
315                 ignore = true
316                 return
317         }
318         if cmpSeqNums == 1 {
319                 ch.delayedReply = reply
320                 err = fmt.Errorf("missing binary API reply with sequence number: %d", expSeqNum)
321                 return
322         }
323
324         if reply.err != nil {
325                 err = reply.err
326                 return
327         }
328         if reply.lastReceived {
329                 lastReplyReceived = true
330                 return
331         }
332
333         // message checks
334         var expMsgID uint16
335         expMsgID, err = ch.msgIdentifier.GetMessageID(msg)
336         if err != nil {
337                 err = fmt.Errorf("message %s with CRC %s is not compatible with the VPP we are connected to",
338                         msg.GetMessageName(), msg.GetCrcString())
339                 return
340         }
341
342         if reply.msgID != expMsgID {
343                 var msgNameCrc string
344                 pkgPath := ch.msgIdentifier.GetMessagePath(msg)
345                 if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
346                         msgNameCrc = err.Error()
347                 } else {
348                         msgNameCrc = getMsgNameWithCrc(replyMsg)
349                 }
350
351                 err = fmt.Errorf("received unexpected message (seqNum=%d), expected %s (ID %d), but got %s (ID %d) "+
352                         "(check if multiple goroutines are not sharing single GoVPP channel)",
353                         reply.seqNum, msg.GetMessageName(), expMsgID, msgNameCrc, reply.msgID)
354                 return
355         }
356
357         // decode the message
358         if err = ch.msgCodec.DecodeMsg(reply.data, msg); err != nil {
359                 return
360         }
361
362         // check Retval and convert it into VnetAPIError error
363         if strings.HasSuffix(msg.GetMessageName(), "_reply") {
364                 // TODO: use categories for messages to avoid checking message name
365                 if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
366                         var retval int32
367                         switch f.Kind() {
368                         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
369                                 retval = int32(f.Int())
370                         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
371                                 retval = int32(f.Uint())
372                         default:
373                                 logrus.Warnf("invalid kind (%v) for Retval field of message %v", f.Kind(), msg.GetMessageName())
374                         }
375                         err = api.RetvalToVPPApiError(retval)
376                 }
377         }
378
379         return
380 }