867278b5c76daa50d238e7a9b1f4e387a9ea8f19
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "encoding/binary"
20         "errors"
21         "fmt"
22         "io"
23         "net"
24         "os"
25         "path/filepath"
26         "strings"
27         "sync"
28         "time"
29
30         "github.com/fsnotify/fsnotify"
31         "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35 )
36
37 const (
38         // DefaultSocketName is default VPP API socket file path.
39         DefaultSocketName = "/run/vpp/api.sock"
40         // DefaultClientName is used for identifying client in socket registration
41         DefaultClientName = "govppsock"
42 )
43
44 var (
45
46         // DefaultConnectTimeout is default timeout for connecting
47         DefaultConnectTimeout = time.Second * 3
48         // DefaultDisconnectTimeout is default timeout for discconnecting
49         DefaultDisconnectTimeout = time.Millisecond * 100
50         // MaxWaitReady defines maximum duration of waiting for socket file
51         MaxWaitReady = time.Second * 10
52 )
53
54 var (
55         debug       = strings.Contains(os.Getenv("DEBUG_GOVPP"), "socketclient")
56         debugMsgIds = strings.Contains(os.Getenv("DEBUG_GOVPP"), "msgtable")
57
58         logger = logrus.New()
59         log    = logger.WithField("logger", "govpp/socketclient")
60 )
61
62 // init initializes global logger
63 func init() {
64         if debug {
65                 logger.Level = logrus.DebugLevel
66                 log.Debug("govpp: debug level enabled for socketclient")
67         }
68 }
69
70 const socketMissing = `
71 ------------------------------------------------------------
72  No socket file found at: %s
73  VPP binary API socket file is missing!
74
75   - is VPP running with socket for binapi enabled?
76   - is the correct socket name configured?
77
78  To enable it add following section to your VPP config:
79    socksvr {
80      default
81    }
82 ------------------------------------------------------------
83 `
84
85 var warnOnce sync.Once
86
87 func (c *socketClient) printMissingSocketMsg() {
88         fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
89 }
90
91 type socketClient struct {
92         sockAddr   string
93         clientName string
94
95         conn   *net.UnixConn
96         reader *bufio.Reader
97         writer *bufio.Writer
98
99         connectTimeout    time.Duration
100         disconnectTimeout time.Duration
101
102         msgCallback  adapter.MsgCallback
103         clientIndex  uint32
104         msgTable     map[string]uint16
105         sockDelMsgId uint16
106         writeMu      sync.Mutex
107
108         headerPool *sync.Pool
109
110         quit chan struct{}
111         wg   sync.WaitGroup
112 }
113
114 func NewVppClient(sockAddr string) *socketClient {
115         if sockAddr == "" {
116                 sockAddr = DefaultSocketName
117         }
118         return &socketClient{
119                 sockAddr:          sockAddr,
120                 clientName:        DefaultClientName,
121                 connectTimeout:    DefaultConnectTimeout,
122                 disconnectTimeout: DefaultDisconnectTimeout,
123                 headerPool: &sync.Pool{New: func() interface{} {
124                         return make([]byte, 16)
125                 }},
126                 msgCallback: func(msgID uint16, data []byte) {
127                         log.Debugf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
128                 },
129         }
130 }
131
132 // SetClientName sets a client name used for identification.
133 func (c *socketClient) SetClientName(name string) {
134         c.clientName = name
135 }
136
137 // SetConnectTimeout sets timeout used during connecting.
138 func (c *socketClient) SetConnectTimeout(t time.Duration) {
139         c.connectTimeout = t
140 }
141
142 // SetDisconnectTimeout sets timeout used during disconnecting.
143 func (c *socketClient) SetDisconnectTimeout(t time.Duration) {
144         c.disconnectTimeout = t
145 }
146
147 func (c *socketClient) SetMsgCallback(cb adapter.MsgCallback) {
148         log.Debug("SetMsgCallback")
149         c.msgCallback = cb
150 }
151
152 const legacySocketName = "/run/vpp-api.sock"
153
154 func (c *socketClient) checkLegacySocket() bool {
155         if c.sockAddr == legacySocketName {
156                 return false
157         }
158         log.Debugf("checking legacy socket: %s", legacySocketName)
159         // check if socket exists
160         if _, err := os.Stat(c.sockAddr); err == nil {
161                 return false // socket exists
162         } else if !os.IsNotExist(err) {
163                 return false // some other error occurred
164         }
165         // check if legacy socket exists
166         if _, err := os.Stat(legacySocketName); err == nil {
167                 // legacy socket exists, update sockAddr
168                 c.sockAddr = legacySocketName
169                 return true
170         }
171         // no socket socket found
172         return false
173 }
174
175 // WaitReady checks socket file existence and waits for it if necessary
176 func (c *socketClient) WaitReady() error {
177         // check if socket already exists
178         if _, err := os.Stat(c.sockAddr); err == nil {
179                 return nil // socket exists, we are ready
180         } else if !os.IsNotExist(err) {
181                 return err // some other error occurred
182         }
183
184         if c.checkLegacySocket() {
185                 return nil
186         }
187
188         // socket does not exist, watch for it
189         watcher, err := fsnotify.NewWatcher()
190         if err != nil {
191                 return err
192         }
193         defer func() {
194                 if err := watcher.Close(); err != nil {
195                         log.Debugf("failed to close file watcher: %v", err)
196                 }
197         }()
198
199         // start directory watcher
200         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
201                 return err
202         }
203
204         timeout := time.NewTimer(MaxWaitReady)
205         for {
206                 select {
207                 case <-timeout.C:
208                         if c.checkLegacySocket() {
209                                 return nil
210                         }
211                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
212
213                 case e := <-watcher.Errors:
214                         return e
215
216                 case ev := <-watcher.Events:
217                         log.Debugf("watcher event: %+v", ev)
218                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
219                                 // socket created, we are ready
220                                 return nil
221                         }
222                 }
223         }
224 }
225
226 func (c *socketClient) Connect() error {
227         c.checkLegacySocket()
228
229         // check if socket exists
230         if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
231                 warnOnce.Do(c.printMissingSocketMsg)
232                 return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
233         } else if err != nil {
234                 return fmt.Errorf("VPP API socket error: %v", err)
235         }
236
237         if err := c.connect(c.sockAddr); err != nil {
238                 return err
239         }
240
241         if err := c.open(); err != nil {
242                 _ = c.disconnect()
243                 return err
244         }
245
246         c.quit = make(chan struct{})
247         c.wg.Add(1)
248         go c.readerLoop()
249
250         return nil
251 }
252
253 func (c *socketClient) Disconnect() error {
254         if c.conn == nil {
255                 return nil
256         }
257         log.Debugf("Disconnecting..")
258
259         close(c.quit)
260
261         if err := c.conn.CloseRead(); err != nil {
262                 log.Debugf("closing readMsg failed: %v", err)
263         }
264
265         // wait for readerLoop to return
266         c.wg.Wait()
267
268         if err := c.close(); err != nil {
269                 log.Debugf("closing failed: %v", err)
270         }
271
272         if err := c.disconnect(); err != nil {
273                 return err
274         }
275
276         return nil
277 }
278
279 const defaultBufferSize = 4096
280
281 func (c *socketClient) connect(sockAddr string) error {
282         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
283
284         log.Debugf("Connecting to: %v", c.sockAddr)
285
286         conn, err := net.DialUnix("unix", nil, addr)
287         if err != nil {
288                 // we try different type of socket for backwards compatbility with VPP<=19.04
289                 if strings.Contains(err.Error(), "wrong type for socket") {
290                         addr.Net = "unixpacket"
291                         log.Debugf("%s, retrying connect with type unixpacket", err)
292                         conn, err = net.DialUnix("unixpacket", nil, addr)
293                 }
294                 if err != nil {
295                         log.Debugf("Connecting to socket %s failed: %s", addr, err)
296                         return err
297                 }
298         }
299
300         c.conn = conn
301         log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
302
303         c.reader = bufio.NewReaderSize(c.conn, defaultBufferSize)
304         c.writer = bufio.NewWriterSize(c.conn, defaultBufferSize)
305
306         return nil
307 }
308
309 func (c *socketClient) disconnect() error {
310         log.Debugf("Closing socket")
311         if err := c.conn.Close(); err != nil {
312                 log.Debugln("Closing socket failed:", err)
313                 return err
314         }
315         return nil
316 }
317
318 const (
319         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
320         createMsgContext = byte(123)
321         deleteMsgContext = byte(124)
322 )
323
324 func (c *socketClient) open() error {
325         var msgCodec = codec.DefaultCodec
326
327         // Request socket client create
328         req := &SockclntCreate{
329                 Name: c.clientName,
330         }
331         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
332         if err != nil {
333                 log.Debugln("Encode  error:", err)
334                 return err
335         }
336         // set non-0 context
337         msg[5] = createMsgContext
338
339         if err := c.writeMsg(msg); err != nil {
340                 log.Debugln("Write error: ", err)
341                 return err
342         }
343         msgReply, err := c.readMsgTimeout(nil, c.connectTimeout)
344         if err != nil {
345                 log.Println("Read error:", err)
346                 return err
347         }
348
349         reply := new(SockclntCreateReply)
350         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
351                 log.Println("Decoding sockclnt_create_reply failed:", err)
352                 return err
353         } else if reply.Response != 0 {
354                 return fmt.Errorf("sockclnt_create_reply: response error (%d)", reply.Response)
355         }
356
357         log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
358                 reply.Response, reply.Index, reply.Count)
359
360         c.clientIndex = reply.Index
361         c.msgTable = make(map[string]uint16, reply.Count)
362         for _, x := range reply.MessageTable {
363                 msgName := strings.Split(x.Name, "\x00")[0]
364                 name := strings.TrimSuffix(msgName, "\x13")
365                 c.msgTable[name] = x.Index
366                 if strings.HasPrefix(name, "sockclnt_delete_") {
367                         c.sockDelMsgId = x.Index
368                 }
369                 if debugMsgIds {
370                         log.Debugf(" - %4d: %q", x.Index, name)
371                 }
372         }
373
374         return nil
375 }
376
377 func (c *socketClient) close() error {
378         var msgCodec = codec.DefaultCodec
379
380         req := &SockclntDelete{
381                 Index: c.clientIndex,
382         }
383         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
384         if err != nil {
385                 log.Debugln("Encode error:", err)
386                 return err
387         }
388         // set non-0 context
389         msg[5] = deleteMsgContext
390
391         log.Debugf("sending socklntDel (%d bytes): % 0X", len(msg), msg)
392
393         if err := c.writeMsg(msg); err != nil {
394                 log.Debugln("Write error: ", err)
395                 return err
396         }
397
398         msgReply, err := c.readMsgTimeout(nil, c.disconnectTimeout)
399         if err != nil {
400                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
401                         // we accept timeout for reply
402                         return nil
403                 }
404                 log.Debugln("Read error:", err)
405                 return err
406         }
407
408         reply := new(SockclntDeleteReply)
409         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
410                 log.Debugln("Decoding sockclnt_delete_reply failed:", err)
411                 return err
412         } else if reply.Response != 0 {
413                 return fmt.Errorf("sockclnt_delete_reply: response error (%d)", reply.Response)
414         }
415
416         return nil
417 }
418
419 func (c *socketClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
420         if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok {
421                 return msgID, nil
422         }
423         return 0, &adapter.UnknownMsgError{
424                 MsgName: msgName,
425                 MsgCrc:  msgCrc,
426         }
427 }
428
429 func (c *socketClient) SendMsg(context uint32, data []byte) error {
430         if len(data) < 10 {
431                 return fmt.Errorf("invalid message data, length must be at least 10 bytes")
432         }
433         setMsgRequestHeader(data, c.clientIndex, context)
434
435         if debug {
436                 log.Debugf("sendMsg (%d) context=%v client=%d: % 02X", len(data), context, c.clientIndex, data)
437         }
438
439         if err := c.writeMsg(data); err != nil {
440                 log.Debugln("writeMsg error: ", err)
441                 return err
442         }
443
444         return nil
445 }
446
447 // setMsgRequestHeader sets client index and context in the message request header
448 //
449 // Message request has following structure:
450 //
451 //    type msgRequestHeader struct {
452 //        MsgID       uint16
453 //        ClientIndex uint32
454 //        Context     uint32
455 //    }
456 //
457 func setMsgRequestHeader(data []byte, clientIndex, context uint32) {
458         // message ID is already set
459         binary.BigEndian.PutUint32(data[2:6], clientIndex)
460         binary.BigEndian.PutUint32(data[6:10], context)
461 }
462
463 func (c *socketClient) writeMsg(msg []byte) error {
464         // we lock to prevent mixing multiple message writes
465         c.writeMu.Lock()
466         defer c.writeMu.Unlock()
467
468         header := c.headerPool.Get().([]byte)
469         err := writeMsgHeader(c.writer, header, len(msg))
470         if err != nil {
471                 return err
472         }
473         c.headerPool.Put(header)
474
475         if err := writeMsgData(c.writer, msg, c.writer.Size()); err != nil {
476                 return err
477         }
478
479         if err := c.writer.Flush(); err != nil {
480                 return err
481         }
482
483         log.Debugf(" -- writeMsg done")
484
485         return nil
486 }
487
488 func writeMsgHeader(w io.Writer, header []byte, dataLen int) error {
489         binary.BigEndian.PutUint32(header[8:12], uint32(dataLen))
490
491         n, err := w.Write(header)
492         if err != nil {
493                 return err
494         }
495         if debug {
496                 log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
497         }
498
499         return nil
500 }
501
502 func writeMsgData(w io.Writer, msg []byte, writerSize int) error {
503         for i := 0; i <= len(msg)/writerSize; i++ {
504                 x := i*writerSize + writerSize
505                 if x > len(msg) {
506                         x = len(msg)
507                 }
508                 if debug {
509                         log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
510                 }
511                 n, err := w.Write(msg[i*writerSize : x])
512                 if err != nil {
513                         return err
514                 }
515                 if debug {
516                         log.Debugf(" - data sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
517                 }
518         }
519         return nil
520 }
521
522 func (c *socketClient) readerLoop() {
523         defer c.wg.Done()
524         defer log.Debugf("reader loop done")
525
526         var buf [8192]byte
527
528         for {
529                 select {
530                 case <-c.quit:
531                         return
532                 default:
533                 }
534
535                 msg, err := c.readMsg(buf[:])
536                 if err != nil {
537                         if isClosedError(err) {
538                                 return
539                         }
540                         log.Debugf("readMsg error: %v", err)
541                         continue
542                 }
543
544                 msgID, context := getMsgReplyHeader(msg)
545                 if debug {
546                         log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), msgID, context)
547                 }
548
549                 c.msgCallback(msgID, msg)
550         }
551 }
552
553 // getMsgReplyHeader gets message ID and context from the message reply header
554 //
555 // Message reply has following structure:
556 //
557 //    type msgReplyHeader struct {
558 //        MsgID       uint16
559 //        Context     uint32
560 //    }
561 //
562 func getMsgReplyHeader(msg []byte) (msgID uint16, context uint32) {
563         msgID = binary.BigEndian.Uint16(msg[0:2])
564         context = binary.BigEndian.Uint32(msg[2:6])
565         return
566 }
567
568 func (c *socketClient) readMsgTimeout(buf []byte, timeout time.Duration) ([]byte, error) {
569         // set read deadline
570         readDeadline := time.Now().Add(timeout)
571         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
572                 return nil, err
573         }
574
575         // read message
576         msgReply, err := c.readMsg(buf)
577         if err != nil {
578                 return nil, err
579         }
580
581         // reset read deadline
582         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
583                 return nil, err
584         }
585
586         return msgReply, nil
587 }
588
589 func (c *socketClient) readMsg(buf []byte) ([]byte, error) {
590         log.Debug("reading msg..")
591
592         header := c.headerPool.Get().([]byte)
593         msgLen, err := readMsgHeader(c.reader, header)
594         if err != nil {
595                 return nil, err
596         }
597         c.headerPool.Put(header)
598
599         msg, err := readMsgData(c.reader, buf, msgLen)
600
601         log.Debugf(" -- readMsg done (buffered: %d)", c.reader.Buffered())
602
603         return msg, nil
604 }
605
606 func readMsgHeader(r io.Reader, header []byte) (int, error) {
607         n, err := io.ReadAtLeast(r, header, 16)
608         if err != nil {
609                 return 0, err
610         }
611         if n == 0 {
612                 log.Debugln("zero bytes header")
613                 return 0, nil
614         } else if n != 16 {
615                 log.Debugf("invalid header (%d bytes): % 0X", n, header[:n])
616                 return 0, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
617         }
618
619         dataLen := binary.BigEndian.Uint32(header[8:12])
620
621         return int(dataLen), nil
622 }
623
624 func readMsgData(r io.Reader, buf []byte, dataLen int) ([]byte, error) {
625         var msg []byte
626         if buf == nil || len(buf) < dataLen {
627                 msg = make([]byte, dataLen)
628         } else {
629                 msg = buf[0:dataLen]
630         }
631
632         n, err := r.Read(msg)
633         if err != nil {
634                 return nil, err
635         }
636         if debug {
637                 log.Debugf(" - read data (%d bytes): % 0X", n, msg[:n])
638         }
639
640         if dataLen > n {
641                 remain := dataLen - n
642                 log.Debugf("continue reading remaining %d bytes", remain)
643                 view := msg[n:]
644
645                 for remain > 0 {
646                         nbytes, err := r.Read(view)
647                         if err != nil {
648                                 return nil, err
649                         } else if nbytes == 0 {
650                                 return nil, fmt.Errorf("zero nbytes")
651                         }
652
653                         remain -= nbytes
654                         log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
655
656                         view = view[nbytes:]
657                 }
658         }
659
660         return msg, nil
661 }
662
663 func isClosedError(err error) bool {
664         if errors.Is(err, io.EOF) {
665                 return true
666         }
667         return strings.HasSuffix(err.Error(), "use of closed network connection")
668 }