2f639b01e8912d14a1aebf10d4ebd41fcf714b16
[govpp.git] / core / stream.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package core
16
17 import (
18         "context"
19         "errors"
20         "fmt"
21         "reflect"
22         "sync"
23         "time"
24
25         "git.fd.io/govpp.git/api"
26 )
27
28 type Stream struct {
29         conn    *Connection
30         ctx     context.Context
31         channel *Channel
32         // available options
33         requestSize  int
34         replySize    int
35         replyTimeout time.Duration
36         // per-request context
37         pkgPath string
38         sync.Mutex
39 }
40
41 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
42         if c == nil {
43                 return nil, errors.New("nil connection passed in")
44         }
45         s := &Stream{
46                 conn: c,
47                 ctx:  ctx,
48                 // default options
49                 requestSize:  RequestChanBufSize,
50                 replySize:    ReplyChanBufSize,
51                 replyTimeout: DefaultReplyTimeout,
52         }
53
54         // parse custom options
55         for _, option := range options {
56                 option(s)
57         }
58
59         s.channel = c.newChannel(s.requestSize, s.replySize)
60         s.channel.SetReplyTimeout(s.replyTimeout)
61
62         // Channel.watchRequests are not started here intentionally, because
63         // requests are sent directly by SendMsg.
64
65         return s, nil
66 }
67
68 func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error {
69         stream, err := c.NewStream(ctx)
70         if err != nil {
71                 return err
72         }
73         defer func() { _ = stream.Close() }()
74         if err := stream.SendMsg(req); err != nil {
75                 return err
76         }
77         s := stream.(*Stream)
78         rep, err := s.recvReply()
79         if err != nil {
80                 return err
81         }
82         if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil {
83                 return err
84         }
85         return nil
86 }
87
88 func (s *Stream) Context() context.Context {
89         return s.ctx
90 }
91
92 func (s *Stream) Close() error {
93         if s.conn == nil {
94                 return errors.New("stream closed")
95         }
96         s.conn.releaseAPIChannel(s.channel)
97         s.conn = nil
98         return nil
99 }
100
101 func (s *Stream) SendMsg(msg api.Message) error {
102         if s.conn == nil {
103                 return errors.New("stream closed")
104         }
105         req := s.channel.newRequest(msg, false)
106         if err := s.conn.processRequest(s.channel, req); err != nil {
107                 return err
108         }
109         s.Lock()
110         s.pkgPath = s.conn.GetMessagePath(msg)
111         s.Unlock()
112         return nil
113 }
114
115 func (s *Stream) RecvMsg() (api.Message, error) {
116         reply, err := s.recvReply()
117         if err != nil {
118                 return nil, err
119         }
120         // resolve message type
121         s.Lock()
122         path := s.pkgPath
123         s.Unlock()
124         msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
125         if err != nil {
126                 return nil, err
127         }
128         // allocate message instance
129         msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
130         // decode message data
131         if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
132                 return nil, err
133         }
134         return msg, nil
135 }
136
137 func WithRequestSize(size int) api.StreamOption {
138         return func(stream api.Stream) {
139                 stream.(*Stream).requestSize = size
140         }
141 }
142
143 func WithReplySize(size int) api.StreamOption {
144         return func(stream api.Stream) {
145                 stream.(*Stream).replySize = size
146         }
147 }
148
149 func WithReplyTimeout(timeout time.Duration) api.StreamOption {
150         return func(stream api.Stream) {
151                 stream.(*Stream).replyTimeout = timeout
152         }
153 }
154
155 func (s *Stream) recvReply() (*vppReply, error) {
156         if s.conn == nil {
157                 return nil, errors.New("stream closed")
158         }
159         select {
160         case reply, ok := <-s.channel.replyChan:
161                 if !ok {
162                         return nil, fmt.Errorf("reply channel closed")
163                 }
164                 if reply.err != nil {
165                         // this case should actually never happen for stream
166                         // since reply.err is only filled in watchRequests
167                         // and stream does not use it
168                         return nil, reply.err
169                 }
170                 return reply, nil
171
172         case <-s.ctx.Done():
173                 return nil, s.ctx.Err()
174         }
175 }