Add socketclient implementation
[govpp.git] / core / channel.go
1 // Copyright (c) 2018 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package core
16
17 import (
18         "errors"
19         "fmt"
20         "reflect"
21         "strings"
22         "time"
23
24         "git.fd.io/govpp.git/api"
25         "github.com/sirupsen/logrus"
26 )
27
28 var (
29         ErrInvalidRequestCtx = errors.New("invalid request context")
30 )
31
32 // MessageCodec provides functionality for decoding binary data to generated API messages.
33 type MessageCodec interface {
34         // EncodeMsg encodes message into binary data.
35         EncodeMsg(msg api.Message, msgID uint16) ([]byte, error)
36         // DecodeMsg decodes binary-encoded data of a message into provided Message structure.
37         DecodeMsg(data []byte, msg api.Message) error
38 }
39
40 // MessageIdentifier provides identification of generated API messages.
41 type MessageIdentifier interface {
42         // GetMessageID returns message identifier of given API message.
43         GetMessageID(msg api.Message) (uint16, error)
44         // LookupByID looks up message name and crc by ID
45         LookupByID(msgID uint16) (api.Message, error)
46 }
47
48 // vppRequest is a request that will be sent to VPP.
49 type vppRequest struct {
50         seqNum uint16      // sequence number
51         msg    api.Message // binary API message to be send to VPP
52         multi  bool        // true if multipart response is expected
53 }
54
55 // vppReply is a reply received from VPP.
56 type vppReply struct {
57         seqNum       uint16 // sequence number
58         msgID        uint16 // ID of the message
59         data         []byte // encoded data with the message
60         lastReceived bool   // for multi request, true if the last reply has been already received
61         err          error  // in case of error, data is nil and this member contains error
62 }
63
64 // requestCtx is a context for request with single reply
65 type requestCtx struct {
66         ch     *Channel
67         seqNum uint16
68 }
69
70 // multiRequestCtx is a context for request with multiple responses
71 type multiRequestCtx struct {
72         ch     *Channel
73         seqNum uint16
74 }
75
76 // subscriptionCtx is a context of subscription for delivery of specific notification messages.
77 type subscriptionCtx struct {
78         ch         *Channel
79         notifChan  chan api.Message   // channel where notification messages will be delivered to
80         msgID      uint16             // message ID for the subscribed event message
81         event      api.Message        // event message that this subscription is for
82         msgFactory func() api.Message // function that returns a new instance of the specific message that is expected as a notification
83 }
84
85 // channel is the main communication interface with govpp core. It contains four Go channels, one for sending the requests
86 // to VPP, one for receiving the replies from it and the same set for notifications. The user can access the Go channels
87 // via methods provided by Channel interface in this package. Do not use the same channel from multiple goroutines
88 // concurrently, otherwise the responses could mix! Use multiple channels instead.
89 type Channel struct {
90         id   uint16
91         conn *Connection
92
93         reqChan   chan *vppRequest // channel for sending the requests to VPP
94         replyChan chan *vppReply   // channel where VPP replies are delivered to
95
96         msgCodec      MessageCodec      // used to decode binary data to generated API messages
97         msgIdentifier MessageIdentifier // used to retrieve message ID of a message
98
99         lastSeqNum uint16 // sequence number of the last sent request
100
101         delayedReply *vppReply     // reply already taken from ReplyChan, buffered for later delivery
102         replyTimeout time.Duration // maximum time that the API waits for a reply from VPP before returning an error, can be set with SetReplyTimeout
103 }
104
105 func newChannel(id uint16, conn *Connection, codec MessageCodec, identifier MessageIdentifier, reqSize, replySize int) *Channel {
106         return &Channel{
107                 id:            id,
108                 conn:          conn,
109                 msgCodec:      codec,
110                 msgIdentifier: identifier,
111                 reqChan:       make(chan *vppRequest, reqSize),
112                 replyChan:     make(chan *vppReply, replySize),
113                 replyTimeout:  DefaultReplyTimeout,
114         }
115 }
116
117 func (ch *Channel) GetID() uint16 {
118         return ch.id
119 }
120
121 func (ch *Channel) nextSeqNum() uint16 {
122         ch.lastSeqNum++
123         return ch.lastSeqNum
124 }
125
126 func (ch *Channel) SendRequest(msg api.Message) api.RequestCtx {
127         seqNum := ch.nextSeqNum()
128         ch.reqChan <- &vppRequest{
129                 msg:    msg,
130                 seqNum: seqNum,
131         }
132         return &requestCtx{ch: ch, seqNum: seqNum}
133 }
134
135 func (ch *Channel) SendMultiRequest(msg api.Message) api.MultiRequestCtx {
136         seqNum := ch.nextSeqNum()
137         ch.reqChan <- &vppRequest{
138                 msg:    msg,
139                 seqNum: seqNum,
140                 multi:  true,
141         }
142         return &multiRequestCtx{ch: ch, seqNum: seqNum}
143 }
144
145 func (ch *Channel) CheckCompatiblity(msgs ...api.Message) error {
146         for _, msg := range msgs {
147                 _, err := ch.msgIdentifier.GetMessageID(msg)
148                 if err != nil {
149                         return err
150                 }
151         }
152         return nil
153 }
154
155 func (ch *Channel) SubscribeNotification(notifChan chan api.Message, event api.Message) (api.SubscriptionCtx, error) {
156         msgID, err := ch.msgIdentifier.GetMessageID(event)
157         if err != nil {
158                 log.WithFields(logrus.Fields{
159                         "msg_name": event.GetMessageName(),
160                         "msg_crc":  event.GetCrcString(),
161                 }).Errorf("unable to retrieve message ID: %v", err)
162                 return nil, fmt.Errorf("unable to retrieve event message ID: %v", err)
163         }
164
165         sub := &subscriptionCtx{
166                 ch:         ch,
167                 notifChan:  notifChan,
168                 msgID:      msgID,
169                 event:      event,
170                 msgFactory: getMsgFactory(event),
171         }
172
173         // add the subscription into map
174         ch.conn.subscriptionsLock.Lock()
175         defer ch.conn.subscriptionsLock.Unlock()
176
177         ch.conn.subscriptions[msgID] = append(ch.conn.subscriptions[msgID], sub)
178
179         return sub, nil
180 }
181
182 func (ch *Channel) SetReplyTimeout(timeout time.Duration) {
183         ch.replyTimeout = timeout
184 }
185
186 func (ch *Channel) Close() {
187         if ch.reqChan != nil {
188                 close(ch.reqChan)
189                 ch.reqChan = nil
190         }
191 }
192
193 func (req *requestCtx) ReceiveReply(msg api.Message) error {
194         if req == nil || req.ch == nil {
195                 return ErrInvalidRequestCtx
196         }
197
198         lastReplyReceived, err := req.ch.receiveReplyInternal(msg, req.seqNum)
199         if err != nil {
200                 return err
201         } else if lastReplyReceived {
202                 return errors.New("multipart reply recieved while a single reply expected")
203         }
204
205         return nil
206 }
207
208 func (req *multiRequestCtx) ReceiveReply(msg api.Message) (lastReplyReceived bool, err error) {
209         if req == nil || req.ch == nil {
210                 return false, ErrInvalidRequestCtx
211         }
212
213         return req.ch.receiveReplyInternal(msg, req.seqNum)
214 }
215
216 func (sub *subscriptionCtx) Unsubscribe() error {
217         log.WithFields(logrus.Fields{
218                 "msg_name": sub.event.GetMessageName(),
219                 "msg_id":   sub.msgID,
220         }).Debug("Removing notification subscription.")
221
222         // remove the subscription from the map
223         sub.ch.conn.subscriptionsLock.Lock()
224         defer sub.ch.conn.subscriptionsLock.Unlock()
225
226         for i, item := range sub.ch.conn.subscriptions[sub.msgID] {
227                 if item == sub {
228                         // remove i-th item in the slice
229                         sub.ch.conn.subscriptions[sub.msgID] = append(sub.ch.conn.subscriptions[sub.msgID][:i], sub.ch.conn.subscriptions[sub.msgID][i+1:]...)
230                         return nil
231                 }
232         }
233
234         return fmt.Errorf("subscription for %q not found", sub.event.GetMessageName())
235 }
236
237 // receiveReplyInternal receives a reply from the reply channel into the provided msg structure.
238 func (ch *Channel) receiveReplyInternal(msg api.Message, expSeqNum uint16) (lastReplyReceived bool, err error) {
239         if msg == nil {
240                 return false, errors.New("nil message passed in")
241         }
242
243         var ignore bool
244
245         if vppReply := ch.delayedReply; vppReply != nil {
246                 // try the delayed reply
247                 ch.delayedReply = nil
248                 ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
249                 if !ignore {
250                         return lastReplyReceived, err
251                 }
252         }
253
254         timer := time.NewTimer(ch.replyTimeout)
255         for {
256                 select {
257                 // blocks until a reply comes to ReplyChan or until timeout expires
258                 case vppReply := <-ch.replyChan:
259                         ignore, lastReplyReceived, err = ch.processReply(vppReply, expSeqNum, msg)
260                         if ignore {
261                                 logrus.Warnf("ignoring reply: %+v", vppReply)
262                                 continue
263                         }
264                         return lastReplyReceived, err
265
266                 case <-timer.C:
267                         err = fmt.Errorf("no reply received within the timeout period %s", ch.replyTimeout)
268                         return false, err
269                 }
270         }
271         return
272 }
273
274 func (ch *Channel) processReply(reply *vppReply, expSeqNum uint16, msg api.Message) (ignore bool, lastReplyReceived bool, err error) {
275         // check the sequence number
276         cmpSeqNums := compareSeqNumbers(reply.seqNum, expSeqNum)
277         if cmpSeqNums == -1 {
278                 // reply received too late, ignore the message
279                 logrus.WithField("seqNum", reply.seqNum).
280                         Warn("Received reply to an already closed binary API request")
281                 ignore = true
282                 return
283         }
284         if cmpSeqNums == 1 {
285                 ch.delayedReply = reply
286                 err = fmt.Errorf("missing binary API reply with sequence number: %d", expSeqNum)
287                 return
288         }
289
290         if reply.err != nil {
291                 err = reply.err
292                 return
293         }
294         if reply.lastReceived {
295                 lastReplyReceived = true
296                 return
297         }
298
299         // message checks
300         var expMsgID uint16
301         expMsgID, err = ch.msgIdentifier.GetMessageID(msg)
302         if err != nil {
303                 err = fmt.Errorf("message %s with CRC %s is not compatible with the VPP we are connected to",
304                         msg.GetMessageName(), msg.GetCrcString())
305                 return
306         }
307
308         if reply.msgID != expMsgID {
309                 var msgNameCrc string
310                 if replyMsg, err := ch.msgIdentifier.LookupByID(reply.msgID); err != nil {
311                         msgNameCrc = err.Error()
312                 } else {
313                         msgNameCrc = getMsgNameWithCrc(replyMsg)
314                 }
315
316                 err = fmt.Errorf("received invalid message ID (seqNum=%d), expected %d (%s), but got %d (%s) "+
317                         "(check if multiple goroutines are not sharing single GoVPP channel)",
318                         reply.seqNum, expMsgID, msg.GetMessageName(), reply.msgID, msgNameCrc)
319                 return
320         }
321
322         // decode the message
323         if err = ch.msgCodec.DecodeMsg(reply.data, msg); err != nil {
324                 return
325         }
326
327         // check Retval and convert it into VnetAPIError error
328         if strings.HasSuffix(msg.GetMessageName(), "_reply") {
329                 // TODO: use categories for messages to avoid checking message name
330                 if f := reflect.Indirect(reflect.ValueOf(msg)).FieldByName("Retval"); f.IsValid() {
331                         retval := int32(f.Int())
332                         err = api.RetvalToVPPApiError(retval)
333                 }
334         }
335
336         return
337 }