Change module name to go.fd.io/govpp
[govpp.git] / core / stream.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package core
16
17 import (
18         "context"
19         "errors"
20         "fmt"
21         "reflect"
22         "sync"
23         "time"
24
25         "go.fd.io/govpp/api"
26 )
27
28 type Stream struct {
29         conn    *Connection
30         ctx     context.Context
31         channel *Channel
32         // available options
33         requestSize  int
34         replySize    int
35         replyTimeout time.Duration
36         // per-request context
37         pkgPath string
38         sync.Mutex
39 }
40
41 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
42         if c == nil {
43                 return nil, errors.New("nil connection passed in")
44         }
45         s := &Stream{
46                 conn: c,
47                 ctx:  ctx,
48                 // default options
49                 requestSize:  RequestChanBufSize,
50                 replySize:    ReplyChanBufSize,
51                 replyTimeout: DefaultReplyTimeout,
52         }
53
54         // parse custom options
55         for _, option := range options {
56                 option(s)
57         }
58
59         ch, err := c.newChannel(s.requestSize, s.replySize)
60         if err != nil {
61                 return nil, err
62         }
63         s.channel = ch
64         s.channel.SetReplyTimeout(s.replyTimeout)
65
66         // Channel.watchRequests are not started here intentionally, because
67         // requests are sent directly by SendMsg.
68
69         return s, nil
70 }
71
72 func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error {
73         stream, err := c.NewStream(ctx)
74         if err != nil {
75                 return err
76         }
77         defer func() { _ = stream.Close() }()
78         if err := stream.SendMsg(req); err != nil {
79                 return err
80         }
81         s := stream.(*Stream)
82         rep, err := s.recvReply()
83         if err != nil {
84                 return err
85         }
86         if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil {
87                 return err
88         }
89         return nil
90 }
91
92 func (s *Stream) Context() context.Context {
93         return s.ctx
94 }
95
96 func (s *Stream) Close() error {
97         if s.conn == nil {
98                 return errors.New("stream closed")
99         }
100         s.conn.releaseAPIChannel(s.channel)
101         s.conn = nil
102         return nil
103 }
104
105 func (s *Stream) SendMsg(msg api.Message) error {
106         if s.conn == nil {
107                 return errors.New("stream closed")
108         }
109         req := s.channel.newRequest(msg, false)
110         if err := s.conn.processRequest(s.channel, req); err != nil {
111                 return err
112         }
113         s.Lock()
114         s.pkgPath = s.conn.GetMessagePath(msg)
115         s.Unlock()
116         return nil
117 }
118
119 func (s *Stream) RecvMsg() (api.Message, error) {
120         reply, err := s.recvReply()
121         if err != nil {
122                 return nil, err
123         }
124         // resolve message type
125         s.Lock()
126         path := s.pkgPath
127         s.Unlock()
128         msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
129         if err != nil {
130                 return nil, err
131         }
132         // allocate message instance
133         msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
134         // decode message data
135         if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
136                 return nil, err
137         }
138         return msg, nil
139 }
140
141 func WithRequestSize(size int) api.StreamOption {
142         return func(stream api.Stream) {
143                 stream.(*Stream).requestSize = size
144         }
145 }
146
147 func WithReplySize(size int) api.StreamOption {
148         return func(stream api.Stream) {
149                 stream.(*Stream).replySize = size
150         }
151 }
152
153 func WithReplyTimeout(timeout time.Duration) api.StreamOption {
154         return func(stream api.Stream) {
155                 stream.(*Stream).replyTimeout = timeout
156         }
157 }
158
159 func (s *Stream) recvReply() (*vppReply, error) {
160         if s.conn == nil {
161                 return nil, errors.New("stream closed")
162         }
163         select {
164         case reply, ok := <-s.channel.replyChan:
165                 if !ok {
166                         return nil, fmt.Errorf("reply channel closed")
167                 }
168                 if reply.err != nil {
169                         // this case should actually never happen for stream
170                         // since reply.err is only filled in watchRequests
171                         // and stream does not use it
172                         return nil, reply.err
173                 }
174                 return reply, nil
175
176         case <-s.ctx.Done():
177                 return nil, s.ctx.Err()
178         }
179 }