Refactoring and fixes
[govpp.git] / core / stream.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package core
16
17 import (
18         "context"
19         "errors"
20         "fmt"
21         "reflect"
22         "sync"
23         "time"
24
25         "git.fd.io/govpp.git/api"
26 )
27
28 type Stream struct {
29         conn    *Connection
30         ctx     context.Context
31         channel *Channel
32         // available options
33         requestSize  int
34         replySize    int
35         replyTimeout time.Duration
36         // per-request context
37         pkgPath string
38         sync.Mutex
39 }
40
41 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
42         if c == nil {
43                 return nil, errors.New("nil connection passed in")
44         }
45         s := &Stream{
46                 conn: c,
47                 ctx:  ctx,
48                 // default options
49                 requestSize:  RequestChanBufSize,
50                 replySize:    ReplyChanBufSize,
51                 replyTimeout: DefaultReplyTimeout,
52         }
53
54         // parse custom options
55         for _, option := range options {
56                 option(s)
57         }
58
59         s.channel = c.newChannel(s.requestSize, s.replySize)
60         s.channel.SetReplyTimeout(s.replyTimeout)
61
62         // Channel.watchRequests are not started here intentionally, because
63         // requests are sent directly by SendMsg.
64
65         return s, nil
66 }
67
68 func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error {
69         stream, err := c.NewStream(ctx)
70         if err != nil {
71                 return err
72         }
73         if err := stream.SendMsg(req); err != nil {
74                 return err
75         }
76         s := stream.(*Stream)
77         rep, err := s.recvReply()
78         if err != nil {
79                 return err
80         }
81         if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil {
82                 return err
83         }
84         return nil
85 }
86
87 func (s *Stream) Context() context.Context {
88         return s.ctx
89 }
90
91 func (s *Stream) Close() error {
92         if s.conn == nil {
93                 return errors.New("stream closed")
94         }
95         s.conn.releaseAPIChannel(s.channel)
96         s.conn = nil
97         return nil
98 }
99
100 func (s *Stream) SendMsg(msg api.Message) error {
101         if s.conn == nil {
102                 return errors.New("stream closed")
103         }
104         req := s.channel.newRequest(msg, false)
105         if err := s.conn.processRequest(s.channel, req); err != nil {
106                 return err
107         }
108         s.Lock()
109         s.pkgPath = s.conn.GetMessagePath(msg)
110         s.Unlock()
111         return nil
112 }
113
114 func (s *Stream) RecvMsg() (api.Message, error) {
115         reply, err := s.recvReply()
116         if err != nil {
117                 return nil, err
118         }
119         // resolve message type
120         s.Lock()
121         path := s.pkgPath
122         s.Unlock()
123         msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
124         if err != nil {
125                 return nil, err
126         }
127         // allocate message instance
128         msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
129         // decode message data
130         if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
131                 return nil, err
132         }
133         return msg, nil
134 }
135
136 func WithRequestSize(size int) api.StreamOption {
137         return func(stream api.Stream) {
138                 stream.(*Stream).requestSize = size
139         }
140 }
141
142 func WithReplySize(size int) api.StreamOption {
143         return func(stream api.Stream) {
144                 stream.(*Stream).replySize = size
145         }
146 }
147
148 func WithReplyTimeout(timeout time.Duration) api.StreamOption {
149         return func(stream api.Stream) {
150                 stream.(*Stream).replyTimeout = timeout
151         }
152 }
153
154 func (s *Stream) recvReply() (*vppReply, error) {
155         if s.conn == nil {
156                 return nil, errors.New("stream closed")
157         }
158         select {
159         case reply, ok := <-s.channel.replyChan:
160                 if !ok {
161                         return nil, fmt.Errorf("reply channel closed")
162                 }
163                 if reply.err != nil {
164                         // this case should actually never happen for stream
165                         // since reply.err is only filled in watchRequests
166                         // and stream does not use it
167                         return nil, reply.err
168                 }
169                 return reply, nil
170
171         case <-s.ctx.Done():
172                 return nil, s.ctx.Err()
173         }
174 }