X-Git-Url: https://gerrit.fd.io/r/gitweb?a=blobdiff_plain;f=core%2Fstream.go;h=67236f1f2da8c992b9ef2fc1579bdb40236ab0ec;hb=a4112fac7b86fe09650d2bb57969fe46404edd7d;hp=171b2017bcef97ed5ab9e3626267285b4afafe19;hpb=d1f24d37bd447b64e402298bb8eb2479681facf9;p=govpp.git diff --git a/core/stream.go b/core/stream.go index 171b201..67236f1 100644 --- a/core/stream.go +++ b/core/stream.go @@ -19,46 +19,54 @@ import ( "errors" "fmt" "reflect" - "sync/atomic" + "sync" + "time" "git.fd.io/govpp.git/api" ) type Stream struct { - id uint32 conn *Connection ctx context.Context channel *Channel + // available options + requestSize int + replySize int + replyTimeout time.Duration + // per-request context + pkgPath string + sync.Mutex } -func (c *Connection) NewStream(ctx context.Context) (api.Stream, error) { +func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption) (api.Stream, error) { if c == nil { return nil, errors.New("nil connection passed in") } - // TODO: add stream options as variadic parameters for customizing: - // - request/reply channel size - // - reply timeout - // - retries - // - ??? + s := &Stream{ + conn: c, + ctx: ctx, + // default options + requestSize: RequestChanBufSize, + replySize: ReplyChanBufSize, + replyTimeout: DefaultReplyTimeout, + } - // create new channel - chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff) - channel := newChannel(chID, c, c.codec, c, 10, 10) + // parse custom options + for _, option := range options { + option(s) + } - // store API channel within the client - c.channelsLock.Lock() - c.channels[chID] = channel - c.channelsLock.Unlock() + ch, err := c.newChannel(s.requestSize, s.replySize) + if err != nil { + return nil, err + } + s.channel = ch + s.channel.SetReplyTimeout(s.replyTimeout) // Channel.watchRequests are not started here intentionally, because // requests are sent directly by SendMsg. - return &Stream{ - id: uint32(chID), - conn: c, - ctx: ctx, - channel: channel, - }, nil + return s, nil } func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Message) error { @@ -66,18 +74,18 @@ func (c *Connection) Invoke(ctx context.Context, req api.Message, reply api.Mess if err != nil { return err } + defer func() { _ = stream.Close() }() if err := stream.SendMsg(req); err != nil { return err } - msg, err := stream.RecvMsg() + s := stream.(*Stream) + rep, err := s.recvReply() if err != nil { return err } - if msg.GetMessageName() != reply.GetMessageName() || - msg.GetCrcString() != reply.GetCrcString() { - return fmt.Errorf("unexpected reply: %T %+v", msg, msg) + if err := s.channel.msgCodec.DecodeMsg(rep.data, reply); err != nil { + return err } - reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(msg).Elem()) return nil } @@ -102,10 +110,53 @@ func (s *Stream) SendMsg(msg api.Message) error { if err := s.conn.processRequest(s.channel, req); err != nil { return err } + s.Lock() + s.pkgPath = s.conn.GetMessagePath(msg) + s.Unlock() return nil } func (s *Stream) RecvMsg() (api.Message, error) { + reply, err := s.recvReply() + if err != nil { + return nil, err + } + // resolve message type + s.Lock() + path := s.pkgPath + s.Unlock() + msg, err := s.channel.msgIdentifier.LookupByID(path, reply.msgID) + if err != nil { + return nil, err + } + // allocate message instance + msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) + // decode message data + if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil { + return nil, err + } + return msg, nil +} + +func WithRequestSize(size int) api.StreamOption { + return func(stream api.Stream) { + stream.(*Stream).requestSize = size + } +} + +func WithReplySize(size int) api.StreamOption { + return func(stream api.Stream) { + stream.(*Stream).replySize = size + } +} + +func WithReplyTimeout(timeout time.Duration) api.StreamOption { + return func(stream api.Stream) { + stream.(*Stream).replyTimeout = timeout + } +} + +func (s *Stream) recvReply() (*vppReply, error) { if s.conn == nil { return nil, errors.New("stream closed") } @@ -120,18 +171,7 @@ func (s *Stream) RecvMsg() (api.Message, error) { // and stream does not use it return nil, reply.err } - // resolve message type - msg, err := s.channel.msgIdentifier.LookupByID(reply.msgID) - if err != nil { - return nil, err - } - // allocate message instance - msg = reflect.New(reflect.TypeOf(msg).Elem()).Interface().(api.Message) - // decode message data - if err := s.channel.msgCodec.DecodeMsg(reply.data, msg); err != nil { - return nil, err - } - return msg, nil + return reply, nil case <-s.ctx.Done(): return nil, s.ctx.Err()