Fix: generate (un)marshall for memory client messages
[govpp.git] / adapter / socketclient / socketclient.go
index daca005..b2c5d47 100644 (file)
@@ -16,8 +16,10 @@ package socketclient
 
 import (
        "bufio"
-       "bytes"
+       "encoding/binary"
+       "errors"
        "fmt"
+       "git.fd.io/govpp.git/adapter/socketclient/binapi/memclnt"
        "io"
        "net"
        "os"
@@ -27,147 +29,218 @@ import (
        "time"
 
        "github.com/fsnotify/fsnotify"
-       "github.com/lunixbochs/struc"
-       logger "github.com/sirupsen/logrus"
+       "github.com/sirupsen/logrus"
 
        "git.fd.io/govpp.git/adapter"
        "git.fd.io/govpp.git/codec"
-       "git.fd.io/govpp.git/examples/binapi/memclnt"
 )
 
 const (
-       // DefaultSocketName is default VPP API socket file name
-       DefaultSocketName = "/run/vpp-api.sock"
+       // DefaultSocketName is default VPP API socket file path.
+       DefaultSocketName = "/run/vpp/api.sock"
+       // DefaultClientName is used for identifying client in socket registration
+       DefaultClientName = "govppsock"
 )
 
 var (
+
        // DefaultConnectTimeout is default timeout for connecting
        DefaultConnectTimeout = time.Second * 3
        // DefaultDisconnectTimeout is default timeout for discconnecting
        DefaultDisconnectTimeout = time.Millisecond * 100
-       // MaxWaitReady defines maximum duration before waiting for socket file
-       // times out
-       MaxWaitReady = time.Second * 15
-       // ClientName is used for identifying client in socket registration
-       ClientName = "govppsock"
+       // MaxWaitReady defines maximum duration of waiting for socket file
+       MaxWaitReady = time.Second * 10
 )
 
 var (
-       // Debug is global variable that determines debug mode
-       Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
-       // DebugMsgIds is global variable that determines debug mode for msg ids
-       DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
+       debug       = strings.Contains(os.Getenv("DEBUG_GOVPP"), "socketclient")
+       debugMsgIds = strings.Contains(os.Getenv("DEBUG_GOVPP"), "msgtable")
 
-       Log = logger.New() // global logger
+       logger = logrus.New()
+       log    = logger.WithField("logger", "govpp/socketclient")
 )
 
-// init initializes global logger, which logs debug level messages to stdout.
+// init initializes global logger
 func init() {
-       Log.Out = os.Stdout
-       if Debug {
-               Log.Level = logger.DebugLevel
+       if debug {
+               logger.Level = logrus.DebugLevel
+               log.Debug("govpp: debug level enabled for socketclient")
        }
 }
 
-type vppClient struct {
-       sockAddr string
-       conn     *net.UnixConn
-       reader   *bufio.Reader
-       writer   *bufio.Writer
+const socketMissing = `
+------------------------------------------------------------
+ No socket file found at: %s
+ VPP binary API socket file is missing!
+
+  - is VPP running with socket for binapi enabled?
+  - is the correct socket name configured?
+
+ To enable it add following section to your VPP config:
+   socksvr {
+     default
+   }
+------------------------------------------------------------
+`
+
+var warnOnce sync.Once
+
+func (c *socketClient) printMissingSocketMsg() {
+       fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
+}
+
+type socketClient struct {
+       sockAddr   string
+       clientName string
+
+       conn   *net.UnixConn
+       reader *bufio.Reader
+       writer *bufio.Writer
 
        connectTimeout    time.Duration
        disconnectTimeout time.Duration
 
-       cb           adapter.MsgCallback
+       msgCallback  adapter.MsgCallback
        clientIndex  uint32
        msgTable     map[string]uint16
        sockDelMsgId uint16
        writeMu      sync.Mutex
 
+       headerPool *sync.Pool
+
        quit chan struct{}
        wg   sync.WaitGroup
 }
 
-func NewVppClient(sockAddr string) *vppClient {
+func NewVppClient(sockAddr string) *socketClient {
        if sockAddr == "" {
                sockAddr = DefaultSocketName
        }
-       return &vppClient{
+       return &socketClient{
                sockAddr:          sockAddr,
+               clientName:        DefaultClientName,
                connectTimeout:    DefaultConnectTimeout,
                disconnectTimeout: DefaultDisconnectTimeout,
-               cb: func(msgID uint16, data []byte) {
-                       Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
+               headerPool: &sync.Pool{New: func() interface{} {
+                       return make([]byte, 16)
+               }},
+               msgCallback: func(msgID uint16, data []byte) {
+                       log.Debugf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
                },
        }
 }
 
+// SetClientName sets a client name used for identification.
+func (c *socketClient) SetClientName(name string) {
+       c.clientName = name
+}
+
 // SetConnectTimeout sets timeout used during connecting.
-func (c *vppClient) SetConnectTimeout(t time.Duration) {
+func (c *socketClient) SetConnectTimeout(t time.Duration) {
        c.connectTimeout = t
 }
 
 // SetDisconnectTimeout sets timeout used during disconnecting.
-func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
+func (c *socketClient) SetDisconnectTimeout(t time.Duration) {
        c.disconnectTimeout = t
 }
 
+func (c *socketClient) SetMsgCallback(cb adapter.MsgCallback) {
+       log.Debug("SetMsgCallback")
+       c.msgCallback = cb
+}
+
+const legacySocketName = "/run/vpp-api.sock"
+
+func (c *socketClient) checkLegacySocket() bool {
+       if c.sockAddr == legacySocketName {
+               return false
+       }
+       log.Debugf("checking legacy socket: %s", legacySocketName)
+       // check if socket exists
+       if _, err := os.Stat(c.sockAddr); err == nil {
+               return false // socket exists
+       } else if !os.IsNotExist(err) {
+               return false // some other error occurred
+       }
+       // check if legacy socket exists
+       if _, err := os.Stat(legacySocketName); err == nil {
+               // legacy socket exists, update sockAddr
+               c.sockAddr = legacySocketName
+               return true
+       }
+       // no socket socket found
+       return false
+}
+
 // WaitReady checks socket file existence and waits for it if necessary
-func (c *vppClient) WaitReady() error {
-       // check if file at the path already exists
+func (c *socketClient) WaitReady() error {
+       // check if socket already exists
        if _, err := os.Stat(c.sockAddr); err == nil {
+               return nil // socket exists, we are ready
+       } else if !os.IsNotExist(err) {
+               return err // some other error occurred
+       }
+
+       if c.checkLegacySocket() {
                return nil
-       } else if os.IsExist(err) {
-               return err
        }
 
-       // if not, watch for it
+       // socket does not exist, watch for it
        watcher, err := fsnotify.NewWatcher()
        if err != nil {
                return err
        }
        defer func() {
                if err := watcher.Close(); err != nil {
-                       Log.Errorf("failed to close file watcher: %v", err)
+                       log.Debugf("failed to close file watcher: %v", err)
                }
        }()
 
-       // start watching directory
+       // start directory watcher
        if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
                return err
        }
 
+       timeout := time.NewTimer(MaxWaitReady)
        for {
                select {
-               case <-time.After(MaxWaitReady):
-                       return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
+               case <-timeout.C:
+                       if c.checkLegacySocket() {
+                               return nil
+                       }
+                       return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
+
                case e := <-watcher.Errors:
                        return e
+
                case ev := <-watcher.Events:
-                       Log.Debugf("watcher event: %+v", ev)
-                       if ev.Name == c.sockAddr {
-                               if (ev.Op & fsnotify.Create) == fsnotify.Create {
-                                       // socket was created, we are ready
-                                       return nil
-                               }
+                       log.Debugf("watcher event: %+v", ev)
+                       if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
+                               // socket created, we are ready
+                               return nil
                        }
                }
        }
 }
 
-func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
-       Log.Debug("SetMsgCallback")
-       c.cb = cb
-}
+func (c *socketClient) Connect() error {
+       c.checkLegacySocket()
 
-func (c *vppClient) Connect() error {
-       Log.Debugf("Connecting to: %v", c.sockAddr)
+       // check if socket exists
+       if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
+               warnOnce.Do(c.printMissingSocketMsg)
+               return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
+       } else if err != nil {
+               return fmt.Errorf("VPP API socket error: %v", err)
+       }
 
        if err := c.connect(c.sockAddr); err != nil {
                return err
        }
 
        if err := c.open(); err != nil {
+               _ = c.disconnect()
                return err
        }
 
@@ -178,261 +251,281 @@ func (c *vppClient) Connect() error {
        return nil
 }
 
-func (c *vppClient) connect(sockAddr string) error {
+func (c *socketClient) Disconnect() error {
+       if c.conn == nil {
+               return nil
+       }
+       log.Debugf("Disconnecting..")
+
+       close(c.quit)
+
+       if err := c.conn.CloseRead(); err != nil {
+               log.Debugf("closing readMsg failed: %v", err)
+       }
+
+       // wait for readerLoop to return
+       c.wg.Wait()
+
+       if err := c.close(); err != nil {
+               log.Debugf("closing failed: %v", err)
+       }
+
+       if err := c.disconnect(); err != nil {
+               return err
+       }
+
+       return nil
+}
+
+const defaultBufferSize = 4096
+
+func (c *socketClient) connect(sockAddr string) error {
        addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
 
+       log.Debugf("Connecting to: %v", c.sockAddr)
+
        conn, err := net.DialUnix("unix", nil, addr)
        if err != nil {
                // we try different type of socket for backwards compatbility with VPP<=19.04
                if strings.Contains(err.Error(), "wrong type for socket") {
                        addr.Net = "unixpacket"
-                       Log.Debugf("%s, retrying connect with type unixpacket", err)
+                       log.Debugf("%s, retrying connect with type unixpacket", err)
                        conn, err = net.DialUnix("unixpacket", nil, addr)
                }
                if err != nil {
-                       Log.Debugf("Connecting to socket %s failed: %s", addr, err)
+                       log.Debugf("Connecting to socket %s failed: %s", addr, err)
                        return err
                }
        }
 
        c.conn = conn
-       c.reader = bufio.NewReader(c.conn)
-       c.writer = bufio.NewWriter(c.conn)
+       log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
 
-       Log.Debugf("Connected to socket: %v", addr)
+       c.reader = bufio.NewReaderSize(c.conn, defaultBufferSize)
+       c.writer = bufio.NewWriterSize(c.conn, defaultBufferSize)
 
        return nil
 }
 
+func (c *socketClient) disconnect() error {
+       log.Debugf("Closing socket")
+       if err := c.conn.Close(); err != nil {
+               log.Debugln("Closing socket failed:", err)
+               return err
+       }
+       return nil
+}
+
 const (
        sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
        createMsgContext = byte(123)
        deleteMsgContext = byte(124)
 )
 
-func (c *vppClient) open() error {
-       msgCodec := new(codec.MsgCodec)
+func (c *socketClient) open() error {
+       var msgCodec = codec.DefaultCodec
 
+       // Request socket client create
        req := &memclnt.SockclntCreate{
-               Name: []byte(ClientName),
+               Name: c.clientName,
        }
        msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
        if err != nil {
-               Log.Debugln("Encode error:", err)
+               log.Debugln("Encode  error:", err)
                return err
        }
        // set non-0 context
        msg[5] = createMsgContext
 
-       if err := c.write(msg); err != nil {
-               Log.Debugln("Write error: ", err)
-               return err
-       }
-
-       readDeadline := time.Now().Add(c.connectTimeout)
-       if err := c.conn.SetReadDeadline(readDeadline); err != nil {
+       if err := c.writeMsg(msg); err != nil {
+               log.Debugln("Write error: ", err)
                return err
        }
-       msgReply, err := c.read()
+       msgReply, err := c.readMsgTimeout(nil, c.connectTimeout)
        if err != nil {
-               Log.Println("Read error:", err)
-               return err
-       }
-       // reset read deadline
-       if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
+               log.Println("Read error:", err)
                return err
        }
 
        reply := new(memclnt.SockclntCreateReply)
        if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
-               Log.Println("Decode error:", err)
+               log.Println("Decoding sockclnt_create_reply failed:", err)
                return err
+       } else if reply.Response != 0 {
+               return fmt.Errorf("sockclnt_create_reply: response error (%d)", reply.Response)
        }
 
-       Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
+       log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
                reply.Response, reply.Index, reply.Count)
 
        c.clientIndex = reply.Index
        c.msgTable = make(map[string]uint16, reply.Count)
        for _, x := range reply.MessageTable {
-               name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
+               msgName := strings.Split(x.Name, "\x00")[0]
+               name := strings.TrimSuffix(msgName, "\x13")
                c.msgTable[name] = x.Index
                if strings.HasPrefix(name, "sockclnt_delete_") {
                        c.sockDelMsgId = x.Index
                }
-               if DebugMsgIds {
-                       Log.Debugf(" - %4d: %q", x.Index, name)
+               if debugMsgIds {
+                       log.Debugf(" - %4d: %q", x.Index, name)
                }
        }
 
        return nil
 }
 
-func (c *vppClient) Disconnect() error {
-       if c.conn == nil {
-               return nil
-       }
-       Log.Debugf("Disconnecting..")
-
-       close(c.quit)
-
-       // force readerLoop to timeout
-       if err := c.conn.SetReadDeadline(time.Now()); err != nil {
-               return err
-       }
-
-       // wait for readerLoop to return
-       c.wg.Wait()
-
-       if err := c.close(); err != nil {
-               return err
-       }
-
-       if err := c.conn.Close(); err != nil {
-               Log.Debugln("Closing socket failed:", err)
-               return err
-       }
-
-       return nil
-}
-
-func (c *vppClient) close() error {
-       msgCodec := new(codec.MsgCodec)
+func (c *socketClient) close() error {
+       var msgCodec = codec.DefaultCodec
 
        req := &memclnt.SockclntDelete{
                Index: c.clientIndex,
        }
        msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
        if err != nil {
-               Log.Debugln("Encode error:", err)
+               log.Debugln("Encode error:", err)
                return err
        }
        // set non-0 context
        msg[5] = deleteMsgContext
 
-       Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
-       if err := c.write(msg); err != nil {
-               Log.Debugln("Write error: ", err)
-               return err
-       }
+       log.Debugf("sending socklntDel (%d bytes): % 0X", len(msg), msg)
 
-       readDeadline := time.Now().Add(c.disconnectTimeout)
-       if err := c.conn.SetReadDeadline(readDeadline); err != nil {
+       if err := c.writeMsg(msg); err != nil {
+               log.Debugln("Write error: ", err)
                return err
        }
-       msgReply, err := c.read()
+
+       msgReply, err := c.readMsgTimeout(nil, c.disconnectTimeout)
        if err != nil {
-               Log.Debugln("Read error:", err)
                if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
                        // we accept timeout for reply
                        return nil
                }
-               return err
-       }
-       // reset read deadline
-       if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
+               log.Debugln("Read error:", err)
                return err
        }
 
        reply := new(memclnt.SockclntDeleteReply)
        if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
-               Log.Debugln("Decode error:", err)
+               log.Debugln("Decoding sockclnt_delete_reply failed:", err)
                return err
+       } else if reply.Response != 0 {
+               return fmt.Errorf("sockclnt_delete_reply: response error (%d)", reply.Response)
        }
 
-       Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
-
        return nil
 }
 
-func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
-       msg := msgName + "_" + msgCrc
-       msgID, ok := c.msgTable[msg]
-       if !ok {
-               return 0, fmt.Errorf("unknown message: %q", msg)
+func (c *socketClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
+       if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok {
+               return msgID, nil
+       }
+       return 0, &adapter.UnknownMsgError{
+               MsgName: msgName,
+               MsgCrc:  msgCrc,
        }
-       return msgID, nil
-}
-
-type reqHeader struct {
-       // MsgID uint16
-       ClientIndex uint32
-       Context     uint32
 }
 
-func (c *vppClient) SendMsg(context uint32, data []byte) error {
-       h := &reqHeader{
-               ClientIndex: c.clientIndex,
-               Context:     context,
-       }
-       buf := new(bytes.Buffer)
-       if err := struc.Pack(buf, h); err != nil {
-               return err
+func (c *socketClient) SendMsg(context uint32, data []byte) error {
+       if len(data) < 10 {
+               return fmt.Errorf("invalid message data, length must be at least 10 bytes")
        }
-       copy(data[2:], buf.Bytes())
+       setMsgRequestHeader(data, c.clientIndex, context)
 
-       Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
+       if debug {
+               log.Debugf("sendMsg (%d) context=%v client=%d: % 02X", len(data), context, c.clientIndex, data)
+       }
 
-       if err := c.write(data); err != nil {
-               Log.Debugln("write error: ", err)
+       if err := c.writeMsg(data); err != nil {
+               log.Debugln("writeMsg error: ", err)
                return err
        }
 
        return nil
 }
 
-func (c *vppClient) write(msg []byte) error {
-       h := &msgheader{
-               DataLen: uint32(len(msg)),
-       }
-       buf := new(bytes.Buffer)
-       if err := struc.Pack(buf, h); err != nil {
-               return err
-       }
-       header := buf.Bytes()
+// setMsgRequestHeader sets client index and context in the message request header
+//
+// Message request has following structure:
+//
+//    type msgRequestHeader struct {
+//        MsgID       uint16
+//        ClientIndex uint32
+//        Context     uint32
+//    }
+//
+func setMsgRequestHeader(data []byte, clientIndex, context uint32) {
+       // message ID is already set
+       binary.BigEndian.PutUint32(data[2:6], clientIndex)
+       binary.BigEndian.PutUint32(data[6:10], context)
+}
 
-       // we lock to prevent mixing multiple message sends
+func (c *socketClient) writeMsg(msg []byte) error {
+       // we lock to prevent mixing multiple message writes
        c.writeMu.Lock()
        defer c.writeMu.Unlock()
 
-       if n, err := c.writer.Write(header); err != nil {
+       header := c.headerPool.Get().([]byte)
+       err := writeMsgHeader(c.writer, header, len(msg))
+       if err != nil {
+               return err
+       }
+       c.headerPool.Put(header)
+
+       if err := writeMsgData(c.writer, msg, c.writer.Size()); err != nil {
                return err
-       } else {
-               Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
        }
 
        if err := c.writer.Flush(); err != nil {
                return err
        }
 
-       for i := 0; i <= len(msg)/c.writer.Size(); i++ {
-               x := i*c.writer.Size() + c.writer.Size()
+       log.Debugf(" -- writeMsg done")
+
+       return nil
+}
+
+func writeMsgHeader(w io.Writer, header []byte, dataLen int) error {
+       binary.BigEndian.PutUint32(header[8:12], uint32(dataLen))
+
+       n, err := w.Write(header)
+       if err != nil {
+               return err
+       }
+       if debug {
+               log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
+       }
+
+       return nil
+}
+
+func writeMsgData(w io.Writer, msg []byte, writerSize int) error {
+       for i := 0; i <= len(msg)/writerSize; i++ {
+               x := i*writerSize + writerSize
                if x > len(msg) {
                        x = len(msg)
                }
-               Log.Debugf("x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/c.writer.Size())
-               if n, err := c.writer.Write(msg[i*c.writer.Size() : x]); err != nil {
-                       return err
-               } else {
-                       Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
+               if debug {
+                       log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
                }
-               if err := c.writer.Flush(); err != nil {
+               n, err := w.Write(msg[i*writerSize : x])
+               if err != nil {
                        return err
                }
-
+               if debug {
+                       log.Debugf(" - data sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
+               }
        }
-
        return nil
 }
 
-type msgHeader struct {
-       MsgID   uint16
-       Context uint32
-}
-
-func (c *vppClient) readerLoop() {
+func (c *socketClient) readerLoop() {
        defer c.wg.Done()
-       defer Log.Debugf("reader quit")
+       defer log.Debugf("reader loop done")
+
+       var buf [8192]byte
+
        for {
                select {
                case <-c.quit:
@@ -440,71 +533,118 @@ func (c *vppClient) readerLoop() {
                default:
                }
 
-               msg, err := c.read()
+               msg, err := c.readMsg(buf[:])
                if err != nil {
                        if isClosedError(err) {
                                return
                        }
-                       Log.Debugf("read failed: %v", err)
+                       log.Debugf("readMsg error: %v", err)
                        continue
                }
-               h := new(msgHeader)
-               if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
-                       Log.Debugf("unpacking header failed: %v", err)
-                       continue
+
+               msgID, context := getMsgReplyHeader(msg)
+               if debug {
+                       log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), msgID, context)
                }
 
-               Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
-               c.cb(h.MsgID, msg)
+               c.msgCallback(msgID, msg)
        }
 }
 
-type msgheader struct {
-       Q               int    `struc:"uint64"`
-       DataLen         uint32 `struc:"uint32"`
-       GcMarkTimestamp uint32 `struc:"uint32"`
+// getMsgReplyHeader gets message ID and context from the message reply header
+//
+// Message reply has following structure:
+//
+//    type msgReplyHeader struct {
+//        MsgID       uint16
+//        Context     uint32
+//    }
+//
+func getMsgReplyHeader(msg []byte) (msgID uint16, context uint32) {
+       msgID = binary.BigEndian.Uint16(msg[0:2])
+       context = binary.BigEndian.Uint32(msg[2:6])
+       return
 }
 
-func (c *vppClient) read() ([]byte, error) {
-       Log.Debug("reading next msg..")
-
-       header := make([]byte, 16)
+func (c *socketClient) readMsgTimeout(buf []byte, timeout time.Duration) ([]byte, error) {
+       // set read deadline
+       readDeadline := time.Now().Add(timeout)
+       if err := c.conn.SetReadDeadline(readDeadline); err != nil {
+               return nil, err
+       }
 
-       n, err := io.ReadAtLeast(c.reader, header, 16)
+       // read message
+       msgReply, err := c.readMsg(buf)
        if err != nil {
                return nil, err
-       } else if n == 0 {
-               Log.Debugln("zero bytes header")
-               return nil, nil
        }
-       if n != 16 {
-               Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
-               return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
+
+       // reset read deadline
+       if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
+               return nil, err
        }
-       Log.Debugf(" - read header %d bytes: % 0X", n, header)
 
-       h := &msgheader{}
-       if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
+       return msgReply, nil
+}
+
+func (c *socketClient) readMsg(buf []byte) ([]byte, error) {
+       log.Debug("reading msg..")
+
+       header := c.headerPool.Get().([]byte)
+       msgLen, err := readMsgHeader(c.reader, header)
+       if err != nil {
                return nil, err
        }
-       Log.Debugf(" - decoded header: %+v", h)
+       c.headerPool.Put(header)
 
-       msgLen := int(h.DataLen)
-       msg := make([]byte, msgLen)
+       msg, err := readMsgData(c.reader, buf, msgLen)
 
-       n, err = c.reader.Read(msg)
+       log.Debugf(" -- readMsg done (buffered: %d)", c.reader.Buffered())
+
+       return msg, nil
+}
+
+func readMsgHeader(r io.Reader, header []byte) (int, error) {
+       n, err := io.ReadAtLeast(r, header, 16)
+       if err != nil {
+               return 0, err
+       }
+       if n == 0 {
+               log.Debugln("zero bytes header")
+               return 0, nil
+       } else if n != 16 {
+               log.Debugf("invalid header (%d bytes): % 0X", n, header[:n])
+               return 0, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
+       }
+
+       dataLen := binary.BigEndian.Uint32(header[8:12])
+
+       return int(dataLen), nil
+}
+
+func readMsgData(r io.Reader, buf []byte, dataLen int) ([]byte, error) {
+       var msg []byte
+       if buf == nil || len(buf) < dataLen {
+               msg = make([]byte, dataLen)
+       } else {
+               msg = buf[0:dataLen]
+       }
+
+       n, err := r.Read(msg)
        if err != nil {
                return nil, err
        }
-       Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
+       if debug {
+               log.Debugf(" - read data (%d bytes): % 0X", n, msg[:n])
+       }
 
-       if msgLen > n {
-               remain := msgLen - n
-               Log.Debugf("continue read for another %d bytes", remain)
+       if dataLen > n {
+               remain := dataLen - n
+               log.Debugf("continue reading remaining %d bytes", remain)
                view := msg[n:]
 
                for remain > 0 {
-                       nbytes, err := c.reader.Read(view)
+                       nbytes, err := r.Read(view)
                        if err != nil {
                                return nil, err
                        } else if nbytes == 0 {
@@ -512,7 +652,7 @@ func (c *vppClient) read() ([]byte, error) {
                        }
 
                        remain -= nbytes
-                       Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
+                       log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
 
                        view = view[nbytes:]
                }
@@ -522,7 +662,7 @@ func (c *vppClient) read() ([]byte, error) {
 }
 
 func isClosedError(err error) bool {
-       if err == io.EOF {
+       if errors.Is(err, io.EOF) {
                return true
        }
        return strings.HasSuffix(err.Error(), "use of closed network connection")