Update VPP config warnings and change log level in proxy
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "encoding/binary"
20         "errors"
21         "fmt"
22         "io"
23         "net"
24         "os"
25         "path/filepath"
26         "strings"
27         "sync"
28         "time"
29
30         "github.com/fsnotify/fsnotify"
31         "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/binapi/memclnt"
35         "git.fd.io/govpp.git/codec"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = "/run/vpp/api.sock"
41         // DefaultClientName is used for identifying client in socket registration
42         DefaultClientName = "govppsock"
43 )
44
45 var (
46
47         // DefaultConnectTimeout is default timeout for connecting
48         DefaultConnectTimeout = time.Second * 3
49         // DefaultDisconnectTimeout is default timeout for discconnecting
50         DefaultDisconnectTimeout = time.Millisecond * 100
51         // MaxWaitReady defines maximum duration of waiting for socket file
52         MaxWaitReady = time.Second * 10
53 )
54
55 var (
56         debug       = strings.Contains(os.Getenv("DEBUG_GOVPP"), "socketclient")
57         debugMsgIds = strings.Contains(os.Getenv("DEBUG_GOVPP"), "msgtable")
58
59         log logrus.FieldLogger
60 )
61
62 // SetLogger sets global logger.
63 func SetLogger(logger logrus.FieldLogger) {
64         log = logger
65 }
66
67 func init() {
68         logger := logrus.New()
69         if debug {
70                 logger.Level = logrus.DebugLevel
71                 logger.Debug("govpp: debug level enabled for socketclient")
72         }
73         log = logger.WithField("logger", "govpp/socketclient")
74 }
75
76 const socketMissing = `
77 ------------------------------------------------------------
78  No socket file found at: %s
79  VPP binary API socket file is missing!
80
81   - is VPP running with socket for binapi enabled?
82   - is the correct socket name configured?
83
84  To enable it add following section to your VPP config:
85    socksvr {
86      socket-name /run/vpp/api.sock
87    }
88 ------------------------------------------------------------
89 `
90
91 var warnOnce sync.Once
92
93 func (c *socketClient) printMissingSocketMsg() {
94         fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
95 }
96
97 type socketClient struct {
98         sockAddr   string
99         clientName string
100
101         conn   *net.UnixConn
102         reader *bufio.Reader
103         writer *bufio.Writer
104
105         connectTimeout    time.Duration
106         disconnectTimeout time.Duration
107
108         msgCallback  adapter.MsgCallback
109         clientIndex  uint32
110         msgTable     map[string]uint16
111         sockDelMsgId uint16
112         writeMu      sync.Mutex
113
114         headerPool *sync.Pool
115
116         quit chan struct{}
117         wg   sync.WaitGroup
118 }
119
120 func NewVppClient(sockAddr string) *socketClient {
121         if sockAddr == "" {
122                 sockAddr = DefaultSocketName
123         }
124         return &socketClient{
125                 sockAddr:          sockAddr,
126                 clientName:        DefaultClientName,
127                 connectTimeout:    DefaultConnectTimeout,
128                 disconnectTimeout: DefaultDisconnectTimeout,
129                 headerPool: &sync.Pool{New: func() interface{} {
130                         return make([]byte, 16)
131                 }},
132                 msgCallback: func(msgID uint16, data []byte) {
133                         log.Debugf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
134                 },
135         }
136 }
137
138 // SetClientName sets a client name used for identification.
139 func (c *socketClient) SetClientName(name string) {
140         c.clientName = name
141 }
142
143 // SetConnectTimeout sets timeout used during connecting.
144 func (c *socketClient) SetConnectTimeout(t time.Duration) {
145         c.connectTimeout = t
146 }
147
148 // SetDisconnectTimeout sets timeout used during disconnecting.
149 func (c *socketClient) SetDisconnectTimeout(t time.Duration) {
150         c.disconnectTimeout = t
151 }
152
153 func (c *socketClient) SetMsgCallback(cb adapter.MsgCallback) {
154         log.Debug("SetMsgCallback")
155         c.msgCallback = cb
156 }
157
158 const legacySocketName = "/run/vpp-api.sock"
159
160 func (c *socketClient) checkLegacySocket() bool {
161         if c.sockAddr == legacySocketName {
162                 return false
163         }
164         log.Debugf("checking legacy socket: %s", legacySocketName)
165         // check if socket exists
166         if _, err := os.Stat(c.sockAddr); err == nil {
167                 return false // socket exists
168         } else if !os.IsNotExist(err) {
169                 return false // some other error occurred
170         }
171         // check if legacy socket exists
172         if _, err := os.Stat(legacySocketName); err == nil {
173                 // legacy socket exists, update sockAddr
174                 c.sockAddr = legacySocketName
175                 return true
176         }
177         // no socket socket found
178         return false
179 }
180
181 // WaitReady checks socket file existence and waits for it if necessary
182 func (c *socketClient) WaitReady() error {
183         // check if socket already exists
184         if _, err := os.Stat(c.sockAddr); err == nil {
185                 return nil // socket exists, we are ready
186         } else if !os.IsNotExist(err) {
187                 return err // some other error occurred
188         }
189
190         if c.checkLegacySocket() {
191                 return nil
192         }
193
194         // socket does not exist, watch for it
195         watcher, err := fsnotify.NewWatcher()
196         if err != nil {
197                 return err
198         }
199         defer func() {
200                 if err := watcher.Close(); err != nil {
201                         log.Debugf("failed to close file watcher: %v", err)
202                 }
203         }()
204
205         // start directory watcher
206         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
207                 return err
208         }
209
210         timeout := time.NewTimer(MaxWaitReady)
211         for {
212                 select {
213                 case <-timeout.C:
214                         if c.checkLegacySocket() {
215                                 return nil
216                         }
217                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
218
219                 case e := <-watcher.Errors:
220                         return e
221
222                 case ev := <-watcher.Events:
223                         log.Debugf("watcher event: %+v", ev)
224                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
225                                 // socket created, we are ready
226                                 return nil
227                         }
228                 }
229         }
230 }
231
232 func (c *socketClient) Connect() error {
233         c.checkLegacySocket()
234
235         // check if socket exists
236         if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
237                 warnOnce.Do(c.printMissingSocketMsg)
238                 return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
239         } else if err != nil {
240                 return fmt.Errorf("VPP API socket error: %v", err)
241         }
242
243         if err := c.connect(c.sockAddr); err != nil {
244                 return err
245         }
246
247         if err := c.open(); err != nil {
248                 _ = c.disconnect()
249                 return err
250         }
251
252         c.quit = make(chan struct{})
253         c.wg.Add(1)
254         go c.readerLoop()
255
256         return nil
257 }
258
259 func (c *socketClient) Disconnect() error {
260         if c.conn == nil {
261                 return nil
262         }
263         log.Debugf("Disconnecting..")
264
265         close(c.quit)
266
267         if err := c.conn.CloseRead(); err != nil {
268                 log.Debugf("closing readMsg failed: %v", err)
269         }
270
271         // wait for readerLoop to return
272         c.wg.Wait()
273
274         // Don't bother sending a vl_api_sockclnt_delete_t message,
275         // just close the socket.
276
277         if err := c.disconnect(); err != nil {
278                 return err
279         }
280
281         return nil
282 }
283
284 const defaultBufferSize = 4096
285
286 func (c *socketClient) connect(sockAddr string) error {
287         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
288
289         log.Debugf("Connecting to: %v", c.sockAddr)
290
291         conn, err := net.DialUnix("unix", nil, addr)
292         if err != nil {
293                 // we try different type of socket for backwards compatbility with VPP<=19.04
294                 if strings.Contains(err.Error(), "wrong type for socket") {
295                         addr.Net = "unixpacket"
296                         log.Debugf("%s, retrying connect with type unixpacket", err)
297                         conn, err = net.DialUnix("unixpacket", nil, addr)
298                 }
299                 if err != nil {
300                         log.Debugf("Connecting to socket %s failed: %s", addr, err)
301                         return err
302                 }
303         }
304
305         c.conn = conn
306         log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
307
308         c.reader = bufio.NewReaderSize(c.conn, defaultBufferSize)
309         c.writer = bufio.NewWriterSize(c.conn, defaultBufferSize)
310
311         return nil
312 }
313
314 func (c *socketClient) disconnect() error {
315         log.Debugf("Closing socket")
316         if err := c.conn.Close(); err != nil {
317                 log.Debugln("Closing socket failed:", err)
318                 return err
319         }
320         return nil
321 }
322
323 const (
324         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
325         createMsgContext = byte(123)
326         deleteMsgContext = byte(124)
327 )
328
329 func (c *socketClient) open() error {
330         var msgCodec = codec.DefaultCodec
331
332         // Request socket client create
333         req := &memclnt.SockclntCreate{
334                 Name: c.clientName,
335         }
336         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
337         if err != nil {
338                 log.Debugln("Encode  error:", err)
339                 return err
340         }
341         // set non-0 context
342         msg[5] = createMsgContext
343
344         if err := c.writeMsg(msg); err != nil {
345                 log.Debugln("Write error: ", err)
346                 return err
347         }
348         msgReply, err := c.readMsgTimeout(nil, c.connectTimeout)
349         if err != nil {
350                 log.Println("Read error:", err)
351                 return err
352         }
353
354         reply := new(memclnt.SockclntCreateReply)
355         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
356                 log.Println("Decoding sockclnt_create_reply failed:", err)
357                 return err
358         } else if reply.Response != 0 {
359                 return fmt.Errorf("sockclnt_create_reply: response error (%d)", reply.Response)
360         }
361
362         log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
363                 reply.Response, reply.Index, reply.Count)
364
365         c.clientIndex = reply.Index
366         c.msgTable = make(map[string]uint16, reply.Count)
367         for _, x := range reply.MessageTable {
368                 msgName := strings.Split(x.Name, "\x00")[0]
369                 name := strings.TrimSuffix(msgName, "\x13")
370                 c.msgTable[name] = x.Index
371                 if strings.HasPrefix(name, "sockclnt_delete_") {
372                         c.sockDelMsgId = x.Index
373                 }
374                 if debugMsgIds {
375                         log.Debugf(" - %4d: %q", x.Index, name)
376                 }
377         }
378
379         return nil
380 }
381
382 func (c *socketClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
383         if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok {
384                 return msgID, nil
385         }
386         return 0, &adapter.UnknownMsgError{
387                 MsgName: msgName,
388                 MsgCrc:  msgCrc,
389         }
390 }
391
392 func (c *socketClient) SendMsg(context uint32, data []byte) error {
393         if len(data) < 10 {
394                 return fmt.Errorf("invalid message data, length must be at least 10 bytes")
395         }
396         setMsgRequestHeader(data, c.clientIndex, context)
397
398         if debug {
399                 log.Debugf("sendMsg (%d) context=%v client=%d: % 02X", len(data), context, c.clientIndex, data)
400         }
401
402         if err := c.writeMsg(data); err != nil {
403                 log.Debugln("writeMsg error: ", err)
404                 return err
405         }
406
407         return nil
408 }
409
410 // setMsgRequestHeader sets client index and context in the message request header
411 //
412 // Message request has following structure:
413 //
414 //    type msgRequestHeader struct {
415 //        MsgID       uint16
416 //        ClientIndex uint32
417 //        Context     uint32
418 //    }
419 //
420 func setMsgRequestHeader(data []byte, clientIndex, context uint32) {
421         // message ID is already set
422         binary.BigEndian.PutUint32(data[2:6], clientIndex)
423         binary.BigEndian.PutUint32(data[6:10], context)
424 }
425
426 func (c *socketClient) writeMsg(msg []byte) error {
427         // we lock to prevent mixing multiple message writes
428         c.writeMu.Lock()
429         defer c.writeMu.Unlock()
430
431         header := c.headerPool.Get().([]byte)
432         err := writeMsgHeader(c.writer, header, len(msg))
433         if err != nil {
434                 return err
435         }
436         c.headerPool.Put(header)
437
438         if err := writeMsgData(c.writer, msg, c.writer.Size()); err != nil {
439                 return err
440         }
441
442         if err := c.writer.Flush(); err != nil {
443                 return err
444         }
445
446         log.Debugf(" -- writeMsg done")
447
448         return nil
449 }
450
451 func writeMsgHeader(w io.Writer, header []byte, dataLen int) error {
452         binary.BigEndian.PutUint32(header[8:12], uint32(dataLen))
453
454         n, err := w.Write(header)
455         if err != nil {
456                 return err
457         }
458         if debug {
459                 log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
460         }
461
462         return nil
463 }
464
465 func writeMsgData(w io.Writer, msg []byte, writerSize int) error {
466         for i := 0; i <= len(msg)/writerSize; i++ {
467                 x := i*writerSize + writerSize
468                 if x > len(msg) {
469                         x = len(msg)
470                 }
471                 if debug {
472                         log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
473                 }
474                 n, err := w.Write(msg[i*writerSize : x])
475                 if err != nil {
476                         return err
477                 }
478                 if debug {
479                         log.Debugf(" - data sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
480                 }
481         }
482         return nil
483 }
484
485 func (c *socketClient) readerLoop() {
486         defer c.wg.Done()
487         defer log.Debugf("reader loop done")
488
489         var buf [8192]byte
490
491         for {
492                 select {
493                 case <-c.quit:
494                         return
495                 default:
496                 }
497
498                 msg, err := c.readMsg(buf[:])
499                 if err != nil {
500                         if isClosedError(err) {
501                                 return
502                         }
503                         log.Debugf("readMsg error: %v", err)
504                         continue
505                 }
506
507                 msgID, context := getMsgReplyHeader(msg)
508                 if debug {
509                         log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), msgID, context)
510                 }
511
512                 c.msgCallback(msgID, msg)
513         }
514 }
515
516 // getMsgReplyHeader gets message ID and context from the message reply header
517 //
518 // Message reply has following structure:
519 //
520 //    type msgReplyHeader struct {
521 //        MsgID       uint16
522 //        Context     uint32
523 //    }
524 //
525 func getMsgReplyHeader(msg []byte) (msgID uint16, context uint32) {
526         msgID = binary.BigEndian.Uint16(msg[0:2])
527         context = binary.BigEndian.Uint32(msg[2:6])
528         return
529 }
530
531 func (c *socketClient) readMsgTimeout(buf []byte, timeout time.Duration) ([]byte, error) {
532         // set read deadline
533         readDeadline := time.Now().Add(timeout)
534         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
535                 return nil, err
536         }
537
538         // read message
539         msgReply, err := c.readMsg(buf)
540         if err != nil {
541                 return nil, err
542         }
543
544         // reset read deadline
545         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
546                 return nil, err
547         }
548
549         return msgReply, nil
550 }
551
552 func (c *socketClient) readMsg(buf []byte) ([]byte, error) {
553         log.Debug("reading msg..")
554
555         header := c.headerPool.Get().([]byte)
556         msgLen, err := readMsgHeader(c.reader, header)
557         if err != nil {
558                 return nil, err
559         }
560         c.headerPool.Put(header)
561
562         msg, err := readMsgData(c.reader, buf, msgLen)
563
564         log.Debugf(" -- readMsg done (buffered: %d)", c.reader.Buffered())
565
566         return msg, nil
567 }
568
569 func readMsgHeader(r io.Reader, header []byte) (int, error) {
570         n, err := io.ReadAtLeast(r, header, 16)
571         if err != nil {
572                 return 0, err
573         }
574         if n == 0 {
575                 log.Debugln("zero bytes header")
576                 return 0, nil
577         } else if n != 16 {
578                 log.Debugf("invalid header (%d bytes): % 0X", n, header[:n])
579                 return 0, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
580         }
581
582         dataLen := binary.BigEndian.Uint32(header[8:12])
583
584         return int(dataLen), nil
585 }
586
587 func readMsgData(r io.Reader, buf []byte, dataLen int) ([]byte, error) {
588         var msg []byte
589         if buf == nil || len(buf) < dataLen {
590                 msg = make([]byte, dataLen)
591         } else {
592                 msg = buf[0:dataLen]
593         }
594
595         n, err := r.Read(msg)
596         if err != nil {
597                 return nil, err
598         }
599         if debug {
600                 log.Debugf(" - read data (%d bytes): % 0X", n, msg[:n])
601         }
602
603         if dataLen > n {
604                 remain := dataLen - n
605                 log.Debugf("continue reading remaining %d bytes", remain)
606                 view := msg[n:]
607
608                 for remain > 0 {
609                         nbytes, err := r.Read(view)
610                         if err != nil {
611                                 return nil, err
612                         } else if nbytes == 0 {
613                                 return nil, fmt.Errorf("zero nbytes")
614                         }
615
616                         remain -= nbytes
617                         log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
618
619                         view = view[nbytes:]
620                 }
621         }
622
623         return msg, nil
624 }
625
626 func isClosedError(err error) bool {
627         if errors.Is(err, io.EOF) {
628                 return true
629         }
630         return strings.HasSuffix(err.Error(), "use of closed network connection")
631 }