Fix error counters for VPP 19.04
[govpp.git] / adapter / socketclient / socketclient.go
1 // Copyright (c) 2019 Cisco and/or its affiliates.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at:
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 package socketclient
16
17 import (
18         "bufio"
19         "bytes"
20         "fmt"
21         "io"
22         "net"
23         "os"
24         "path/filepath"
25         "strings"
26         "sync"
27         "time"
28
29         "github.com/fsnotify/fsnotify"
30         "github.com/lunixbochs/struc"
31         logger "github.com/sirupsen/logrus"
32
33         "git.fd.io/govpp.git/adapter"
34         "git.fd.io/govpp.git/codec"
35         "git.fd.io/govpp.git/examples/binapi/memclnt"
36 )
37
38 const (
39         // DefaultSocketName is default VPP API socket file path.
40         DefaultSocketName = adapter.DefaultBinapiSocket
41 )
42
43 var (
44         // DefaultConnectTimeout is default timeout for connecting
45         DefaultConnectTimeout = time.Second * 3
46         // DefaultDisconnectTimeout is default timeout for discconnecting
47         DefaultDisconnectTimeout = time.Millisecond * 100
48         // MaxWaitReady defines maximum duration before waiting for socket file
49         // times out
50         MaxWaitReady = time.Second * 15
51         // ClientName is used for identifying client in socket registration
52         ClientName = "govppsock"
53 )
54
55 var (
56         // Debug is global variable that determines debug mode
57         Debug = os.Getenv("DEBUG_GOVPP_SOCK") != ""
58         // DebugMsgIds is global variable that determines debug mode for msg ids
59         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
60
61         // Log is global logger
62         Log = logger.New()
63 )
64
65 // init initializes global logger, which logs debug level messages to stdout.
66 func init() {
67         Log.Out = os.Stdout
68         if Debug {
69                 Log.Level = logger.DebugLevel
70                 Log.Debug("govpp/socketclient: enabled debug mode")
71         }
72 }
73
74 type vppClient struct {
75         sockAddr string
76         conn     *net.UnixConn
77         reader   *bufio.Reader
78         writer   *bufio.Writer
79
80         connectTimeout    time.Duration
81         disconnectTimeout time.Duration
82
83         cb           adapter.MsgCallback
84         clientIndex  uint32
85         msgTable     map[string]uint16
86         sockDelMsgId uint16
87         writeMu      sync.Mutex
88
89         quit chan struct{}
90         wg   sync.WaitGroup
91 }
92
93 func NewVppClient(sockAddr string) *vppClient {
94         if sockAddr == "" {
95                 sockAddr = DefaultSocketName
96         }
97         return &vppClient{
98                 sockAddr:          sockAddr,
99                 connectTimeout:    DefaultConnectTimeout,
100                 disconnectTimeout: DefaultDisconnectTimeout,
101                 cb: func(msgID uint16, data []byte) {
102                         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
103                 },
104         }
105 }
106
107 // SetConnectTimeout sets timeout used during connecting.
108 func (c *vppClient) SetConnectTimeout(t time.Duration) {
109         c.connectTimeout = t
110 }
111
112 // SetDisconnectTimeout sets timeout used during disconnecting.
113 func (c *vppClient) SetDisconnectTimeout(t time.Duration) {
114         c.disconnectTimeout = t
115 }
116
117 // WaitReady checks socket file existence and waits for it if necessary
118 func (c *vppClient) WaitReady() error {
119         // check if file at the path already exists
120         if _, err := os.Stat(c.sockAddr); err == nil {
121                 return nil
122         } else if os.IsExist(err) {
123                 return err
124         }
125
126         // if not, watch for it
127         watcher, err := fsnotify.NewWatcher()
128         if err != nil {
129                 return err
130         }
131         defer func() {
132                 if err := watcher.Close(); err != nil {
133                         Log.Errorf("failed to close file watcher: %v", err)
134                 }
135         }()
136
137         // start watching directory
138         if err := watcher.Add(filepath.Dir(c.sockAddr)); err != nil {
139                 return err
140         }
141
142         for {
143                 select {
144                 case <-time.After(MaxWaitReady):
145                         return fmt.Errorf("waiting for socket file timed out (%s)", MaxWaitReady)
146                 case e := <-watcher.Errors:
147                         return e
148                 case ev := <-watcher.Events:
149                         Log.Debugf("watcher event: %+v", ev)
150                         if ev.Name == c.sockAddr {
151                                 if (ev.Op & fsnotify.Create) == fsnotify.Create {
152                                         // socket was created, we are ready
153                                         return nil
154                                 }
155                         }
156                 }
157         }
158 }
159
160 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
161         Log.Debug("SetMsgCallback")
162         c.cb = cb
163 }
164
165 func (c *vppClient) Connect() error {
166         Log.Debugf("Connecting to: %v", c.sockAddr)
167
168         if err := c.connect(c.sockAddr); err != nil {
169                 return err
170         }
171
172         if err := c.open(); err != nil {
173                 return err
174         }
175
176         c.quit = make(chan struct{})
177         c.wg.Add(1)
178         go c.readerLoop()
179
180         return nil
181 }
182
183 func (c *vppClient) connect(sockAddr string) error {
184         addr := &net.UnixAddr{Name: sockAddr, Net: "unix"}
185
186         conn, err := net.DialUnix("unix", nil, addr)
187         if err != nil {
188                 // we try different type of socket for backwards compatbility with VPP<=19.04
189                 if strings.Contains(err.Error(), "wrong type for socket") {
190                         addr.Net = "unixpacket"
191                         Log.Debugf("%s, retrying connect with type unixpacket", err)
192                         conn, err = net.DialUnix("unixpacket", nil, addr)
193                 }
194                 if err != nil {
195                         Log.Debugf("Connecting to socket %s failed: %s", addr, err)
196                         return err
197                 }
198         }
199
200         c.conn = conn
201         c.reader = bufio.NewReader(c.conn)
202         c.writer = bufio.NewWriter(c.conn)
203
204         Log.Debugf("Connected to socket: %v", addr)
205
206         return nil
207 }
208
209 const (
210         sockCreateMsgId  = 15 // hard-coded sockclnt_create message ID
211         createMsgContext = byte(123)
212         deleteMsgContext = byte(124)
213 )
214
215 func (c *vppClient) open() error {
216         msgCodec := new(codec.MsgCodec)
217
218         req := &memclnt.SockclntCreate{
219                 Name: []byte(ClientName),
220         }
221         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
222         if err != nil {
223                 Log.Debugln("Encode error:", err)
224                 return err
225         }
226         // set non-0 context
227         msg[5] = createMsgContext
228
229         if err := c.write(msg); err != nil {
230                 Log.Debugln("Write error: ", err)
231                 return err
232         }
233
234         readDeadline := time.Now().Add(c.connectTimeout)
235         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
236                 return err
237         }
238         msgReply, err := c.read()
239         if err != nil {
240                 Log.Println("Read error:", err)
241                 return err
242         }
243         // reset read deadline
244         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
245                 return err
246         }
247
248         reply := new(memclnt.SockclntCreateReply)
249         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
250                 Log.Println("Decode error:", err)
251                 return err
252         }
253
254         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
255                 reply.Response, reply.Index, reply.Count)
256
257         c.clientIndex = reply.Index
258         c.msgTable = make(map[string]uint16, reply.Count)
259         for _, x := range reply.MessageTable {
260                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
261                 c.msgTable[name] = x.Index
262                 if strings.HasPrefix(name, "sockclnt_delete_") {
263                         c.sockDelMsgId = x.Index
264                 }
265                 if DebugMsgIds {
266                         Log.Debugf(" - %4d: %q", x.Index, name)
267                 }
268         }
269
270         return nil
271 }
272
273 func (c *vppClient) Disconnect() error {
274         if c.conn == nil {
275                 return nil
276         }
277         Log.Debugf("Disconnecting..")
278
279         close(c.quit)
280
281         // force readerLoop to timeout
282         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
283                 return err
284         }
285
286         // wait for readerLoop to return
287         c.wg.Wait()
288
289         if err := c.close(); err != nil {
290                 return err
291         }
292
293         if err := c.conn.Close(); err != nil {
294                 Log.Debugln("Closing socket failed:", err)
295                 return err
296         }
297
298         return nil
299 }
300
301 func (c *vppClient) close() error {
302         msgCodec := new(codec.MsgCodec)
303
304         req := &memclnt.SockclntDelete{
305                 Index: c.clientIndex,
306         }
307         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
308         if err != nil {
309                 Log.Debugln("Encode error:", err)
310                 return err
311         }
312         // set non-0 context
313         msg[5] = deleteMsgContext
314
315         Log.Debugf("sending socklntDel (%d byes): % 0X", len(msg), msg)
316         if err := c.write(msg); err != nil {
317                 Log.Debugln("Write error: ", err)
318                 return err
319         }
320
321         readDeadline := time.Now().Add(c.disconnectTimeout)
322         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
323                 return err
324         }
325         msgReply, err := c.read()
326         if err != nil {
327                 Log.Debugln("Read error:", err)
328                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
329                         // we accept timeout for reply
330                         return nil
331                 }
332                 return err
333         }
334         // reset read deadline
335         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
336                 return err
337         }
338
339         reply := new(memclnt.SockclntDeleteReply)
340         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
341                 Log.Debugln("Decode error:", err)
342                 return err
343         }
344
345         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
346
347         return nil
348 }
349
350 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
351         msg := msgName + "_" + msgCrc
352         msgID, ok := c.msgTable[msg]
353         if !ok {
354                 return 0, fmt.Errorf("unknown message: %q", msg)
355         }
356         return msgID, nil
357 }
358
359 type reqHeader struct {
360         // MsgID uint16
361         ClientIndex uint32
362         Context     uint32
363 }
364
365 func (c *vppClient) SendMsg(context uint32, data []byte) error {
366         h := &reqHeader{
367                 ClientIndex: c.clientIndex,
368                 Context:     context,
369         }
370         buf := new(bytes.Buffer)
371         if err := struc.Pack(buf, h); err != nil {
372                 return err
373         }
374         copy(data[2:], buf.Bytes())
375
376         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
377
378         if err := c.write(data); err != nil {
379                 Log.Debugln("write error: ", err)
380                 return err
381         }
382
383         return nil
384 }
385
386 func (c *vppClient) write(msg []byte) error {
387         h := &msgheader{
388                 DataLen: uint32(len(msg)),
389         }
390         buf := new(bytes.Buffer)
391         if err := struc.Pack(buf, h); err != nil {
392                 return err
393         }
394         header := buf.Bytes()
395
396         // we lock to prevent mixing multiple message sends
397         c.writeMu.Lock()
398         defer c.writeMu.Unlock()
399
400         if n, err := c.writer.Write(header); err != nil {
401                 return err
402         } else {
403                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
404         }
405
406         if err := c.writer.Flush(); err != nil {
407                 return err
408         }
409
410         for i := 0; i <= len(msg)/c.writer.Size(); i++ {
411                 x := i*c.writer.Size() + c.writer.Size()
412                 if x > len(msg) {
413                         x = len(msg)
414                 }
415                 Log.Debugf("x=%v i=%v len=%v mod=%v", x, i, len(msg), len(msg)/c.writer.Size())
416                 if n, err := c.writer.Write(msg[i*c.writer.Size() : x]); err != nil {
417                         return err
418                 } else {
419                         Log.Debugf(" - msg sent x=%d (%d/%d): % 0X", x, n, len(msg), msg)
420                 }
421                 if err := c.writer.Flush(); err != nil {
422                         return err
423                 }
424
425         }
426
427         return nil
428 }
429
430 type msgHeader struct {
431         MsgID   uint16
432         Context uint32
433 }
434
435 func (c *vppClient) readerLoop() {
436         defer c.wg.Done()
437         defer Log.Debugf("reader quit")
438         for {
439                 select {
440                 case <-c.quit:
441                         return
442                 default:
443                 }
444
445                 msg, err := c.read()
446                 if err != nil {
447                         if isClosedError(err) {
448                                 return
449                         }
450                         Log.Debugf("read failed: %v", err)
451                         continue
452                 }
453                 h := new(msgHeader)
454                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
455                         Log.Debugf("unpacking header failed: %v", err)
456                         continue
457                 }
458
459                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
460                 c.cb(h.MsgID, msg)
461         }
462 }
463
464 type msgheader struct {
465         Q               int    `struc:"uint64"`
466         DataLen         uint32 `struc:"uint32"`
467         GcMarkTimestamp uint32 `struc:"uint32"`
468 }
469
470 func (c *vppClient) read() ([]byte, error) {
471         Log.Debug("reading next msg..")
472
473         header := make([]byte, 16)
474
475         n, err := io.ReadAtLeast(c.reader, header, 16)
476         if err != nil {
477                 return nil, err
478         } else if n == 0 {
479                 Log.Debugln("zero bytes header")
480                 return nil, nil
481         }
482         if n != 16 {
483                 Log.Debugf("invalid header data (%d): % 0X", n, header[:n])
484                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
485         }
486         Log.Debugf(" - read header %d bytes: % 0X", n, header)
487
488         h := &msgheader{}
489         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
490                 return nil, err
491         }
492         Log.Debugf(" - decoded header: %+v", h)
493
494         msgLen := int(h.DataLen)
495         msg := make([]byte, msgLen)
496
497         n, err = c.reader.Read(msg)
498         if err != nil {
499                 return nil, err
500         }
501         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
502
503         if msgLen > n {
504                 remain := msgLen - n
505                 Log.Debugf("continue read for another %d bytes", remain)
506                 view := msg[n:]
507
508                 for remain > 0 {
509                         nbytes, err := c.reader.Read(view)
510                         if err != nil {
511                                 return nil, err
512                         } else if nbytes == 0 {
513                                 return nil, fmt.Errorf("zero nbytes")
514                         }
515
516                         remain -= nbytes
517                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
518
519                         view = view[nbytes:]
520                 }
521         }
522
523         return msg, nil
524 }
525
526 func isClosedError(err error) bool {
527         if err == io.EOF {
528                 return true
529         }
530         return strings.HasSuffix(err.Error(), "use of closed network connection")
531 }