Fix: generate (un)marshall for memory client messages
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "encoding/binary"
20         "errors"
21         "fmt"
22         "git.fd.io/govpp.git/adapter/socketclient/binapi/memclnt"
23         "io"
24         "net"
25         "os"
26         "path/filepath"
27         "strings"
28         "sync"
29         "time"
30
31         "github.com/fsnotify/fsnotify"
32         "github.com/sirupsen/logrus"
33
34         "git.fd.io/govpp.git/adapter"
35         "git.fd.io/govpp.git/codec"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = "/run/vpp/api.sock"
41         // DefaultClientName is used for identifying client in socket registration
42         DefaultClientName = "govppsock"
43 )
44
45 var (
46
47         // DefaultConnectTimeout is default timeout for connecting
48         DefaultConnectTimeout = time.Second * 3
49         // DefaultDisconnectTimeout is default timeout for discconnecting
50         DefaultDisconnectTimeout = time.Millisecond * 100
51         // MaxWaitReady defines maximum duration of waiting for socket file
52         MaxWaitReady = time.Second * 10
53 )
54
55 var (
56         debug       = strings.Contains(os.Getenv("DEBUG_GOVPP"), "socketclient")
57         debugMsgIds = strings.Contains(os.Getenv("DEBUG_GOVPP"), "msgtable")
58
59         logger = logrus.New()
60         log    = logger.WithField("logger", "govpp/socketclient")
61 )
62
63 // init initializes global logger
64 func init() {
65         if debug {
66                 logger.Level = logrus.DebugLevel
67                 log.Debug("govpp: debug level enabled for socketclient")
68         }
69 }
70
71 const socketMissing = `
72 ------------------------------------------------------------
73  No socket file found at: %s
74  VPP binary API socket file is missing!
75
76   - is VPP running with socket for binapi enabled?
77   - is the correct socket name configured?
78
79  To enable it add following section to your VPP config:
80    socksvr {
81      default
82    }
83 ------------------------------------------------------------
84 `
85
86 var warnOnce sync.Once
87
88 func (c *socketClient) printMissingSocketMsg() {
89         fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
90 }
91
92 type socketClient struct {
93         sockAddr   string
94         clientName string
95
96         conn   *net.UnixConn
97         reader *bufio.Reader
98         writer *bufio.Writer
99
100         connectTimeout    time.Duration
101         disconnectTimeout time.Duration
102
103         msgCallback  adapter.MsgCallback
104         clientIndex  uint32
105         msgTable     map[string]uint16
106         sockDelMsgId uint16
107         writeMu      sync.Mutex
108
109         headerPool *sync.Pool
110
111         quit chan struct{}
112         wg   sync.WaitGroup
113 }
114
115 func NewVppClient(sockAddr string) *socketClient {
116         if sockAddr == "" {
117                 sockAddr = DefaultSocketName
118         }
119         return &socketClient{
120                 sockAddr:          sockAddr,
121                 clientName:        DefaultClientName,
122                 connectTimeout:    DefaultConnectTimeout,
123                 disconnectTimeout: DefaultDisconnectTimeout,
124                 headerPool: &sync.Pool{New: func() interface{} {
125                         return make([]byte, 16)
126                 }},
127                 msgCallback: func(msgID uint16, data []byte) {
128                         log.Debugf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
129                 },
130         }
131 }
132
133 // SetClientName sets a client name used for identification.
134 func (c *socketClient) SetClientName(name string) {
135         c.clientName = name
136 }
137
138 // SetConnectTimeout sets timeout used during connecting.
139 func (c *socketClient) SetConnectTimeout(t time.Duration) {
140         c.connectTimeout = t
141 }
142
143 // SetDisconnectTimeout sets timeout used during disconnecting.
144 func (c *socketClient) SetDisconnectTimeout(t time.Duration) {
145         c.disconnectTimeout = t
146 }
147
148 func (c *socketClient) SetMsgCallback(cb adapter.MsgCallback) {
149         log.Debug("SetMsgCallback")
150         c.msgCallback = cb
151 }
152
153 const legacySocketName = "/run/vpp-api.sock"
154
155 func (c *socketClient) checkLegacySocket() bool {
156         if c.sockAddr == legacySocketName {
157                 return false
158         }
159         log.Debugf("checking legacy socket: %s", legacySocketName)
160         // check if socket exists
161         if _, err := os.Stat(c.sockAddr); err == nil {
162                 return false // socket exists
163         } else if !os.IsNotExist(err) {
164                 return false // some other error occurred
165         }
166         // check if legacy socket exists
167         if _, err := os.Stat(legacySocketName); err == nil {
168                 // legacy socket exists, update sockAddr
169                 c.sockAddr = legacySocketName
170                 return true
171         }
172         // no socket socket found
173         return false
174 }
175
176 // WaitReady checks socket file existence and waits for it if necessary
177 func (c *socketClient) WaitReady() error {
178         // check if socket already exists
179         if _, err := os.Stat(c.sockAddr); err == nil {
180                 return nil // socket exists, we are ready
181         } else if !os.IsNotExist(err) {
182                 return err // some other error occurred
183         }
184
185         if c.checkLegacySocket() {
186                 return nil
187         }
188
189         // socket does not exist, watch for it
190         watcher, err := fsnotify.NewWatcher()
191         if err != nil {
192                 return err
193         }
194         defer func() {
195                 if err := watcher.Close(); err != nil {
196                         log.Debugf("failed to close file watcher: %v", err)
197                 }
198         }()
199
200         // start directory watcher
201         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
202                 return err
203         }
204
205         timeout := time.NewTimer(MaxWaitReady)
206         for {
207                 select {
208                 case <-timeout.C:
209                         if c.checkLegacySocket() {
210                                 return nil
211                         }
212                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
213
214                 case e := <-watcher.Errors:
215                         return e
216
217                 case ev := <-watcher.Events:
218                         log.Debugf("watcher event: %+v", ev)
219                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
220                                 // socket created, we are ready
221                                 return nil
222                         }
223                 }
224         }
225 }
226
227 func (c *socketClient) Connect() error {
228         c.checkLegacySocket()
229
230         // check if socket exists
231         if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
232                 warnOnce.Do(c.printMissingSocketMsg)
233                 return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
234         } else if err != nil {
235                 return fmt.Errorf("VPP API socket error: %v", err)
236         }
237
238         if err := c.connect(c.sockAddr); err != nil {
239                 return err
240         }
241
242         if err := c.open(); err != nil {
243                 _ = c.disconnect()
244                 return err
245         }
246
247         c.quit = make(chan struct{})
248         c.wg.Add(1)
249         go c.readerLoop()
250
251         return nil
252 }
253
254 func (c *socketClient) Disconnect() error {
255         if c.conn == nil {
256                 return nil
257         }
258         log.Debugf("Disconnecting..")
259
260         close(c.quit)
261
262         if err := c.conn.CloseRead(); err != nil {
263                 log.Debugf("closing readMsg failed: %v", err)
264         }
265
266         // wait for readerLoop to return
267         c.wg.Wait()
268
269         if err := c.close(); err != nil {
270                 log.Debugf("closing failed: %v", err)
271         }
272
273         if err := c.disconnect(); err != nil {
274                 return err
275         }
276
277         return nil
278 }
279
280 const defaultBufferSize = 4096
281
282 func (c *socketClient) connect(sockAddr string) error {
283         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
284
285         log.Debugf("Connecting to: %v", c.sockAddr)
286
287         conn, err := net.DialUnix("unix", nil, addr)
288         if err != nil {
289                 // we try different type of socket for backwards compatbility with VPP<=19.04
290                 if strings.Contains(err.Error(), "wrong type for socket") {
291                         addr.Net = "unixpacket"
292                         log.Debugf("%s, retrying connect with type unixpacket", err)
293                         conn, err = net.DialUnix("unixpacket", nil, addr)
294                 }
295                 if err != nil {
296                         log.Debugf("Connecting to socket %s failed: %s", addr, err)
297                         return err
298                 }
299         }
300
301         c.conn = conn
302         log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
303
304         c.reader = bufio.NewReaderSize(c.conn, defaultBufferSize)
305         c.writer = bufio.NewWriterSize(c.conn, defaultBufferSize)
306
307         return nil
308 }
309
310 func (c *socketClient) disconnect() error {
311         log.Debugf("Closing socket")
312         if err := c.conn.Close(); err != nil {
313                 log.Debugln("Closing socket failed:", err)
314                 return err
315         }
316         return nil
317 }
318
319 const (
320         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
321         createMsgContext = byte(123)
322         deleteMsgContext = byte(124)
323 )
324
325 func (c *socketClient) open() error {
326         var msgCodec = codec.DefaultCodec
327
328         // Request socket client create
329         req := &memclnt.SockclntCreate{
330                 Name: c.clientName,
331         }
332         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
333         if err != nil {
334                 log.Debugln("Encode  error:", err)
335                 return err
336         }
337         // set non-0 context
338         msg[5] = createMsgContext
339
340         if err := c.writeMsg(msg); err != nil {
341                 log.Debugln("Write error: ", err)
342                 return err
343         }
344         msgReply, err := c.readMsgTimeout(nil, c.connectTimeout)
345         if err != nil {
346                 log.Println("Read error:", err)
347                 return err
348         }
349
350         reply := new(memclnt.SockclntCreateReply)
351         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
352                 log.Println("Decoding sockclnt_create_reply failed:", err)
353                 return err
354         } else if reply.Response != 0 {
355                 return fmt.Errorf("sockclnt_create_reply: response error (%d)", reply.Response)
356         }
357
358         log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
359                 reply.Response, reply.Index, reply.Count)
360
361         c.clientIndex = reply.Index
362         c.msgTable = make(map[string]uint16, reply.Count)
363         for _, x := range reply.MessageTable {
364                 msgName := strings.Split(x.Name, "\x00")[0]
365                 name := strings.TrimSuffix(msgName, "\x13")
366                 c.msgTable[name] = x.Index
367                 if strings.HasPrefix(name, "sockclnt_delete_") {
368                         c.sockDelMsgId = x.Index
369                 }
370                 if debugMsgIds {
371                         log.Debugf(" - %4d: %q", x.Index, name)
372                 }
373         }
374
375         return nil
376 }
377
378 func (c *socketClient) close() error {
379         var msgCodec = codec.DefaultCodec
380
381         req := &memclnt.SockclntDelete{
382                 Index: c.clientIndex,
383         }
384         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
385         if err != nil {
386                 log.Debugln("Encode error:", err)
387                 return err
388         }
389         // set non-0 context
390         msg[5] = deleteMsgContext
391
392         log.Debugf("sending socklntDel (%d bytes): % 0X", len(msg), msg)
393
394         if err := c.writeMsg(msg); err != nil {
395                 log.Debugln("Write error: ", err)
396                 return err
397         }
398
399         msgReply, err := c.readMsgTimeout(nil, c.disconnectTimeout)
400         if err != nil {
401                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
402                         // we accept timeout for reply
403                         return nil
404                 }
405                 log.Debugln("Read error:", err)
406                 return err
407         }
408
409         reply := new(memclnt.SockclntDeleteReply)
410         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
411                 log.Debugln("Decoding sockclnt_delete_reply failed:", err)
412                 return err
413         } else if reply.Response != 0 {
414                 return fmt.Errorf("sockclnt_delete_reply: response error (%d)", reply.Response)
415         }
416
417         return nil
418 }
419
420 func (c *socketClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
421         if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok {
422                 return msgID, nil
423         }
424         return 0, &adapter.UnknownMsgError{
425                 MsgName: msgName,
426                 MsgCrc:  msgCrc,
427         }
428 }
429
430 func (c *socketClient) SendMsg(context uint32, data []byte) error {
431         if len(data) < 10 {
432                 return fmt.Errorf("invalid message data, length must be at least 10 bytes")
433         }
434         setMsgRequestHeader(data, c.clientIndex, context)
435
436         if debug {
437                 log.Debugf("sendMsg (%d) context=%v client=%d: % 02X", len(data), context, c.clientIndex, data)
438         }
439
440         if err := c.writeMsg(data); err != nil {
441                 log.Debugln("writeMsg error: ", err)
442                 return err
443         }
444
445         return nil
446 }
447
448 // setMsgRequestHeader sets client index and context in the message request header
449 //
450 // Message request has following structure:
451 //
452 //    type msgRequestHeader struct {
453 //        MsgID       uint16
454 //        ClientIndex uint32
455 //        Context     uint32
456 //    }
457 //
458 func setMsgRequestHeader(data []byte, clientIndex, context uint32) {
459         // message ID is already set
460         binary.BigEndian.PutUint32(data[2:6], clientIndex)
461         binary.BigEndian.PutUint32(data[6:10], context)
462 }
463
464 func (c *socketClient) writeMsg(msg []byte) error {
465         // we lock to prevent mixing multiple message writes
466         c.writeMu.Lock()
467         defer c.writeMu.Unlock()
468
469         header := c.headerPool.Get().([]byte)
470         err := writeMsgHeader(c.writer, header, len(msg))
471         if err != nil {
472                 return err
473         }
474         c.headerPool.Put(header)
475
476         if err := writeMsgData(c.writer, msg, c.writer.Size()); err != nil {
477                 return err
478         }
479
480         if err := c.writer.Flush(); err != nil {
481                 return err
482         }
483
484         log.Debugf(" -- writeMsg done")
485
486         return nil
487 }
488
489 func writeMsgHeader(w io.Writer, header []byte, dataLen int) error {
490         binary.BigEndian.PutUint32(header[8:12], uint32(dataLen))
491
492         n, err := w.Write(header)
493         if err != nil {
494                 return err
495         }
496         if debug {
497                 log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
498         }
499
500         return nil
501 }
502
503 func writeMsgData(w io.Writer, msg []byte, writerSize int) error {
504         for i := 0; i <= len(msg)/writerSize; i++ {
505                 x := i*writerSize + writerSize
506                 if x > len(msg) {
507                         x = len(msg)
508                 }
509                 if debug {
510                         log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
511                 }
512                 n, err := w.Write(msg[i*writerSize : x])
513                 if err != nil {
514                         return err
515                 }
516                 if debug {
517                         log.Debugf(" - data sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
518                 }
519         }
520         return nil
521 }
522
523 func (c *socketClient) readerLoop() {
524         defer c.wg.Done()
525         defer log.Debugf("reader loop done")
526
527         var buf [8192]byte
528
529         for {
530                 select {
531                 case <-c.quit:
532                         return
533                 default:
534                 }
535
536                 msg, err := c.readMsg(buf[:])
537                 if err != nil {
538                         if isClosedError(err) {
539                                 return
540                         }
541                         log.Debugf("readMsg error: %v", err)
542                         continue
543                 }
544
545                 msgID, context := getMsgReplyHeader(msg)
546                 if debug {
547                         log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), msgID, context)
548                 }
549
550                 c.msgCallback(msgID, msg)
551         }
552 }
553
554 // getMsgReplyHeader gets message ID and context from the message reply header
555 //
556 // Message reply has following structure:
557 //
558 //    type msgReplyHeader struct {
559 //        MsgID       uint16
560 //        Context     uint32
561 //    }
562 //
563 func getMsgReplyHeader(msg []byte) (msgID uint16, context uint32) {
564         msgID = binary.BigEndian.Uint16(msg[0:2])
565         context = binary.BigEndian.Uint32(msg[2:6])
566         return
567 }
568
569 func (c *socketClient) readMsgTimeout(buf []byte, timeout time.Duration) ([]byte, error) {
570         // set read deadline
571         readDeadline := time.Now().Add(timeout)
572         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
573                 return nil, err
574         }
575
576         // read message
577         msgReply, err := c.readMsg(buf)
578         if err != nil {
579                 return nil, err
580         }
581
582         // reset read deadline
583         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
584                 return nil, err
585         }
586
587         return msgReply, nil
588 }
589
590 func (c *socketClient) readMsg(buf []byte) ([]byte, error) {
591         log.Debug("reading msg..")
592
593         header := c.headerPool.Get().([]byte)
594         msgLen, err := readMsgHeader(c.reader, header)
595         if err != nil {
596                 return nil, err
597         }
598         c.headerPool.Put(header)
599
600         msg, err := readMsgData(c.reader, buf, msgLen)
601
602         log.Debugf(" -- readMsg done (buffered: %d)", c.reader.Buffered())
603
604         return msg, nil
605 }
606
607 func readMsgHeader(r io.Reader, header []byte) (int, error) {
608         n, err := io.ReadAtLeast(r, header, 16)
609         if err != nil {
610                 return 0, err
611         }
612         if n == 0 {
613                 log.Debugln("zero bytes header")
614                 return 0, nil
615         } else if n != 16 {
616                 log.Debugf("invalid header (%d bytes): % 0X", n, header[:n])
617                 return 0, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
618         }
619
620         dataLen := binary.BigEndian.Uint32(header[8:12])
621
622         return int(dataLen), nil
623 }
624
625 func readMsgData(r io.Reader, buf []byte, dataLen int) ([]byte, error) {
626         var msg []byte
627         if buf == nil || len(buf) < dataLen {
628                 msg = make([]byte, dataLen)
629         } else {
630                 msg = buf[0:dataLen]
631         }
632
633         n, err := r.Read(msg)
634         if err != nil {
635                 return nil, err
636         }
637         if debug {
638                 log.Debugf(" - read data (%d bytes): % 0X", n, msg[:n])
639         }
640
641         if dataLen > n {
642                 remain := dataLen - n
643                 log.Debugf("continue reading remaining %d bytes", remain)
644                 view := msg[n:]
645
646                 for remain > 0 {
647                         nbytes, err := r.Read(view)
648                         if err != nil {
649                                 return nil, err
650                         } else if nbytes == 0 {
651                                 return nil, fmt.Errorf("zero nbytes")
652                         }
653
654                         remain -= nbytes
655                         log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
656
657                         view = view[nbytes:]
658                 }
659         }
660
661         return msg, nil
662 }
663
664 func isClosedError(err error) bool {
665         if errors.Is(err, io.EOF) {
666                 return true
667         }
668         return strings.HasSuffix(err.Error(), "use of closed network connection")
669 }