Try using different type of unix socket connection
[govpp.git] / adapter / socketclient / socketclient.go
1 package socketclient
2
3 import (
4         "bufio"
5         "bytes"
6         "fmt"
7         "io"
8         "net"
9         "os"
10         "path/filepath"
11         "strings"
12         "sync"
13         "time"
14
15         "github.com/fsnotify/fsnotify"
16         "github.com/lunixbochs/struc"
17         logger "github.com/sirupsen/logrus"
18
19         "git.fd.io/govpp.git/adapter"
20         "git.fd.io/govpp.git/codec"
21         "git.fd.io/govpp.git/examples/bin_api/memclnt"
22 )
23
24 const (
25         // DefaultSocketName is default VPP API socket file name
26         DefaultSocketName = "/run/vpp-api.sock"
27 )
28
29 var (
30         // DefaultConnectTimeout is default timeout for connecting
31         DefaultConnectTimeout = time.Second * 3
32         // DefaultDisconnectTimeout is default timeout for discconnecting
33         DefaultDisconnectTimeout = time.Second
34         // MaxWaitReady defines maximum duration before waiting for socket file
35         // times out
36         MaxWaitReady = time.Second * 15
37         // ClientName is used for identifying client in socket registration
38         ClientName = "govppsock"
39 )
40
41 var (
42         // Debug is global variable that determines debug mode
43         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
44         // DebugMsgIds is global variable that determines debug mode for msg ids
45         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
46
47         Log = logger.New() // global logger
48 )
49
50 // init initializes global logger, which logs debug level messages to stdout.
51 func init() {
52         Log.Out = os.Stdout
53         if Debug {
54                 Log.Level = logger.DebugLevel
55         }
56 }
57
58 type vppClient struct {
59         sockAddr string
60         conn     *net.UnixConn
61         reader   *bufio.Reader
62         writer   *bufio.Writer
63
64         connectTimeout    time.Duration
65         disconnectTimeout time.Duration
66
67         cb           adapter.MsgCallback
68         clientIndex  uint32
69         msgTable     map[string]uint16
70         sockDelMsgId uint16
71         writeMu      sync.Mutex
72
73         quit chan struct{}
74         wg   sync.WaitGroup
75 }
76
77 func NewVppClient(sockAddr string) *vppClient {
78         if sockAddr == "" {
79                 sockAddr = DefaultSocketName
80         }
81         return &vppClient{
82                 sockAddr:          sockAddr,
83                 connectTimeout:    DefaultConnectTimeout,
84                 disconnectTimeout: DefaultDisconnectTimeout,
85                 cb: func(msgID uint16, data []byte) {
86                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
87                 },
88         }
89 }
90
91 // SetConnectTimeout sets timeout used during connecting.
92 func (c *vppClient) SetConnectTimeout(t time.Duration) {
93         c.connectTimeout = t
94 }
95
96 // SetDisconnectTimeout sets timeout used during disconnecting.
97 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
98         c.disconnectTimeout = t
99 }
100
101 // WaitReady checks socket file existence and waits for it if necessary
102 func (c *vppClient) WaitReady() error {
103         // check if file at the path already exists
104         if _, err := os.Stat(c.sockAddr); err == nil {
105                 return nil
106         } else if os.IsExist(err) {
107                 return err
108         }
109
110         // if not, watch for it
111         watcher, err := fsnotify.NewWatcher()
112         if err != nil {
113                 return err
114         }
115         defer func() {
116                 if err := watcher.Close(); err != nil {
117                         Log.Errorf("failed to close file watcher: %v", err)
118                 }
119         }()
120
121         // start watching directory
122         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
123                 return err
124         }
125
126         for {
127                 select {
128                 case <-time.After(MaxWaitReady):
129                         return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
130                 case e := <-watcher.Errors:
131                         return e
132                 case ev := <-watcher.Events:
133                         Log.Debugf("watcher event: %+v", ev)
134                         if ev.Name == c.sockAddr {
135                                 if (ev.Op & fsnotify.Create) == fsnotify.Create {
136                                         // socket was created, we are ready
137                                         return nil
138                                 }
139                         }
140                 }
141         }
142 }
143
144 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
145         Log.Debug("SetMsgCallback")
146         c.cb = cb
147 }
148
149 func (c *vppClient) Connect() error {
150         Log.Debugf("Connecting to: %v", c.sockAddr)
151
152         if err := c.connect(c.sockAddr); err != nil {
153                 return err
154         }
155
156         if err := c.open(); err != nil {
157                 return err
158         }
159
160         c.quit = make(chan struct{})
161         c.wg.Add(1)
162         go c.readerLoop()
163
164         return nil
165 }
166
167 func (c *vppClient) connect(sockAddr string) error {
168         addr, err := net.ResolveUnixAddr("unixpacket", sockAddr)
169         if err != nil {
170                 Log.Debugln("ResolveUnixAddr error:", err)
171                 return err
172         }
173
174         conn, err := net.DialUnix("unix", nil, addr)
175         if err != nil {
176                 if strings.Contains(err.Error(), "wrong type for socket") {
177                         conn, err = net.DialUnix("unixpacket", nil, addr)
178                 }
179                 if err != nil {
180                         Log.Debugln("Dial error:", err)
181                         return err
182                 }
183         }
184
185         c.conn = conn
186         c.reader = bufio.NewReader(c.conn)
187         c.writer = bufio.NewWriter(c.conn)
188
189         Log.Debugf("Connected to socket: %v", addr)
190
191         return nil
192 }
193
194 const (
195         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
196         createMsgContext = byte(123)
197         deleteMsgContext = byte(124)
198 )
199
200 func (c *vppClient) open() error {
201         msgCodec := new(codec.MsgCodec)
202
203         req := &memclnt.SockclntCreate{
204                 Name: []byte(ClientName),
205         }
206         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
207         if err != nil {
208                 Log.Debugln("Encode error:", err)
209                 return err
210         }
211         // set non-0 context
212         msg[5] = createMsgContext
213
214         if err := c.write(msg); err != nil {
215                 Log.Debugln("Write error: ", err)
216                 return err
217         }
218
219         readDeadline := time.Now().Add(c.connectTimeout)
220         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
221                 return err
222         }
223         msgReply, err := c.read()
224         if err != nil {
225                 Log.Println("Read error:", err)
226                 return err
227         }
228         // reset read deadline
229         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
230                 return err
231         }
232
233         reply := new(memclnt.SockclntCreateReply)
234         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
235                 Log.Println("Decode error:", err)
236                 return err
237         }
238
239         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
240                 reply.Response, reply.Index, reply.Count)
241
242         c.clientIndex = reply.Index
243         c.msgTable = make(map[string]uint16, reply.Count)
244         for _, x := range reply.MessageTable {
245                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
246                 c.msgTable[name] = x.Index
247                 if strings.HasPrefix(name, "sockclnt_delete_") {
248                         c.sockDelMsgId = x.Index
249                 }
250                 if DebugMsgIds {
251                         Log.Debugf(" - %4d: %q", x.Index, name)
252                 }
253         }
254
255         return nil
256 }
257
258 func (c *vppClient) Disconnect() error {
259         if c.conn == nil {
260                 return nil
261         }
262         Log.Debugf("Disconnecting..")
263
264         close(c.quit)
265
266         // force readerLoop to timeout
267         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
268                 return err
269         }
270
271         // wait for readerLoop to return
272         c.wg.Wait()
273
274         if err := c.close(); err != nil {
275                 return err
276         }
277
278         if err := c.conn.Close(); err != nil {
279                 Log.Debugln("Close socket conn failed:", err)
280                 return err
281         }
282
283         return nil
284 }
285
286 func (c *vppClient) close() error {
287         msgCodec := new(codec.MsgCodec)
288
289         req := &memclnt.SockclntDelete{
290                 Index: c.clientIndex,
291         }
292         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
293         if err != nil {
294                 Log.Debugln("Encode error:", err)
295                 return err
296         }
297         // set non-0 context
298         msg[5] = deleteMsgContext
299
300         Log.Debugf("sending socklntDel (%d byes): % 0X\n", len(msg), msg)
301         if err := c.write(msg); err != nil {
302                 Log.Debugln("Write error: ", err)
303                 return err
304         }
305
306         readDeadline := time.Now().Add(c.disconnectTimeout)
307         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
308                 return err
309         }
310         msgReply, err := c.read()
311         if err != nil {
312                 Log.Debugln("Read error:", err)
313                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
314                         // we accept timeout for reply
315                         return nil
316                 }
317                 return err
318         }
319         // reset read deadline
320         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
321                 return err
322         }
323
324         reply := new(memclnt.SockclntDeleteReply)
325         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
326                 Log.Debugln("Decode error:", err)
327                 return err
328         }
329
330         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
331
332         return nil
333 }
334
335 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
336         msg := msgName + "_" + msgCrc
337         msgID, ok := c.msgTable[msg]
338         if !ok {
339                 return 0, fmt.Errorf("unknown message: %q", msg)
340         }
341         return msgID, nil
342 }
343
344 type reqHeader struct {
345         // MsgID uint16
346         ClientIndex uint32
347         Context     uint32
348 }
349
350 func (c *vppClient) SendMsg(context uint32, data []byte) error {
351         h := &reqHeader{
352                 ClientIndex: c.clientIndex,
353                 Context:     context,
354         }
355         buf := new(bytes.Buffer)
356         if err := struc.Pack(buf, h); err != nil {
357                 return err
358         }
359         copy(data[2:], buf.Bytes())
360
361         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
362
363         if err := c.write(data); err != nil {
364                 Log.Debugln("write error: ", err)
365                 return err
366         }
367
368         return nil
369 }
370
371 func (c *vppClient) write(msg []byte) error {
372         h := &msgheader{
373                 DataLen: uint32(len(msg)),
374         }
375         buf := new(bytes.Buffer)
376         if err := struc.Pack(buf, h); err != nil {
377                 return err
378         }
379         header := buf.Bytes()
380
381         // we lock to prevent mixing multiple message sends
382         c.writeMu.Lock()
383         defer c.writeMu.Unlock()
384
385         if n, err := c.writer.Write(header); err != nil {
386                 return err
387         } else {
388                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
389         }
390
391         if err := c.writer.Flush(); err != nil {
392                 return err
393         }
394
395         for i := 0; i <= len(msg)/c.writer.Size(); i++ {
396                 x := i*c.writer.Size() + c.writer.Size()
397                 if x > len(msg) {
398                         x = len(msg)
399                 }
400                 Log.Debugf("x=%v i=%v len=%v mod=%v\n", x, i, len(msg), len(msg)/c.writer.Size())
401                 if n, err := c.writer.Write(msg[i*c.writer.Size() : x]); err != nil {
402                         return err
403                 } else {
404                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
405                 }
406                 if err := c.writer.Flush(); err != nil {
407                         return err
408                 }
409
410         }
411
412         return nil
413 }
414
415 type msgHeader struct {
416         MsgID   uint16
417         Context uint32
418 }
419
420 func (c *vppClient) readerLoop() {
421         defer c.wg.Done()
422         for {
423                 select {
424                 case <-c.quit:
425                         Log.Debugf("reader quit")
426                         return
427                 default:
428                 }
429
430                 msg, err := c.read()
431                 if err != nil {
432                         if isClosedError(err) {
433                                 return
434                         }
435                         Log.Debugf("READ FAILED: %v", err)
436                         continue
437                 }
438                 h := new(msgHeader)
439                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
440                         Log.Debugf("unpacking header failed: %v", err)
441                         continue
442                 }
443
444                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
445                 c.cb(h.MsgID, msg)
446         }
447 }
448
449 type msgheader struct {
450         Q               int    `struc:"uint64"`
451         DataLen         uint32 `struc:"uint32"`
452         GcMarkTimestamp uint32 `struc:"uint32"`
453 }
454
455 func (c *vppClient) read() ([]byte, error) {
456         Log.Debug("reading next msg..")
457
458         header := make([]byte, 16)
459
460         n, err := io.ReadAtLeast(c.reader, header, 16)
461         if err != nil {
462                 return nil, err
463         } else if n == 0 {
464                 Log.Debugln("zero bytes header")
465                 return nil, nil
466         }
467         if n != 16 {
468                 Log.Debug("invalid header data (%d): % 0X", n, header[:n])
469                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
470         }
471         Log.Debugf(" - read header %d bytes: % 0X", n, header)
472
473         h := &msgheader{}
474         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
475                 return nil, err
476         }
477         Log.Debugf(" - decoded header: %+v", h)
478
479         msgLen := int(h.DataLen)
480         msg := make([]byte, msgLen)
481
482         n, err = c.reader.Read(msg)
483         if err != nil {
484                 return nil, err
485         }
486         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
487
488         if msgLen > n {
489                 remain := msgLen - n
490                 Log.Debugf("continue read for another %d bytes", remain)
491                 view := msg[n:]
492
493                 for remain > 0 {
494
495                         nbytes, err := c.reader.Read(view)
496                         if err != nil {
497                                 return nil, err
498                         } else if nbytes == 0 {
499                                 return nil, fmt.Errorf("zero nbytes")
500                         }
501
502                         remain -= nbytes
503                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
504
505                         view = view[nbytes:]
506                 }
507         }
508
509         return msg, nil
510 }
511
512 func isClosedError(err error) bool {
513         if err == io.EOF {
514                 return true
515         }
516         return strings.HasSuffix(err.Error(), "use of closed network connection")
517 }