Fix WaitReady for VPP client adapters
[govpp.git] / adapter / socketclient / socketclient.go
1 package socketclient
2
3 import (
4         "bufio"
5         "bytes"
6         "fmt"
7         "io"
8         "net"
9         "os"
10         "path/filepath"
11         "strings"
12         "sync"
13         "time"
14
15         "github.com/fsnotify/fsnotify"
16
17         "github.com/lunixbochs/struc"
18         logger "github.com/sirupsen/logrus"
19
20         "git.fd.io/govpp.git/adapter"
21         "git.fd.io/govpp.git/codec"
22         "git.fd.io/govpp.git/examples/bin_api/memclnt"
23 )
24
25 const (
26         // DefaultSocketName is default VPP API socket file name
27         DefaultSocketName = "/run/vpp-api.sock"
28
29         sockCreateMsgId = 15          // hard-coded id for sockclnt_create message
30         govppClientName = "govppsock" // client name used for socket registration
31 )
32
33 var (
34         ConnectTimeout    = time.Second * 3
35         DisconnectTimeout = time.Second
36
37         // MaxWaitReady defines maximum duration before waiting for socket file
38         // times out
39         MaxWaitReady = time.Second * 15
40
41         Debug       = os.Getenv("DEBUG_GOVPP_SOCK") != ""
42         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
43
44         Log = logger.New() // global logger
45 )
46
47 // init initializes global logger, which logs debug level messages to stdout.
48 func init() {
49         Log.Out = os.Stdout
50         if Debug {
51                 Log.Level = logger.DebugLevel
52         }
53 }
54
55 type vppClient struct {
56         sockAddr     string
57         conn         *net.UnixConn
58         reader       *bufio.Reader
59         cb           adapter.MsgCallback
60         clientIndex  uint32
61         msgTable     map[string]uint16
62         sockDelMsgId uint16
63         writeMu      sync.Mutex
64         quit         chan struct{}
65         wg           sync.WaitGroup
66 }
67
68 func NewVppClient(sockAddr string) *vppClient {
69         if sockAddr == "" {
70                 sockAddr = DefaultSocketName
71         }
72         return &vppClient{
73                 sockAddr: sockAddr,
74                 cb:       nilCallback,
75         }
76 }
77
78 func nilCallback(msgID uint16, data []byte) {
79         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
80 }
81
82 // WaitReady checks socket file existence and waits for it if necessary
83 func (c *vppClient) WaitReady() error {
84         // check if file at the path already exists
85         if _, err := os.Stat(c.sockAddr); err == nil {
86                 return nil
87         } else if os.IsExist(err) {
88                 return err
89         }
90
91         // if not, watch for it
92         watcher, err := fsnotify.NewWatcher()
93         if err != nil {
94                 return err
95         }
96         defer func() {
97                 if err := watcher.Close(); err != nil {
98                         Log.Errorf("failed to close file watcher: %v", err)
99                 }
100         }()
101
102         // start watching directory
103         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
104                 return err
105         }
106
107         for {
108                 select {
109                 case <-time.After(MaxWaitReady):
110                         return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
111                 case e := <-watcher.Errors:
112                         return e
113                 case ev := <-watcher.Events:
114                         Log.Debugf("watcher event: %+v", ev)
115                         if ev.Name == c.sockAddr {
116                                 if (ev.Op & fsnotify.Create) == fsnotify.Create {
117                                         // socket was created, we are ready
118                                         return nil
119                                 }
120                         }
121                 }
122         }
123 }
124
125 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
126         Log.Debug("SetMsgCallback")
127         c.cb = cb
128 }
129
130 func (c *vppClient) Connect() error {
131         Log.Debugf("Connecting to: %v", c.sockAddr)
132
133         if err := c.connect(c.sockAddr); err != nil {
134                 return err
135         }
136
137         if err := c.open(); err != nil {
138                 return err
139         }
140
141         c.quit = make(chan struct{})
142         c.wg.Add(1)
143         go c.readerLoop()
144
145         return nil
146 }
147
148 func (c *vppClient) connect(sockAddr string) error {
149         addr, err := net.ResolveUnixAddr("unixpacket", sockAddr)
150         if err != nil {
151                 Log.Debugln("ResolveUnixAddr error:", err)
152                 return err
153         }
154
155         conn, err := net.DialUnix("unixpacket", nil, addr)
156         if err != nil {
157                 Log.Debugln("Dial error:", err)
158                 return err
159         }
160
161         c.conn = conn
162         c.reader = bufio.NewReader(c.conn)
163
164         Log.Debugf("Connected to socket: %v", addr)
165
166         return nil
167 }
168
169 func (c *vppClient) open() error {
170         msgCodec := new(codec.MsgCodec)
171
172         req := &memclnt.SockclntCreate{
173                 Name: []byte(govppClientName),
174         }
175         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
176         if err != nil {
177                 Log.Debugln("Encode error:", err)
178                 return err
179         }
180         // set non-0 context
181         msg[5] = 123
182
183         if err := c.write(msg); err != nil {
184                 Log.Debugln("Write error: ", err)
185                 return err
186         }
187
188         readDeadline := time.Now().Add(ConnectTimeout)
189         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
190                 return err
191         }
192         msgReply, err := c.read()
193         if err != nil {
194                 Log.Println("Read error:", err)
195                 return err
196         }
197         // reset read deadline
198         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
199                 return err
200         }
201
202         //log.Printf("Client got (%d): % 0X", len(msgReply), msgReply)
203
204         reply := new(memclnt.SockclntCreateReply)
205         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
206                 Log.Println("Decode error:", err)
207                 return err
208         }
209
210         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
211                 reply.Response, reply.Index, reply.Count)
212
213         c.clientIndex = reply.Index
214         c.msgTable = make(map[string]uint16, reply.Count)
215         for _, x := range reply.MessageTable {
216                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
217                 c.msgTable[name] = x.Index
218                 if strings.HasPrefix(name, "sockclnt_delete_") {
219                         c.sockDelMsgId = x.Index
220                 }
221                 if DebugMsgIds {
222                         Log.Debugf(" - %4d: %q", x.Index, name)
223                 }
224         }
225
226         return nil
227 }
228
229 func (c *vppClient) Disconnect() error {
230         if c.conn == nil {
231                 return nil
232         }
233         Log.Debugf("Disconnecting..")
234
235         close(c.quit)
236
237         // force readerLoop to timeout
238         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
239                 return err
240         }
241
242         // wait for readerLoop to return
243         c.wg.Wait()
244
245         if err := c.close(); err != nil {
246                 return err
247         }
248
249         if err := c.conn.Close(); err != nil {
250                 Log.Debugln("Close socket conn failed:", err)
251                 return err
252         }
253
254         return nil
255 }
256
257 func (c *vppClient) close() error {
258         msgCodec := new(codec.MsgCodec)
259
260         req := &memclnt.SockclntDelete{
261                 Index: c.clientIndex,
262         }
263         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
264         if err != nil {
265                 Log.Debugln("Encode error:", err)
266                 return err
267         }
268         // set non-0 context
269         msg[5] = 124
270
271         Log.Debugf("sending socklntDel (%d byes): % 0X\n", len(msg), msg)
272         if err := c.write(msg); err != nil {
273                 Log.Debugln("Write error: ", err)
274                 return err
275         }
276
277         readDeadline := time.Now().Add(DisconnectTimeout)
278         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
279                 return err
280         }
281         msgReply, err := c.read()
282         if err != nil {
283                 Log.Debugln("Read error:", err)
284                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
285                         // we accept timeout for reply
286                         return nil
287                 }
288                 return err
289         }
290         // reset read deadline
291         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
292                 return err
293         }
294
295         reply := new(memclnt.SockclntDeleteReply)
296         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
297                 Log.Debugln("Decode error:", err)
298                 return err
299         }
300
301         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
302
303         return nil
304 }
305
306 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
307         msg := msgName + "_" + msgCrc
308         msgID, ok := c.msgTable[msg]
309         if !ok {
310                 return 0, fmt.Errorf("unknown message: %q", msg)
311         }
312         return msgID, nil
313 }
314
315 type reqHeader struct {
316         //MsgID uint16
317         ClientIndex uint32
318         Context     uint32
319 }
320
321 func (c *vppClient) SendMsg(context uint32, data []byte) error {
322         h := &reqHeader{
323                 ClientIndex: c.clientIndex,
324                 Context:     context,
325         }
326         buf := new(bytes.Buffer)
327         if err := struc.Pack(buf, h); err != nil {
328                 return err
329         }
330         copy(data[2:], buf.Bytes())
331
332         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
333
334         if err := c.write(data); err != nil {
335                 Log.Debugln("write error: ", err)
336                 return err
337         }
338
339         return nil
340 }
341
342 func (c *vppClient) write(msg []byte) error {
343         h := &msgheader{
344                 Data_len: uint32(len(msg)),
345         }
346         buf := new(bytes.Buffer)
347         if err := struc.Pack(buf, h); err != nil {
348                 return err
349         }
350         header := buf.Bytes()
351
352         // we lock to prevent mixing multiple message sends
353         c.writeMu.Lock()
354         defer c.writeMu.Unlock()
355
356         var w io.Writer = c.conn
357
358         if n, err := w.Write(header); err != nil {
359                 return err
360         } else {
361                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
362         }
363         if n, err := w.Write(msg); err != nil {
364                 return err
365         } else {
366                 Log.Debugf(" - msg sent (%d/%d): % 0X", n, len(msg), msg)
367         }
368
369         return nil
370 }
371
372 type msgHeader struct {
373         MsgID   uint16
374         Context uint32
375 }
376
377 func (c *vppClient) readerLoop() {
378         defer c.wg.Done()
379         for {
380                 select {
381                 case <-c.quit:
382                         Log.Debugf("reader quit")
383                         return
384                 default:
385                 }
386
387                 msg, err := c.read()
388                 if err != nil {
389                         if isClosedError(err) {
390                                 return
391                         }
392                         Log.Debugf("READ FAILED: %v", err)
393                         continue
394                 }
395                 h := new(msgHeader)
396                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
397                         Log.Debugf("unpacking header failed: %v", err)
398                         continue
399                 }
400
401                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
402                 c.cb(h.MsgID, msg)
403         }
404 }
405
406 type msgheader struct {
407         Q                 int    `struc:"uint64"`
408         Data_len          uint32 `struc:"uint32"`
409         Gc_mark_timestamp uint32 `struc:"uint32"`
410         //data              [0]uint8
411 }
412
413 func (c *vppClient) read() ([]byte, error) {
414         Log.Debug("reading next msg..")
415
416         header := make([]byte, 16)
417
418         n, err := io.ReadAtLeast(c.reader, header, 16)
419         if err != nil {
420                 return nil, err
421         } else if n == 0 {
422                 Log.Debugln("zero bytes header")
423                 return nil, nil
424         }
425         if n != 16 {
426                 Log.Debug("invalid header data (%d): % 0X", n, header[:n])
427                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
428         }
429         Log.Debugf(" - read header %d bytes: % 0X", n, header)
430
431         h := &msgheader{}
432         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
433                 return nil, err
434         }
435         Log.Debugf(" - decoded header: %+v", h)
436
437         msgLen := int(h.Data_len)
438         msg := make([]byte, msgLen)
439
440         n, err = c.reader.Read(msg)
441         if err != nil {
442                 return nil, err
443         }
444         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
445
446         if msgLen > n {
447                 remain := msgLen - n
448                 Log.Debugf("continue read for another %d bytes", remain)
449                 view := msg[n:]
450
451                 for remain > 0 {
452
453                         nbytes, err := c.reader.Read(view)
454                         if err != nil {
455                                 return nil, err
456                         } else if nbytes == 0 {
457                                 return nil, fmt.Errorf("zero nbytes")
458                         }
459
460                         remain -= nbytes
461                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
462
463                         view = view[nbytes:]
464                 }
465         }
466
467         return msg, nil
468 }
469
470 func isClosedError(err error) bool {
471         if err == io.EOF {
472                 return true
473         }
474         return strings.HasSuffix(err.Error(), "use of closed network connection")
475 }