test: Fix test dependancy
[govpp.git] / core / stream.go
index 171b201..67236f1 100644 (file)
@@ -19,46 +19,54 @@ import (
        "errors"
        "fmt"
        "reflect"
-       "sync/atomic"
+       "sync"
+       "time"
 
        "git.fd.io/govpp.git/api"
 )
 
 type Stream struct {
-       id      uint32
        conn    *Connection
        ctx     context.Context
        channel *Channel
+       // available options
+       requestSize  int
+       replySize    int
+       replyTimeout time.Duration
+       // per-request context
+       pkgPath string
+       sync.Mutex
 }
 
-func (c *Connection) NewStream(ctx context.Context) (api.Stream, error) {
+func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
        if c == nil {
                return nil, errors.New("nil connection passed in")
        }
-       // TODO: add stream options as variadic parameters for customizing:
-       // - request/reply channel size
-       // - reply timeout
-       // - retries
-       // - ???
+       s := &Stream{
+               conn: c,
+               ctx:  ctx,
+               // default options
+               requestSize:  RequestChanBufSize,
+               replySize:    ReplyChanBufSize,
+               replyTimeout: DefaultReplyTimeout,
+       }
 
-       // create new channel
-       chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff)
-       channel := newChannel(chID, c, c.codec, c, 10, 10)
+       // parse custom options
+       for _, option := range options {
+               option(s)
+       }
 
-       // store API channel within the client
-       c.channelsLock.Lock()
-       c.channels[chID] = channel
-       c.channelsLock.Unlock()
+       ch, err := c.newChannel(s.requestSize, s.replySize)
+       if err != nil {
+               return nil, err
+       }
+       s.channel = ch
+       s.channel.SetReplyTimeout(s.replyTimeout)
 
        // Channel.watchRequests are not started here intentionally, because
        // requests are sent directly by SendMsg.
 
-       return &Stream{
-               id:      uint32(chID),
-               conn:    c,
-               ctx:     ctx,
-               channel: channel,
-       }, nil
+       return s, nil
 }
 
 func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error {
@@ -66,18 +74,18 @@ func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Mess
        if err != nil {
                return err
        }
+       defer func() { _ = stream.Close() }()
        if err := stream.SendMsg(req); err != nil {
                return err
        }
-       msg, err := stream.RecvMsg()
+       s := stream.(*Stream)
+       rep, err := s.recvReply()
        if err != nil {
                return err
        }
-       if msg.GetMessageName() != reply.GetMessageName() ||
-               msg.GetCrcString() != reply.GetCrcString() {
-               return fmt.Errorf("unexpected reply: %T %+v", msg, msg)
+       if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil {
+               return err
        }
-       reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(msg).Elem())
        return nil
 }
 
@@ -102,10 +110,53 @@ func (s *Stream) SendMsg(msg api.Message) error {
        if err := s.conn.processRequest(s.channel, req); err != nil {
                return err
        }
+       s.Lock()
+       s.pkgPath = s.conn.GetMessagePath(msg)
+       s.Unlock()
        return nil
 }
 
 func (s *Stream) RecvMsg() (api.Message, error) {
+       reply, err := s.recvReply()
+       if err != nil {
+               return nil, err
+       }
+       // resolve message type
+       s.Lock()
+       path := s.pkgPath
+       s.Unlock()
+       msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
+       if err != nil {
+               return nil, err
+       }
+       // allocate message instance
+       msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
+       // decode message data
+       if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
+               return nil, err
+       }
+       return msg, nil
+}
+
+func WithRequestSize(size int) api.StreamOption {
+       return func(stream api.Stream) {
+               stream.(*Stream).requestSize = size
+       }
+}
+
+func WithReplySize(size int) api.StreamOption {
+       return func(stream api.Stream) {
+               stream.(*Stream).replySize = size
+       }
+}
+
+func WithReplyTimeout(timeout time.Duration) api.StreamOption {
+       return func(stream api.Stream) {
+               stream.(*Stream).replyTimeout = timeout
+       }
+}
+
+func (s *Stream) recvReply() (*vppReply, error) {
        if s.conn == nil {
                return nil, errors.New("stream closed")
        }
@@ -120,18 +171,7 @@ func (s *Stream) RecvMsg() (api.Message, error) {
                        // and stream does not use it
                        return nil, reply.err
                }
-               // resolve message type
-               msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
-               if err != nil {
-                       return nil, err
-               }
-               // allocate message instance
-               msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message)
-               // decode message data
-               if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil {
-                       return nil, err
-               }
-               return msg, nil
+               return reply, nil
 
        case <-s.ctx.Done():
                return nil, s.ctx.Err()