Use new default binapi socket with fallback to legacy
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "net"
23         "os"
24         "path/filepath"
25         "strings"
26         "sync"
27         "time"
28
29         "github.com/fsnotify/fsnotify"
30         "github.com/lunixbochs/struc"
31         logger "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35         "git.fd.io/govpp.git/examples/binapi/memclnt"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = adapter.DefaultBinapiSocket
41         legacySocketName  = "/run/vpp-api.sock"
42 )
43
44 var (
45         // DefaultConnectTimeout is default timeout for connecting
46         DefaultConnectTimeout = time.Second * 3
47         // DefaultDisconnectTimeout is default timeout for discconnecting
48         DefaultDisconnectTimeout = time.Millisecond * 100
49         // MaxWaitReady defines maximum duration before waiting for socket file
50         // times out
51         MaxWaitReady = time.Second * 10
52         // ClientName is used for identifying client in socket registration
53         ClientName = "govppsock"
54 )
55
56 var (
57         // Debug is global variable that determines debug mode
58         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
59         // DebugMsgIds is global variable that determines debug mode for msg ids
60         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
61
62         // Log is global logger
63         Log = logger.New()
64 )
65
66 // init initializes global logger, which logs debug level messages to stdout.
67 func init() {
68         Log.Out = os.Stdout
69         if Debug {
70                 Log.Level = logger.DebugLevel
71                 Log.Debug("govpp/socketclient: enabled debug mode")
72         }
73 }
74
75 const socketMissing = `
76 ------------------------------------------------------------
77  No socket file found at: %s
78  VPP binary API socket file is missing!
79
80   - is VPP running with socket for binapi enabled?
81   - is the correct socket name configured?
82
83  To enable it add following section to your VPP config:
84    socksvr {
85      default
86    }
87 ------------------------------------------------------------
88 `
89
90 var warnOnce sync.Once
91
92 func (c *vppClient) printMissingSocketMsg() {
93         fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
94 }
95
96 type vppClient struct {
97         sockAddr string
98
99         conn   *net.UnixConn
100         reader *bufio.Reader
101         writer *bufio.Writer
102
103         connectTimeout    time.Duration
104         disconnectTimeout time.Duration
105
106         cb           adapter.MsgCallback
107         clientIndex  uint32
108         msgTable     map[string]uint16
109         sockDelMsgId uint16
110         writeMu      sync.Mutex
111
112         quit chan struct{}
113         wg   sync.WaitGroup
114 }
115
116 func NewVppClient(sockAddr string) *vppClient {
117         if sockAddr == "" {
118                 sockAddr = DefaultSocketName
119         }
120         return &vppClient{
121                 sockAddr:          sockAddr,
122                 connectTimeout:    DefaultConnectTimeout,
123                 disconnectTimeout: DefaultDisconnectTimeout,
124                 cb: func(msgID uint16, data []byte) {
125                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
126                 },
127         }
128 }
129
130 // SetConnectTimeout sets timeout used during connecting.
131 func (c *vppClient) SetConnectTimeout(t time.Duration) {
132         c.connectTimeout = t
133 }
134
135 // SetDisconnectTimeout sets timeout used during disconnecting.
136 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
137         c.disconnectTimeout = t
138 }
139
140 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
141         Log.Debug("SetMsgCallback")
142         c.cb = cb
143 }
144
145 func (c *vppClient) checkLegacySocket() bool {
146         if c.sockAddr == legacySocketName {
147                 return false
148         }
149         Log.Debugf("checking legacy socket: %s", legacySocketName)
150         // check if socket exists
151         if _, err := os.Stat(c.sockAddr); err == nil {
152                 return false // socket exists
153         } else if !os.IsNotExist(err) {
154                 return false // some other error occurred
155         }
156         // check if legacy socket exists
157         if _, err := os.Stat(legacySocketName); err == nil {
158                 // legacy socket exists, update sockAddr
159                 c.sockAddr = legacySocketName
160                 return true
161         }
162         // no socket socket found
163         return false
164 }
165
166 // WaitReady checks socket file existence and waits for it if necessary
167 func (c *vppClient) WaitReady() error {
168         // check if socket already exists
169         if _, err := os.Stat(c.sockAddr); err == nil {
170                 return nil // socket exists, we are ready
171         } else if !os.IsNotExist(err) {
172                 return err // some other error occurred
173         }
174
175         if c.checkLegacySocket() {
176                 return nil
177         }
178
179         // socket does not exist, watch for it
180         watcher, err := fsnotify.NewWatcher()
181         if err != nil {
182                 return err
183         }
184         defer func() {
185                 if err := watcher.Close(); err != nil {
186                         Log.Warnf("failed to close file watcher: %v", err)
187                 }
188         }()
189
190         // start directory watcher
191         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
192                 return err
193         }
194
195         timeout := time.NewTimer(MaxWaitReady)
196         for {
197                 select {
198                 case <-timeout.C:
199                         if c.checkLegacySocket() {
200                                 return nil
201                         }
202                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
203
204                 case e := <-watcher.Errors:
205                         return e
206
207                 case ev := <-watcher.Events:
208                         Log.Debugf("watcher event: %+v", ev)
209                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
210                                 // socket created, we are ready
211                                 return nil
212                         }
213                 }
214         }
215 }
216
217 func (c *vppClient) Connect() error {
218         c.checkLegacySocket()
219
220         // check if socket exists
221         if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
222                 warnOnce.Do(c.printMissingSocketMsg)
223                 return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
224         } else if err != nil {
225                 return fmt.Errorf("VPP API socket error: %v", err)
226         }
227
228         if err := c.connect(c.sockAddr); err != nil {
229                 return err
230         }
231
232         if err := c.open(); err != nil {
233                 c.disconnect()
234                 return err
235         }
236
237         c.quit = make(chan struct{})
238         c.wg.Add(1)
239         go c.readerLoop()
240
241         return nil
242 }
243
244 func (c *vppClient) Disconnect() error {
245         if c.conn == nil {
246                 return nil
247         }
248         Log.Debugf("Disconnecting..")
249
250         close(c.quit)
251
252         if err := c.conn.CloseRead(); err != nil {
253                 Log.Debugf("closing read failed: %v", err)
254         }
255
256         // wait for readerLoop to return
257         c.wg.Wait()
258
259         if err := c.close(); err != nil {
260                 Log.Debugf("closing failed: %v", err)
261         }
262
263         if err := c.disconnect(); err != nil {
264                 return err
265         }
266
267         return nil
268 }
269
270 func (c *vppClient) connect(sockAddr string) error {
271         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
272
273         Log.Debugf("Connecting to: %v", c.sockAddr)
274
275         conn, err := net.DialUnix("unix", nil, addr)
276         if err != nil {
277                 // we try different type of socket for backwards compatbility with VPP<=19.04
278                 if strings.Contains(err.Error(), "wrong type for socket") {
279                         addr.Net = "unixpacket"
280                         Log.Debugf("%s, retrying connect with type unixpacket", err)
281                         conn, err = net.DialUnix("unixpacket", nil, addr)
282                 }
283                 if err != nil {
284                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
285                         return err
286                 }
287         }
288
289         c.conn = conn
290         Log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
291
292         c.reader = bufio.NewReader(c.conn)
293         c.writer = bufio.NewWriter(c.conn)
294
295         return nil
296 }
297
298 func (c *vppClient) disconnect() error {
299         Log.Debugf("Closing socket")
300         if err := c.conn.Close(); err != nil {
301                 Log.Debugln("Closing socket failed:", err)
302                 return err
303         }
304         return nil
305 }
306
307 const (
308         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
309         createMsgContext = byte(123)
310         deleteMsgContext = byte(124)
311 )
312
313 func (c *vppClient) open() error {
314         msgCodec := new(codec.MsgCodec)
315
316         req := &memclnt.SockclntCreate{
317                 Name: []byte(ClientName),
318         }
319         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
320         if err != nil {
321                 Log.Debugln("Encode error:", err)
322                 return err
323         }
324         // set non-0 context
325         msg[5] = createMsgContext
326
327         if err := c.write(msg); err != nil {
328                 Log.Debugln("Write error: ", err)
329                 return err
330         }
331
332         readDeadline := time.Now().Add(c.connectTimeout)
333         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
334                 return err
335         }
336         msgReply, err := c.read()
337         if err != nil {
338                 Log.Println("Read error:", err)
339                 return err
340         }
341         // reset read deadline
342         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
343                 return err
344         }
345
346         reply := new(memclnt.SockclntCreateReply)
347         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
348                 Log.Println("Decode error:", err)
349                 return err
350         }
351
352         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
353                 reply.Response, reply.Index, reply.Count)
354
355         c.clientIndex = reply.Index
356         c.msgTable = make(map[string]uint16, reply.Count)
357         for _, x := range reply.MessageTable {
358                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
359                 c.msgTable[name] = x.Index
360                 if strings.HasPrefix(name, "sockclnt_delete_") {
361                         c.sockDelMsgId = x.Index
362                 }
363                 if DebugMsgIds {
364                         Log.Debugf(" - %4d: %q", x.Index, name)
365                 }
366         }
367
368         return nil
369 }
370
371 func (c *vppClient) close() error {
372         msgCodec := new(codec.MsgCodec)
373
374         req := &memclnt.SockclntDelete{
375                 Index: c.clientIndex,
376         }
377         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
378         if err != nil {
379                 Log.Debugln("Encode error:", err)
380                 return err
381         }
382         // set non-0 context
383         msg[5] = deleteMsgContext
384
385         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
386         if err := c.write(msg); err != nil {
387                 Log.Debugln("Write error: ", err)
388                 return err
389         }
390
391         readDeadline := time.Now().Add(c.disconnectTimeout)
392         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
393                 return err
394         }
395         msgReply, err := c.read()
396         if err != nil {
397                 Log.Debugln("Read error:", err)
398                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
399                         // we accept timeout for reply
400                         return nil
401                 }
402                 return err
403         }
404         // reset read deadline
405         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
406                 return err
407         }
408
409         reply := new(memclnt.SockclntDeleteReply)
410         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
411                 Log.Debugln("Decode error:", err)
412                 return err
413         }
414
415         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
416
417         return nil
418 }
419
420 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
421         msg := msgName + "_" + msgCrc
422         msgID, ok := c.msgTable[msg]
423         if !ok {
424                 return 0, fmt.Errorf("unknown message: %q", msg)
425         }
426         return msgID, nil
427 }
428
429 type reqHeader struct {
430         // MsgID uint16
431         ClientIndex uint32
432         Context     uint32
433 }
434
435 func (c *vppClient) SendMsg(context uint32, data []byte) error {
436         h := &reqHeader{
437                 ClientIndex: c.clientIndex,
438                 Context:     context,
439         }
440         buf := new(bytes.Buffer)
441         if err := struc.Pack(buf, h); err != nil {
442                 return err
443         }
444         copy(data[2:], buf.Bytes())
445
446         Log.Debugf("sendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
447
448         if err := c.write(data); err != nil {
449                 Log.Debugln("write error: ", err)
450                 return err
451         }
452
453         return nil
454 }
455
456 func (c *vppClient) write(msg []byte) error {
457         h := &msgheader{
458                 DataLen: uint32(len(msg)),
459         }
460         buf := new(bytes.Buffer)
461         if err := struc.Pack(buf, h); err != nil {
462                 return err
463         }
464         header := buf.Bytes()
465
466         // we lock to prevent mixing multiple message sends
467         c.writeMu.Lock()
468         defer c.writeMu.Unlock()
469
470         if n, err := c.writer.Write(header); err != nil {
471                 return err
472         } else {
473                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
474         }
475
476         writerSize := c.writer.Size()
477         for i := 0; i <= len(msg)/writerSize; i++ {
478                 x := i*writerSize + writerSize
479                 if x > len(msg) {
480                         x = len(msg)
481                 }
482                 Log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
483                 if n, err := c.writer.Write(msg[i*writerSize : x]); err != nil {
484                         return err
485                 } else {
486                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
487                 }
488         }
489         if err := c.writer.Flush(); err != nil {
490                 return err
491         }
492
493         Log.Debugf(" -- write done")
494
495         return nil
496 }
497
498 type msgHeader struct {
499         MsgID   uint16
500         Context uint32
501 }
502
503 func (c *vppClient) readerLoop() {
504         defer c.wg.Done()
505         defer Log.Debugf("reader quit")
506
507         for {
508                 select {
509                 case <-c.quit:
510                         return
511                 default:
512                 }
513
514                 msg, err := c.read()
515                 if err != nil {
516                         if isClosedError(err) {
517                                 return
518                         }
519                         Log.Debugf("read failed: %v", err)
520                         continue
521                 }
522
523                 h := new(msgHeader)
524                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
525                         Log.Debugf("unpacking header failed: %v", err)
526                         continue
527                 }
528
529                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
530                 c.cb(h.MsgID, msg)
531         }
532 }
533
534 type msgheader struct {
535         Q               int    `struc:"uint64"`
536         DataLen         uint32 `struc:"uint32"`
537         GcMarkTimestamp uint32 `struc:"uint32"`
538 }
539
540 func (c *vppClient) read() ([]byte, error) {
541         Log.Debug(" reading next msg..")
542
543         header := make([]byte, 16)
544
545         n, err := io.ReadAtLeast(c.reader, header, 16)
546         if err != nil {
547                 return nil, err
548         }
549         if n == 0 {
550                 Log.Debugln("zero bytes header")
551                 return nil, nil
552         } else if n != 16 {
553                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
554                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
555         }
556         Log.Debugf(" read header %d bytes: % 0X", n, header)
557
558         h := &msgheader{}
559         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
560                 return nil, err
561         }
562         Log.Debugf(" - decoded header: %+v", h)
563
564         msgLen := int(h.DataLen)
565         msg := make([]byte, msgLen)
566
567         n, err = c.reader.Read(msg)
568         if err != nil {
569                 return nil, err
570         }
571         Log.Debugf(" - read msg %d bytes (%d buffered) % 0X", n, c.reader.Buffered(), msg[:n])
572
573         if msgLen > n {
574                 remain := msgLen - n
575                 Log.Debugf("continue read for another %d bytes", remain)
576                 view := msg[n:]
577
578                 for remain > 0 {
579                         nbytes, err := c.reader.Read(view)
580                         if err != nil {
581                                 return nil, err
582                         } else if nbytes == 0 {
583                                 return nil, fmt.Errorf("zero nbytes")
584                         }
585
586                         remain -= nbytes
587                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
588
589                         view = view[nbytes:]
590                 }
591         }
592
593         Log.Debugf(" -- read done (buffered: %d)", c.reader.Buffered())
594
595         return msg, nil
596 }
597
598 func isClosedError(err error) bool {
599         if err == io.EOF {
600                 return true
601         }
602         return strings.HasSuffix(err.Error(), "use of closed network connection")
603 }