connection: prevent channel ID overlap 50/35150/2
authorLukas Vogel <vogel@anapaya.net>
Mon, 31 Jan 2022 10:33:37 +0000 (11:33 +0100)
committerLukas Vogel <vogel@anapaya.net>
Mon, 31 Jan 2022 10:53:39 +0000 (11:53 +0100)
When creating a new channel and the channel ID wraps around, make sure
to not re-use a channel ID that is still in use. Re-using the channel
ID usually means that the connection health check will stop working and
other things might break as well.

Also rename maxChannelID to nextChannelID and use a lock to guard access
instead of using an atomic. The lock does anyway need to be acquired
because to put the entry in the map.

This commit was inspired by the following PR on Github:
https://github.com/FDio/govpp/pull/14.

Change-Id: I8c1a4ca63a53d07a6482b6047a3005065168c0b4
Signed-off-by: Lukas Vogel <vogel@anapaya.net>
core/channel.go
core/connection.go
core/stream.go

index 1086c36..112c14e 100644 (file)
@@ -19,7 +19,6 @@ import (
        "fmt"
        "reflect"
        "strings"
-       "sync/atomic"
        "time"
 
        "github.com/sirupsen/logrus"
@@ -110,11 +109,9 @@ type Channel struct {
        receiveReplyTimeout time.Duration // maximum time that we wait for receiver to consume reply
 }
 
-func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) *Channel {
+func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) (*Channel, error) {
        // create new channel
-       chID := uint16(atomic.AddUint32(&c.maxChannelID, 1) & 0x7fff)
        channel := &Channel{
-               id:                  chID,
                conn:                c,
                msgCodec:            c.codec,
                msgIdentifier:       c,
@@ -126,10 +123,22 @@ func (c *Connection) newChannel(reqChanBufSize, replyChanBufSize int) *Channel {
 
        // store API channel within the client
        c.channelsLock.Lock()
-       c.channels[chID] = channel
+       if len(c.channels) >= 0x7fff {
+               return nil, errors.New("all channel IDs are used")
+       }
+       for {
+               c.nextChannelID++
+               chID := c.nextChannelID & 0x7fff
+               _, ok := c.channels[chID]
+               if !ok {
+                       channel.id = chID
+                       c.channels[chID] = channel
+                       break
+               }
+       }
        c.channelsLock.Unlock()
 
-       return channel
+       return channel, nil
 }
 
 func (ch *Channel) GetID() uint16 {
index 442eb51..1bfcae5 100644 (file)
@@ -109,9 +109,9 @@ type Connection struct {
        msgIDs       map[string]uint16                 // map of message IDs indexed by message name + CRC
        msgMapByPath map[string]map[uint16]api.Message // map of messages indexed by message ID which are indexed by path
 
-       maxChannelID uint32              // maximum used channel ID (the real limit is 2^15, 32-bit is used for atomic operations)
-       channelsLock sync.RWMutex        // lock for the channels map
-       channels     map[uint16]*Channel // map of all API channels indexed by the channel ID
+       channelsLock  sync.RWMutex        // lock for the channels map and the channel ID
+       nextChannelID uint16              // next potential channel ID (the real limit is 2^15)
+       channels      map[uint16]*Channel // map of all API channels indexed by the channel ID
 
        subscriptionsLock sync.RWMutex                  // lock for the subscriptions map
        subscriptions     map[uint16][]*subscriptionCtx // map od all notification subscriptions indexed by message ID
@@ -248,7 +248,10 @@ func (c *Connection) newAPIChannel(reqChanBufSize, replyChanBufSize int) (*Chann
                return nil, errors.New("nil connection passed in")
        }
 
-       channel := c.newChannel(reqChanBufSize, replyChanBufSize)
+       channel, err := c.newChannel(reqChanBufSize, replyChanBufSize)
+       if err != nil {
+               return nil, err
+       }
 
        // start watching on the request channel
        go c.watchRequests(channel)
index 2f639b0..67236f1 100644 (file)
@@ -56,7 +56,11 @@ func (c *Connection) NewStream(ctx context.Context, options ...api.StreamOption)
                option(s)
        }
 
-       s.channel = c.newChannel(s.requestSize, s.replySize)
+       ch, err := c.newChannel(s.requestSize, s.replySize)
+       if err != nil {
+               return nil, err
+       }
+       s.channel = ch
        s.channel.SetReplyTimeout(s.replyTimeout)
 
        // Channel.watchRequests are not started here intentionally, because