Improve compatibility checking
[govpp.git] / core / channel.go
1 // Copyright (c) 2018 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "reflect"
21         "strings"
22         "time"
23
24         "github.com/sirupsen/logrus"
25
26         "git.fd.io/govpp.git/adapter"
27         "git.fd.io/govpp.git/api"
28 )
29
30 var (
31         ErrInvalidRequestCtx = errors.New("invalid request context")
32 )
33
34 // MessageCodec provides functionality for decoding binary data to generated API messages.
35 type MessageCodec interface {
36         // EncodeMsg encodes message into binary data.
37         EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
38         // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
39         DecodeMsg(data []byte, msg api.Message) error
40 }
41
42 // MessageIdentifier provides identification of generated API messages.
43 type MessageIdentifier interface {
44         // GetMessageID returns message identifier of given API message.
45         GetMessageID(msg api.Message) (uint16, error)
46         // LookupByID looks up message name and crc by ID
47         LookupByID(msgID uint16) (api.Message, error)
48 }
49
50 // vppRequest is a request that will be sent to VPP.
51 type vppRequest struct {
52         seqNum uint16      // sequence number
53         msg    api.Message // binary API message to be send to VPP
54         multi  bool        // true if multipart response is expected
55 }
56
57 // vppReply is a reply received from VPP.
58 type vppReply struct {
59         seqNum       uint16 // sequence number
60         msgID        uint16 // ID of the message
61         data         []byte // encoded data with the message
62         lastReceived bool   // for multi request, true if the last reply has been already received
63         err          error  // in case of error, data is nil and this member contains error
64 }
65
66 // requestCtx is a context for request with single reply
67 type requestCtx struct {
68         ch     *Channel
69         seqNum uint16
70 }
71
72 // multiRequestCtx is a context for request with multiple responses
73 type multiRequestCtx struct {
74         ch     *Channel
75         seqNum uint16
76 }
77
78 // subscriptionCtx is a context of subscription for delivery of specific notification messages.
79 type subscriptionCtx struct {
80         ch         *Channel
81         notifChan  chan api.Message   // channel where notification messages will be delivered to
82         msgID      uint16             // message ID for the subscribed event message
83         event      api.Message        // event message that this subscription is for
84         msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
85 }
86
87 // channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
88 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
89 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
90 // concurrently, otherwise the responses could mix! Use multiple channels instead.
91 type Channel struct {
92         id   uint16
93         conn *Connection
94
95         reqChan   chan *vppRequest // channel for sending the requests to VPP
96         replyChan chan *vppReply   // channel where VPP replies are delivered to
97
98         msgCodec      MessageCodec      // used to decode binary data to generated API messages
99         msgIdentifier MessageIdentifier // used to retrieve message ID of a message
100
101         lastSeqNum uint16 // sequence number of the last sent request
102
103         delayedReply *vppReply     // reply already taken from ReplyChan, buffered for later delivery
104         replyTimeout time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
105 }
106
107 func newChannel(id uint16, conn *Connection, codec MessageCodec, identifier MessageIdentifier, reqSize, replySize int) *Channel {
108         return &Channel{
109                 id:            id,
110                 conn:          conn,
111                 msgCodec:      codec,
112                 msgIdentifier: identifier,
113                 reqChan:       make(chan *vppRequest, reqSize),
114                 replyChan:     make(chan *vppReply, replySize),
115                 replyTimeout:  DefaultReplyTimeout,
116         }
117 }
118
119 func (ch *Channel) GetID() uint16 {
120         return ch.id
121 }
122
123 func (ch *Channel) nextSeqNum() uint16 {
124         ch.lastSeqNum++
125         return ch.lastSeqNum
126 }
127
128 func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
129         seqNum := ch.nextSeqNum()
130         ch.reqChan <- &vppRequest{
131                 msg:    msg,
132                 seqNum: seqNum,
133         }
134         return &requestCtx{ch: ch, seqNum: seqNum}
135 }
136
137 func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
138         seqNum := ch.nextSeqNum()
139         ch.reqChan <- &vppRequest{
140                 msg:    msg,
141                 seqNum: seqNum,
142                 multi:  true,
143         }
144         return &multiRequestCtx{ch: ch, seqNum: seqNum}
145 }
146
147 func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
148         var comperr api.CompatibilityError
149         for _, msg := range msgs {
150                 _, err := ch.msgIdentifier.GetMessageID(msg)
151                 if err != nil {
152                         if uerr, ok := err.(*adapter.UnknownMsgError); ok {
153                                 m := fmt.Sprintf("%s_%s", uerr.MsgName, uerr.MsgCrc)
154                                 comperr.IncompatibleMessages = append(comperr.IncompatibleMessages, m)
155                                 continue
156                         }
157                         // other errors return immediatelly
158                         return err
159                 }
160         }
161         if len(comperr.IncompatibleMessages) == 0 {
162                 return nil
163         }
164         return &comperr
165 }
166
167 func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
168         msgID, err := ch.msgIdentifier.GetMessageID(event)
169         if err != nil {
170                 log.WithFields(logrus.Fields{
171                         "msg_name": event.GetMessageName(),
172                         "msg_crc":  event.GetCrcString(),
173                 }).Errorf("unable to retrieve message ID: %v", err)
174                 return nil, fmt.Errorf("unable to retrieve event message ID: %v", err)
175         }
176
177         sub := &subscriptionCtx{
178                 ch:         ch,
179                 notifChan:  notifChan,
180                 msgID:      msgID,
181                 event:      event,
182                 msgFactory: getMsgFactory(event),
183         }
184
185         // add the subscription into map
186         ch.conn.subscriptionsLock.Lock()
187         defer ch.conn.subscriptionsLock.Unlock()
188
189         ch.conn.subscriptions[msgID] = append(ch.conn.subscriptions[msgID], sub)
190
191         return sub, nil
192 }
193
194 func (ch *Channel) SetReplyTimeout(timeout time.Duration) {
195         ch.replyTimeout = timeout
196 }
197
198 func (ch *Channel) Close() {
199         close(ch.reqChan)
200 }
201
202 func (req *requestCtx) ReceiveReply(msg api.Message) error {
203         if req == nil || req.ch == nil {
204                 return ErrInvalidRequestCtx
205         }
206
207         lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
208         if err != nil {
209                 return err
210         } else if lastReplyReceived {
211                 return errors.New("multipart reply recieved while a single reply expected")
212         }
213
214         return nil
215 }
216
217 func (req *multiRequestCtx) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
218         if req == nil || req.ch == nil {
219                 return false, ErrInvalidRequestCtx
220         }
221
222         return req.ch.receiveReplyInternal(msg, req.seqNum)
223 }
224
225 func (sub *subscriptionCtx) Unsubscribe() error {
226         log.WithFields(logrus.Fields{
227                 "msg_name": sub.event.GetMessageName(),
228                 "msg_id":   sub.msgID,
229         }).Debug("Removing notification subscription.")
230
231         // remove the subscription from the map
232         sub.ch.conn.subscriptionsLock.Lock()
233         defer sub.ch.conn.subscriptionsLock.Unlock()
234
235         for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
236                 if item == sub {
237                         // remove i-th item in the slice
238                         sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
239                         return nil
240                 }
241         }
242
243         return fmt.Errorf("subscription for %q not found", sub.event.GetMessageName())
244 }
245
246 // receiveReplyInternal receives a reply from the reply channel into the provided msg structure.
247 func (ch *Channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
248         if msg == nil {
249                 return false, errors.New("nil message passed in")
250         }
251
252         var ignore bool
253
254         if vppReply := ch.delayedReply; vppReply != nil {
255                 // try the delayed reply
256                 ch.delayedReply = nil
257                 ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
258                 if !ignore {
259                         return lastReplyReceived, err
260                 }
261         }
262
263         timer := time.NewTimer(ch.replyTimeout)
264         for {
265                 select {
266                 // blocks until a reply comes to ReplyChan or until timeout expires
267                 case vppReply := <-ch.replyChan:
268                         ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
269                         if ignore {
270                                 log.WithFields(logrus.Fields{
271                                         "expSeqNum": expSeqNum,
272                                         "channel":   ch.id,
273                                 }).Warnf("ignoring received reply: %+v (expecting: %s)", vppReply, msg.GetMessageName())
274                                 continue
275                         }
276                         return lastReplyReceived, err
277
278                 case <-timer.C:
279                         log.WithFields(logrus.Fields{
280                                 "expSeqNum": expSeqNum,
281                                 "channel":   ch.id,
282                         }).Debugf("timeout (%v) waiting for reply: %s", ch.replyTimeout, msg.GetMessageName())
283                         err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout)
284                         return false, err
285                 }
286         }
287 }
288
289 func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
290         // check the sequence number
291         cmpSeqNums := compareSeqNumbers(reply.seqNum, expSeqNum)
292         if cmpSeqNums == -1 {
293                 // reply received too late, ignore the message
294                 log.WithField("seqNum", reply.seqNum).
295                         Warn("Received reply to an already closed binary API request")
296                 ignore = true
297                 return
298         }
299         if cmpSeqNums == 1 {
300                 ch.delayedReply = reply
301                 err = fmt.Errorf("missing binary API reply with sequence number: %d", expSeqNum)
302                 return
303         }
304
305         if reply.err != nil {
306                 err = reply.err
307                 return
308         }
309         if reply.lastReceived {
310                 lastReplyReceived = true
311                 return
312         }
313
314         // message checks
315         var expMsgID uint16
316         expMsgID, err = ch.msgIdentifier.GetMessageID(msg)
317         if err != nil {
318                 err = fmt.Errorf("message %s with CRC %s is not compatible with the VPP we are connected to",
319                         msg.GetMessageName(), msg.GetCrcString())
320                 return
321         }
322
323         if reply.msgID != expMsgID {
324                 var msgNameCrc string
325                 if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
326                         msgNameCrc = err.Error()
327                 } else {
328                         msgNameCrc = getMsgNameWithCrc(replyMsg)
329                 }
330
331                 err = fmt.Errorf("received invalid message ID (seqNum=%d), expected %d (%s), but got %d (%s) "+
332                         "(check if multiple goroutines are not sharing single GoVPP channel)",
333                         reply.seqNum, expMsgID, msg.GetMessageName(), reply.msgID, msgNameCrc)
334                 return
335         }
336
337         // decode the message
338         if err = ch.msgCodec.DecodeMsg(reply.data, msg); err != nil {
339                 return
340         }
341
342         // check Retval and convert it into VnetAPIError error
343         if strings.HasSuffix(msg.GetMessageName(), "_reply") {
344                 // TODO: use categories for messages to avoid checking message name
345                 if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
346                         retval := int32(f.Int())
347                         err = api.RetvalToVPPApiError(retval)
348                 }
349         }
350
351         return
352 }