Add various generator improvements
[govpp.git] / adapter / socketclient / socketclient.go
1 package socketclient
2
3 import (
4         "bufio"
5         "bytes"
6         "fmt"
7         "io"
8         "net"
9         "os"
10         "path/filepath"
11         "strings"
12         "sync"
13         "time"
14
15         "github.com/fsnotify/fsnotify"
16         "github.com/lunixbochs/struc"
17         logger "github.com/sirupsen/logrus"
18
19         "git.fd.io/govpp.git/adapter"
20         "git.fd.io/govpp.git/codec"
21         "git.fd.io/govpp.git/examples/binapi/memclnt"
22 )
23
24 const (
25         // DefaultSocketName is default VPP API socket file name
26         DefaultSocketName = "/run/vpp-api.sock"
27 )
28
29 var (
30         // DefaultConnectTimeout is default timeout for connecting
31         DefaultConnectTimeout = time.Second * 3
32         // DefaultDisconnectTimeout is default timeout for discconnecting
33         DefaultDisconnectTimeout = time.Second
34         // MaxWaitReady defines maximum duration before waiting for socket file
35         // times out
36         MaxWaitReady = time.Second * 15
37         // ClientName is used for identifying client in socket registration
38         ClientName = "govppsock"
39 )
40
41 var (
42         // Debug is global variable that determines debug mode
43         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
44         // DebugMsgIds is global variable that determines debug mode for msg ids
45         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
46
47         Log = logger.New() // global logger
48 )
49
50 // init initializes global logger, which logs debug level messages to stdout.
51 func init() {
52         Log.Out = os.Stdout
53         if Debug {
54                 Log.Level = logger.DebugLevel
55         }
56 }
57
58 type vppClient struct {
59         sockAddr string
60         conn     *net.UnixConn
61         reader   *bufio.Reader
62         writer   *bufio.Writer
63
64         connectTimeout    time.Duration
65         disconnectTimeout time.Duration
66
67         cb           adapter.MsgCallback
68         clientIndex  uint32
69         msgTable     map[string]uint16
70         sockDelMsgId uint16
71         writeMu      sync.Mutex
72
73         quit chan struct{}
74         wg   sync.WaitGroup
75 }
76
77 func NewVppClient(sockAddr string) *vppClient {
78         if sockAddr == "" {
79                 sockAddr = DefaultSocketName
80         }
81         return &vppClient{
82                 sockAddr:          sockAddr,
83                 connectTimeout:    DefaultConnectTimeout,
84                 disconnectTimeout: DefaultDisconnectTimeout,
85                 cb: func(msgID uint16, data []byte) {
86                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
87                 },
88         }
89 }
90
91 // SetConnectTimeout sets timeout used during connecting.
92 func (c *vppClient) SetConnectTimeout(t time.Duration) {
93         c.connectTimeout = t
94 }
95
96 // SetDisconnectTimeout sets timeout used during disconnecting.
97 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
98         c.disconnectTimeout = t
99 }
100
101 // WaitReady checks socket file existence and waits for it if necessary
102 func (c *vppClient) WaitReady() error {
103         // check if file at the path already exists
104         if _, err := os.Stat(c.sockAddr); err == nil {
105                 return nil
106         } else if os.IsExist(err) {
107                 return err
108         }
109
110         // if not, watch for it
111         watcher, err := fsnotify.NewWatcher()
112         if err != nil {
113                 return err
114         }
115         defer func() {
116                 if err := watcher.Close(); err != nil {
117                         Log.Errorf("failed to close file watcher: %v", err)
118                 }
119         }()
120
121         // start watching directory
122         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
123                 return err
124         }
125
126         for {
127                 select {
128                 case <-time.After(MaxWaitReady):
129                         return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
130                 case e := <-watcher.Errors:
131                         return e
132                 case ev := <-watcher.Events:
133                         Log.Debugf("watcher event: %+v", ev)
134                         if ev.Name == c.sockAddr {
135                                 if (ev.Op & fsnotify.Create) == fsnotify.Create {
136                                         // socket was created, we are ready
137                                         return nil
138                                 }
139                         }
140                 }
141         }
142 }
143
144 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
145         Log.Debug("SetMsgCallback")
146         c.cb = cb
147 }
148
149 func (c *vppClient) Connect() error {
150         Log.Debugf("Connecting to: %v", c.sockAddr)
151
152         if err := c.connect(c.sockAddr); err != nil {
153                 return err
154         }
155
156         if err := c.open(); err != nil {
157                 return err
158         }
159
160         c.quit = make(chan struct{})
161         c.wg.Add(1)
162         go c.readerLoop()
163
164         return nil
165 }
166
167 func (c *vppClient) connect(sockAddr string) error {
168         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
169
170         conn, err := net.DialUnix("unix", nil, addr)
171         if err != nil {
172                 // we try different type of socket for backwards compatbility with VPP<=19.04
173                 if strings.Contains(err.Error(), "wrong type for socket") {
174                         addr.Net = "unixpacket"
175                         Log.Debugf("%s, retrying connect with type unixpacket", err)
176                         conn, err = net.DialUnix("unixpacket", nil, addr)
177                 }
178                 if err != nil {
179                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
180                         return err
181                 }
182         }
183
184         c.conn = conn
185         c.reader = bufio.NewReader(c.conn)
186         c.writer = bufio.NewWriter(c.conn)
187
188         Log.Debugf("Connected to socket: %v", addr)
189
190         return nil
191 }
192
193 const (
194         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
195         createMsgContext = byte(123)
196         deleteMsgContext = byte(124)
197 )
198
199 func (c *vppClient) open() error {
200         msgCodec := new(codec.MsgCodec)
201
202         req := &memclnt.SockclntCreate{
203                 Name: []byte(ClientName),
204         }
205         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
206         if err != nil {
207                 Log.Debugln("Encode error:", err)
208                 return err
209         }
210         // set non-0 context
211         msg[5] = createMsgContext
212
213         if err := c.write(msg); err != nil {
214                 Log.Debugln("Write error: ", err)
215                 return err
216         }
217
218         readDeadline := time.Now().Add(c.connectTimeout)
219         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
220                 return err
221         }
222         msgReply, err := c.read()
223         if err != nil {
224                 Log.Println("Read error:", err)
225                 return err
226         }
227         // reset read deadline
228         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
229                 return err
230         }
231
232         reply := new(memclnt.SockclntCreateReply)
233         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
234                 Log.Println("Decode error:", err)
235                 return err
236         }
237
238         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
239                 reply.Response, reply.Index, reply.Count)
240
241         c.clientIndex = reply.Index
242         c.msgTable = make(map[string]uint16, reply.Count)
243         for _, x := range reply.MessageTable {
244                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
245                 c.msgTable[name] = x.Index
246                 if strings.HasPrefix(name, "sockclnt_delete_") {
247                         c.sockDelMsgId = x.Index
248                 }
249                 if DebugMsgIds {
250                         Log.Debugf(" - %4d: %q", x.Index, name)
251                 }
252         }
253
254         return nil
255 }
256
257 func (c *vppClient) Disconnect() error {
258         if c.conn == nil {
259                 return nil
260         }
261         Log.Debugf("Disconnecting..")
262
263         close(c.quit)
264
265         // force readerLoop to timeout
266         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
267                 return err
268         }
269
270         // wait for readerLoop to return
271         c.wg.Wait()
272
273         if err := c.close(); err != nil {
274                 return err
275         }
276
277         if err := c.conn.Close(); err != nil {
278                 Log.Debugln("Closing socket failed:", err)
279                 return err
280         }
281
282         return nil
283 }
284
285 func (c *vppClient) close() error {
286         msgCodec := new(codec.MsgCodec)
287
288         req := &memclnt.SockclntDelete{
289                 Index: c.clientIndex,
290         }
291         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
292         if err != nil {
293                 Log.Debugln("Encode error:", err)
294                 return err
295         }
296         // set non-0 context
297         msg[5] = deleteMsgContext
298
299         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
300         if err := c.write(msg); err != nil {
301                 Log.Debugln("Write error: ", err)
302                 return err
303         }
304
305         readDeadline := time.Now().Add(c.disconnectTimeout)
306         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
307                 return err
308         }
309         msgReply, err := c.read()
310         if err != nil {
311                 Log.Debugln("Read error:", err)
312                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
313                         // we accept timeout for reply
314                         return nil
315                 }
316                 return err
317         }
318         // reset read deadline
319         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
320                 return err
321         }
322
323         reply := new(memclnt.SockclntDeleteReply)
324         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
325                 Log.Debugln("Decode error:", err)
326                 return err
327         }
328
329         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
330
331         return nil
332 }
333
334 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
335         msg := msgName + "_" + msgCrc
336         msgID, ok := c.msgTable[msg]
337         if !ok {
338                 return 0, fmt.Errorf("unknown message: %q", msg)
339         }
340         return msgID, nil
341 }
342
343 type reqHeader struct {
344         // MsgID uint16
345         ClientIndex uint32
346         Context     uint32
347 }
348
349 func (c *vppClient) SendMsg(context uint32, data []byte) error {
350         h := &reqHeader{
351                 ClientIndex: c.clientIndex,
352                 Context:     context,
353         }
354         buf := new(bytes.Buffer)
355         if err := struc.Pack(buf, h); err != nil {
356                 return err
357         }
358         copy(data[2:], buf.Bytes())
359
360         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
361
362         if err := c.write(data); err != nil {
363                 Log.Debugln("write error: ", err)
364                 return err
365         }
366
367         return nil
368 }
369
370 func (c *vppClient) write(msg []byte) error {
371         h := &msgheader{
372                 DataLen: uint32(len(msg)),
373         }
374         buf := new(bytes.Buffer)
375         if err := struc.Pack(buf, h); err != nil {
376                 return err
377         }
378         header := buf.Bytes()
379
380         // we lock to prevent mixing multiple message sends
381         c.writeMu.Lock()
382         defer c.writeMu.Unlock()
383
384         if n, err := c.writer.Write(header); err != nil {
385                 return err
386         } else {
387                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
388         }
389
390         if err := c.writer.Flush(); err != nil {
391                 return err
392         }
393
394         for i := 0; i <= len(msg)/c.writer.Size(); i++ {
395                 x := i*c.writer.Size() + c.writer.Size()
396                 if x > len(msg) {
397                         x = len(msg)
398                 }
399                 Log.Debugf("x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/c.writer.Size())
400                 if n, err := c.writer.Write(msg[i*c.writer.Size() : x]); err != nil {
401                         return err
402                 } else {
403                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
404                 }
405                 if err := c.writer.Flush(); err != nil {
406                         return err
407                 }
408
409         }
410
411         return nil
412 }
413
414 type msgHeader struct {
415         MsgID   uint16
416         Context uint32
417 }
418
419 func (c *vppClient) readerLoop() {
420         defer c.wg.Done()
421         defer Log.Debugf("reader quit")
422         for {
423                 select {
424                 case <-c.quit:
425                         return
426                 default:
427                 }
428
429                 msg, err := c.read()
430                 if err != nil {
431                         if isClosedError(err) {
432                                 return
433                         }
434                         Log.Debugf("read failed: %v", err)
435                         continue
436                 }
437                 h := new(msgHeader)
438                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
439                         Log.Debugf("unpacking header failed: %v", err)
440                         continue
441                 }
442
443                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
444                 c.cb(h.MsgID, msg)
445         }
446 }
447
448 type msgheader struct {
449         Q               int    `struc:"uint64"`
450         DataLen         uint32 `struc:"uint32"`
451         GcMarkTimestamp uint32 `struc:"uint32"`
452 }
453
454 func (c *vppClient) read() ([]byte, error) {
455         Log.Debug("reading next msg..")
456
457         header := make([]byte, 16)
458
459         n, err := io.ReadAtLeast(c.reader, header, 16)
460         if err != nil {
461                 return nil, err
462         } else if n == 0 {
463                 Log.Debugln("zero bytes header")
464                 return nil, nil
465         }
466         if n != 16 {
467                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
468                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
469         }
470         Log.Debugf(" - read header %d bytes: % 0X", n, header)
471
472         h := &msgheader{}
473         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
474                 return nil, err
475         }
476         Log.Debugf(" - decoded header: %+v", h)
477
478         msgLen := int(h.DataLen)
479         msg := make([]byte, msgLen)
480
481         n, err = c.reader.Read(msg)
482         if err != nil {
483                 return nil, err
484         }
485         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
486
487         if msgLen > n {
488                 remain := msgLen - n
489                 Log.Debugf("continue read for another %d bytes", remain)
490                 view := msg[n:]
491
492                 for remain > 0 {
493                         nbytes, err := c.reader.Read(view)
494                         if err != nil {
495                                 return nil, err
496                         } else if nbytes == 0 {
497                                 return nil, fmt.Errorf("zero nbytes")
498                         }
499
500                         remain -= nbytes
501                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
502
503                         view = view[nbytes:]
504                 }
505         }
506
507         return msg, nil
508 }
509
510 func isClosedError(err error) bool {
511         if err == io.EOF {
512                 return true
513         }
514         return strings.HasSuffix(err.Error(), "use of closed network connection")
515 }