Reload stats socket when VPP restarts
[govpp.git] / adapter / statsclient / statsclient.go
index a39cbd5..8693410 100644 (file)
@@ -25,6 +25,7 @@ import (
        "time"
 
        "git.fd.io/govpp.git/adapter"
+       "github.com/fsnotify/fsnotify"
        "github.com/ftrvxmtrx/fd"
        logger "github.com/sirupsen/logrus"
 )
@@ -80,6 +81,9 @@ type StatsClient struct {
        headerData  []byte
        isConnected bool
 
+       // to quit socket monitor
+       done chan struct{}
+
        statSegment
 }
 
@@ -92,39 +96,29 @@ func NewStatsClient(sockAddr string) *StatsClient {
                sockAddr: sockAddr,
        }
 }
-// Connect to the VPP stats socket
+
+// Connect to validated VPP stats socket and start monitoring
+// socket file changes
 func (sc *StatsClient) Connect() (err error) {
-       // check if socket exists
-       if _, err := os.Stat(sc.sockAddr); os.IsNotExist(err) {
-               fmt.Fprintf(os.Stderr, socketMissing, sc.sockAddr)
-               return fmt.Errorf("stats socket file %s does not exist", sc.sockAddr)
-       } else if err != nil {
-               return fmt.Errorf("stats socket error: %v", err)
-       }
        if sc.isConnected {
                return fmt.Errorf("already connected")
        }
+       if err := sc.validate(); err != nil {
+               return err
+       }
+       sc.done = make(chan struct{})
+       sc.monitorSocket()
        if sc.statSegment, err = sc.connect(); err != nil {
                return err
        }
-       sc.isConnected = true
        return nil
 }
 
-// Disconnect from the socket and unmap shared memory
+// Disconnect from the socket, unmap shared memory and terminate
+// socket monitor
 func (sc *StatsClient) Disconnect() error {
-       sc.isConnected = false
-       if sc.headerData == nil {
-               return nil
-       }
-       if err := syscall.Munmap(sc.headerData); err != nil {
-               Log.Debugf("unmapping shared memory failed: %v", err)
-               return fmt.Errorf("unmapping shared memory failed: %v", err)
-       }
-       sc.headerData = nil
-
-       Log.Debugf("successfully unmapped shared memory")
-       return nil
+       close(sc.done)
+       return sc.disconnect()
 }
 
 func (sc *StatsClient) ListStats(patterns ...string) ([]string, error) {
@@ -293,7 +287,20 @@ func (sc *StatsClient) UpdateDir(dir *adapter.StatDir) (err error) {
        return nil
 }
 
-func (sc *StatsClient) connect() (statSegment, error) {
+// validate file presence by retrieving its file info
+func (sc *StatsClient) validate() error {
+       if _, err := os.Stat(sc.sockAddr); os.IsNotExist(err) {
+               fmt.Fprintf(os.Stderr, socketMissing, sc.sockAddr)
+               return fmt.Errorf("stats socket file %s does not exist", sc.sockAddr)
+       } else if err != nil {
+               return fmt.Errorf("stats socket error: %v", err)
+       }
+       return nil
+}
+
+// connect to the socket and map it into the memory. According to the
+// header version info, an appropriate segment handler is returned
+func (sc *StatsClient) connect() (ss statSegment, err error) {
        addr := net.UnixAddr{
                Net:  "unixpacket",
                Name: sc.sockAddr,
@@ -343,13 +350,81 @@ func (sc *StatsClient) connect() (statSegment, error) {
        version := getVersion(sc.headerData)
        switch version {
        case 1:
-               return newStatSegmentV1(sc.headerData, size), nil
+               ss = newStatSegmentV1(sc.headerData, size)
        case 2:
-               return newStatSegmentV2(sc.headerData, size), nil
+               ss = newStatSegmentV2(sc.headerData, size)
        default:
                return nil, fmt.Errorf("stat segment version is not supported: %v (min: %v, max: %v)",
                        version, minVersion, maxVersion)
        }
+       sc.isConnected = true
+       return ss, nil
+}
+
+// reconnect disconnects from the socket, re-validates it and
+// connects again
+func (sc *StatsClient) reconnect() (err error) {
+       if err = sc.disconnect(); err != nil {
+               return fmt.Errorf("error disconnecting socket: %v", err)
+       }
+       if err = sc.validate(); err != nil {
+               return fmt.Errorf("error validating socket: %v", err)
+       }
+       if sc.statSegment, err = sc.connect(); err != nil {
+               return fmt.Errorf("error connecting socket: %v", err)
+       }
+       return nil
+}
+
+// disconnect unmaps socket data from the memory and resets the header
+func (sc *StatsClient) disconnect() error {
+       sc.isConnected = false
+       if sc.headerData == nil {
+               return nil
+       }
+       if err := syscall.Munmap(sc.headerData); err != nil {
+               Log.Debugf("unmapping shared memory failed: %v", err)
+               return fmt.Errorf("unmapping shared memory failed: %v", err)
+       }
+       sc.headerData = nil
+
+       Log.Debugf("successfully unmapped shared memory")
+       return nil
+}
+
+func (sc *StatsClient) monitorSocket() {
+       watcher, err := fsnotify.NewWatcher()
+       if err != nil {
+               Log.Errorf("error starting socket monitor: %v", err)
+               return
+       }
+
+       go func() {
+               for {
+                       select {
+                       case event := <-watcher.Events:
+                               if event.Op == fsnotify.Remove {
+                                       if err := sc.reconnect(); err != nil {
+                                               Log.Errorf("error occurred during socket reconnect: %v", err)
+                                       }
+                                       // path must be re-added to the watcher
+                                       if err = watcher.Add(sc.sockAddr); err != nil {
+                                               Log.Errorf("failed to add socket address to the watcher: %v", err)
+                                       }
+                               }
+                       case err := <-watcher.Errors:
+                               Log.Errorf("socket monitor delivered error event: %v", err)
+                       case <-sc.done:
+                               err := watcher.Close()
+                               Log.Debugf("socket monitor closed (error: %v)", err)
+                               return
+                       }
+               }
+       }()
+
+       if err := watcher.Add(sc.sockAddr); err != nil {
+               Log.Errorf("failed to add socket address to the watcher: %v", err)
+       }
 }
 
 // Starts monitoring 'inProgress' field. Returns stats segment