Print info for users to stderr when socket files are missing
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "net"
23         "os"
24         "path/filepath"
25         "strings"
26         "sync"
27         "time"
28
29         "github.com/fsnotify/fsnotify"
30         "github.com/lunixbochs/struc"
31         logger "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35         "git.fd.io/govpp.git/examples/binapi/memclnt"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = adapter.DefaultBinapiSocket
41 )
42
43 const socketMissing = `
44 ------------------------------------------------------------
45  VPP binary API socket file %s is missing!
46
47   - is VPP running with socket for binapi enabled?
48   - is the correct socket name configured?
49
50  To enable it add following section to your VPP config:
51    socksvr {
52      default
53    }
54 ------------------------------------------------------------
55 `
56
57 var (
58         // DefaultConnectTimeout is default timeout for connecting
59         DefaultConnectTimeout = time.Second * 3
60         // DefaultDisconnectTimeout is default timeout for discconnecting
61         DefaultDisconnectTimeout = time.Millisecond * 100
62         // MaxWaitReady defines maximum duration before waiting for socket file
63         // times out
64         MaxWaitReady = time.Second * 10
65         // ClientName is used for identifying client in socket registration
66         ClientName = "govppsock"
67 )
68
69 var (
70         // Debug is global variable that determines debug mode
71         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
72         // DebugMsgIds is global variable that determines debug mode for msg ids
73         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
74
75         // Log is global logger
76         Log = logger.New()
77 )
78
79 // init initializes global logger, which logs debug level messages to stdout.
80 func init() {
81         Log.Out = os.Stdout
82         if Debug {
83                 Log.Level = logger.DebugLevel
84                 Log.Debug("govpp/socketclient: enabled debug mode")
85         }
86 }
87
88 type vppClient struct {
89         sockAddr string
90
91         conn   *net.UnixConn
92         reader *bufio.Reader
93         writer *bufio.Writer
94
95         connectTimeout    time.Duration
96         disconnectTimeout time.Duration
97
98         cb           adapter.MsgCallback
99         clientIndex  uint32
100         msgTable     map[string]uint16
101         sockDelMsgId uint16
102         writeMu      sync.Mutex
103
104         quit chan struct{}
105         wg   sync.WaitGroup
106 }
107
108 func NewVppClient(sockAddr string) *vppClient {
109         if sockAddr == "" {
110                 sockAddr = DefaultSocketName
111         }
112         return &vppClient{
113                 sockAddr:          sockAddr,
114                 connectTimeout:    DefaultConnectTimeout,
115                 disconnectTimeout: DefaultDisconnectTimeout,
116                 cb: func(msgID uint16, data []byte) {
117                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
118                 },
119         }
120 }
121
122 // SetConnectTimeout sets timeout used during connecting.
123 func (c *vppClient) SetConnectTimeout(t time.Duration) {
124         c.connectTimeout = t
125 }
126
127 // SetDisconnectTimeout sets timeout used during disconnecting.
128 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
129         c.disconnectTimeout = t
130 }
131
132 // WaitReady checks socket file existence and waits for it if necessary
133 func (c *vppClient) WaitReady() error {
134         // check if socket already exists
135         if _, err := os.Stat(c.sockAddr); err == nil {
136                 return nil // socket exists, we are ready
137         } else if !os.IsNotExist(err) {
138                 return err // some other error occurred
139         }
140
141         // socket does not exist, watch for it
142         watcher, err := fsnotify.NewWatcher()
143         if err != nil {
144                 return err
145         }
146         defer func() {
147                 if err := watcher.Close(); err != nil {
148                         Log.Warnf("failed to close file watcher: %v", err)
149                 }
150         }()
151
152         // start directory watcher
153         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
154                 return err
155         }
156
157         timeout := time.NewTimer(MaxWaitReady)
158         for {
159                 select {
160                 case <-timeout.C:
161                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
162
163                 case e := <-watcher.Errors:
164                         return e
165
166                 case ev := <-watcher.Events:
167                         Log.Debugf("watcher event: %+v", ev)
168                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
169                                 // socket created, we are ready
170                                 return nil
171                         }
172                 }
173         }
174 }
175
176 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
177         Log.Debug("SetMsgCallback")
178         c.cb = cb
179 }
180
181 func (c *vppClient) Connect() error {
182         // check if socket exists
183         if _, err := os.Stat(c.sockAddr); os.IsNotExist(err) {
184                 fmt.Fprintf(os.Stderr, socketMissing, c.sockAddr)
185                 return fmt.Errorf("VPP API socket file %s does not exist", c.sockAddr)
186         } else if err != nil {
187                 return fmt.Errorf("VPP API socket error: %v", err)
188         }
189
190         if err := c.connect(c.sockAddr); err != nil {
191                 return err
192         }
193
194         if err := c.open(); err != nil {
195                 c.disconnect()
196                 return err
197         }
198
199         c.quit = make(chan struct{})
200         c.wg.Add(1)
201         go c.readerLoop()
202
203         return nil
204 }
205
206 func (c *vppClient) Disconnect() error {
207         if c.conn == nil {
208                 return nil
209         }
210         Log.Debugf("Disconnecting..")
211
212         close(c.quit)
213
214         if err := c.conn.CloseRead(); err != nil {
215                 Log.Debugf("closing read failed: %v", err)
216         }
217
218         // wait for readerLoop to return
219         c.wg.Wait()
220
221         if err := c.close(); err != nil {
222                 Log.Debugf("closing failed: %v", err)
223         }
224
225         if err := c.disconnect(); err != nil {
226                 return err
227         }
228
229         return nil
230 }
231
232 func (c *vppClient) connect(sockAddr string) error {
233         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
234
235         Log.Debugf("Connecting to: %v", c.sockAddr)
236
237         conn, err := net.DialUnix("unix", nil, addr)
238         if err != nil {
239                 // we try different type of socket for backwards compatbility with VPP<=19.04
240                 if strings.Contains(err.Error(), "wrong type for socket") {
241                         addr.Net = "unixpacket"
242                         Log.Debugf("%s, retrying connect with type unixpacket", err)
243                         conn, err = net.DialUnix("unixpacket", nil, addr)
244                 }
245                 if err != nil {
246                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
247                         return err
248                 }
249         }
250
251         c.conn = conn
252         Log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
253
254         c.reader = bufio.NewReader(c.conn)
255         c.writer = bufio.NewWriter(c.conn)
256
257         return nil
258 }
259
260 func (c *vppClient) disconnect() error {
261         Log.Debugf("Closing socket")
262         if err := c.conn.Close(); err != nil {
263                 Log.Debugln("Closing socket failed:", err)
264                 return err
265         }
266         return nil
267 }
268
269 const (
270         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
271         createMsgContext = byte(123)
272         deleteMsgContext = byte(124)
273 )
274
275 func (c *vppClient) open() error {
276         msgCodec := new(codec.MsgCodec)
277
278         req := &memclnt.SockclntCreate{
279                 Name: []byte(ClientName),
280         }
281         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
282         if err != nil {
283                 Log.Debugln("Encode error:", err)
284                 return err
285         }
286         // set non-0 context
287         msg[5] = createMsgContext
288
289         if err := c.write(msg); err != nil {
290                 Log.Debugln("Write error: ", err)
291                 return err
292         }
293
294         readDeadline := time.Now().Add(c.connectTimeout)
295         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
296                 return err
297         }
298         msgReply, err := c.read()
299         if err != nil {
300                 Log.Println("Read error:", err)
301                 return err
302         }
303         // reset read deadline
304         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
305                 return err
306         }
307
308         reply := new(memclnt.SockclntCreateReply)
309         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
310                 Log.Println("Decode error:", err)
311                 return err
312         }
313
314         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
315                 reply.Response, reply.Index, reply.Count)
316
317         c.clientIndex = reply.Index
318         c.msgTable = make(map[string]uint16, reply.Count)
319         for _, x := range reply.MessageTable {
320                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
321                 c.msgTable[name] = x.Index
322                 if strings.HasPrefix(name, "sockclnt_delete_") {
323                         c.sockDelMsgId = x.Index
324                 }
325                 if DebugMsgIds {
326                         Log.Debugf(" - %4d: %q", x.Index, name)
327                 }
328         }
329
330         return nil
331 }
332
333 func (c *vppClient) close() error {
334         msgCodec := new(codec.MsgCodec)
335
336         req := &memclnt.SockclntDelete{
337                 Index: c.clientIndex,
338         }
339         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
340         if err != nil {
341                 Log.Debugln("Encode error:", err)
342                 return err
343         }
344         // set non-0 context
345         msg[5] = deleteMsgContext
346
347         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
348         if err := c.write(msg); err != nil {
349                 Log.Debugln("Write error: ", err)
350                 return err
351         }
352
353         readDeadline := time.Now().Add(c.disconnectTimeout)
354         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
355                 return err
356         }
357         msgReply, err := c.read()
358         if err != nil {
359                 Log.Debugln("Read error:", err)
360                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
361                         // we accept timeout for reply
362                         return nil
363                 }
364                 return err
365         }
366         // reset read deadline
367         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
368                 return err
369         }
370
371         reply := new(memclnt.SockclntDeleteReply)
372         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
373                 Log.Debugln("Decode error:", err)
374                 return err
375         }
376
377         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
378
379         return nil
380 }
381
382 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
383         msg := msgName + "_" + msgCrc
384         msgID, ok := c.msgTable[msg]
385         if !ok {
386                 return 0, fmt.Errorf("unknown message: %q", msg)
387         }
388         return msgID, nil
389 }
390
391 type reqHeader struct {
392         // MsgID uint16
393         ClientIndex uint32
394         Context     uint32
395 }
396
397 func (c *vppClient) SendMsg(context uint32, data []byte) error {
398         h := &reqHeader{
399                 ClientIndex: c.clientIndex,
400                 Context:     context,
401         }
402         buf := new(bytes.Buffer)
403         if err := struc.Pack(buf, h); err != nil {
404                 return err
405         }
406         copy(data[2:], buf.Bytes())
407
408         Log.Debugf("sendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
409
410         if err := c.write(data); err != nil {
411                 Log.Debugln("write error: ", err)
412                 return err
413         }
414
415         return nil
416 }
417
418 func (c *vppClient) write(msg []byte) error {
419         h := &msgheader{
420                 DataLen: uint32(len(msg)),
421         }
422         buf := new(bytes.Buffer)
423         if err := struc.Pack(buf, h); err != nil {
424                 return err
425         }
426         header := buf.Bytes()
427
428         // we lock to prevent mixing multiple message sends
429         c.writeMu.Lock()
430         defer c.writeMu.Unlock()
431
432         if n, err := c.writer.Write(header); err != nil {
433                 return err
434         } else {
435                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
436         }
437
438         writerSize := c.writer.Size()
439         for i := 0; i <= len(msg)/writerSize; i++ {
440                 x := i*writerSize + writerSize
441                 if x > len(msg) {
442                         x = len(msg)
443                 }
444                 Log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
445                 if n, err := c.writer.Write(msg[i*writerSize : x]); err != nil {
446                         return err
447                 } else {
448                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
449                 }
450         }
451         if err := c.writer.Flush(); err != nil {
452                 return err
453         }
454
455         Log.Debugf(" -- write done")
456
457         return nil
458 }
459
460 type msgHeader struct {
461         MsgID   uint16
462         Context uint32
463 }
464
465 func (c *vppClient) readerLoop() {
466         defer c.wg.Done()
467         defer Log.Debugf("reader quit")
468
469         for {
470                 select {
471                 case <-c.quit:
472                         return
473                 default:
474                 }
475
476                 msg, err := c.read()
477                 if err != nil {
478                         if isClosedError(err) {
479                                 return
480                         }
481                         Log.Debugf("read failed: %v", err)
482                         continue
483                 }
484
485                 h := new(msgHeader)
486                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
487                         Log.Debugf("unpacking header failed: %v", err)
488                         continue
489                 }
490
491                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
492                 c.cb(h.MsgID, msg)
493         }
494 }
495
496 type msgheader struct {
497         Q               int    `struc:"uint64"`
498         DataLen         uint32 `struc:"uint32"`
499         GcMarkTimestamp uint32 `struc:"uint32"`
500 }
501
502 func (c *vppClient) read() ([]byte, error) {
503         Log.Debug(" reading next msg..")
504
505         header := make([]byte, 16)
506
507         n, err := io.ReadAtLeast(c.reader, header, 16)
508         if err != nil {
509                 return nil, err
510         }
511         if n == 0 {
512                 Log.Debugln("zero bytes header")
513                 return nil, nil
514         } else if n != 16 {
515                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
516                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
517         }
518         Log.Debugf(" read header %d bytes: % 0X", n, header)
519
520         h := &msgheader{}
521         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
522                 return nil, err
523         }
524         Log.Debugf(" - decoded header: %+v", h)
525
526         msgLen := int(h.DataLen)
527         msg := make([]byte, msgLen)
528
529         n, err = c.reader.Read(msg)
530         if err != nil {
531                 return nil, err
532         }
533         Log.Debugf(" - read msg %d bytes (%d buffered) % 0X", n, c.reader.Buffered(), msg[:n])
534
535         if msgLen > n {
536                 remain := msgLen - n
537                 Log.Debugf("continue read for another %d bytes", remain)
538                 view := msg[n:]
539
540                 for remain > 0 {
541                         nbytes, err := c.reader.Read(view)
542                         if err != nil {
543                                 return nil, err
544                         } else if nbytes == 0 {
545                                 return nil, fmt.Errorf("zero nbytes")
546                         }
547
548                         remain -= nbytes
549                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
550
551                         view = view[nbytes:]
552                 }
553         }
554
555         Log.Debugf(" -- read done (buffered: %d)", c.reader.Buffered())
556
557         return msg, nil
558 }
559
560 func isClosedError(err error) bool {
561         if err == io.EOF {
562                 return true
563         }
564         return strings.HasSuffix(err.Error(), "use of closed network connection")
565 }