4173d7e19f2b4b3bfe63a3dcf491797e49673bde
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "encoding/binary"
20         "errors"
21         "fmt"
22         "io"
23         "net"
24         "os"
25         "path/filepath"
26         "strings"
27         "sync"
28         "time"
29
30         "github.com/fsnotify/fsnotify"
31         "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/binapi/memclnt"
35         "git.fd.io/govpp.git/codec"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = "/run/vpp/api.sock"
41         // DefaultClientName is used for identifying client in socket registration
42         DefaultClientName = "govppsock"
43 )
44
45 var (
46
47         // DefaultConnectTimeout is default timeout for connecting
48         DefaultConnectTimeout = time.Second * 3
49         // DefaultDisconnectTimeout is default timeout for discconnecting
50         DefaultDisconnectTimeout = time.Millisecond * 100
51         // MaxWaitReady defines maximum duration of waiting for socket file
52         MaxWaitReady = time.Second * 10
53 )
54
55 var (
56         debug       = strings.Contains(os.Getenv("DEBUG_GOVPP"), "socketclient")
57         debugMsgIds = strings.Contains(os.Getenv("DEBUG_GOVPP"), "msgtable")
58
59         log logrus.FieldLogger
60 )
61
62 // SetLogger sets global logger.
63 func SetLogger(logger logrus.FieldLogger) {
64         log = logger
65 }
66
67 func init() {
68         logger := logrus.New()
69         if debug {
70                 logger.Level = logrus.DebugLevel
71                 logger.Debug("govpp: debug level enabled for socketclient")
72         }
73         log = logger.WithField("logger", "govpp/socketclient")
74 }
75
76 type Client struct {
77         socketPath string
78         clientName string
79
80         conn   *net.UnixConn
81         reader *bufio.Reader
82         writer *bufio.Writer
83
84         connectTimeout    time.Duration
85         disconnectTimeout time.Duration
86
87         msgCallback  adapter.MsgCallback
88         clientIndex  uint32
89         msgTable     map[string]uint16
90         sockDelMsgId uint16
91         writeMu      sync.Mutex
92
93         headerPool *sync.Pool
94
95         quit chan struct{}
96         wg   sync.WaitGroup
97 }
98
99 // NewVppClient returns a new Client using socket.
100 // If socket is empty string DefaultSocketName is used.
101 func NewVppClient(socket string) *Client {
102         if socket == "" {
103                 socket = DefaultSocketName
104         }
105         return &Client{
106                 socketPath:        socket,
107                 clientName:        DefaultClientName,
108                 connectTimeout:    DefaultConnectTimeout,
109                 disconnectTimeout: DefaultDisconnectTimeout,
110                 headerPool: &sync.Pool{New: func() interface{} {
111                         return make([]byte, 16)
112                 }},
113                 msgCallback: func(msgID uint16, data []byte) {
114                         log.Debugf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
115                 },
116         }
117 }
118
119 // SetClientName sets a client name used for identification.
120 func (c *Client) SetClientName(name string) {
121         c.clientName = name
122 }
123
124 // SetConnectTimeout sets timeout used during connecting.
125 func (c *Client) SetConnectTimeout(t time.Duration) {
126         c.connectTimeout = t
127 }
128
129 // SetDisconnectTimeout sets timeout used during disconnecting.
130 func (c *Client) SetDisconnectTimeout(t time.Duration) {
131         c.disconnectTimeout = t
132 }
133
134 func (c *Client) SetMsgCallback(cb adapter.MsgCallback) {
135         log.Debug("SetMsgCallback")
136         c.msgCallback = cb
137 }
138
139 // WaitReady checks socket file existence and waits for it if necessary
140 func (c *Client) WaitReady() error {
141         // check if socket already exists
142         if _, err := os.Stat(c.socketPath); err == nil {
143                 return nil // socket exists, we are ready
144         } else if !os.IsNotExist(err) {
145                 return err // some other error occurred
146         }
147
148         // socket does not exist, watch for it
149         watcher, err := fsnotify.NewWatcher()
150         if err != nil {
151                 return err
152         }
153         defer func() {
154                 if err := watcher.Close(); err != nil {
155                         log.Debugf("failed to close file watcher: %v", err)
156                 }
157         }()
158
159         // start directory watcher
160         if err := watcher.Add(filepath.Dir(c.socketPath)); err != nil {
161                 return err
162         }
163
164         timeout := time.NewTimer(MaxWaitReady)
165         for {
166                 select {
167                 case <-timeout.C:
168                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.socketPath)
169
170                 case e := <-watcher.Errors:
171                         return e
172
173                 case ev := <-watcher.Events:
174                         log.Debugf("watcher event: %+v", ev)
175                         if ev.Name == c.socketPath && (ev.Op&fsnotify.Create) == fsnotify.Create {
176                                 // socket created, we are ready
177                                 return nil
178                         }
179                 }
180         }
181 }
182
183 func (c *Client) Connect() error {
184         // check if socket exists
185         if _, err := os.Stat(c.socketPath); os.IsNotExist(err) {
186                 return fmt.Errorf("VPP API socket file %s does not exist", c.socketPath)
187         } else if err != nil {
188                 return fmt.Errorf("VPP API socket error: %v", err)
189         }
190
191         if err := c.connect(c.socketPath); err != nil {
192                 return err
193         }
194
195         if err := c.open(); err != nil {
196                 _ = c.disconnect()
197                 return err
198         }
199
200         c.quit = make(chan struct{})
201         c.wg.Add(1)
202         go c.readerLoop()
203
204         return nil
205 }
206
207 func (c *Client) Disconnect() error {
208         if c.conn == nil {
209                 return nil
210         }
211         log.Debugf("Disconnecting..")
212
213         close(c.quit)
214
215         if err := c.conn.CloseRead(); err != nil {
216                 log.Debugf("closing readMsg failed: %v", err)
217         }
218
219         // wait for readerLoop to return
220         c.wg.Wait()
221
222         // Don't bother sending a vl_api_sockclnt_delete_t message,
223         // just close the socket.
224         if err := c.disconnect(); err != nil {
225                 return err
226         }
227
228         return nil
229 }
230
231 const defaultBufferSize = 4096
232
233 func (c *Client) connect(sockAddr string) error {
234         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
235
236         log.Debugf("Connecting to: %v", c.socketPath)
237
238         conn, err := net.DialUnix("unix", nil, addr)
239         if err != nil {
240                 // we try different type of socket for backwards compatbility with VPP<=19.04
241                 if strings.Contains(err.Error(), "wrong type for socket") {
242                         addr.Net = "unixpacket"
243                         log.Debugf("%s, retrying connect with type unixpacket", err)
244                         conn, err = net.DialUnix("unixpacket", nil, addr)
245                 }
246                 if err != nil {
247                         log.Debugf("Connecting to socket %s failed: %s", addr, err)
248                         return err
249                 }
250         }
251
252         c.conn = conn
253         log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
254
255         c.reader = bufio.NewReaderSize(c.conn, defaultBufferSize)
256         c.writer = bufio.NewWriterSize(c.conn, defaultBufferSize)
257
258         return nil
259 }
260
261 func (c *Client) disconnect() error {
262         log.Debugf("Closing socket")
263         if err := c.conn.Close(); err != nil {
264                 log.Debugln("Closing socket failed:", err)
265                 return err
266         }
267         return nil
268 }
269
270 const (
271         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
272         createMsgContext = byte(123)
273         deleteMsgContext = byte(124)
274 )
275
276 func (c *Client) open() error {
277         var msgCodec = codec.DefaultCodec
278
279         // Request socket client create
280         req := &memclnt.SockclntCreate{
281                 Name: c.clientName,
282         }
283         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
284         if err != nil {
285                 log.Debugln("Encode  error:", err)
286                 return err
287         }
288         // set non-0 context
289         msg[5] = createMsgContext
290
291         if err := c.writeMsg(msg); err != nil {
292                 log.Debugln("Write error: ", err)
293                 return err
294         }
295         msgReply, err := c.readMsgTimeout(nil, c.connectTimeout)
296         if err != nil {
297                 log.Println("Read error:", err)
298                 return err
299         }
300
301         reply := new(memclnt.SockclntCreateReply)
302         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
303                 log.Println("Decoding sockclnt_create_reply failed:", err)
304                 return err
305         } else if reply.Response != 0 {
306                 return fmt.Errorf("sockclnt_create_reply: response error (%d)", reply.Response)
307         }
308
309         log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
310                 reply.Response, reply.Index, reply.Count)
311
312         c.clientIndex = reply.Index
313         c.msgTable = make(map[string]uint16, reply.Count)
314         for _, x := range reply.MessageTable {
315                 msgName := strings.Split(x.Name, "\x00")[0]
316                 name := strings.TrimSuffix(msgName, "\x13")
317                 c.msgTable[name] = x.Index
318                 if strings.HasPrefix(name, "sockclnt_delete_") {
319                         c.sockDelMsgId = x.Index
320                 }
321                 if debugMsgIds {
322                         log.Debugf(" - %4d: %q", x.Index, name)
323                 }
324         }
325
326         return nil
327 }
328
329 func (c *Client) GetMsgID(msgName string, msgCrc string) (uint16, error) {
330         if msgID, ok := c.msgTable[msgName+"_"+msgCrc]; ok {
331                 return msgID, nil
332         }
333         return 0, &adapter.UnknownMsgError{
334                 MsgName: msgName,
335                 MsgCrc:  msgCrc,
336         }
337 }
338
339 func (c *Client) SendMsg(context uint32, data []byte) error {
340         if len(data) < 10 {
341                 return fmt.Errorf("invalid message data, length must be at least 10 bytes")
342         }
343         setMsgRequestHeader(data, c.clientIndex, context)
344
345         if debug {
346                 log.Debugf("sendMsg (%d) context=%v client=%d: % 02X", len(data), context, c.clientIndex, data)
347         }
348
349         if err := c.writeMsg(data); err != nil {
350                 log.Debugln("writeMsg error: ", err)
351                 return err
352         }
353
354         return nil
355 }
356
357 // setMsgRequestHeader sets client index and context in the message request header
358 //
359 // Message request has following structure:
360 //
361 //    type msgRequestHeader struct {
362 //        MsgID       uint16
363 //        ClientIndex uint32
364 //        Context     uint32
365 //    }
366 //
367 func setMsgRequestHeader(data []byte, clientIndex, context uint32) {
368         // message ID is already set
369         binary.BigEndian.PutUint32(data[2:6], clientIndex)
370         binary.BigEndian.PutUint32(data[6:10], context)
371 }
372
373 func (c *Client) writeMsg(msg []byte) error {
374         // we lock to prevent mixing multiple message writes
375         c.writeMu.Lock()
376         defer c.writeMu.Unlock()
377
378         header := c.headerPool.Get().([]byte)
379         err := writeMsgHeader(c.writer, header, len(msg))
380         if err != nil {
381                 return err
382         }
383         c.headerPool.Put(header)
384
385         if err := writeMsgData(c.writer, msg, c.writer.Size()); err != nil {
386                 return err
387         }
388
389         if err := c.writer.Flush(); err != nil {
390                 return err
391         }
392
393         log.Debugf(" -- writeMsg done")
394
395         return nil
396 }
397
398 func writeMsgHeader(w io.Writer, header []byte, dataLen int) error {
399         binary.BigEndian.PutUint32(header[8:12], uint32(dataLen))
400
401         n, err := w.Write(header)
402         if err != nil {
403                 return err
404         }
405         if debug {
406                 log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
407         }
408
409         return nil
410 }
411
412 func writeMsgData(w io.Writer, msg []byte, writerSize int) error {
413         for i := 0; i <= len(msg)/writerSize; i++ {
414                 x := i*writerSize + writerSize
415                 if x > len(msg) {
416                         x = len(msg)
417                 }
418                 if debug {
419                         log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
420                 }
421                 n, err := w.Write(msg[i*writerSize : x])
422                 if err != nil {
423                         return err
424                 }
425                 if debug {
426                         log.Debugf(" - data sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
427                 }
428         }
429         return nil
430 }
431
432 func (c *Client) readerLoop() {
433         defer c.wg.Done()
434         defer log.Debugf("reader loop done")
435
436         var buf [8192]byte
437
438         for {
439                 select {
440                 case <-c.quit:
441                         return
442                 default:
443                 }
444
445                 msg, err := c.readMsg(buf[:])
446                 if err != nil {
447                         if isClosedError(err) {
448                                 return
449                         }
450                         log.Debugf("readMsg error: %v", err)
451                         continue
452                 }
453
454                 msgID, context := getMsgReplyHeader(msg)
455                 if debug {
456                         log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), msgID, context)
457                 }
458
459                 c.msgCallback(msgID, msg)
460         }
461 }
462
463 // getMsgReplyHeader gets message ID and context from the message reply header
464 //
465 // Message reply has following structure:
466 //
467 //    type msgReplyHeader struct {
468 //        MsgID       uint16
469 //        Context     uint32
470 //    }
471 //
472 func getMsgReplyHeader(msg []byte) (msgID uint16, context uint32) {
473         msgID = binary.BigEndian.Uint16(msg[0:2])
474         context = binary.BigEndian.Uint32(msg[2:6])
475         return
476 }
477
478 func (c *Client) readMsgTimeout(buf []byte, timeout time.Duration) ([]byte, error) {
479         // set read deadline
480         readDeadline := time.Now().Add(timeout)
481         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
482                 return nil, err
483         }
484
485         // read message
486         msgReply, err := c.readMsg(buf)
487         if err != nil {
488                 return nil, err
489         }
490
491         // reset read deadline
492         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
493                 return nil, err
494         }
495
496         return msgReply, nil
497 }
498
499 func (c *Client) readMsg(buf []byte) ([]byte, error) {
500         log.Debug("reading msg..")
501
502         header := c.headerPool.Get().([]byte)
503         msgLen, err := readMsgHeader(c.reader, header)
504         if err != nil {
505                 return nil, err
506         }
507         c.headerPool.Put(header)
508
509         msg, err := readMsgData(c.reader, buf, msgLen)
510
511         log.Debugf(" -- readMsg done (buffered: %d)", c.reader.Buffered())
512
513         return msg, nil
514 }
515
516 func readMsgHeader(r io.Reader, header []byte) (int, error) {
517         n, err := io.ReadAtLeast(r, header, 16)
518         if err != nil {
519                 return 0, err
520         }
521         if n == 0 {
522                 log.Debugln("zero bytes header")
523                 return 0, nil
524         } else if n != 16 {
525                 log.Debugf("invalid header (%d bytes): % 0X", n, header[:n])
526                 return 0, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
527         }
528
529         dataLen := binary.BigEndian.Uint32(header[8:12])
530
531         return int(dataLen), nil
532 }
533
534 func readMsgData(r io.Reader, buf []byte, dataLen int) ([]byte, error) {
535         var msg []byte
536         if buf == nil || len(buf) < dataLen {
537                 msg = make([]byte, dataLen)
538         } else {
539                 msg = buf[0:dataLen]
540         }
541
542         n, err := r.Read(msg)
543         if err != nil {
544                 return nil, err
545         }
546         if debug {
547                 log.Debugf(" - read data (%d bytes): % 0X", n, msg[:n])
548         }
549
550         if dataLen > n {
551                 remain := dataLen - n
552                 log.Debugf("continue reading remaining %d bytes", remain)
553                 view := msg[n:]
554
555                 for remain > 0 {
556                         nbytes, err := r.Read(view)
557                         if err != nil {
558                                 return nil, err
559                         } else if nbytes == 0 {
560                                 return nil, fmt.Errorf("zero nbytes")
561                         }
562
563                         remain -= nbytes
564                         log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
565
566                         view = view[nbytes:]
567                 }
568         }
569
570         return msg, nil
571 }
572
573 func isClosedError(err error) bool {
574         if errors.Is(err, io.EOF) {
575                 return true
576         }
577         return strings.HasSuffix(err.Error(), "use of closed network connection")
578 }