Fixup build with golang 1.12
[govpp.git] / adapter / socketclient / socketclient.go
1 package socketclient
2
3 import (
4         "bufio"
5         "bytes"
6         "fmt"
7         "io"
8         "net"
9         "os"
10         "strings"
11         "sync"
12         "time"
13
14         "github.com/lunixbochs/struc"
15         logger "github.com/sirupsen/logrus"
16
17         "git.fd.io/govpp.git/adapter"
18         "git.fd.io/govpp.git/codec"
19         "git.fd.io/govpp.git/examples/bin_api/memclnt"
20 )
21
22 const (
23         // DefaultSocketName is default VPP API socket file name
24         DefaultSocketName = "/run/vpp-api.sock"
25
26         sockCreateMsgId = 15          // hard-coded id for sockclnt_create message
27         govppClientName = "govppsock" // client name used for socket registration
28 )
29
30 var (
31         ConnectTimeout    = time.Second * 3
32         DisconnectTimeout = time.Second
33
34         Debug       = os.Getenv("DEBUG_GOVPP_SOCK") != ""
35         DebugMsgIds = os.Getenv("DEBUG_GOVPP_SOCKMSG") != ""
36
37         Log = logger.New() // global logger
38 )
39
40 // init initializes global logger, which logs debug level messages to stdout.
41 func init() {
42         Log.Out = os.Stdout
43         if Debug {
44                 Log.Level = logger.DebugLevel
45         }
46 }
47
48 type vppClient struct {
49         sockAddr     string
50         conn         *net.UnixConn
51         reader       *bufio.Reader
52         cb           adapter.MsgCallback
53         clientIndex  uint32
54         msgTable     map[string]uint16
55         sockDelMsgId uint16
56         writeMu      sync.Mutex
57         quit         chan struct{}
58         wg           sync.WaitGroup
59 }
60
61 func NewVppClient(sockAddr string) *vppClient {
62         if sockAddr == "" {
63                 sockAddr = DefaultSocketName
64         }
65         return &vppClient{
66                 sockAddr: sockAddr,
67                 cb:       nilCallback,
68         }
69 }
70
71 func nilCallback(msgID uint16, data []byte) {
72         Log.Warnf("no callback set, dropping message: ID=%v len=%d", msgID, len(data))
73 }
74
75 func (*vppClient) WaitReady() error {
76         // TODO: add watcher for socket file?
77         return nil
78 }
79
80 func (c *vppClient) SetMsgCallback(cb adapter.MsgCallback) {
81         Log.Debug("SetMsgCallback")
82         c.cb = cb
83 }
84
85 func (c *vppClient) Connect() error {
86         Log.Debugf("Connecting to: %v", c.sockAddr)
87
88         if err := c.connect(c.sockAddr); err != nil {
89                 return err
90         }
91
92         if err := c.open(); err != nil {
93                 return err
94         }
95
96         c.quit = make(chan struct{})
97         c.wg.Add(1)
98         go c.readerLoop()
99
100         return nil
101 }
102
103 func (c *vppClient) connect(sockAddr string) error {
104         addr, err := net.ResolveUnixAddr("unixpacket", sockAddr)
105         if err != nil {
106                 Log.Debugln("ResolveUnixAddr error:", err)
107                 return err
108         }
109
110         conn, err := net.DialUnix("unixpacket", nil, addr)
111         if err != nil {
112                 Log.Debugln("Dial error:", err)
113                 return err
114         }
115
116         c.conn = conn
117         c.reader = bufio.NewReader(c.conn)
118
119         Log.Debugf("Connected to socket: %v", addr)
120
121         return nil
122 }
123
124 func (c *vppClient) open() error {
125         msgCodec := new(codec.MsgCodec)
126
127         req := &memclnt.SockclntCreate{
128                 Name: []byte(govppClientName),
129         }
130         msg, err := msgCodec.EncodeMsg(req, sockCreateMsgId)
131         if err != nil {
132                 Log.Debugln("Encode error:", err)
133                 return err
134         }
135         // set non-0 context
136         msg[5] = 123
137
138         if err := c.write(msg); err != nil {
139                 Log.Debugln("Write error: ", err)
140                 return err
141         }
142
143         readDeadline := time.Now().Add(ConnectTimeout)
144         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
145                 return err
146         }
147         msgReply, err := c.read()
148         if err != nil {
149                 Log.Println("Read error:", err)
150                 return err
151         }
152         // reset read deadline
153         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
154                 return err
155         }
156
157         //log.Printf("Client got (%d): % 0X", len(msgReply), msgReply)
158
159         reply := new(memclnt.SockclntCreateReply)
160         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
161                 Log.Println("Decode error:", err)
162                 return err
163         }
164
165         Log.Debugf("SockclntCreateReply: Response=%v Index=%v Count=%v",
166                 reply.Response, reply.Index, reply.Count)
167
168         c.clientIndex = reply.Index
169         c.msgTable = make(map[string]uint16, reply.Count)
170         for _, x := range reply.MessageTable {
171                 name := string(bytes.TrimSuffix(bytes.Split(x.Name, []byte{0x00})[0], []byte{0x13}))
172                 c.msgTable[name] = x.Index
173                 if strings.HasPrefix(name, "sockclnt_delete_") {
174                         c.sockDelMsgId = x.Index
175                 }
176                 if DebugMsgIds {
177                         Log.Debugf(" - %4d: %q", x.Index, name)
178                 }
179         }
180
181         return nil
182 }
183
184 func (c *vppClient) Disconnect() error {
185         if c.conn == nil {
186                 return nil
187         }
188         Log.Debugf("Disconnecting..")
189
190         close(c.quit)
191
192         // force readerLoop to timeout
193         if err := c.conn.SetReadDeadline(time.Now()); err != nil {
194                 return err
195         }
196
197         // wait for readerLoop to return
198         c.wg.Wait()
199
200         if err := c.close(); err != nil {
201                 return err
202         }
203
204         if err := c.conn.Close(); err != nil {
205                 Log.Debugln("Close socket conn failed:", err)
206                 return err
207         }
208
209         return nil
210 }
211
212 func (c *vppClient) close() error {
213         msgCodec := new(codec.MsgCodec)
214
215         req := &memclnt.SockclntDelete{
216                 Index: c.clientIndex,
217         }
218         msg, err := msgCodec.EncodeMsg(req, c.sockDelMsgId)
219         if err != nil {
220                 Log.Debugln("Encode error:", err)
221                 return err
222         }
223         // set non-0 context
224         msg[5] = 124
225
226         Log.Debugf("sending socklntDel (%d byes): % 0X\n", len(msg), msg)
227         if err := c.write(msg); err != nil {
228                 Log.Debugln("Write error: ", err)
229                 return err
230         }
231
232         readDeadline := time.Now().Add(DisconnectTimeout)
233         if err := c.conn.SetReadDeadline(readDeadline); err != nil {
234                 return err
235         }
236         msgReply, err := c.read()
237         if err != nil {
238                 Log.Debugln("Read error:", err)
239                 if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
240                         // we accept timeout for reply
241                         return nil
242                 }
243                 return err
244         }
245         // reset read deadline
246         if err := c.conn.SetReadDeadline(time.Time{}); err != nil {
247                 return err
248         }
249
250         reply := new(memclnt.SockclntDeleteReply)
251         if err := msgCodec.DecodeMsg(msgReply, reply); err != nil {
252                 Log.Debugln("Decode error:", err)
253                 return err
254         }
255
256         Log.Debugf("SockclntDeleteReply: Response=%v", reply.Response)
257
258         return nil
259 }
260
261 func (c *vppClient) GetMsgID(msgName string, msgCrc string) (uint16, error) {
262         msg := msgName + "_" + msgCrc
263         msgID, ok := c.msgTable[msg]
264         if !ok {
265                 return 0, fmt.Errorf("unknown message: %q", msg)
266         }
267         return msgID, nil
268 }
269
270 type reqHeader struct {
271         //MsgID uint16
272         ClientIndex uint32
273         Context     uint32
274 }
275
276 func (c *vppClient) SendMsg(context uint32, data []byte) error {
277         h := &reqHeader{
278                 ClientIndex: c.clientIndex,
279                 Context:     context,
280         }
281         buf := new(bytes.Buffer)
282         if err := struc.Pack(buf, h); err != nil {
283                 return err
284         }
285         copy(data[2:], buf.Bytes())
286
287         Log.Debugf("SendMsg (%d) context=%v client=%d: data: % 02X", len(data), context, c.clientIndex, data)
288
289         if err := c.write(data); err != nil {
290                 Log.Debugln("write error: ", err)
291                 return err
292         }
293
294         return nil
295 }
296
297 func (c *vppClient) write(msg []byte) error {
298         h := &msgheader{
299                 Data_len: uint32(len(msg)),
300         }
301         buf := new(bytes.Buffer)
302         if err := struc.Pack(buf, h); err != nil {
303                 return err
304         }
305         header := buf.Bytes()
306
307         // we lock to prevent mixing multiple message sends
308         c.writeMu.Lock()
309         defer c.writeMu.Unlock()
310
311         var w io.Writer = c.conn
312
313         if n, err := w.Write(header); err != nil {
314                 return err
315         } else {
316                 Log.Debugf(" - header sent (%d/%d): % 0X", n, len(header), header)
317         }
318         if n, err := w.Write(msg); err != nil {
319                 return err
320         } else {
321                 Log.Debugf(" - msg sent (%d/%d): % 0X", n, len(msg), msg)
322         }
323
324         return nil
325 }
326
327 type msgHeader struct {
328         MsgID   uint16
329         Context uint32
330 }
331
332 func (c *vppClient) readerLoop() {
333         defer c.wg.Done()
334         for {
335                 select {
336                 case <-c.quit:
337                         Log.Debugf("reader quit")
338                         return
339                 default:
340                 }
341
342                 msg, err := c.read()
343                 if err != nil {
344                         if isClosedError(err) {
345                                 return
346                         }
347                         Log.Debugf("READ FAILED: %v", err)
348                         continue
349                 }
350                 h := new(msgHeader)
351                 if err := struc.Unpack(bytes.NewReader(msg), h); err != nil {
352                         Log.Debugf("unpacking header failed: %v", err)
353                         continue
354                 }
355
356                 Log.Debugf("recvMsg (%d) msgID=%d context=%v", len(msg), h.MsgID, h.Context)
357                 c.cb(h.MsgID, msg)
358         }
359 }
360
361 type msgheader struct {
362         Q                 int    `struc:"uint64"`
363         Data_len          uint32 `struc:"uint32"`
364         Gc_mark_timestamp uint32 `struc:"uint32"`
365         //data              [0]uint8
366 }
367
368 func (c *vppClient) read() ([]byte, error) {
369         Log.Debug("reading next msg..")
370
371         header := make([]byte, 16)
372
373         n, err := io.ReadAtLeast(c.reader, header, 16)
374         if err != nil {
375                 return nil, err
376         } else if n == 0 {
377                 Log.Debugln("zero bytes header")
378                 return nil, nil
379         }
380         if n != 16 {
381                 Log.Debug("invalid header data (%d): % 0X", n, header[:n])
382                 return nil, fmt.Errorf("invalid header (expected 16 bytes, got %d)", n)
383         }
384         Log.Debugf(" - read header %d bytes: % 0X", n, header)
385
386         h := &msgheader{}
387         if err := struc.Unpack(bytes.NewReader(header[:]), h); err != nil {
388                 return nil, err
389         }
390         Log.Debugf(" - decoded header: %+v", h)
391
392         msgLen := int(h.Data_len)
393         msg := make([]byte, msgLen)
394
395         n, err = c.reader.Read(msg)
396         if err != nil {
397                 return nil, err
398         }
399         Log.Debugf(" - read msg %d bytes (%d buffered)", n, c.reader.Buffered())
400
401         if msgLen > n {
402                 remain := msgLen - n
403                 Log.Debugf("continue read for another %d bytes", remain)
404                 view := msg[n:]
405
406                 for remain > 0 {
407
408                         nbytes, err := c.reader.Read(view)
409                         if err != nil {
410                                 return nil, err
411                         } else if nbytes == 0 {
412                                 return nil, fmt.Errorf("zero nbytes")
413                         }
414
415                         remain -= nbytes
416                         Log.Debugf("another data received: %d bytes (remain: %d)", nbytes, remain)
417
418                         view = view[nbytes:]
419                 }
420         }
421
422         return msg, nil
423 }
424
425 func isClosedError(err error) bool {
426         if err == io.EOF {
427                 return true
428         }
429         return strings.HasSuffix(err.Error(), "use of closed network connection")
430 }