Fix socketclient for VPP 19.08
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "net"
23         "os"
24         "path/filepath"
25         "strings"
26         "sync"
27         "time"
28
29         "github.com/fsnotify/fsnotify"
30         "github.com/lunixbochs/struc"
31         logger "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35         "git.fd.io/govpp.git/examples/binapi/memclnt"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = adapter.DefaultBinapiSocket
41 )
42
43 var (
44         // DefaultConnectTimeout is default timeout for connecting
45         DefaultConnectTimeout = time.Second * 3
46         // DefaultDisconnectTimeout is default timeout for discconnecting
47         DefaultDisconnectTimeout = time.Millisecond * 100
48         // MaxWaitReady defines maximum duration before waiting for socket file
49         // times out
50         MaxWaitReady = time.Second * 10
51         // ClientName is used for identifying client in socket registration
52         ClientName = "govppsock"
53 )
54
55 var (
56         // Debug is global variable that determines debug mode
57         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
58         // DebugMsgIds is global variable that determines debug mode for msg ids
59         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
60
61         // Log is global logger
62         Log = logger.New()
63 )
64
65 // init initializes global logger, which logs debug level messages to stdout.
66 func init() {
67         Log.Out = os.Stdout
68         if Debug {
69                 Log.Level = logger.DebugLevel
70                 Log.Debug("govpp/socketclient: enabled debug mode")
71         }
72 }
73
74 type vppClient struct {
75         sockAddr string
76
77         conn   *net.UnixConn
78         reader *bufio.Reader
79         writer *bufio.Writer
80
81         connectTimeout    time.Duration
82         disconnectTimeout time.Duration
83
84         cb           adapter.MsgCallback
85         clientIndex  uint32
86         msgTable     map[string]uint16
87         sockDelMsgId uint16
88         writeMu      sync.Mutex
89
90         quit chan struct{}
91         wg   sync.WaitGroup
92 }
93
94 func NewVppClient(sockAddr string) *vppClient {
95         if sockAddr == "" {
96                 sockAddr = DefaultSocketName
97         }
98         return &vppClient{
99                 sockAddr:          sockAddr,
100                 connectTimeout:    DefaultConnectTimeout,
101                 disconnectTimeout: DefaultDisconnectTimeout,
102                 cb: func(msgID uint16, data []byte) {
103                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
104                 },
105         }
106 }
107
108 // SetConnectTimeout sets timeout used during connecting.
109 func (c *vppClient) SetConnectTimeout(t time.Duration) {
110         c.connectTimeout = t
111 }
112
113 // SetDisconnectTimeout sets timeout used during disconnecting.
114 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
115         c.disconnectTimeout = t
116 }
117
118 // WaitReady checks socket file existence and waits for it if necessary
119 func (c *vppClient) WaitReady() error {
120         // check if socket already exists
121         if _, err := os.Stat(c.sockAddr); err == nil {
122                 return nil // socket exists, we are ready
123         } else if !os.IsNotExist(err) {
124                 return err // some other error occurred
125         }
126
127         // socket does not exist, watch for it
128         watcher, err := fsnotify.NewWatcher()
129         if err != nil {
130                 return err
131         }
132         defer func() {
133                 if err := watcher.Close(); err != nil {
134                         Log.Warnf("failed to close file watcher: %v", err)
135                 }
136         }()
137
138         // start directory watcher
139         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
140                 return err
141         }
142
143         timeout := time.NewTimer(MaxWaitReady)
144         for {
145                 select {
146                 case <-timeout.C:
147                         return fmt.Errorf("timeout waiting (%s) for socket file: %s", MaxWaitReady, c.sockAddr)
148
149                 case e := <-watcher.Errors:
150                         return e
151
152                 case ev := <-watcher.Events:
153                         Log.Debugf("watcher event: %+v", ev)
154                         if ev.Name == c.sockAddr && (ev.Op&fsnotify.Create) == fsnotify.Create {
155                                 // socket created, we are ready
156                                 return nil
157                         }
158                 }
159         }
160 }
161
162 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
163         Log.Debug("SetMsgCallback")
164         c.cb = cb
165 }
166
167 func (c *vppClient) Connect() error {
168         Log.Debugf("Connecting to: %v", c.sockAddr)
169
170         if err := c.connect(c.sockAddr); err != nil {
171                 return err
172         }
173
174         if err := c.open(); err != nil {
175                 c.disconnect()
176                 return err
177         }
178
179         c.quit = make(chan struct{})
180         c.wg.Add(1)
181         go c.readerLoop()
182
183         return nil
184 }
185
186 func (c *vppClient) Disconnect() error {
187         if c.conn == nil {
188                 return nil
189         }
190         Log.Debugf("Disconnecting..")
191
192         close(c.quit)
193
194         if err := c.conn.CloseRead(); err != nil {
195                 Log.Debugf("closing read failed: %v", err)
196         }
197
198         // wait for readerLoop to return
199         c.wg.Wait()
200
201         if err := c.close(); err != nil {
202                 Log.Debugf("closing failed: %v", err)
203         }
204
205         if err := c.disconnect(); err != nil {
206                 return err
207         }
208
209         return nil
210 }
211
212 func (c *vppClient) connect(sockAddr string) error {
213         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
214
215         conn, err := net.DialUnix("unix", nil, addr)
216         if err != nil {
217                 // we try different type of socket for backwards compatbility with VPP<=19.04
218                 if strings.Contains(err.Error(), "wrong type for socket") {
219                         addr.Net = "unixpacket"
220                         Log.Warnf("%s, retrying connect with type unixpacket", err)
221                         conn, err = net.DialUnix("unixpacket", nil, addr)
222                 }
223                 if err != nil {
224                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
225                         return err
226                 }
227         }
228
229         c.conn = conn
230         Log.Debugf("Connected to socket (local addr: %v)", c.conn.LocalAddr().(*net.UnixAddr))
231
232         c.reader = bufio.NewReader(c.conn)
233         c.writer = bufio.NewWriter(c.conn)
234
235         return nil
236 }
237
238 func (c *vppClient) disconnect() error {
239         Log.Debugf("Closing socket")
240         if err := c.conn.Close(); err != nil {
241                 Log.Debugln("Closing socket failed:", err)
242                 return err
243         }
244         return nil
245 }
246
247 const (
248         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
249         createMsgContext = byte(123)
250         deleteMsgContext = byte(124)
251 )
252
253 func (c *vppClient) open() error {
254         msgCodec := new(codec.MsgCodec)
255
256         req := &memclnt.SockclntCreate{
257                 Name: []byte(ClientName),
258         }
259         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
260         if err != nil {
261                 Log.Debugln("Encode error:", err)
262                 return err
263         }
264         // set non-0 context
265         msg[5] = createMsgContext
266
267         if err := c.write(msg); err != nil {
268                 Log.Debugln("Write error: ", err)
269                 return err
270         }
271
272         readDeadline := time.Now().Add(c.connectTimeout)
273         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
274                 return err
275         }
276         msgReply, err := c.read()
277         if err != nil {
278                 Log.Println("Read error:", err)
279                 return err
280         }
281         // reset read deadline
282         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
283                 return err
284         }
285
286         reply := new(memclnt.SockclntCreateReply)
287         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
288                 Log.Println("Decode error:", err)
289                 return err
290         }
291
292         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
293                 reply.Response, reply.Index, reply.Count)
294
295         c.clientIndex = reply.Index
296         c.msgTable = make(map[string]uint16, reply.Count)
297         for _, x := range reply.MessageTable {
298                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
299                 c.msgTable[name] = x.Index
300                 if strings.HasPrefix(name, "sockclnt_delete_") {
301                         c.sockDelMsgId = x.Index
302                 }
303                 if DebugMsgIds {
304                         Log.Debugf(" - %4d: %q", x.Index, name)
305                 }
306         }
307
308         return nil
309 }
310
311 func (c *vppClient) close() error {
312         msgCodec := new(codec.MsgCodec)
313
314         req := &memclnt.SockclntDelete{
315                 Index: c.clientIndex,
316         }
317         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
318         if err != nil {
319                 Log.Debugln("Encode error:", err)
320                 return err
321         }
322         // set non-0 context
323         msg[5] = deleteMsgContext
324
325         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
326         if err := c.write(msg); err != nil {
327                 Log.Debugln("Write error: ", err)
328                 return err
329         }
330
331         readDeadline := time.Now().Add(c.disconnectTimeout)
332         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
333                 return err
334         }
335         msgReply, err := c.read()
336         if err != nil {
337                 Log.Debugln("Read error:", err)
338                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
339                         // we accept timeout for reply
340                         return nil
341                 }
342                 return err
343         }
344         // reset read deadline
345         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
346                 return err
347         }
348
349         reply := new(memclnt.SockclntDeleteReply)
350         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
351                 Log.Debugln("Decode error:", err)
352                 return err
353         }
354
355         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
356
357         return nil
358 }
359
360 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
361         msg := msgName + "_" + msgCrc
362         msgID, ok := c.msgTable[msg]
363         if !ok {
364                 return 0, fmt.Errorf("unknown message: %q", msg)
365         }
366         return msgID, nil
367 }
368
369 type reqHeader struct {
370         // MsgID uint16
371         ClientIndex uint32
372         Context     uint32
373 }
374
375 func (c *vppClient) SendMsg(context uint32, data []byte) error {
376         h := &reqHeader{
377                 ClientIndex: c.clientIndex,
378                 Context:     context,
379         }
380         buf := new(bytes.Buffer)
381         if err := struc.Pack(buf, h); err != nil {
382                 return err
383         }
384         copy(data[2:], buf.Bytes())
385
386         Log.Debugf("sendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
387
388         if err := c.write(data); err != nil {
389                 Log.Debugln("write error: ", err)
390                 return err
391         }
392
393         return nil
394 }
395
396 func (c *vppClient) write(msg []byte) error {
397         h := &msgheader{
398                 DataLen: uint32(len(msg)),
399         }
400         buf := new(bytes.Buffer)
401         if err := struc.Pack(buf, h); err != nil {
402                 return err
403         }
404         header := buf.Bytes()
405
406         // we lock to prevent mixing multiple message sends
407         c.writeMu.Lock()
408         defer c.writeMu.Unlock()
409
410         if n, err := c.writer.Write(header); err != nil {
411                 return err
412         } else {
413                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
414         }
415
416         writerSize := c.writer.Size()
417         for i := 0; i <= len(msg)/writerSize; i++ {
418                 x := i*writerSize + writerSize
419                 if x > len(msg) {
420                         x = len(msg)
421                 }
422                 Log.Debugf(" - x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/writerSize)
423                 if n, err := c.writer.Write(msg[i*writerSize : x]); err != nil {
424                         return err
425                 } else {
426                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
427                 }
428         }
429         if err := c.writer.Flush(); err != nil {
430                 return err
431         }
432
433         Log.Debugf(" -- write done")
434
435         return nil
436 }
437
438 type msgHeader struct {
439         MsgID   uint16
440         Context uint32
441 }
442
443 func (c *vppClient) readerLoop() {
444         defer c.wg.Done()
445         defer Log.Debugf("reader quit")
446
447         for {
448                 select {
449                 case <-c.quit:
450                         return
451                 default:
452                 }
453
454                 msg, err := c.read()
455                 if err != nil {
456                         if isClosedError(err) {
457                                 return
458                         }
459                         Log.Debugf("read failed: %v", err)
460                         continue
461                 }
462
463                 h := new(msgHeader)
464                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
465                         Log.Debugf("unpacking header failed: %v", err)
466                         continue
467                 }
468
469                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
470                 c.cb(h.MsgID, msg)
471         }
472 }
473
474 type msgheader struct {
475         Q               int    `struc:"uint64"`
476         DataLen         uint32 `struc:"uint32"`
477         GcMarkTimestamp uint32 `struc:"uint32"`
478 }
479
480 func (c *vppClient) read() ([]byte, error) {
481         Log.Debug(" reading next msg..")
482
483         header := make([]byte, 16)
484
485         n, err := io.ReadAtLeast(c.reader, header, 16)
486         if err != nil {
487                 return nil, err
488         }
489         if n == 0 {
490                 Log.Debugln("zero bytes header")
491                 return nil, nil
492         } else if n != 16 {
493                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
494                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
495         }
496         Log.Debugf(" read header %d bytes: % 0X", n, header)
497
498         h := &msgheader{}
499         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
500                 return nil, err
501         }
502         Log.Debugf(" - decoded header: %+v", h)
503
504         msgLen := int(h.DataLen)
505         msg := make([]byte, msgLen)
506
507         n, err = c.reader.Read(msg)
508         if err != nil {
509                 return nil, err
510         }
511         Log.Debugf(" - read msg %d bytes (%d buffered) % 0X", n, c.reader.Buffered(), msg[:n])
512
513         if msgLen > n {
514                 remain := msgLen - n
515                 Log.Debugf("continue read for another %d bytes", remain)
516                 view := msg[n:]
517
518                 for remain > 0 {
519                         nbytes, err := c.reader.Read(view)
520                         if err != nil {
521                                 return nil, err
522                         } else if nbytes == 0 {
523                                 return nil, fmt.Errorf("zero nbytes")
524                         }
525
526                         remain -= nbytes
527                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
528
529                         view = view[nbytes:]
530                 }
531         }
532
533         Log.Debugf(" -- read done (buffered: %d)", c.reader.Buffered())
534
535         return msg, nil
536 }
537
538 func isClosedError(err error) bool {
539         if err == io.EOF {
540                 return true
541         }
542         return strings.HasSuffix(err.Error(), "use of closed network connection")
543 }