socketclient: wait for socket to be created
[govpp.git] / adapter / socketclient / socketclient.go
1 package socketclient
2
3 import (
4         "bufio"
5         "bytes"
6         "fmt"
7         "github.com/fsnotify/fsnotify"
8         "io"
9         "net"
10         "os"
11         "path/filepath"
12         "strings"
13         "sync"
14         "time"
15
16         "github.com/lunixbochs/struc"
17         logger "github.com/sirupsen/logrus"
18
19         "git.fd.io/govpp.git/adapter"
20         "git.fd.io/govpp.git/codec"
21         "git.fd.io/govpp.git/examples/bin_api/memclnt"
22 )
23
24 const (
25         // DefaultSocketName is default VPP API socket file name
26         DefaultSocketName = "/run/vpp-api.sock"
27
28         sockCreateMsgId = 15          // hard-coded id for sockclnt_create message
29         govppClientName = "govppsock" // client name used for socket registration
30 )
31
32 var (
33         ConnectTimeout    = time.Second * 3
34         DisconnectTimeout = time.Second
35
36         Debug       = os.Getenv("DEBUG_GOVPP_SOCK") != ""
37         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
38
39         Log = logger.New() // global logger
40 )
41
42 // init initializes global logger, which logs debug level messages to stdout.
43 func init() {
44         Log.Out = os.Stdout
45         if Debug {
46                 Log.Level = logger.DebugLevel
47         }
48 }
49
50 type vppClient struct {
51         sockAddr     string
52         conn         *net.UnixConn
53         reader       *bufio.Reader
54         cb           adapter.MsgCallback
55         clientIndex  uint32
56         msgTable     map[string]uint16
57         sockDelMsgId uint16
58         writeMu      sync.Mutex
59         quit         chan struct{}
60         wg           sync.WaitGroup
61 }
62
63 func NewVppClient(sockAddr string) *vppClient {
64         if sockAddr == "" {
65                 sockAddr = DefaultSocketName
66         }
67         return &vppClient{
68                 sockAddr: sockAddr,
69                 cb:       nilCallback,
70         }
71 }
72
73 func nilCallback(msgID uint16, data []byte) {
74         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
75 }
76
77 // WaitReady checks socket file existence and waits for it if necessary
78 func (c *vppClient) WaitReady() error {
79         // verify file existence
80         if _, err := os.Stat(c.sockAddr); err == nil {
81                 return nil
82         } else if os.IsExist(err) {
83                 return err
84         }
85
86         // if not, watch for it
87         watcher, err := fsnotify.NewWatcher()
88         if err != nil {
89                 return err
90         }
91         defer func() {
92                 if err := watcher.Close(); err != nil {
93                         Log.Errorf("failed to close file watcher: %v", err)
94                 }
95         }()
96         path := filepath.Dir(c.sockAddr)
97         if err := watcher.Add(path); err != nil {
98                 return err
99         }
100
101         for {
102                 ev := <-watcher.Events
103                 if ev.Name == path {
104                         if (ev.Op & fsnotify.Create) == fsnotify.Create {
105                                 // socket ready
106                                 return nil
107                         }
108                 }
109         }
110 }
111
112 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
113         Log.Debug("SetMsgCallback")
114         c.cb = cb
115 }
116
117 func (c *vppClient) Connect() error {
118         Log.Debugf("Connecting to: %v", c.sockAddr)
119
120         if err := c.connect(c.sockAddr); err != nil {
121                 return err
122         }
123
124         if err := c.open(); err != nil {
125                 return err
126         }
127
128         c.quit = make(chan struct{})
129         c.wg.Add(1)
130         go c.readerLoop()
131
132         return nil
133 }
134
135 func (c *vppClient) connect(sockAddr string) error {
136         addr, err := net.ResolveUnixAddr("unixpacket", sockAddr)
137         if err != nil {
138                 Log.Debugln("ResolveUnixAddr error:", err)
139                 return err
140         }
141
142         conn, err := net.DialUnix("unixpacket", nil, addr)
143         if err != nil {
144                 Log.Debugln("Dial error:", err)
145                 return err
146         }
147
148         c.conn = conn
149         c.reader = bufio.NewReader(c.conn)
150
151         Log.Debugf("Connected to socket: %v", addr)
152
153         return nil
154 }
155
156 func (c *vppClient) open() error {
157         msgCodec := new(codec.MsgCodec)
158
159         req := &memclnt.SockclntCreate{
160                 Name: []byte(govppClientName),
161         }
162         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
163         if err != nil {
164                 Log.Debugln("Encode error:", err)
165                 return err
166         }
167         // set non-0 context
168         msg[5] = 123
169
170         if err := c.write(msg); err != nil {
171                 Log.Debugln("Write error: ", err)
172                 return err
173         }
174
175         readDeadline := time.Now().Add(ConnectTimeout)
176         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
177                 return err
178         }
179         msgReply, err := c.read()
180         if err != nil {
181                 Log.Println("Read error:", err)
182                 return err
183         }
184         // reset read deadline
185         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
186                 return err
187         }
188
189         //log.Printf("Client got (%d): % 0X", len(msgReply), msgReply)
190
191         reply := new(memclnt.SockclntCreateReply)
192         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
193                 Log.Println("Decode error:", err)
194                 return err
195         }
196
197         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
198                 reply.Response, reply.Index, reply.Count)
199
200         c.clientIndex = reply.Index
201         c.msgTable = make(map[string]uint16, reply.Count)
202         for _, x := range reply.MessageTable {
203                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
204                 c.msgTable[name] = x.Index
205                 if strings.HasPrefix(name, "sockclnt_delete_") {
206                         c.sockDelMsgId = x.Index
207                 }
208                 if DebugMsgIds {
209                         Log.Debugf(" - %4d: %q", x.Index, name)
210                 }
211         }
212
213         return nil
214 }
215
216 func (c *vppClient) Disconnect() error {
217         if c.conn == nil {
218                 return nil
219         }
220         Log.Debugf("Disconnecting..")
221
222         close(c.quit)
223
224         // force readerLoop to timeout
225         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
226                 return err
227         }
228
229         // wait for readerLoop to return
230         c.wg.Wait()
231
232         if err := c.close(); err != nil {
233                 return err
234         }
235
236         if err := c.conn.Close(); err != nil {
237                 Log.Debugln("Close socket conn failed:", err)
238                 return err
239         }
240
241         return nil
242 }
243
244 func (c *vppClient) close() error {
245         msgCodec := new(codec.MsgCodec)
246
247         req := &memclnt.SockclntDelete{
248                 Index: c.clientIndex,
249         }
250         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
251         if err != nil {
252                 Log.Debugln("Encode error:", err)
253                 return err
254         }
255         // set non-0 context
256         msg[5] = 124
257
258         Log.Debugf("sending socklntDel (%d byes): % 0X\n", len(msg), msg)
259         if err := c.write(msg); err != nil {
260                 Log.Debugln("Write error: ", err)
261                 return err
262         }
263
264         readDeadline := time.Now().Add(DisconnectTimeout)
265         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
266                 return err
267         }
268         msgReply, err := c.read()
269         if err != nil {
270                 Log.Debugln("Read error:", err)
271                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
272                         // we accept timeout for reply
273                         return nil
274                 }
275                 return err
276         }
277         // reset read deadline
278         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
279                 return err
280         }
281
282         reply := new(memclnt.SockclntDeleteReply)
283         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
284                 Log.Debugln("Decode error:", err)
285                 return err
286         }
287
288         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
289
290         return nil
291 }
292
293 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
294         msg := msgName + "_" + msgCrc
295         msgID, ok := c.msgTable[msg]
296         if !ok {
297                 return 0, fmt.Errorf("unknown message: %q", msg)
298         }
299         return msgID, nil
300 }
301
302 type reqHeader struct {
303         //MsgID uint16
304         ClientIndex uint32
305         Context     uint32
306 }
307
308 func (c *vppClient) SendMsg(context uint32, data []byte) error {
309         h := &reqHeader{
310                 ClientIndex: c.clientIndex,
311                 Context:     context,
312         }
313         buf := new(bytes.Buffer)
314         if err := struc.Pack(buf, h); err != nil {
315                 return err
316         }
317         copy(data[2:], buf.Bytes())
318
319         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
320
321         if err := c.write(data); err != nil {
322                 Log.Debugln("write error: ", err)
323                 return err
324         }
325
326         return nil
327 }
328
329 func (c *vppClient) write(msg []byte) error {
330         h := &msgheader{
331                 Data_len: uint32(len(msg)),
332         }
333         buf := new(bytes.Buffer)
334         if err := struc.Pack(buf, h); err != nil {
335                 return err
336         }
337         header := buf.Bytes()
338
339         // we lock to prevent mixing multiple message sends
340         c.writeMu.Lock()
341         defer c.writeMu.Unlock()
342
343         var w io.Writer = c.conn
344
345         if n, err := w.Write(header); err != nil {
346                 return err
347         } else {
348                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
349         }
350         if n, err := w.Write(msg); err != nil {
351                 return err
352         } else {
353                 Log.Debugf(" - msg sent (%d/%d): % 0X", n, len(msg), msg)
354         }
355
356         return nil
357 }
358
359 type msgHeader struct {
360         MsgID   uint16
361         Context uint32
362 }
363
364 func (c *vppClient) readerLoop() {
365         defer c.wg.Done()
366         for {
367                 select {
368                 case <-c.quit:
369                         Log.Debugf("reader quit")
370                         return
371                 default:
372                 }
373
374                 msg, err := c.read()
375                 if err != nil {
376                         if isClosedError(err) {
377                                 return
378                         }
379                         Log.Debugf("READ FAILED: %v", err)
380                         continue
381                 }
382                 h := new(msgHeader)
383                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
384                         Log.Debugf("unpacking header failed: %v", err)
385                         continue
386                 }
387
388                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
389                 c.cb(h.MsgID, msg)
390         }
391 }
392
393 type msgheader struct {
394         Q                 int    `struc:"uint64"`
395         Data_len          uint32 `struc:"uint32"`
396         Gc_mark_timestamp uint32 `struc:"uint32"`
397         //data              [0]uint8
398 }
399
400 func (c *vppClient) read() ([]byte, error) {
401         Log.Debug("reading next msg..")
402
403         header := make([]byte, 16)
404
405         n, err := io.ReadAtLeast(c.reader, header, 16)
406         if err != nil {
407                 return nil, err
408         } else if n == 0 {
409                 Log.Debugln("zero bytes header")
410                 return nil, nil
411         }
412         if n != 16 {
413                 Log.Debug("invalid header data (%d): % 0X", n, header[:n])
414                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
415         }
416         Log.Debugf(" - read header %d bytes: % 0X", n, header)
417
418         h := &msgheader{}
419         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
420                 return nil, err
421         }
422         Log.Debugf(" - decoded header: %+v", h)
423
424         msgLen := int(h.Data_len)
425         msg := make([]byte, msgLen)
426
427         n, err = c.reader.Read(msg)
428         if err != nil {
429                 return nil, err
430         }
431         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
432
433         if msgLen > n {
434                 remain := msgLen - n
435                 Log.Debugf("continue read for another %d bytes", remain)
436                 view := msg[n:]
437
438                 for remain > 0 {
439
440                         nbytes, err := c.reader.Read(view)
441                         if err != nil {
442                                 return nil, err
443                         } else if nbytes == 0 {
444                                 return nil, fmt.Errorf("zero nbytes")
445                         }
446
447                         remain -= nbytes
448                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
449
450                         view = view[nbytes:]
451                 }
452         }
453
454         return msg, nil
455 }
456
457 func isClosedError(err error) bool {
458         if err == io.EOF {
459                 return true
460         }
461         return strings.HasSuffix(err.Error(), "use of closed network connection")
462 }