Fixed incorrect message error in the stream API
[govpp.git] / core / stream.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package core
16
17 import (
18         "context"
19         "errors"
20         "fmt"
21         "reflect"
22         "sync"
23         "sync/atomic"
24         "time"
25
26         "git.fd.io/govpp.git/api"
27 )
28
29 type Stream struct {
30         id      uint32
31         conn    *Connection
32         ctx     context.Context
33         channel *Channel
34         // available options
35         requestSize  int
36         replySize    int
37         replyTimeout time.Duration
38         // per-request context
39         pkgPath string
40         sync.Mutex
41 }
42
43 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
44         if c == nil {
45                 return nil, errors.New("nil connection passed in")
46         }
47         s := &Stream{
48                 conn: c,
49                 ctx:  ctx,
50                 // default options
51                 requestSize:  RequestChanBufSize,
52                 replySize:    ReplyChanBufSize,
53                 replyTimeout: DefaultReplyTimeout,
54         }
55
56         // parse custom options
57         for _, option := range options {
58                 option(s)
59         }
60         // create and store a new channel
61         s.id = atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff
62         s.channel = newChannel(uint16(s.id), c, c.codec, c, s.requestSize, s.replySize)
63         s.channel.SetReplyTimeout(s.replyTimeout)
64
65         // store API channel within the client
66         c.channelsLock.Lock()
67         c.channels[uint16(s.id)] = s.channel
68         c.channelsLock.Unlock()
69
70         // Channel.watchRequests are not started here intentionally, because
71         // requests are sent directly by SendMsg.
72
73         return s, nil
74 }
75
76 func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error {
77         stream, err := c.NewStream(ctx)
78         if err != nil {
79                 return err
80         }
81         if err := stream.SendMsg(req); err != nil {
82                 return err
83         }
84         s := stream.(*Stream)
85         rep, err := s.recvReply()
86         if err != nil {
87                 return err
88         }
89         if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil {
90                 return err
91         }
92         return nil
93 }
94
95 func (s *Stream) Context() context.Context {
96         return s.ctx
97 }
98
99 func (s *Stream) Close() error {
100         if s.conn == nil {
101                 return errors.New("stream closed")
102         }
103         s.conn.releaseAPIChannel(s.channel)
104         s.conn = nil
105         return nil
106 }
107
108 func (s *Stream) SendMsg(msg api.Message) error {
109         if s.conn == nil {
110                 return errors.New("stream closed")
111         }
112         req := s.channel.newRequest(msg, false)
113         if err := s.conn.processRequest(s.channel, req); err != nil {
114                 return err
115         }
116         s.Lock()
117         s.pkgPath = s.conn.GetMessagePath(msg)
118         s.Unlock()
119         return nil
120 }
121
122 func (s *Stream) RecvMsg() (api.Message, error) {
123         reply, err := s.recvReply()
124         if err != nil {
125                 return nil, err
126         }
127         // resolve message type
128         s.Lock()
129         path := s.pkgPath
130         s.Unlock()
131         msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
132         if err != nil {
133                 return nil, err
134         }
135         // allocate message instance
136         msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
137         // decode message data
138         if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
139                 return nil, err
140         }
141         return msg, nil
142 }
143
144 func WithRequestSize(size int) api.StreamOption {
145         return func(stream api.Stream) {
146                 stream.(*Stream).requestSize = size
147         }
148 }
149
150 func WithReplySize(size int) api.StreamOption {
151         return func(stream api.Stream) {
152                 stream.(*Stream).replySize = size
153         }
154 }
155
156 func WithReplyTimeout(timeout time.Duration) api.StreamOption {
157         return func(stream api.Stream) {
158                 stream.(*Stream).replyTimeout = timeout
159         }
160 }
161
162 func (s *Stream) recvReply() (*vppReply, error) {
163         if s.conn == nil {
164                 return nil, errors.New("stream closed")
165         }
166         select {
167         case reply, ok := <-s.channel.replyChan:
168                 if !ok {
169                         return nil, fmt.Errorf("reply channel closed")
170                 }
171                 if reply.err != nil {
172                         // this case should actually never happen for stream
173                         // since reply.err is only filled in watchRequests
174                         // and stream does not use it
175                         return nil, reply.err
176                 }
177                 return reply, nil
178
179         case <-s.ctx.Done():
180                 return nil, s.ctx.Err()
181         }
182 }