Add statsclient - pure Go implementation for stats API
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "net"
23         "os"
24         "path/filepath"
25         "strings"
26         "sync"
27         "time"
28
29         "github.com/fsnotify/fsnotify"
30         "github.com/lunixbochs/struc"
31         logger "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35         "git.fd.io/govpp.git/examples/binapi/memclnt"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file name
40         DefaultSocketName = "/run/vpp-api.sock"
41 )
42
43 var (
44         // DefaultConnectTimeout is default timeout for connecting
45         DefaultConnectTimeout = time.Second * 3
46         // DefaultDisconnectTimeout is default timeout for discconnecting
47         DefaultDisconnectTimeout = time.Second
48         // MaxWaitReady defines maximum duration before waiting for socket file
49         // times out
50         MaxWaitReady = time.Second * 15
51         // ClientName is used for identifying client in socket registration
52         ClientName = "govppsock"
53 )
54
55 var (
56         // Debug is global variable that determines debug mode
57         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
58         // DebugMsgIds is global variable that determines debug mode for msg ids
59         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
60
61         Log = logger.New() // global logger
62 )
63
64 // init initializes global logger, which logs debug level messages to stdout.
65 func init() {
66         Log.Out = os.Stdout
67         if Debug {
68                 Log.Level = logger.DebugLevel
69         }
70 }
71
72 type vppClient struct {
73         sockAddr string
74         conn     *net.UnixConn
75         reader   *bufio.Reader
76         writer   *bufio.Writer
77
78         connectTimeout    time.Duration
79         disconnectTimeout time.Duration
80
81         cb           adapter.MsgCallback
82         clientIndex  uint32
83         msgTable     map[string]uint16
84         sockDelMsgId uint16
85         writeMu      sync.Mutex
86
87         quit chan struct{}
88         wg   sync.WaitGroup
89 }
90
91 func NewVppClient(sockAddr string) *vppClient {
92         if sockAddr == "" {
93                 sockAddr = DefaultSocketName
94         }
95         return &vppClient{
96                 sockAddr:          sockAddr,
97                 connectTimeout:    DefaultConnectTimeout,
98                 disconnectTimeout: DefaultDisconnectTimeout,
99                 cb: func(msgID uint16, data []byte) {
100                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
101                 },
102         }
103 }
104
105 // SetConnectTimeout sets timeout used during connecting.
106 func (c *vppClient) SetConnectTimeout(t time.Duration) {
107         c.connectTimeout = t
108 }
109
110 // SetDisconnectTimeout sets timeout used during disconnecting.
111 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
112         c.disconnectTimeout = t
113 }
114
115 // WaitReady checks socket file existence and waits for it if necessary
116 func (c *vppClient) WaitReady() error {
117         // check if file at the path already exists
118         if _, err := os.Stat(c.sockAddr); err == nil {
119                 return nil
120         } else if os.IsExist(err) {
121                 return err
122         }
123
124         // if not, watch for it
125         watcher, err := fsnotify.NewWatcher()
126         if err != nil {
127                 return err
128         }
129         defer func() {
130                 if err := watcher.Close(); err != nil {
131                         Log.Errorf("failed to close file watcher: %v", err)
132                 }
133         }()
134
135         // start watching directory
136         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
137                 return err
138         }
139
140         for {
141                 select {
142                 case <-time.After(MaxWaitReady):
143                         return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
144                 case e := <-watcher.Errors:
145                         return e
146                 case ev := <-watcher.Events:
147                         Log.Debugf("watcher event: %+v", ev)
148                         if ev.Name == c.sockAddr {
149                                 if (ev.Op & fsnotify.Create) == fsnotify.Create {
150                                         // socket was created, we are ready
151                                         return nil
152                                 }
153                         }
154                 }
155         }
156 }
157
158 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
159         Log.Debug("SetMsgCallback")
160         c.cb = cb
161 }
162
163 func (c *vppClient) Connect() error {
164         Log.Debugf("Connecting to: %v", c.sockAddr)
165
166         if err := c.connect(c.sockAddr); err != nil {
167                 return err
168         }
169
170         if err := c.open(); err != nil {
171                 return err
172         }
173
174         c.quit = make(chan struct{})
175         c.wg.Add(1)
176         go c.readerLoop()
177
178         return nil
179 }
180
181 func (c *vppClient) connect(sockAddr string) error {
182         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
183
184         conn, err := net.DialUnix("unix", nil, addr)
185         if err != nil {
186                 // we try different type of socket for backwards compatbility with VPP<=19.04
187                 if strings.Contains(err.Error(), "wrong type for socket") {
188                         addr.Net = "unixpacket"
189                         Log.Debugf("%s, retrying connect with type unixpacket", err)
190                         conn, err = net.DialUnix("unixpacket", nil, addr)
191                 }
192                 if err != nil {
193                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
194                         return err
195                 }
196         }
197
198         c.conn = conn
199         c.reader = bufio.NewReader(c.conn)
200         c.writer = bufio.NewWriter(c.conn)
201
202         Log.Debugf("Connected to socket: %v", addr)
203
204         return nil
205 }
206
207 const (
208         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
209         createMsgContext = byte(123)
210         deleteMsgContext = byte(124)
211 )
212
213 func (c *vppClient) open() error {
214         msgCodec := new(codec.MsgCodec)
215
216         req := &memclnt.SockclntCreate{
217                 Name: []byte(ClientName),
218         }
219         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
220         if err != nil {
221                 Log.Debugln("Encode error:", err)
222                 return err
223         }
224         // set non-0 context
225         msg[5] = createMsgContext
226
227         if err := c.write(msg); err != nil {
228                 Log.Debugln("Write error: ", err)
229                 return err
230         }
231
232         readDeadline := time.Now().Add(c.connectTimeout)
233         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
234                 return err
235         }
236         msgReply, err := c.read()
237         if err != nil {
238                 Log.Println("Read error:", err)
239                 return err
240         }
241         // reset read deadline
242         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
243                 return err
244         }
245
246         reply := new(memclnt.SockclntCreateReply)
247         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
248                 Log.Println("Decode error:", err)
249                 return err
250         }
251
252         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
253                 reply.Response, reply.Index, reply.Count)
254
255         c.clientIndex = reply.Index
256         c.msgTable = make(map[string]uint16, reply.Count)
257         for _, x := range reply.MessageTable {
258                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
259                 c.msgTable[name] = x.Index
260                 if strings.HasPrefix(name, "sockclnt_delete_") {
261                         c.sockDelMsgId = x.Index
262                 }
263                 if DebugMsgIds {
264                         Log.Debugf(" - %4d: %q", x.Index, name)
265                 }
266         }
267
268         return nil
269 }
270
271 func (c *vppClient) Disconnect() error {
272         if c.conn == nil {
273                 return nil
274         }
275         Log.Debugf("Disconnecting..")
276
277         close(c.quit)
278
279         // force readerLoop to timeout
280         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
281                 return err
282         }
283
284         // wait for readerLoop to return
285         c.wg.Wait()
286
287         if err := c.close(); err != nil {
288                 return err
289         }
290
291         if err := c.conn.Close(); err != nil {
292                 Log.Debugln("Closing socket failed:", err)
293                 return err
294         }
295
296         return nil
297 }
298
299 func (c *vppClient) close() error {
300         msgCodec := new(codec.MsgCodec)
301
302         req := &memclnt.SockclntDelete{
303                 Index: c.clientIndex,
304         }
305         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
306         if err != nil {
307                 Log.Debugln("Encode error:", err)
308                 return err
309         }
310         // set non-0 context
311         msg[5] = deleteMsgContext
312
313         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
314         if err := c.write(msg); err != nil {
315                 Log.Debugln("Write error: ", err)
316                 return err
317         }
318
319         readDeadline := time.Now().Add(c.disconnectTimeout)
320         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
321                 return err
322         }
323         msgReply, err := c.read()
324         if err != nil {
325                 Log.Debugln("Read error:", err)
326                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
327                         // we accept timeout for reply
328                         return nil
329                 }
330                 return err
331         }
332         // reset read deadline
333         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
334                 return err
335         }
336
337         reply := new(memclnt.SockclntDeleteReply)
338         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
339                 Log.Debugln("Decode error:", err)
340                 return err
341         }
342
343         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
344
345         return nil
346 }
347
348 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
349         msg := msgName + "_" + msgCrc
350         msgID, ok := c.msgTable[msg]
351         if !ok {
352                 return 0, fmt.Errorf("unknown message: %q", msg)
353         }
354         return msgID, nil
355 }
356
357 type reqHeader struct {
358         // MsgID uint16
359         ClientIndex uint32
360         Context     uint32
361 }
362
363 func (c *vppClient) SendMsg(context uint32, data []byte) error {
364         h := &reqHeader{
365                 ClientIndex: c.clientIndex,
366                 Context:     context,
367         }
368         buf := new(bytes.Buffer)
369         if err := struc.Pack(buf, h); err != nil {
370                 return err
371         }
372         copy(data[2:], buf.Bytes())
373
374         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
375
376         if err := c.write(data); err != nil {
377                 Log.Debugln("write error: ", err)
378                 return err
379         }
380
381         return nil
382 }
383
384 func (c *vppClient) write(msg []byte) error {
385         h := &msgheader{
386                 DataLen: uint32(len(msg)),
387         }
388         buf := new(bytes.Buffer)
389         if err := struc.Pack(buf, h); err != nil {
390                 return err
391         }
392         header := buf.Bytes()
393
394         // we lock to prevent mixing multiple message sends
395         c.writeMu.Lock()
396         defer c.writeMu.Unlock()
397
398         if n, err := c.writer.Write(header); err != nil {
399                 return err
400         } else {
401                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
402         }
403
404         if err := c.writer.Flush(); err != nil {
405                 return err
406         }
407
408         for i := 0; i <= len(msg)/c.writer.Size(); i++ {
409                 x := i*c.writer.Size() + c.writer.Size()
410                 if x > len(msg) {
411                         x = len(msg)
412                 }
413                 Log.Debugf("x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/c.writer.Size())
414                 if n, err := c.writer.Write(msg[i*c.writer.Size() : x]); err != nil {
415                         return err
416                 } else {
417                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
418                 }
419                 if err := c.writer.Flush(); err != nil {
420                         return err
421                 }
422
423         }
424
425         return nil
426 }
427
428 type msgHeader struct {
429         MsgID   uint16
430         Context uint32
431 }
432
433 func (c *vppClient) readerLoop() {
434         defer c.wg.Done()
435         defer Log.Debugf("reader quit")
436         for {
437                 select {
438                 case <-c.quit:
439                         return
440                 default:
441                 }
442
443                 msg, err := c.read()
444                 if err != nil {
445                         if isClosedError(err) {
446                                 return
447                         }
448                         Log.Debugf("read failed: %v", err)
449                         continue
450                 }
451                 h := new(msgHeader)
452                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
453                         Log.Debugf("unpacking header failed: %v", err)
454                         continue
455                 }
456
457                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
458                 c.cb(h.MsgID, msg)
459         }
460 }
461
462 type msgheader struct {
463         Q               int    `struc:"uint64"`
464         DataLen         uint32 `struc:"uint32"`
465         GcMarkTimestamp uint32 `struc:"uint32"`
466 }
467
468 func (c *vppClient) read() ([]byte, error) {
469         Log.Debug("reading next msg..")
470
471         header := make([]byte, 16)
472
473         n, err := io.ReadAtLeast(c.reader, header, 16)
474         if err != nil {
475                 return nil, err
476         } else if n == 0 {
477                 Log.Debugln("zero bytes header")
478                 return nil, nil
479         }
480         if n != 16 {
481                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
482                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
483         }
484         Log.Debugf(" - read header %d bytes: % 0X", n, header)
485
486         h := &msgheader{}
487         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
488                 return nil, err
489         }
490         Log.Debugf(" - decoded header: %+v", h)
491
492         msgLen := int(h.DataLen)
493         msg := make([]byte, msgLen)
494
495         n, err = c.reader.Read(msg)
496         if err != nil {
497                 return nil, err
498         }
499         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
500
501         if msgLen > n {
502                 remain := msgLen - n
503                 Log.Debugf("continue read for another %d bytes", remain)
504                 view := msg[n:]
505
506                 for remain > 0 {
507                         nbytes, err := c.reader.Read(view)
508                         if err != nil {
509                                 return nil, err
510                         } else if nbytes == 0 {
511                                 return nil, fmt.Errorf("zero nbytes")
512                         }
513
514                         remain -= nbytes
515                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
516
517                         view = view[nbytes:]
518                 }
519         }
520
521         return msg, nil
522 }
523
524 func isClosedError(err error) bool {
525         if err == io.EOF {
526                 return true
527         }
528         return strings.HasSuffix(err.Error(), "use of closed network connection")
529 }