Fix binapigen decoding and minor improvements
[govpp.git] / core / channel.go
1 // Copyright (c) 2018 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "reflect"
21         "strings"
22         "time"
23
24         "github.com/sirupsen/logrus"
25
26         "git.fd.io/govpp.git/adapter"
27         "git.fd.io/govpp.git/api"
28 )
29
30 var (
31         ErrInvalidRequestCtx = errors.New("invalid request context")
32 )
33
34 // MessageCodec provides functionality for decoding binary data to generated API messages.
35 type MessageCodec interface {
36         // EncodeMsg encodes message into binary data.
37         EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
38         // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
39         DecodeMsg(data []byte, msg api.Message) error
40         // DecodeMsgContext decodes context from message data.
41         DecodeMsgContext(data []byte, msg api.Message) (context uint32, err error)
42 }
43
44 // MessageIdentifier provides identification of generated API messages.
45 type MessageIdentifier interface {
46         // GetMessageID returns message identifier of given API message.
47         GetMessageID(msg api.Message) (uint16, error)
48         // LookupByID looks up message name and crc by ID
49         LookupByID(msgID uint16) (api.Message, error)
50 }
51
52 // vppRequest is a request that will be sent to VPP.
53 type vppRequest struct {
54         seqNum uint16      // sequence number
55         msg    api.Message // binary API message to be send to VPP
56         multi  bool        // true if multipart response is expected
57 }
58
59 // vppReply is a reply received from VPP.
60 type vppReply struct {
61         seqNum       uint16 // sequence number
62         msgID        uint16 // ID of the message
63         data         []byte // encoded data with the message
64         lastReceived bool   // for multi request, true if the last reply has been already received
65         err          error  // in case of error, data is nil and this member contains error
66 }
67
68 // requestCtx is a context for request with single reply
69 type requestCtx struct {
70         ch     *Channel
71         seqNum uint16
72 }
73
74 // multiRequestCtx is a context for request with multiple responses
75 type multiRequestCtx struct {
76         ch     *Channel
77         seqNum uint16
78 }
79
80 // subscriptionCtx is a context of subscription for delivery of specific notification messages.
81 type subscriptionCtx struct {
82         ch         *Channel
83         notifChan  chan api.Message   // channel where notification messages will be delivered to
84         msgID      uint16             // message ID for the subscribed event message
85         event      api.Message        // event message that this subscription is for
86         msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
87 }
88
89 // Channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
90 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
91 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
92 // concurrently, otherwise the responses could mix! Use multiple channels instead.
93 type Channel struct {
94         id   uint16
95         conn *Connection
96
97         reqChan   chan *vppRequest // channel for sending the requests to VPP
98         replyChan chan *vppReply   // channel where VPP replies are delivered to
99
100         msgCodec      MessageCodec      // used to decode binary data to generated API messages
101         msgIdentifier MessageIdentifier // used to retrieve message ID of a message
102
103         lastSeqNum uint16 // sequence number of the last sent request
104
105         delayedReply        *vppReply     // reply already taken from ReplyChan, buffered for later delivery
106         replyTimeout        time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
107         receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
108 }
109
110 func newChannel(id uint16, conn *Connection, codec MessageCodec, identifier MessageIdentifier, reqSize, replySize int) *Channel {
111         return &Channel{
112                 id:                  id,
113                 conn:                conn,
114                 msgCodec:            codec,
115                 msgIdentifier:       identifier,
116                 reqChan:             make(chan *vppRequest, reqSize),
117                 replyChan:           make(chan *vppReply, replySize),
118                 replyTimeout:        DefaultReplyTimeout,
119                 receiveReplyTimeout: ReplyChannelTimeout,
120         }
121 }
122
123 func (ch *Channel) GetID() uint16 {
124         return ch.id
125 }
126
127 func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
128         req := ch.newRequest(msg, false)
129         ch.reqChan <- req
130         return &requestCtx{ch: ch, seqNum: req.seqNum}
131 }
132
133 func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
134         req := ch.newRequest(msg, true)
135         ch.reqChan <- req
136         return &multiRequestCtx{ch: ch, seqNum: req.seqNum}
137 }
138
139 func (ch *Channel) nextSeqNum() uint16 {
140         ch.lastSeqNum++
141         return ch.lastSeqNum
142 }
143
144 func (ch *Channel) newRequest(msg api.Message, multi bool) *vppRequest {
145         return &vppRequest{
146                 msg:    msg,
147                 seqNum: ch.nextSeqNum(),
148                 multi:  multi,
149         }
150 }
151
152 func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
153         var comperr api.CompatibilityError
154         for _, msg := range msgs {
155                 _, err := ch.msgIdentifier.GetMessageID(msg)
156                 if err != nil {
157                         if uerr, ok := err.(*adapter.UnknownMsgError); ok {
158                                 comperr.IncompatibleMessages = append(comperr.IncompatibleMessages, getMsgID(uerr.MsgName, uerr.MsgCrc))
159                                 continue
160                         }
161                         // other errors return immediatelly
162                         return err
163                 }
164                 comperr.CompatibleMessages = append(comperr.CompatibleMessages, getMsgNameWithCrc(msg))
165         }
166         if len(comperr.IncompatibleMessages) == 0 {
167                 return nil
168         }
169         return &comperr
170 }
171
172 func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
173         msgID, err := ch.msgIdentifier.GetMessageID(event)
174         if err != nil {
175                 log.WithFields(logrus.Fields{
176                         "msg_name": event.GetMessageName(),
177                         "msg_crc":  event.GetCrcString(),
178                 }).Errorf("unable to retrieve message ID: %v", err)
179                 return nil, fmt.Errorf("unable to retrieve event message ID: %v", err)
180         }
181
182         sub := &subscriptionCtx{
183                 ch:         ch,
184                 notifChan:  notifChan,
185                 msgID:      msgID,
186                 event:      event,
187                 msgFactory: getMsgFactory(event),
188         }
189
190         // add the subscription into map
191         ch.conn.subscriptionsLock.Lock()
192         defer ch.conn.subscriptionsLock.Unlock()
193
194         ch.conn.subscriptions[msgID] = append(ch.conn.subscriptions[msgID], sub)
195
196         return sub, nil
197 }
198
199 func (ch *Channel) SetReplyTimeout(timeout time.Duration) {
200         ch.replyTimeout = timeout
201 }
202
203 func (ch *Channel) Close() {
204         close(ch.reqChan)
205 }
206
207 func (req *requestCtx) ReceiveReply(msg api.Message) error {
208         if req == nil || req.ch == nil {
209                 return ErrInvalidRequestCtx
210         }
211
212         lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
213         if err != nil {
214                 return err
215         } else if lastReplyReceived {
216                 return errors.New("multipart reply recieved while a single reply expected")
217         }
218
219         return nil
220 }
221
222 func (req *multiRequestCtx) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
223         if req == nil || req.ch == nil {
224                 return false, ErrInvalidRequestCtx
225         }
226
227         return req.ch.receiveReplyInternal(msg, req.seqNum)
228 }
229
230 func (sub *subscriptionCtx) Unsubscribe() error {
231         log.WithFields(logrus.Fields{
232                 "msg_name": sub.event.GetMessageName(),
233                 "msg_id":   sub.msgID,
234         }).Debug("Removing notification subscription.")
235
236         // remove the subscription from the map
237         sub.ch.conn.subscriptionsLock.Lock()
238         defer sub.ch.conn.subscriptionsLock.Unlock()
239
240         for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
241                 if item == sub {
242                         // close notification channel
243                         close(sub.ch.conn.subscriptions[sub.msgID][i].notifChan)
244                         // remove i-th item in the slice
245                         sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
246                         return nil
247                 }
248         }
249
250         return fmt.Errorf("subscription for %q not found", sub.event.GetMessageName())
251 }
252
253 // receiveReplyInternal receives a reply from the reply channel into the provided msg structure.
254 func (ch *Channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
255         if msg == nil {
256                 return false, errors.New("nil message passed in")
257         }
258
259         var ignore bool
260
261         if vppReply := ch.delayedReply; vppReply != nil {
262                 // try the delayed reply
263                 ch.delayedReply = nil
264                 ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
265                 if !ignore {
266                         return lastReplyReceived, err
267                 }
268         }
269
270         timer := time.NewTimer(ch.replyTimeout)
271         for {
272                 select {
273                 // blocks until a reply comes to ReplyChan or until timeout expires
274                 case vppReply := <-ch.replyChan:
275                         ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
276                         if ignore {
277                                 log.WithFields(logrus.Fields{
278                                         "expSeqNum": expSeqNum,
279                                         "channel":   ch.id,
280                                 }).Warnf("ignoring received reply: %+v (expecting: %s)", vppReply, msg.GetMessageName())
281                                 continue
282                         }
283                         return lastReplyReceived, err
284
285                 case <-timer.C:
286                         log.WithFields(logrus.Fields{
287                                 "expSeqNum": expSeqNum,
288                                 "channel":   ch.id,
289                         }).Debugf("timeout (%v) waiting for reply: %s", ch.replyTimeout, msg.GetMessageName())
290                         err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout)
291                         return false, err
292                 }
293         }
294 }
295
296 func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
297         // check the sequence number
298         cmpSeqNums := compareSeqNumbers(reply.seqNum, expSeqNum)
299         if cmpSeqNums == -1 {
300                 // reply received too late, ignore the message
301                 log.WithField("seqNum", reply.seqNum).
302                         Warn("Received reply to an already closed binary API request")
303                 ignore = true
304                 return
305         }
306         if cmpSeqNums == 1 {
307                 ch.delayedReply = reply
308                 err = fmt.Errorf("missing binary API reply with sequence number: %d", expSeqNum)
309                 return
310         }
311
312         if reply.err != nil {
313                 err = reply.err
314                 return
315         }
316         if reply.lastReceived {
317                 lastReplyReceived = true
318                 return
319         }
320
321         // message checks
322         var expMsgID uint16
323         expMsgID, err = ch.msgIdentifier.GetMessageID(msg)
324         if err != nil {
325                 err = fmt.Errorf("message %s with CRC %s is not compatible with the VPP we are connected to",
326                         msg.GetMessageName(), msg.GetCrcString())
327                 return
328         }
329
330         if reply.msgID != expMsgID {
331                 var msgNameCrc string
332                 if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
333                         msgNameCrc = err.Error()
334                 } else {
335                         msgNameCrc = getMsgNameWithCrc(replyMsg)
336                 }
337
338                 err = fmt.Errorf("received unexpected message (seqNum=%d), expected %s (ID %d), but got %s (ID %d) "+
339                         "(check if multiple goroutines are not sharing single GoVPP channel)",
340                         reply.seqNum, msg.GetMessageName(), expMsgID, msgNameCrc, reply.msgID)
341                 return
342         }
343
344         // decode the message
345         if err = ch.msgCodec.DecodeMsg(reply.data, msg); err != nil {
346                 return
347         }
348
349         // check Retval and convert it into VnetAPIError error
350         if strings.HasSuffix(msg.GetMessageName(), "_reply") {
351                 // TODO: use categories for messages to avoid checking message name
352                 if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
353                         retval := int32(f.Int())
354                         err = api.RetvalToVPPApiError(retval)
355                 }
356         }
357
358         return
359 }