043d253ed13094db5d327fcf83930a2ed0a3e1a5
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "net"
23         "os"
24         "path/filepath"
25         "strings"
26         "sync"
27         "time"
28
29         "github.com/fsnotify/fsnotify"
30         "github.com/lunixbochs/struc"
31         logger "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35 )
36
37 const (
38         // DefaultSocketName is default VPP API socket file path.
39         DefaultSocketName = adapter.DefaultBinapiSocket
40         legacySocketName  = "/run/vpp-api.sock"
41 )
42
43 var (
44         // DefaultConnectTimeout is default timeout for connecting
45         DefaultConnectTimeout = time.Second * 3
46         // DefaultDisconnectTimeout is default timeout for discconnecting
47         DefaultDisconnectTimeout = time.Millisecond * 100
48         // MaxWaitReady defines maximum duration before waiting for socket file
49         // times out
50         MaxWaitReady = time.Second * 10
51         // ClientName is used for identifying client in socket registration
52         ClientName = "govppsock"
53 )
54
55 var (
56         // Debug is global variable that determines debug mode
57         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
58         // DebugMsgIds is global variable that determines debug mode for msg ids
59         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
60
61         // Log is global logger
62         Log = logger.New()
63 )
64
65 // init initializes global logger, which logs debug level messages to stdout.
66 func init() {
67         Log.Out = os.Stdout
68         if Debug {
69                 Log.Level = logger.DebugLevel
70                 Log.Debug("govpp/socketclient: enabled debug mode")
71         }
72 }
73
74 const socketMissing = `
75 ------------------------------------------------------------
76  No socket file found at: %s
77  VPP binary API socket file is missing!
78
79   - is VPP running with socket for binapi enabled?
80   - is the correct socket name configured?
81
82  To enable it add following section to your VPP config:
83    socksvr {
84      default
85    }
86 ------------------------------------------------------------
87 `
88
89 var warnOnce sync.Once
90
91 func (c *vppClient) printMissingSocketMsg() {
92         fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
93 }
94
95 type vppClient struct {
96         sockAddr string
97
98         conn   *net.UnixConn
99         reader *bufio.Reader
100         writer *bufio.Writer
101
102         connectTimeout    time.Duration
103         disconnectTimeout time.Duration
104
105         cb           adapter.MsgCallback
106         clientIndex  uint32
107         msgTable     map[string]uint16
108         sockDelMsgId uint16
109         writeMu      sync.Mutex
110
111         quit chan struct{}
112         wg   sync.WaitGroup
113 }
114
115 func NewVppClient(sockAddr string) *vppClient {
116         if sockAddr == "" {
117                 sockAddr = DefaultSocketName
118         }
119         return &vppClient{
120                 sockAddr:          sockAddr,
121                 connectTimeout:    DefaultConnectTimeout,
122                 disconnectTimeout: DefaultDisconnectTimeout,
123                 cb: func(msgID uint16, data []byte) {
124                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
125                 },
126         }
127 }
128
129 // SetConnectTimeout sets timeout used during connecting.
130 func (c *vppClient) SetConnectTimeout(t time.Duration) {
131         c.connectTimeout = t
132 }
133
134 // SetDisconnectTimeout sets timeout used during disconnecting.
135 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
136         c.disconnectTimeout = t
137 }
138
139 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
140         Log.Debug("SetMsgCallback")
141         c.cb = cb
142 }
143
144 func (c *vppClient) checkLegacySocket() bool {
145         if c.sockAddr == legacySocketName {
146                 return false
147         }
148         Log.Debugf("checking legacy socket: %s", legacySocketName)
149         // check if socket exists
150         if _, err := os.Stat(c.sockAddr); err == nil {
151                 return false // socket exists
152         } else if !os.IsNotExist(err) {
153                 return false // some other error occurred
154         }
155         // check if legacy socket exists
156         if _, err := os.Stat(legacySocketName); err == nil {
157                 // legacy socket exists, update sockAddr
158                 c.sockAddr = legacySocketName
159                 return true
160         }
161         // no socket socket found
162         return false
163 }
164
165 // WaitReady checks socket file existence and waits for it if necessary
166 func (c *vppClient) WaitReady() error {
167         // check if socket already exists
168         if _, err := os.Stat(c.sockAddr); err == nil {
169                 return nil // socket exists, we are ready
170         } else if !os.IsNotExist(err) {
171                 return err // some other error occurred
172         }
173
174         if c.checkLegacySocket() {
175                 return nil
176         }
177
178         // socket does not exist, watch for it
179         watcher, err := fsnotify.NewWatcher()
180         if err != nil {
181                 return err
182         }
183         defer func() {
184                 if err := watcher.Close(); err != nil {
185                         Log.Warnf("failed to close file watcher: %v", err)
186                 }
187         }()
188
189         // start directory watcher
190         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
191                 return err
192         }
193
194         timeout := time.NewTimer(MaxWaitReady)
195         for {
196                 select {
197                 case <-timeout.C:
198                         if c.checkLegacySocket() {
199                                 return nil
200                         }
201                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
202
203                 case e := <-watcher.Errors:
204                         return e
205
206                 case ev := <-watcher.Events:
207                         Log.Debugf("watcher event: %+v", ev)
208                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
209                                 // socket created, we are ready
210                                 return nil
211                         }
212                 }
213         }
214 }
215
216 func (c *vppClient) Connect() error {
217         c.checkLegacySocket()
218
219         // check if socket exists
220         if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
221                 warnOnce.Do(c.printMissingSocketMsg)
222                 return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
223         } else if err != nil {
224                 return fmt.Errorf("VPP API socket error: %v", err)
225         }
226
227         if err := c.connect(c.sockAddr); err != nil {
228                 return err
229         }
230
231         if err := c.open(); err != nil {
232                 c.disconnect()
233                 return err
234         }
235
236         c.quit = make(chan struct{})
237         c.wg.Add(1)
238         go c.readerLoop()
239
240         return nil
241 }
242
243 func (c *vppClient) Disconnect() error {
244         if c.conn == nil {
245                 return nil
246         }
247         Log.Debugf("Disconnecting..")
248
249         close(c.quit)
250
251         if err := c.conn.CloseRead(); err != nil {
252                 Log.Debugf("closing read failed: %v", err)
253         }
254
255         // wait for readerLoop to return
256         c.wg.Wait()
257
258         if err := c.close(); err != nil {
259                 Log.Debugf("closing failed: %v", err)
260         }
261
262         if err := c.disconnect(); err != nil {
263                 return err
264         }
265
266         return nil
267 }
268
269 func (c *vppClient) connect(sockAddr string) error {
270         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
271
272         Log.Debugf("Connecting to: %v", c.sockAddr)
273
274         conn, err := net.DialUnix("unix", nil, addr)
275         if err != nil {
276                 // we try different type of socket for backwards compatbility with VPP<=19.04
277                 if strings.Contains(err.Error(), "wrong type for socket") {
278                         addr.Net = "unixpacket"
279                         Log.Debugf("%s, retrying connect with type unixpacket", err)
280                         conn, err = net.DialUnix("unixpacket", nil, addr)
281                 }
282                 if err != nil {
283                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
284                         return err
285                 }
286         }
287
288         c.conn = conn
289         Log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
290
291         c.reader = bufio.NewReader(c.conn)
292         c.writer = bufio.NewWriter(c.conn)
293
294         return nil
295 }
296
297 func (c *vppClient) disconnect() error {
298         Log.Debugf("Closing socket")
299         if err := c.conn.Close(); err != nil {
300                 Log.Debugln("Closing socket failed:", err)
301                 return err
302         }
303         return nil
304 }
305
306 const (
307         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
308         createMsgContext = byte(123)
309         deleteMsgContext = byte(124)
310 )
311
312 func (c *vppClient) open() error {
313         msgCodec := new(codec.MsgCodec)
314
315         req := &SockclntCreate{Name: ClientName}
316         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
317         if err != nil {
318                 Log.Debugln("Encode error:", err)
319                 return err
320         }
321         // set non-0 context
322         msg[5] = createMsgContext
323
324         if err := c.write(msg); err != nil {
325                 Log.Debugln("Write error: ", err)
326                 return err
327         }
328
329         readDeadline := time.Now().Add(c.connectTimeout)
330         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
331                 return err
332         }
333         msgReply, err := c.read()
334         if err != nil {
335                 Log.Println("Read error:", err)
336                 return err
337         }
338         // reset read deadline
339         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
340                 return err
341         }
342
343         reply := new(SockclntCreateReply)
344         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
345                 Log.Println("Decode error:", err)
346                 return err
347         }
348
349         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
350                 reply.Response, reply.Index, reply.Count)
351
352         c.clientIndex = reply.Index
353         c.msgTable = make(map[string]uint16, reply.Count)
354         for _, x := range reply.MessageTable {
355                 msgName := strings.Split(x.Name, "\x00")[0]
356                 name := strings.TrimSuffix(msgName, "\x13")
357                 c.msgTable[name] = x.Index
358                 if strings.HasPrefix(name, "sockclnt_delete_") {
359                         c.sockDelMsgId = x.Index
360                 }
361                 if DebugMsgIds {
362                         Log.Debugf(" - %4d: %q", x.Index, name)
363                 }
364         }
365
366         return nil
367 }
368
369 func (c *vppClient) close() error {
370         msgCodec := new(codec.MsgCodec)
371
372         req := &SockclntDelete{
373                 Index: c.clientIndex,
374         }
375         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
376         if err != nil {
377                 Log.Debugln("Encode error:", err)
378                 return err
379         }
380         // set non-0 context
381         msg[5] = deleteMsgContext
382
383         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
384         if err := c.write(msg); err != nil {
385                 Log.Debugln("Write error: ", err)
386                 return err
387         }
388
389         readDeadline := time.Now().Add(c.disconnectTimeout)
390         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
391                 return err
392         }
393         msgReply, err := c.read()
394         if err != nil {
395                 Log.Debugln("Read error:", err)
396                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
397                         // we accept timeout for reply
398                         return nil
399                 }
400                 return err
401         }
402         // reset read deadline
403         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
404                 return err
405         }
406
407         reply := new(SockclntDeleteReply)
408         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
409                 Log.Debugln("Decode error:", err)
410                 return err
411         }
412
413         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
414
415         return nil
416 }
417
418 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
419         msg := msgName + "_" + msgCrc
420         msgID, ok := c.msgTable[msg]
421         if !ok {
422                 return 0, fmt.Errorf("unknown message: %q", msg)
423         }
424         return msgID, nil
425 }
426
427 type reqHeader struct {
428         // MsgID uint16
429         ClientIndex uint32
430         Context     uint32
431 }
432
433 func (c *vppClient) SendMsg(context uint32, data []byte) error {
434         h := &reqHeader{
435                 ClientIndex: c.clientIndex,
436                 Context:     context,
437         }
438         buf := new(bytes.Buffer)
439         if err := struc.Pack(buf, h); err != nil {
440                 return err
441         }
442         copy(data[2:], buf.Bytes())
443
444         Log.Debugf("sendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
445
446         if err := c.write(data); err != nil {
447                 Log.Debugln("write error: ", err)
448                 return err
449         }
450
451         return nil
452 }
453
454 func (c *vppClient) write(msg []byte) error {
455         h := &msgheader{
456                 DataLen: uint32(len(msg)),
457         }
458         buf := new(bytes.Buffer)
459         if err := struc.Pack(buf, h); err != nil {
460                 return err
461         }
462         header := buf.Bytes()
463
464         // we lock to prevent mixing multiple message sends
465         c.writeMu.Lock()
466         defer c.writeMu.Unlock()
467
468         if n, err := c.writer.Write(header); err != nil {
469                 return err
470         } else {
471                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
472         }
473
474         writerSize := c.writer.Size()
475         for i := 0; i <= len(msg)/writerSize; i++ {
476                 x := i*writerSize + writerSize
477                 if x > len(msg) {
478                         x = len(msg)
479                 }
480                 Log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
481                 if n, err := c.writer.Write(msg[i*writerSize : x]); err != nil {
482                         return err
483                 } else {
484                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
485                 }
486         }
487         if err := c.writer.Flush(); err != nil {
488                 return err
489         }
490
491         Log.Debugf(" -- write done")
492
493         return nil
494 }
495
496 type msgHeader struct {
497         MsgID   uint16
498         Context uint32
499 }
500
501 func (c *vppClient) readerLoop() {
502         defer c.wg.Done()
503         defer Log.Debugf("reader quit")
504
505         for {
506                 select {
507                 case <-c.quit:
508                         return
509                 default:
510                 }
511
512                 msg, err := c.read()
513                 if err != nil {
514                         if isClosedError(err) {
515                                 return
516                         }
517                         Log.Debugf("read failed: %v", err)
518                         continue
519                 }
520
521                 h := new(msgHeader)
522                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
523                         Log.Debugf("unpacking header failed: %v", err)
524                         continue
525                 }
526
527                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
528                 c.cb(h.MsgID, msg)
529         }
530 }
531
532 type msgheader struct {
533         Q               int    `struc:"uint64"`
534         DataLen         uint32 `struc:"uint32"`
535         GcMarkTimestamp uint32 `struc:"uint32"`
536 }
537
538 func (c *vppClient) read() ([]byte, error) {
539         Log.Debug(" reading next msg..")
540
541         header := make([]byte, 16)
542
543         n, err := io.ReadAtLeast(c.reader, header, 16)
544         if err != nil {
545                 return nil, err
546         }
547         if n == 0 {
548                 Log.Debugln("zero bytes header")
549                 return nil, nil
550         } else if n != 16 {
551                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
552                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
553         }
554         Log.Debugf(" read header %d bytes: % 0X", n, header)
555
556         h := &msgheader{}
557         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
558                 return nil, err
559         }
560         Log.Debugf(" - decoded header: %+v", h)
561
562         msgLen := int(h.DataLen)
563         msg := make([]byte, msgLen)
564
565         n, err = c.reader.Read(msg)
566         if err != nil {
567                 return nil, err
568         }
569         Log.Debugf(" - read msg %d bytes (%d buffered) % 0X", n, c.reader.Buffered(), msg[:n])
570
571         if msgLen > n {
572                 remain := msgLen - n
573                 Log.Debugf("continue read for another %d bytes", remain)
574                 view := msg[n:]
575
576                 for remain > 0 {
577                         nbytes, err := c.reader.Read(view)
578                         if err != nil {
579                                 return nil, err
580                         } else if nbytes == 0 {
581                                 return nil, fmt.Errorf("zero nbytes")
582                         }
583
584                         remain -= nbytes
585                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
586
587                         view = view[nbytes:]
588                 }
589         }
590
591         Log.Debugf(" -- read done (buffered: %d)", c.reader.Buffered())
592
593         return msg, nil
594 }
595
596 func isClosedError(err error) bool {
597         if err == io.EOF {
598                 return true
599         }
600         return strings.HasSuffix(err.Error(), "use of closed network connection")
601 }