Split outgoing packet data by 4096 bytes
[govpp.git] / adapter / socketclient / socketclient.go
1 package socketclient
2
3 import (
4         "bufio"
5         "bytes"
6         "fmt"
7         "io"
8         "net"
9         "os"
10         "path/filepath"
11         "strings"
12         "sync"
13         "time"
14
15         "github.com/fsnotify/fsnotify"
16         "github.com/lunixbochs/struc"
17         logger "github.com/sirupsen/logrus"
18
19         "git.fd.io/govpp.git/adapter"
20         "git.fd.io/govpp.git/codec"
21         "git.fd.io/govpp.git/examples/bin_api/memclnt"
22 )
23
24 const (
25         // DefaultSocketName is default VPP API socket file name
26         DefaultSocketName = "/run/vpp-api.sock"
27 )
28
29 var (
30         // DefaultConnectTimeout is default timeout for connecting
31         DefaultConnectTimeout = time.Second * 3
32         // DefaultDisconnectTimeout is default timeout for discconnecting
33         DefaultDisconnectTimeout = time.Second
34         // MaxWaitReady defines maximum duration before waiting for socket file
35         // times out
36         MaxWaitReady = time.Second * 15
37         // ClientName is used for identifying client in socket registration
38         ClientName = "govppsock"
39 )
40
41 var (
42         // Debug is global variable that determines debug mode
43         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
44         // DebugMsgIds is global variable that determines debug mode for msg ids
45         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
46
47         Log = logger.New() // global logger
48 )
49
50 // init initializes global logger, which logs debug level messages to stdout.
51 func init() {
52         Log.Out = os.Stdout
53         if Debug {
54                 Log.Level = logger.DebugLevel
55         }
56 }
57
58 type vppClient struct {
59         sockAddr string
60         conn     *net.UnixConn
61         reader   *bufio.Reader
62         writer   *bufio.Writer
63
64         connectTimeout    time.Duration
65         disconnectTimeout time.Duration
66
67         cb           adapter.MsgCallback
68         clientIndex  uint32
69         msgTable     map[string]uint16
70         sockDelMsgId uint16
71         writeMu      sync.Mutex
72
73         quit chan struct{}
74         wg   sync.WaitGroup
75 }
76
77 func NewVppClient(sockAddr string) *vppClient {
78         if sockAddr == "" {
79                 sockAddr = DefaultSocketName
80         }
81         return &vppClient{
82                 sockAddr:          sockAddr,
83                 connectTimeout:    DefaultConnectTimeout,
84                 disconnectTimeout: DefaultDisconnectTimeout,
85                 cb: func(msgID uint16, data []byte) {
86                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
87                 },
88         }
89 }
90
91 // SetConnectTimeout sets timeout used during connecting.
92 func (c *vppClient) SetConnectTimeout(t time.Duration) {
93         c.connectTimeout = t
94 }
95
96 // SetDisconnectTimeout sets timeout used during disconnecting.
97 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
98         c.disconnectTimeout = t
99 }
100
101 // WaitReady checks socket file existence and waits for it if necessary
102 func (c *vppClient) WaitReady() error {
103         // check if file at the path already exists
104         if _, err := os.Stat(c.sockAddr); err == nil {
105                 return nil
106         } else if os.IsExist(err) {
107                 return err
108         }
109
110         // if not, watch for it
111         watcher, err := fsnotify.NewWatcher()
112         if err != nil {
113                 return err
114         }
115         defer func() {
116                 if err := watcher.Close(); err != nil {
117                         Log.Errorf("failed to close file watcher: %v", err)
118                 }
119         }()
120
121         // start watching directory
122         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
123                 return err
124         }
125
126         for {
127                 select {
128                 case <-time.After(MaxWaitReady):
129                         return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
130                 case e := <-watcher.Errors:
131                         return e
132                 case ev := <-watcher.Events:
133                         Log.Debugf("watcher event: %+v", ev)
134                         if ev.Name == c.sockAddr {
135                                 if (ev.Op & fsnotify.Create) == fsnotify.Create {
136                                         // socket was created, we are ready
137                                         return nil
138                                 }
139                         }
140                 }
141         }
142 }
143
144 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
145         Log.Debug("SetMsgCallback")
146         c.cb = cb
147 }
148
149 func (c *vppClient) Connect() error {
150         Log.Debugf("Connecting to: %v", c.sockAddr)
151
152         if err := c.connect(c.sockAddr); err != nil {
153                 return err
154         }
155
156         if err := c.open(); err != nil {
157                 return err
158         }
159
160         c.quit = make(chan struct{})
161         c.wg.Add(1)
162         go c.readerLoop()
163
164         return nil
165 }
166
167 func (c *vppClient) connect(sockAddr string) error {
168         addr, err := net.ResolveUnixAddr("unixpacket", sockAddr)
169         if err != nil {
170                 Log.Debugln("ResolveUnixAddr error:", err)
171                 return err
172         }
173
174         conn, err := net.DialUnix("unixpacket", nil, addr)
175         if err != nil {
176                 Log.Debugln("Dial error:", err)
177                 return err
178         }
179
180         c.conn = conn
181         c.reader = bufio.NewReader(c.conn)
182         c.writer = bufio.NewWriter(c.conn)
183
184         Log.Debugf("Connected to socket: %v", addr)
185
186         return nil
187 }
188
189 const (
190         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
191         createMsgContext = byte(123)
192         deleteMsgContext = byte(124)
193 )
194
195 func (c *vppClient) open() error {
196         msgCodec := new(codec.MsgCodec)
197
198         req := &memclnt.SockclntCreate{
199                 Name: []byte(ClientName),
200         }
201         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
202         if err != nil {
203                 Log.Debugln("Encode error:", err)
204                 return err
205         }
206         // set non-0 context
207         msg[5] = createMsgContext
208
209         if err := c.write(msg); err != nil {
210                 Log.Debugln("Write error: ", err)
211                 return err
212         }
213
214         readDeadline := time.Now().Add(c.connectTimeout)
215         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
216                 return err
217         }
218         msgReply, err := c.read()
219         if err != nil {
220                 Log.Println("Read error:", err)
221                 return err
222         }
223         // reset read deadline
224         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
225                 return err
226         }
227
228         reply := new(memclnt.SockclntCreateReply)
229         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
230                 Log.Println("Decode error:", err)
231                 return err
232         }
233
234         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
235                 reply.Response, reply.Index, reply.Count)
236
237         c.clientIndex = reply.Index
238         c.msgTable = make(map[string]uint16, reply.Count)
239         for _, x := range reply.MessageTable {
240                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
241                 c.msgTable[name] = x.Index
242                 if strings.HasPrefix(name, "sockclnt_delete_") {
243                         c.sockDelMsgId = x.Index
244                 }
245                 if DebugMsgIds {
246                         Log.Debugf(" - %4d: %q", x.Index, name)
247                 }
248         }
249
250         return nil
251 }
252
253 func (c *vppClient) Disconnect() error {
254         if c.conn == nil {
255                 return nil
256         }
257         Log.Debugf("Disconnecting..")
258
259         close(c.quit)
260
261         // force readerLoop to timeout
262         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
263                 return err
264         }
265
266         // wait for readerLoop to return
267         c.wg.Wait()
268
269         if err := c.close(); err != nil {
270                 return err
271         }
272
273         if err := c.conn.Close(); err != nil {
274                 Log.Debugln("Close socket conn failed:", err)
275                 return err
276         }
277
278         return nil
279 }
280
281 func (c *vppClient) close() error {
282         msgCodec := new(codec.MsgCodec)
283
284         req := &memclnt.SockclntDelete{
285                 Index: c.clientIndex,
286         }
287         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
288         if err != nil {
289                 Log.Debugln("Encode error:", err)
290                 return err
291         }
292         // set non-0 context
293         msg[5] = deleteMsgContext
294
295         Log.Debugf("sending socklntDel (%d byes): % 0X\n", len(msg), msg)
296         if err := c.write(msg); err != nil {
297                 Log.Debugln("Write error: ", err)
298                 return err
299         }
300
301         readDeadline := time.Now().Add(c.disconnectTimeout)
302         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
303                 return err
304         }
305         msgReply, err := c.read()
306         if err != nil {
307                 Log.Debugln("Read error:", err)
308                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
309                         // we accept timeout for reply
310                         return nil
311                 }
312                 return err
313         }
314         // reset read deadline
315         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
316                 return err
317         }
318
319         reply := new(memclnt.SockclntDeleteReply)
320         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
321                 Log.Debugln("Decode error:", err)
322                 return err
323         }
324
325         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
326
327         return nil
328 }
329
330 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
331         msg := msgName + "_" + msgCrc
332         msgID, ok := c.msgTable[msg]
333         if !ok {
334                 return 0, fmt.Errorf("unknown message: %q", msg)
335         }
336         return msgID, nil
337 }
338
339 type reqHeader struct {
340         // MsgID uint16
341         ClientIndex uint32
342         Context     uint32
343 }
344
345 func (c *vppClient) SendMsg(context uint32, data []byte) error {
346         h := &reqHeader{
347                 ClientIndex: c.clientIndex,
348                 Context:     context,
349         }
350         buf := new(bytes.Buffer)
351         if err := struc.Pack(buf, h); err != nil {
352                 return err
353         }
354         copy(data[2:], buf.Bytes())
355
356         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
357
358         if err := c.write(data); err != nil {
359                 Log.Debugln("write error: ", err)
360                 return err
361         }
362
363         return nil
364 }
365
366 func (c *vppClient) write(msg []byte) error {
367         h := &msgheader{
368                 DataLen: uint32(len(msg)),
369         }
370         buf := new(bytes.Buffer)
371         if err := struc.Pack(buf, h); err != nil {
372                 return err
373         }
374         header := buf.Bytes()
375
376         // we lock to prevent mixing multiple message sends
377         c.writeMu.Lock()
378         defer c.writeMu.Unlock()
379
380         if n, err := c.writer.Write(header); err != nil {
381                 return err
382         } else {
383                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
384         }
385
386         if err := c.writer.Flush(); err != nil {
387                 return err
388         }
389
390         for i := 0; i <= len(msg)/c.writer.Size(); i++ {
391                 x := i*c.writer.Size() + c.writer.Size()
392                 if x > len(msg) {
393                         x = len(msg)
394                 }
395                 Log.Debugf("x=%v i=%v len=%v mod=%v\n", x, i, len(msg), len(msg)/c.writer.Size())
396                 if n, err := c.writer.Write(msg[i*c.writer.Size() : x]); err != nil {
397                         return err
398                 } else {
399                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
400                 }
401                 if err := c.writer.Flush(); err != nil {
402                         return err
403                 }
404
405         }
406
407         return nil
408 }
409
410 type msgHeader struct {
411         MsgID   uint16
412         Context uint32
413 }
414
415 func (c *vppClient) readerLoop() {
416         defer c.wg.Done()
417         for {
418                 select {
419                 case <-c.quit:
420                         Log.Debugf("reader quit")
421                         return
422                 default:
423                 }
424
425                 msg, err := c.read()
426                 if err != nil {
427                         if isClosedError(err) {
428                                 return
429                         }
430                         Log.Debugf("READ FAILED: %v", err)
431                         continue
432                 }
433                 h := new(msgHeader)
434                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
435                         Log.Debugf("unpacking header failed: %v", err)
436                         continue
437                 }
438
439                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
440                 c.cb(h.MsgID, msg)
441         }
442 }
443
444 type msgheader struct {
445         Q               int    `struc:"uint64"`
446         DataLen         uint32 `struc:"uint32"`
447         GcMarkTimestamp uint32 `struc:"uint32"`
448 }
449
450 func (c *vppClient) read() ([]byte, error) {
451         Log.Debug("reading next msg..")
452
453         header := make([]byte, 16)
454
455         n, err := io.ReadAtLeast(c.reader, header, 16)
456         if err != nil {
457                 return nil, err
458         } else if n == 0 {
459                 Log.Debugln("zero bytes header")
460                 return nil, nil
461         }
462         if n != 16 {
463                 Log.Debug("invalid header data (%d): % 0X", n, header[:n])
464                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
465         }
466         Log.Debugf(" - read header %d bytes: % 0X", n, header)
467
468         h := &msgheader{}
469         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
470                 return nil, err
471         }
472         Log.Debugf(" - decoded header: %+v", h)
473
474         msgLen := int(h.DataLen)
475         msg := make([]byte, msgLen)
476
477         n, err = c.reader.Read(msg)
478         if err != nil {
479                 return nil, err
480         }
481         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
482
483         if msgLen > n {
484                 remain := msgLen - n
485                 Log.Debugf("continue read for another %d bytes", remain)
486                 view := msg[n:]
487
488                 for remain > 0 {
489
490                         nbytes, err := c.reader.Read(view)
491                         if err != nil {
492                                 return nil, err
493                         } else if nbytes == 0 {
494                                 return nil, fmt.Errorf("zero nbytes")
495                         }
496
497                         remain -= nbytes
498                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
499
500                         view = view[nbytes:]
501                 }
502         }
503
504         return msg, nil
505 }
506
507 func isClosedError(err error) bool {
508         if err == io.EOF {
509                 return true
510         }
511         return strings.HasSuffix(err.Error(), "use of closed network connection")
512 }