Refactoring and fixes
[govpp.git] / core / stream.go
index abe9d55..363cc9f 100644 (file)
@@ -19,14 +19,13 @@ import (
        "errors"
        "fmt"
        "reflect"
-       "sync/atomic"
+       "sync"
        "time"
 
        "git.fd.io/govpp.git/api"
 )
 
 type Stream struct {
-       id      uint32
        conn    *Connection
        ctx     context.Context
        channel *Channel
@@ -34,6 +33,9 @@ type Stream struct {
        requestSize  int
        replySize    int
        replyTimeout time.Duration
+       // per-request context
+       pkgPath string
+       sync.Mutex
 }
 
 func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) {
@@ -53,15 +55,9 @@ func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption)
        for _, option := range options {
                option(s)
        }
-       // create and store a new channel
-       s.id = atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff
-       s.channel = newChannel(uint16(s.id), c, c.codec, c, s.requestSize, s.replySize)
-       s.channel.SetReplyTimeout(s.replyTimeout)
 
-       // store API channel within the client
-       c.channelsLock.Lock()
-       c.channels[uint16(s.id)] = s.channel
-       c.channelsLock.Unlock()
+       s.channel = c.newChannel(s.requestSize, s.replySize)
+       s.channel.SetReplyTimeout(s.replyTimeout)
 
        // Channel.watchRequests are not started here intentionally, because
        // requests are sent directly by SendMsg.
@@ -109,6 +105,9 @@ func (s *Stream) SendMsg(msg api.Message) error {
        if err := s.conn.processRequest(s.channel, req); err != nil {
                return err
        }
+       s.Lock()
+       s.pkgPath = s.conn.GetMessagePath(msg)
+       s.Unlock()
        return nil
 }
 
@@ -118,7 +117,10 @@ func (s *Stream) RecvMsg() (api.Message, error) {
                return nil, err
        }
        // resolve message type
-       msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID)
+       s.Lock()
+       path := s.pkgPath
+       s.Unlock()
+       msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID)
        if err != nil {
                return nil, err
        }