4788fcb5ea5655ea3f54c394a25698a387b01b57
[vpp.git] / extras / gomemif / memif / control_channel.go
1 /*
2  *------------------------------------------------------------------
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *------------------------------------------------------------------
16  */
17
18 package memif
19
20 import (
21         "bytes"
22         "container/list"
23         "encoding/binary"
24         "fmt"
25         "os"
26         "sync"
27         "syscall"
28 )
29
30 const maxEpollEvents = 1
31 const maxControlLen = 256
32
33 const errorFdNotFound = "fd not found"
34
35 // controlMsg represents a message used in communication between memif peers
36 type controlMsg struct {
37         Buffer *bytes.Buffer
38         Fd     int
39 }
40
41 // listener represents a listener functionality of UNIX domain socket
42 type listener struct {
43         socket *Socket
44         event  syscall.EpollEvent
45 }
46
47 // controlChannel represents a communication channel between memif peers
48 // backed by UNIX domain socket
49 type controlChannel struct {
50         listRef     *list.Element
51         socket      *Socket
52         i           *Interface
53         event       syscall.EpollEvent
54         data        [msgSize]byte
55         control     [maxControlLen]byte
56         controlLen  int
57         msgQueue    []controlMsg
58         isConnected bool
59 }
60
61 // Socket represents a UNIX domain socket used for communication
62 // between memif peers
63 type Socket struct {
64         appName       string
65         filename      string
66         listener      *listener
67         interfaceList *list.List
68         ccList        *list.List
69         epfd          int
70         interruptfd   int
71         wakeEvent     syscall.EpollEvent
72         stopPollChan  chan struct{}
73         wg            sync.WaitGroup
74 }
75
76 type interrupt struct {
77         socket *Socket
78         event  syscall.EpollEvent
79 }
80
81 type memifInterrupt struct {
82         connection *Socket
83         qid        uint16
84 }
85
86 // StopPolling stops polling events on the socket
87 func (socket *Socket) StopPolling() error {
88         if socket.stopPollChan != nil {
89                 // stop polling msg
90                 close(socket.stopPollChan)
91                 // wake epoll
92                 buf := make([]byte, 8)
93                 binary.PutUvarint(buf, 1)
94                 n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:])
95                 if err != nil {
96                         return err
97                 }
98                 if n != 8 {
99                         return fmt.Errorf("Faild to write to eventfd")
100                 }
101                 // wait until polling is stopped
102                 socket.wg.Wait()
103         }
104
105         return nil
106 }
107
108 // StartPolling starts polling and handling events on the socket,
109 // enabling communication between memif peers
110 func (socket *Socket) StartPolling(errChan chan<- error) {
111         socket.stopPollChan = make(chan struct{})
112         socket.wg.Add(1)
113         go func() {
114                 var events [maxEpollEvents]syscall.EpollEvent
115                 defer socket.wg.Done()
116
117                 for {
118                         select {
119                         case <-socket.stopPollChan:
120                                 return
121                         default:
122                                 num, err := syscall.EpollWait(socket.epfd, events[:], -1)
123                                 if err != nil {
124                                         errChan <- fmt.Errorf("EpollWait: ", err)
125                                         return
126                                 }
127
128                                 for ev := 0; ev < num; ev++ {
129                                         if events[0].Fd == socket.wakeEvent.Fd {
130                                                 continue
131                                         }
132                                         err = socket.handleEvent(&events[0])
133                                         if err != nil {
134                                                 errChan <- fmt.Errorf("handleEvent: ", err)
135                                         }
136                                 }
137                         }
138                 }
139         }()
140 }
141
142 // addEvent adds event to epoll instance associated with the socket
143 func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
144         err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event)
145         if err != nil {
146                 return fmt.Errorf("EpollCtl: %s", err)
147         }
148         return nil
149 }
150
151 // addEvent deletes event to epoll instance associated with the socket
152 func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
153         err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event)
154         if err != nil {
155                 return fmt.Errorf("EpollCtl: %s", err)
156         }
157         return nil
158 }
159
160 // Delete deletes the socket
161 func (socket *Socket) Delete() (err error) {
162         for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
163                 cc, ok := elt.Value.(*controlChannel)
164                 if ok {
165                         err = cc.close(true, "Socket deleted")
166                         if err != nil {
167                                 return nil
168                         }
169                 }
170         }
171         for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
172                 i, ok := elt.Value.(*Interface)
173                 if ok {
174                         err = i.Delete()
175                         if err != nil {
176                                 return err
177                         }
178                 }
179         }
180
181         if socket.listener != nil {
182                 err = socket.listener.close()
183                 if err != nil {
184                         return err
185                 }
186                 err = os.Remove(socket.filename)
187                 if err != nil {
188                         return nil
189                 }
190         }
191
192         err = socket.delEvent(&socket.wakeEvent)
193         if err != nil {
194                 return fmt.Errorf("Failed to delete event: ", err)
195         }
196
197         syscall.Close(socket.epfd)
198
199         return nil
200 }
201
202 // NewSocket returns a new Socket
203 func NewSocket(appName string, filename string) (socket *Socket, err error) {
204         socket = &Socket{
205                 appName:       appName,
206                 filename:      filename,
207                 interfaceList: list.New(),
208                 ccList:        list.New(),
209         }
210         if socket.filename == "" {
211                 socket.filename = DefaultSocketFilename
212         }
213
214         socket.epfd, _ = syscall.EpollCreate1(0)
215
216         efd, err := eventFd()
217         socket.wakeEvent = syscall.EpollEvent{
218                 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
219                 Fd:     int32(efd),
220         }
221         err = socket.addEvent(&socket.wakeEvent)
222         if err != nil {
223                 return nil, fmt.Errorf("Failed to add event: ", err)
224         }
225
226         return socket, nil
227 }
228
229 // handleEvent handles epoll event
230 func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
231         if socket.listener != nil && socket.listener.event.Fd == event.Fd {
232                 return socket.listener.handleEvent(event)
233         }
234         intf := socket.interfaceList.Back().Value.(*Interface)
235         if intf.args.InterruptFunc != nil {
236                 if int(event.Fd) == int(intf.args.InterruptFd) {
237                         b := make([]byte, 8)
238                         syscall.Read(int(event.Fd), b)
239                         intf.onInterrupt(intf)
240                         return nil
241                 }
242         }
243
244         for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
245                 cc, ok := elt.Value.(*controlChannel)
246                 if ok {
247                         if cc.event.Fd == event.Fd {
248                                 return cc.handleEvent(event)
249                         }
250                 }
251         }
252
253         return fmt.Errorf(errorFdNotFound)
254 }
255
256 func (socket *Socket) addInterrupt(fd int) (err error) {
257         l := &interrupt{
258                 // we will need this to look up master interface by id
259                 socket: socket,
260         }
261
262         l.event = syscall.EpollEvent{
263                 Events: syscall.EPOLLIN,
264                 Fd:     int32(fd),
265         }
266         err = socket.addEvent(&l.event)
267         if err != nil {
268                 return fmt.Errorf("Failed to add event: ", err)
269         }
270
271         return nil
272
273 }
274
275 // handleEvent handles epoll event for listener
276 func (l *listener) handleEvent(event *syscall.EpollEvent) error {
277         // hang up
278         if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
279                 err := l.close()
280                 if err != nil {
281                         return fmt.Errorf("Failed to close listener after hang up event: ", err)
282                 }
283                 return fmt.Errorf("Hang up: ", l.socket.filename)
284         }
285
286         // error
287         if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
288                 err := l.close()
289                 if err != nil {
290                         return fmt.Errorf("Failed to close listener after receiving an error event: ", err)
291                 }
292                 return fmt.Errorf("Received error event on listener ", l.socket.filename)
293         }
294
295         // read message
296         if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
297                 newFd, _, err := syscall.Accept(int(l.event.Fd))
298                 if err != nil {
299                         return fmt.Errorf("Accept: %s", err)
300                 }
301
302                 cc, err := l.socket.addControlChannel(newFd, nil)
303                 if err != nil {
304                         return fmt.Errorf("Failed to add control channel: %s", err)
305                 }
306
307                 err = cc.msgEnqHello()
308                 if err != nil {
309                         return fmt.Errorf("msgEnqHello: %s", err)
310                 }
311
312                 err = cc.sendMsg()
313                 if err != nil {
314                         return err
315                 }
316
317                 return nil
318         }
319
320         return fmt.Errorf("Unexpected event: ", event.Events)
321 }
322
323 // handleEvent handles epoll event for control channel
324 func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
325         var size int
326         var err error
327
328         // hang up
329         if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
330                 // close cc, don't send msg
331                 err := cc.close(false, "")
332                 if err != nil {
333                         return fmt.Errorf("Failed to close control channel after hang up event: ", err)
334                 }
335                 return fmt.Errorf("Hang up: ", cc.i.GetName())
336         }
337
338         if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
339                 // close cc, don't send msg
340                 err := cc.close(false, "")
341                 if err != nil {
342                         return fmt.Errorf("Failed to close control channel after receiving an error event: ", err)
343                 }
344                 return fmt.Errorf("Received error event on control channel ", cc.i.GetName())
345         }
346
347         if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
348                 size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
349                 if err != nil {
350                         return fmt.Errorf("recvmsg: %s", err)
351                 }
352                 if size != msgSize {
353                         return fmt.Errorf("invalid message size %d", size)
354                 }
355
356                 err = cc.parseMsg()
357                 if err != nil {
358                         return err
359                 }
360
361                 err = cc.sendMsg()
362                 if err != nil {
363                         return err
364                 }
365
366                 return nil
367         }
368
369         return fmt.Errorf("Unexpected event: ", event.Events)
370 }
371
372 // close closes the listener
373 func (l *listener) close() error {
374         err := l.socket.delEvent(&l.event)
375         if err != nil {
376                 return fmt.Errorf("Failed to del event: ", err)
377         }
378         err = syscall.Close(int(l.event.Fd))
379         if err != nil {
380                 return fmt.Errorf("Failed to close socket: ", err)
381         }
382         return nil
383 }
384
385 // AddListener adds a lisntener to the socket. The fd must describe a
386 // UNIX domain socket already bound to a UNIX domain filename and
387 // marked as listener
388 func (socket *Socket) AddListener(fd int) (err error) {
389         l := &listener{
390                 // we will need this to look up master interface by id
391                 socket: socket,
392         }
393
394         l.event = syscall.EpollEvent{
395                 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
396                 Fd:     int32(fd),
397         }
398         err = socket.addEvent(&l.event)
399         if err != nil {
400                 return fmt.Errorf("Failed to add event: ", err)
401         }
402
403         socket.listener = l
404
405         return nil
406 }
407
408 // addListener creates new UNIX domain socket, binds it to the address
409 // and marks it as listener
410 func (socket *Socket) addListener() (err error) {
411         // create socket
412         fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
413         if err != nil {
414                 return fmt.Errorf("Failed to create UNIX domain socket")
415         }
416         usa := &syscall.SockaddrUnix{Name: socket.filename}
417         // Bind to address and start listening
418         err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
419         if err != nil {
420                 return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err)
421         }
422         err = syscall.Bind(fd, usa)
423         if err != nil {
424                 return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err)
425         }
426         err = syscall.Listen(fd, syscall.SOMAXCONN)
427         if err != nil {
428                 return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err)
429         }
430
431         return socket.AddListener(fd)
432 }
433
434 // close closes a control channel, if the control channel is assigned an
435 // interface, the interface is disconnected
436 func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
437         if sendMsg == true {
438                 // first clear message queue so that the disconnect
439                 // message is the only message in queue
440                 cc.msgQueue = []controlMsg{}
441                 cc.msgEnqDisconnect(str)
442
443                 err = cc.sendMsg()
444                 if err != nil {
445                         return err
446                 }
447         }
448
449         err = cc.socket.delEvent(&cc.event)
450         if err != nil {
451                 return fmt.Errorf("Failed to del event: ", err)
452         }
453
454         // remove referance form socket
455         cc.socket.ccList.Remove(cc.listRef)
456
457         if cc.i != nil {
458                 err = cc.i.disconnect()
459                 if err != nil {
460                         return fmt.Errorf("Interface Disconnect: ", err)
461                 }
462         }
463
464         return nil
465 }
466
467 //addControlChannel returns a new controlChannel and adds it to the socket
468 func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) {
469         cc := &controlChannel{
470                 socket:      socket,
471                 i:           i,
472                 isConnected: false,
473         }
474
475         var err error
476
477         cc.event = syscall.EpollEvent{
478                 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
479                 Fd:     int32(fd),
480         }
481         err = socket.addEvent(&cc.event)
482         if err != nil {
483                 return nil, fmt.Errorf("Failed to add event: ", err)
484         }
485
486         cc.listRef = socket.ccList.PushBack(cc)
487
488         return cc, nil
489 }
490
491 func (cc *controlChannel) msgEnqAck() (err error) {
492         buf := new(bytes.Buffer)
493         err = binary.Write(buf, binary.LittleEndian, msgTypeAck)
494
495         msg := controlMsg{
496                 Buffer: buf,
497                 Fd:     -1,
498         }
499
500         cc.msgQueue = append(cc.msgQueue, msg)
501
502         return nil
503 }
504
505 func (cc *controlChannel) msgEnqHello() (err error) {
506         hello := MsgHello{
507                 VersionMin:      Version,
508                 VersionMax:      Version,
509                 MaxRegion:       255,
510                 MaxRingM2S:      255,
511                 MaxRingS2M:      255,
512                 MaxLog2RingSize: 14,
513         }
514
515         copy(hello.Name[:], []byte(cc.socket.appName))
516
517         buf := new(bytes.Buffer)
518         err = binary.Write(buf, binary.LittleEndian, msgTypeHello)
519         err = binary.Write(buf, binary.LittleEndian, hello)
520
521         msg := controlMsg{
522                 Buffer: buf,
523                 Fd:     -1,
524         }
525
526         cc.msgQueue = append(cc.msgQueue, msg)
527
528         return nil
529 }
530
531 func (cc *controlChannel) parseHello() (err error) {
532         var hello MsgHello
533
534         buf := bytes.NewReader(cc.data[msgTypeSize:])
535         err = binary.Read(buf, binary.LittleEndian, &hello)
536         if err != nil {
537                 return
538         }
539
540         if hello.VersionMin > Version || hello.VersionMax < Version {
541                 return fmt.Errorf("Incompatible memif version")
542         }
543
544         cc.i.run = cc.i.args.MemoryConfig
545
546         cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M)
547         cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S)
548         cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)
549
550         cc.i.remoteName = string(hello.Name[:])
551
552         return nil
553 }
554
555 func (cc *controlChannel) msgEnqInit() (err error) {
556         init := MsgInit{
557                 Version: Version,
558                 Id:      cc.i.args.Id,
559                 Mode:    cc.i.args.Mode,
560         }
561
562         copy(init.Name[:], []byte(cc.socket.appName))
563
564         buf := new(bytes.Buffer)
565         err = binary.Write(buf, binary.LittleEndian, msgTypeInit)
566         err = binary.Write(buf, binary.LittleEndian, init)
567
568         msg := controlMsg{
569                 Buffer: buf,
570                 Fd:     -1,
571         }
572
573         cc.msgQueue = append(cc.msgQueue, msg)
574
575         return nil
576 }
577
578 func (cc *controlChannel) parseInit() (err error) {
579         var init MsgInit
580
581         buf := bytes.NewReader(cc.data[msgTypeSize:])
582         err = binary.Read(buf, binary.LittleEndian, &init)
583         if err != nil {
584                 return
585         }
586
587         if init.Version != Version {
588                 return fmt.Errorf("Incompatible memif driver version")
589         }
590
591         // find peer interface
592         for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
593                 i, ok := elt.Value.(*Interface)
594                 if ok {
595                         if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil {
596                                 // verify secret
597                                 if i.args.Secret != init.Secret {
598                                         return fmt.Errorf("Invalid secret")
599                                 }
600                                 // interface is assigned to control channel
601                                 i.cc = cc
602                                 cc.i = i
603                                 cc.i.run = cc.i.args.MemoryConfig
604                                 cc.i.remoteName = string(init.Name[:])
605
606                                 return nil
607                         }
608                 }
609         }
610
611         return fmt.Errorf("Invalid interface id")
612 }
613
614 func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
615         if len(cc.i.regions) <= int(regionIndex) {
616                 return fmt.Errorf("Invalid region index")
617         }
618
619         addRegion := MsgAddRegion{
620                 Index: regionIndex,
621                 Size:  cc.i.regions[regionIndex].size,
622         }
623
624         buf := new(bytes.Buffer)
625         err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion)
626         err = binary.Write(buf, binary.LittleEndian, addRegion)
627
628         msg := controlMsg{
629                 Buffer: buf,
630                 Fd:     cc.i.regions[regionIndex].fd,
631         }
632
633         cc.msgQueue = append(cc.msgQueue, msg)
634
635         return nil
636 }
637
638 func (cc *controlChannel) parseAddRegion() (err error) {
639         var addRegion MsgAddRegion
640
641         buf := bytes.NewReader(cc.data[msgTypeSize:])
642         err = binary.Read(buf, binary.LittleEndian, &addRegion)
643         if err != nil {
644                 return
645         }
646
647         fd, err := cc.parseControlMsg()
648         if err != nil {
649                 return fmt.Errorf("parseControlMsg: %s", err)
650         }
651
652         if addRegion.Index > 255 {
653                 return fmt.Errorf("Invalid memory region index")
654         }
655
656         region := memoryRegion{
657                 size: addRegion.Size,
658                 fd:   fd,
659         }
660
661         cc.i.regions = append(cc.i.regions, region)
662
663         return nil
664 }
665
666 func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
667         var q Queue
668         var flags uint16 = 0
669
670         if ringType == ringTypeS2M {
671                 q = cc.i.txQueues[ringIndex]
672                 flags = msgAddRingFlagS2M
673         } else {
674                 q = cc.i.rxQueues[ringIndex]
675         }
676
677         addRing := MsgAddRing{
678                 Index:          ringIndex,
679                 Offset:         uint32(q.ring.offset),
680                 Region:         uint16(q.ring.region),
681                 RingSizeLog2:   uint8(q.ring.log2Size),
682                 Flags:          flags,
683                 PrivateHdrSize: 0,
684         }
685
686         buf := new(bytes.Buffer)
687         err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing)
688         err = binary.Write(buf, binary.LittleEndian, addRing)
689
690         msg := controlMsg{
691                 Buffer: buf,
692                 Fd:     q.interruptFd,
693         }
694
695         cc.msgQueue = append(cc.msgQueue, msg)
696
697         return nil
698 }
699
700 func (cc *controlChannel) parseAddRing() (err error) {
701         var addRing MsgAddRing
702
703         buf := bytes.NewReader(cc.data[msgTypeSize:])
704         err = binary.Read(buf, binary.LittleEndian, &addRing)
705         if err != nil {
706                 return
707         }
708
709         fd, err := cc.parseControlMsg()
710         if err != nil {
711                 return err
712         }
713
714         if addRing.Index >= cc.i.run.NumQueuePairs {
715                 return fmt.Errorf("invalid ring index")
716         }
717
718         q := Queue{
719                 i:           cc.i,
720                 interruptFd: fd,
721         }
722
723         if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
724                 q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
725                 cc.i.rxQueues = append(cc.i.rxQueues, q)
726         } else {
727                 q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
728                 cc.i.txQueues = append(cc.i.txQueues, q)
729         }
730
731         return nil
732 }
733
734 func (cc *controlChannel) msgEnqConnect() (err error) {
735         var connect MsgConnect
736         copy(connect.Name[:], []byte(cc.i.args.Name))
737
738         buf := new(bytes.Buffer)
739         err = binary.Write(buf, binary.LittleEndian, msgTypeConnect)
740         err = binary.Write(buf, binary.LittleEndian, connect)
741
742         msg := controlMsg{
743                 Buffer: buf,
744                 Fd:     -1,
745         }
746
747         cc.msgQueue = append(cc.msgQueue, msg)
748
749         return nil
750 }
751
752 func (cc *controlChannel) parseConnect() (err error) {
753         var connect MsgConnect
754
755         buf := bytes.NewReader(cc.data[msgTypeSize:])
756         err = binary.Read(buf, binary.LittleEndian, &connect)
757         if err != nil {
758                 return
759         }
760
761         cc.i.peerName = string(connect.Name[:])
762
763         err = cc.i.connect()
764         if err != nil {
765                 return err
766         }
767         cc.isConnected = true
768
769         return nil
770 }
771
772 func (cc *controlChannel) msgEnqConnected() (err error) {
773         var connected MsgConnected
774         copy(connected.Name[:], []byte(cc.i.args.Name))
775
776         buf := new(bytes.Buffer)
777         err = binary.Write(buf, binary.LittleEndian, msgTypeConnected)
778         err = binary.Write(buf, binary.LittleEndian, connected)
779
780         msg := controlMsg{
781                 Buffer: buf,
782                 Fd:     -1,
783         }
784
785         cc.msgQueue = append(cc.msgQueue, msg)
786
787         return nil
788 }
789
790 func (cc *controlChannel) parseConnected() (err error) {
791         var conn MsgConnected
792
793         buf := bytes.NewReader(cc.data[msgTypeSize:])
794         err = binary.Read(buf, binary.LittleEndian, &conn)
795         if err != nil {
796                 return
797         }
798
799         cc.i.peerName = string(conn.Name[:])
800
801         err = cc.i.connect()
802         if err != nil {
803                 return err
804         }
805         cc.isConnected = true
806
807         return nil
808 }
809
810 func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
811         dc := MsgDisconnect{
812                 // not implemented
813                 Code: 0,
814         }
815         copy(dc.String[:], str)
816
817         buf := new(bytes.Buffer)
818         err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect)
819         err = binary.Write(buf, binary.LittleEndian, dc)
820
821         msg := controlMsg{
822                 Buffer: buf,
823                 Fd:     -1,
824         }
825
826         cc.msgQueue = append(cc.msgQueue, msg)
827
828         return nil
829 }
830
831 func (cc *controlChannel) parseDisconnect() (err error) {
832         var dc MsgDisconnect
833
834         buf := bytes.NewReader(cc.data[msgTypeSize:])
835         err = binary.Read(buf, binary.LittleEndian, &dc)
836         if err != nil {
837                 return
838         }
839
840         err = cc.close(false, string(dc.String[:]))
841         if err != nil {
842                 return fmt.Errorf("Failed to disconnect control channel: ", err)
843         }
844
845         return nil
846 }
847
848 func (cc *controlChannel) parseMsg() error {
849         var msgType msgType
850         var err error
851
852         buf := bytes.NewReader(cc.data[:])
853         err = binary.Read(buf, binary.LittleEndian, &msgType)
854
855         if msgType == msgTypeAck {
856                 return nil
857         } else if msgType == msgTypeHello {
858                 // Configure
859                 err = cc.parseHello()
860                 if err != nil {
861                         goto error
862                 }
863                 // Initialize slave memif
864                 err = cc.i.initializeRegions()
865                 if err != nil {
866                         goto error
867                 }
868                 err = cc.i.initializeQueues()
869                 if err != nil {
870                         goto error
871                 }
872                 // Enqueue messages
873                 err = cc.msgEnqInit()
874                 if err != nil {
875                         goto error
876                 }
877                 for i := 0; i < len(cc.i.regions); i++ {
878                         err = cc.msgEnqAddRegion(uint16(i))
879                         if err != nil {
880                                 goto error
881                         }
882                 }
883                 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
884                         err = cc.msgEnqAddRing(ringTypeS2M, uint16(i))
885                         if err != nil {
886                                 goto error
887                         }
888                 }
889                 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
890                         err = cc.msgEnqAddRing(ringTypeM2S, uint16(i))
891                         if err != nil {
892                                 goto error
893                         }
894                 }
895                 err = cc.msgEnqConnect()
896                 if err != nil {
897                         goto error
898                 }
899         } else if msgType == msgTypeInit {
900                 err = cc.parseInit()
901                 if err != nil {
902                         goto error
903                 }
904
905                 err = cc.msgEnqAck()
906                 if err != nil {
907                         goto error
908                 }
909         } else if msgType == msgTypeAddRegion {
910                 err = cc.parseAddRegion()
911                 if err != nil {
912                         goto error
913                 }
914
915                 err = cc.msgEnqAck()
916                 if err != nil {
917                         goto error
918                 }
919         } else if msgType == msgTypeAddRing {
920                 err = cc.parseAddRing()
921                 if err != nil {
922                         goto error
923                 }
924
925                 err = cc.msgEnqAck()
926                 if err != nil {
927                         goto error
928                 }
929         } else if msgType == msgTypeConnect {
930                 err = cc.parseConnect()
931                 if err != nil {
932                         goto error
933                 }
934
935                 err = cc.msgEnqConnected()
936                 if err != nil {
937                         goto error
938                 }
939         } else if msgType == msgTypeConnected {
940                 err = cc.parseConnected()
941                 if err != nil {
942                         goto error
943                 }
944         } else if msgType == msgTypeDisconnect {
945                 err = cc.parseDisconnect()
946                 if err != nil {
947                         goto error
948                 }
949         } else {
950                 err = fmt.Errorf("unknown message %d", msgType)
951                 goto error
952         }
953
954         return nil
955
956 error:
957         err1 := cc.close(true, err.Error())
958         if err1 != nil {
959                 return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1)
960         }
961
962         return err
963 }
964
965 // parseControlMsg parses control message and returns file descriptor
966 // if any
967 func (cc *controlChannel) parseControlMsg() (fd int, err error) {
968         // Assert only called when we require FD
969         fd = -1
970
971         controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
972         if err != nil {
973                 return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
974         }
975
976         if len(controlMsgs) == 0 {
977                 return -1, fmt.Errorf("Missing control message")
978         }
979
980         for _, cmsg := range controlMsgs {
981                 if cmsg.Header.Level == syscall.SOL_SOCKET {
982                         if cmsg.Header.Type == syscall.SCM_RIGHTS {
983                                 FDs, err := syscall.ParseUnixRights(&cmsg)
984                                 if err != nil {
985                                         return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err)
986                                 }
987                                 if len(FDs) == 0 {
988                                         continue
989                                 }
990                                 // Only expect single FD
991                                 fd = FDs[0]
992                         }
993                 }
994         }
995
996         if fd == -1 {
997                 return -1, fmt.Errorf("Missing file descriptor")
998         }
999
1000         return fd, nil
1001 }