test: Fix test dependancy
[govpp.git] / core / stream.go
index abe9d55..67236f1 100644 (file)
@@ -19,14 +19,13 @@ import (
        "errors"
        "fmt"
        "reflect"
-       "sync/atomic"
+       "sync"
        "time"
 
        "git.fd.io/govpp.git/api"
 )
 
 type Stream struct {
-       id      uint32
        conn    *Connection
        ctx     context.Context
        channel *Channel
@@ -34,6 +33,9 @@ type Stream struct {
        requestSize  int
        replySize    int
        replyTimeout time.Duration
+       // per-request context
+       pkgPath string
+       sync.Mutex
 }
 
 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
@@ -53,15 +55,13 @@ func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption)
        for _, option := range options {
                option(s)
        }
-       // create and store a new channel
-       s.id = atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff
-       s.channel = newChannel(uint16(s.id), c, c.codec, c, s.requestSize, s.replySize)
-       s.channel.SetReplyTimeout(s.replyTimeout)
 
-       // store API channel within the client
-       c.channelsLock.Lock()
-       c.channels[uint16(s.id)] = s.channel
-       c.channelsLock.Unlock()
+       ch, err := c.newChannel(s.requestSize, s.replySize)
+       if err != nil {
+               return nil, err
+       }
+       s.channel = ch
+       s.channel.SetReplyTimeout(s.replyTimeout)
 
        // Channel.watchRequests are not started here intentionally, because
        // requests are sent directly by SendMsg.
@@ -74,6 +74,7 @@ func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Mess
        if err != nil {
                return err
        }
+       defer func() { _ = stream.Close() }()
        if err := stream.SendMsg(req); err != nil {
                return err
        }
@@ -109,6 +110,9 @@ func (s *Stream) SendMsg(msg api.Message) error {
        if err := s.conn.processRequest(s.channel, req); err != nil {
                return err
        }
+       s.Lock()
+       s.pkgPath = s.conn.GetMessagePath(msg)
+       s.Unlock()
        return nil
 }
 
@@ -118,7 +122,10 @@ func (s *Stream) RecvMsg() (api.Message, error) {
                return nil, err
        }
        // resolve message type
-       msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
+       s.Lock()
+       path := s.pkgPath
+       s.Unlock()
+       msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
        if err != nil {
                return nil, err
        }