112c14e28baf6dbba38f9c76af2ea77c9bbd2bbe
[govpp.git] / core / channel.go
1 // Copyright (c) 2018 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "reflect"
21         "strings"
22         "time"
23
24         "github.com/sirupsen/logrus"
25
26         "git.fd.io/govpp.git/adapter"
27         "git.fd.io/govpp.git/api"
28 )
29
30 var (
31         ErrInvalidRequestCtx = errors.New("invalid request context")
32 )
33
34 // MessageCodec provides functionality for decoding binary data to generated API messages.
35 type MessageCodec interface {
36         // EncodeMsg encodes message into binary data.
37         EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
38         // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
39         DecodeMsg(data []byte, msg api.Message) error
40         // DecodeMsgContext decodes context from message data and type.
41         DecodeMsgContext(data []byte, msgType api.MessageType) (context uint32, err error)
42 }
43
44 // MessageIdentifier provides identification of generated API messages.
45 type MessageIdentifier interface {
46         // GetMessageID returns message identifier of given API message.
47         GetMessageID(msg api.Message) (uint16, error)
48         // GetMessagePath returns path for the given message
49         GetMessagePath(msg api.Message) string
50         // LookupByID looks up message name and crc by ID
51         LookupByID(path string, msgID uint16) (api.Message, error)
52 }
53
54 // vppRequest is a request that will be sent to VPP.
55 type vppRequest struct {
56         seqNum uint16      // sequence number
57         msg    api.Message // binary API message to be send to VPP
58         multi  bool        // true if multipart response is expected
59 }
60
61 // vppReply is a reply received from VPP.
62 type vppReply struct {
63         seqNum       uint16 // sequence number
64         msgID        uint16 // ID of the message
65         data         []byte // encoded data with the message
66         lastReceived bool   // for multi request, true if the last reply has been already received
67         err          error  // in case of error, data is nil and this member contains error
68 }
69
70 // requestCtx is a context for request with single reply
71 type requestCtx struct {
72         ch     *Channel
73         seqNum uint16
74 }
75
76 // multiRequestCtx is a context for request with multiple responses
77 type multiRequestCtx struct {
78         ch     *Channel
79         seqNum uint16
80 }
81
82 // subscriptionCtx is a context of subscription for delivery of specific notification messages.
83 type subscriptionCtx struct {
84         ch         *Channel
85         notifChan  chan api.Message   // channel where notification messages will be delivered to
86         msgID      uint16             // message ID for the subscribed event message
87         event      api.Message        // event message that this subscription is for
88         msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
89 }
90
91 // Channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
92 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
93 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
94 // concurrently, otherwise the responses could mix! Use multiple channels instead.
95 type Channel struct {
96         id   uint16
97         conn *Connection
98
99         reqChan   chan *vppRequest // channel for sending the requests to VPP
100         replyChan chan *vppReply   // channel where VPP replies are delivered to
101
102         msgCodec      MessageCodec      // used to decode binary data to generated API messages
103         msgIdentifier MessageIdentifier // used to retrieve message ID of a message
104
105         lastSeqNum uint16 // sequence number of the last sent request
106
107         delayedReply        *vppReply     // reply already taken from ReplyChan, buffered for later delivery
108         replyTimeout        time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
109         receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
110 }
111
112 func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) (*Channel, error) {
113         // create new channel
114         channel := &Channel{
115                 conn:                c,
116                 msgCodec:            c.codec,
117                 msgIdentifier:       c,
118                 reqChan:             make(chan *vppRequest, reqChanBufSize),
119                 replyChan:           make(chan *vppReply, replyChanBufSize),
120                 replyTimeout:        DefaultReplyTimeout,
121                 receiveReplyTimeout: ReplyChannelTimeout,
122         }
123
124         // store API channel within the client
125         c.channelsLock.Lock()
126         if len(c.channels) >= 0x7fff {
127                 return nil, errors.New("all channel IDs are used")
128         }
129         for {
130                 c.nextChannelID++
131                 chID := c.nextChannelID & 0x7fff
132                 _, ok := c.channels[chID]
133                 if !ok {
134                         channel.id = chID
135                         c.channels[chID] = channel
136                         break
137                 }
138         }
139         c.channelsLock.Unlock()
140
141         return channel, nil
142 }
143
144 func (ch *Channel) GetID() uint16 {
145         return ch.id
146 }
147
148 func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
149         req := ch.newRequest(msg, false)
150         ch.reqChan <- req
151         return &requestCtx{ch: ch, seqNum: req.seqNum}
152 }
153
154 func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
155         req := ch.newRequest(msg, true)
156         ch.reqChan <- req
157         return &multiRequestCtx{ch: ch, seqNum: req.seqNum}
158 }
159
160 func (ch *Channel) nextSeqNum() uint16 {
161         ch.lastSeqNum++
162         return ch.lastSeqNum
163 }
164
165 func (ch *Channel) newRequest(msg api.Message, multi bool) *vppRequest {
166         return &vppRequest{
167                 msg:    msg,
168                 seqNum: ch.nextSeqNum(),
169                 multi:  multi,
170         }
171 }
172
173 func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
174         var comperr api.CompatibilityError
175         for _, msg := range msgs {
176                 _, err := ch.msgIdentifier.GetMessageID(msg)
177                 if err != nil {
178                         if uerr, ok := err.(*adapter.UnknownMsgError); ok {
179                                 comperr.IncompatibleMessages = append(comperr.IncompatibleMessages, getMsgID(uerr.MsgName, uerr.MsgCrc))
180                                 continue
181                         }
182                         // other errors return immediatelly
183                         return err
184                 }
185                 comperr.CompatibleMessages = append(comperr.CompatibleMessages, getMsgNameWithCrc(msg))
186         }
187         if len(comperr.IncompatibleMessages) == 0 {
188                 return nil
189         }
190         return &comperr
191 }
192
193 func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
194         msgID, err := ch.msgIdentifier.GetMessageID(event)
195         if err != nil {
196                 log.WithFields(logrus.Fields{
197                         "msg_name": event.GetMessageName(),
198                         "msg_crc":  event.GetCrcString(),
199                 }).Errorf("unable to retrieve message ID: %v", err)
200                 return nil, fmt.Errorf("unable to retrieve event message ID: %v", err)
201         }
202
203         sub := &subscriptionCtx{
204                 ch:         ch,
205                 notifChan:  notifChan,
206                 msgID:      msgID,
207                 event:      event,
208                 msgFactory: getMsgFactory(event),
209         }
210
211         // add the subscription into map
212         ch.conn.subscriptionsLock.Lock()
213         defer ch.conn.subscriptionsLock.Unlock()
214
215         ch.conn.subscriptions[msgID] = append(ch.conn.subscriptions[msgID], sub)
216
217         return sub, nil
218 }
219
220 func (ch *Channel) SetReplyTimeout(timeout time.Duration) {
221         ch.replyTimeout = timeout
222 }
223
224 func (ch *Channel) Close() {
225         close(ch.reqChan)
226 }
227
228 func (req *requestCtx) ReceiveReply(msg api.Message) error {
229         if req == nil || req.ch == nil {
230                 return ErrInvalidRequestCtx
231         }
232
233         lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
234         if err != nil {
235                 return err
236         } else if lastReplyReceived {
237                 return errors.New("multipart reply recieved while a single reply expected")
238         }
239
240         return nil
241 }
242
243 func (req *multiRequestCtx) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
244         if req == nil || req.ch == nil {
245                 return false, ErrInvalidRequestCtx
246         }
247
248         return req.ch.receiveReplyInternal(msg, req.seqNum)
249 }
250
251 func (sub *subscriptionCtx) Unsubscribe() error {
252         log.WithFields(logrus.Fields{
253                 "msg_name": sub.event.GetMessageName(),
254                 "msg_id":   sub.msgID,
255         }).Debug("Removing notification subscription.")
256
257         // remove the subscription from the map
258         sub.ch.conn.subscriptionsLock.Lock()
259         defer sub.ch.conn.subscriptionsLock.Unlock()
260
261         for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
262                 if item == sub {
263                         // close notification channel
264                         close(sub.ch.conn.subscriptions[sub.msgID][i].notifChan)
265                         // remove i-th item in the slice
266                         sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
267                         return nil
268                 }
269         }
270
271         return fmt.Errorf("subscription for %q not found", sub.event.GetMessageName())
272 }
273
274 // receiveReplyInternal receives a reply from the reply channel into the provided msg structure.
275 func (ch *Channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
276         if msg == nil {
277                 return false, errors.New("nil message passed in")
278         }
279
280         var ignore bool
281
282         if vppReply := ch.delayedReply; vppReply != nil {
283                 // try the delayed reply
284                 ch.delayedReply = nil
285                 ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
286                 if !ignore {
287                         return lastReplyReceived, err
288                 }
289         }
290
291         timer := time.NewTimer(ch.replyTimeout)
292         for {
293                 select {
294                 // blocks until a reply comes to ReplyChan or until timeout expires
295                 case vppReply := <-ch.replyChan:
296                         ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
297                         if ignore {
298                                 log.WithFields(logrus.Fields{
299                                         "expSeqNum": expSeqNum,
300                                         "channel":   ch.id,
301                                 }).Warnf("ignoring received reply: %+v (expecting: %s)", vppReply, msg.GetMessageName())
302                                 continue
303                         }
304                         return lastReplyReceived, err
305
306                 case <-timer.C:
307                         log.WithFields(logrus.Fields{
308                                 "expSeqNum": expSeqNum,
309                                 "channel":   ch.id,
310                         }).Debugf("timeout (%v) waiting for reply: %s", ch.replyTimeout, msg.GetMessageName())
311                         err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout)
312                         return false, err
313                 }
314         }
315 }
316
317 func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
318         // check the sequence number
319         cmpSeqNums := compareSeqNumbers(reply.seqNum, expSeqNum)
320         if cmpSeqNums == -1 {
321                 // reply received too late, ignore the message
322                 log.WithField("seqNum", reply.seqNum).
323                         Warn("Received reply to an already closed binary API request")
324                 ignore = true
325                 return
326         }
327         if cmpSeqNums == 1 {
328                 ch.delayedReply = reply
329                 err = fmt.Errorf("missing binary API reply with sequence number: %d", expSeqNum)
330                 return
331         }
332
333         if reply.err != nil {
334                 err = reply.err
335                 return
336         }
337         if reply.lastReceived {
338                 lastReplyReceived = true
339                 return
340         }
341
342         // message checks
343         var expMsgID uint16
344         expMsgID, err = ch.msgIdentifier.GetMessageID(msg)
345         if err != nil {
346                 err = fmt.Errorf("message %s with CRC %s is not compatible with the VPP we are connected to",
347                         msg.GetMessageName(), msg.GetCrcString())
348                 return
349         }
350
351         if reply.msgID != expMsgID {
352                 var msgNameCrc string
353                 pkgPath := ch.msgIdentifier.GetMessagePath(msg)
354                 if replyMsg, err := ch.msgIdentifier.LookupByID(pkgPath, reply.msgID); err != nil {
355                         msgNameCrc = err.Error()
356                 } else {
357                         msgNameCrc = getMsgNameWithCrc(replyMsg)
358                 }
359
360                 err = fmt.Errorf("received unexpected message (seqNum=%d), expected %s (ID %d), but got %s (ID %d) "+
361                         "(check if multiple goroutines are not sharing single GoVPP channel)",
362                         reply.seqNum, msg.GetMessageName(), expMsgID, msgNameCrc, reply.msgID)
363                 return
364         }
365
366         // decode the message
367         if err = ch.msgCodec.DecodeMsg(reply.data, msg); err != nil {
368                 return
369         }
370
371         // check Retval and convert it into VnetAPIError error
372         if strings.HasSuffix(msg.GetMessageName(), "_reply") {
373                 // TODO: use categories for messages to avoid checking message name
374                 if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
375                         var retval int32
376                         switch f.Kind() {
377                         case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
378                                 retval = int32(f.Int())
379                         case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
380                                 retval = int32(f.Uint())
381                         default:
382                                 logrus.Warnf("invalid kind (%v) for Retval field of message %v", f.Kind(), msg.GetMessageName())
383                         }
384                         err = api.RetvalToVPPApiError(retval)
385                 }
386         }
387
388         return
389 }