abe9d55ff824f3c3ed47ae014d6a897b6ab210e9
[govpp.git] / core / stream.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package core
16
17 import (
18         "context"
19         "errors"
20         "fmt"
21         "reflect"
22         "sync/atomic"
23         "time"
24
25         "git.fd.io/govpp.git/api"
26 )
27
28 type Stream struct {
29         id      uint32
30         conn    *Connection
31         ctx     context.Context
32         channel *Channel
33         // available options
34         requestSize  int
35         replySize    int
36         replyTimeout time.Duration
37 }
38
39 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
40         if c == nil {
41                 return nil, errors.New("nil connection passed in")
42         }
43         s := &Stream{
44                 conn: c,
45                 ctx:  ctx,
46                 // default options
47                 requestSize:  RequestChanBufSize,
48                 replySize:    ReplyChanBufSize,
49                 replyTimeout: DefaultReplyTimeout,
50         }
51
52         // parse custom options
53         for _, option := range options {
54                 option(s)
55         }
56         // create and store a new channel
57         s.id = atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff
58         s.channel = newChannel(uint16(s.id), c, c.codec, c, s.requestSize, s.replySize)
59         s.channel.SetReplyTimeout(s.replyTimeout)
60
61         // store API channel within the client
62         c.channelsLock.Lock()
63         c.channels[uint16(s.id)] = s.channel
64         c.channelsLock.Unlock()
65
66         // Channel.watchRequests are not started here intentionally, because
67         // requests are sent directly by SendMsg.
68
69         return s, nil
70 }
71
72 func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error {
73         stream, err := c.NewStream(ctx)
74         if err != nil {
75                 return err
76         }
77         if err := stream.SendMsg(req); err != nil {
78                 return err
79         }
80         s := stream.(*Stream)
81         rep, err := s.recvReply()
82         if err != nil {
83                 return err
84         }
85         if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil {
86                 return err
87         }
88         return nil
89 }
90
91 func (s *Stream) Context() context.Context {
92         return s.ctx
93 }
94
95 func (s *Stream) Close() error {
96         if s.conn == nil {
97                 return errors.New("stream closed")
98         }
99         s.conn.releaseAPIChannel(s.channel)
100         s.conn = nil
101         return nil
102 }
103
104 func (s *Stream) SendMsg(msg api.Message) error {
105         if s.conn == nil {
106                 return errors.New("stream closed")
107         }
108         req := s.channel.newRequest(msg, false)
109         if err := s.conn.processRequest(s.channel, req); err != nil {
110                 return err
111         }
112         return nil
113 }
114
115 func (s *Stream) RecvMsg() (api.Message, error) {
116         reply, err := s.recvReply()
117         if err != nil {
118                 return nil, err
119         }
120         // resolve message type
121         msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
122         if err != nil {
123                 return nil, err
124         }
125         // allocate message instance
126         msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
127         // decode message data
128         if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
129                 return nil, err
130         }
131         return msg, nil
132 }
133
134 func WithRequestSize(size int) api.StreamOption {
135         return func(stream api.Stream) {
136                 stream.(*Stream).requestSize = size
137         }
138 }
139
140 func WithReplySize(size int) api.StreamOption {
141         return func(stream api.Stream) {
142                 stream.(*Stream).replySize = size
143         }
144 }
145
146 func WithReplyTimeout(timeout time.Duration) api.StreamOption {
147         return func(stream api.Stream) {
148                 stream.(*Stream).replyTimeout = timeout
149         }
150 }
151
152 func (s *Stream) recvReply() (*vppReply, error) {
153         if s.conn == nil {
154                 return nil, errors.New("stream closed")
155         }
156         select {
157         case reply, ok := <-s.channel.replyChan:
158                 if !ok {
159                         return nil, fmt.Errorf("reply channel closed")
160                 }
161                 if reply.err != nil {
162                         // this case should actually never happen for stream
163                         // since reply.err is only filled in watchRequests
164                         // and stream does not use it
165                         return nil, reply.err
166                 }
167                 return reply, nil
168
169         case <-s.ctx.Done():
170                 return nil, s.ctx.Err()
171         }
172 }