gomemif: Add mode support
[vpp.git] / extras / gomemif / memif / control_channel.go
1 /*
2  *------------------------------------------------------------------
3  * Copyright (c) 2020 Cisco and/or its affiliates.
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *------------------------------------------------------------------
16  */
17
18 package memif
19
20 import (
21         "bytes"
22         "container/list"
23         "encoding/binary"
24         "fmt"
25         "os"
26         "sync"
27         "syscall"
28 )
29
30 const maxEpollEvents = 1
31 const maxControlLen = 256
32
33 const errorFdNotFound = "fd not found"
34
35 // controlMsg represents a message used in communication between memif peers
36 type controlMsg struct {
37         Buffer *bytes.Buffer
38         Fd     int
39 }
40
41 // listener represents a listener functionality of UNIX domain socket
42 type listener struct {
43         socket *Socket
44         event  syscall.EpollEvent
45 }
46
47 // controlChannel represents a communication channel between memif peers
48 // backed by UNIX domain socket
49 type controlChannel struct {
50         listRef     *list.Element
51         socket      *Socket
52         i           *Interface
53         event       syscall.EpollEvent
54         data        [msgSize]byte
55         control     [maxControlLen]byte
56         controlLen  int
57         msgQueue    []controlMsg
58         isConnected bool
59 }
60
61 // Socket represents a UNIX domain socket used for communication
62 // between memif peers
63 type Socket struct {
64         appName       string
65         filename      string
66         listener      *listener
67         interfaceList *list.List
68         ccList        *list.List
69         epfd          int
70         wakeEvent     syscall.EpollEvent
71         stopPollChan  chan struct{}
72         wg            sync.WaitGroup
73 }
74
75 // StopPolling stops polling events on the socket
76 func (socket *Socket) StopPolling() error {
77         if socket.stopPollChan != nil {
78                 // stop polling msg
79                 close(socket.stopPollChan)
80                 // wake epoll
81                 buf := make([]byte, 8)
82                 binary.PutUvarint(buf, 1)
83                 n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:])
84                 if err != nil {
85                         return err
86                 }
87                 if n != 8 {
88                         return fmt.Errorf("Faild to write to eventfd")
89                 }
90                 // wait until polling is stopped
91                 socket.wg.Wait()
92         }
93
94         return nil
95 }
96
97 // StartPolling starts polling and handling events on the socket,
98 // enabling communication between memif peers
99 func (socket *Socket) StartPolling(errChan chan<- error) {
100         socket.stopPollChan = make(chan struct{})
101         socket.wg.Add(1)
102         go func() {
103                 var events [maxEpollEvents]syscall.EpollEvent
104                 defer socket.wg.Done()
105
106                 for {
107                         select {
108                         case <-socket.stopPollChan:
109                                 return
110                         default:
111                                 num, err := syscall.EpollWait(socket.epfd, events[:], -1)
112                                 if err != nil {
113                                         errChan <- fmt.Errorf("EpollWait: ", err)
114                                         return
115                                 }
116
117                                 for ev := 0; ev < num; ev++ {
118                                         if events[0].Fd == socket.wakeEvent.Fd {
119                                                 continue
120                                         }
121                                         err = socket.handleEvent(&events[0])
122                                         if err != nil {
123                                                 errChan <- fmt.Errorf("handleEvent: ", err)
124                                         }
125                                 }
126                         }
127                 }
128         }()
129 }
130
131 // addEvent adds event to epoll instance associated with the socket
132 func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
133         err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event)
134         if err != nil {
135                 return fmt.Errorf("EpollCtl: %s", err)
136         }
137         return nil
138 }
139
140 // addEvent deletes event to epoll instance associated with the socket
141 func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
142         err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event)
143         if err != nil {
144                 return fmt.Errorf("EpollCtl: %s", err)
145         }
146         return nil
147 }
148
149 // Delete deletes the socket
150 func (socket *Socket) Delete() (err error) {
151         for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
152                 cc, ok := elt.Value.(*controlChannel)
153                 if ok {
154                         err = cc.close(true, "Socket deleted")
155                         if err != nil {
156                                 return nil
157                         }
158                 }
159         }
160         for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
161                 i, ok := elt.Value.(*Interface)
162                 if ok {
163                         err = i.Delete()
164                         if err != nil {
165                                 return err
166                         }
167                 }
168         }
169
170         if socket.listener != nil {
171                 err = socket.listener.close()
172                 if err != nil {
173                         return err
174                 }
175                 err = os.Remove(socket.filename)
176                 if err != nil {
177                         return nil
178                 }
179         }
180
181         err = socket.delEvent(&socket.wakeEvent)
182         if err != nil {
183                 return fmt.Errorf("Failed to delete event: ", err)
184         }
185
186         syscall.Close(socket.epfd)
187
188         return nil
189 }
190
191 // NewSocket returns a new Socket
192 func NewSocket(appName string, filename string) (socket *Socket, err error) {
193         socket = &Socket{
194                 appName:       appName,
195                 filename:      filename,
196                 interfaceList: list.New(),
197                 ccList:        list.New(),
198         }
199         if socket.filename == "" {
200                 socket.filename = DefaultSocketFilename
201         }
202
203         socket.epfd, _ = syscall.EpollCreate1(0)
204
205         efd, err := eventFd()
206         socket.wakeEvent = syscall.EpollEvent{
207                 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
208                 Fd:     int32(efd),
209         }
210         err = socket.addEvent(&socket.wakeEvent)
211         if err != nil {
212                 return nil, fmt.Errorf("Failed to add event: ", err)
213         }
214
215         return socket, nil
216 }
217
218 // handleEvent handles epoll event
219 func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
220         if socket.listener != nil && socket.listener.event.Fd == event.Fd {
221                 return socket.listener.handleEvent(event)
222         }
223
224         for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
225                 cc, ok := elt.Value.(*controlChannel)
226                 if ok {
227                         if cc.event.Fd == event.Fd {
228                                 return cc.handleEvent(event)
229                         }
230                 }
231         }
232
233         return fmt.Errorf(errorFdNotFound)
234 }
235
236 // handleEvent handles epoll event for listener
237 func (l *listener) handleEvent(event *syscall.EpollEvent) error {
238         // hang up
239         if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
240                 err := l.close()
241                 if err != nil {
242                         return fmt.Errorf("Failed to close listener after hang up event: ", err)
243                 }
244                 return fmt.Errorf("Hang up: ", l.socket.filename)
245         }
246
247         // error
248         if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
249                 err := l.close()
250                 if err != nil {
251                         return fmt.Errorf("Failed to close listener after receiving an error event: ", err)
252                 }
253                 return fmt.Errorf("Received error event on listener ", l.socket.filename)
254         }
255
256         // read message
257         if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
258                 newFd, _, err := syscall.Accept(int(l.event.Fd))
259                 if err != nil {
260                         return fmt.Errorf("Accept: %s", err)
261                 }
262
263                 cc, err := l.socket.addControlChannel(newFd, nil)
264                 if err != nil {
265                         return fmt.Errorf("Failed to add control channel: %s", err)
266                 }
267
268                 err = cc.msgEnqHello()
269                 if err != nil {
270                         return fmt.Errorf("msgEnqHello: %s", err)
271                 }
272
273                 err = cc.sendMsg()
274                 if err != nil {
275                         return err
276                 }
277
278                 return nil
279         }
280
281         return fmt.Errorf("Unexpected event: ", event.Events)
282 }
283
284 // handleEvent handles epoll event for control channel
285 func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
286         var size int
287         var err error
288
289         // hang up
290         if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
291                 // close cc, don't send msg
292                 err := cc.close(false, "")
293                 if err != nil {
294                         return fmt.Errorf("Failed to close control channel after hang up event: ", err)
295                 }
296                 return fmt.Errorf("Hang up: ", cc.i.GetName())
297         }
298
299         if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
300                 // close cc, don't send msg
301                 err := cc.close(false, "")
302                 if err != nil {
303                         return fmt.Errorf("Failed to close control channel after receiving an error event: ", err)
304                 }
305                 return fmt.Errorf("Received error event on control channel ", cc.i.GetName())
306         }
307
308         if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
309                 size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
310                 if err != nil {
311                         return fmt.Errorf("recvmsg: %s", err)
312                 }
313                 if size != msgSize {
314                         return fmt.Errorf("invalid message size %d", size)
315                 }
316
317                 err = cc.parseMsg()
318                 if err != nil {
319                         return err
320                 }
321
322                 err = cc.sendMsg()
323                 if err != nil {
324                         return err
325                 }
326
327                 return nil
328         }
329
330         return fmt.Errorf("Unexpected event: ", event.Events)
331 }
332
333 // close closes the listener
334 func (l *listener) close() error {
335         err := l.socket.delEvent(&l.event)
336         if err != nil {
337                 return fmt.Errorf("Failed to del event: ", err)
338         }
339         err = syscall.Close(int(l.event.Fd))
340         if err != nil {
341                 return fmt.Errorf("Failed to close socket: ", err)
342         }
343         return nil
344 }
345
346 // AddListener adds a lisntener to the socket. The fd must describe a
347 // UNIX domain socket already bound to a UNIX domain filename and
348 // marked as listener
349 func (socket *Socket) AddListener(fd int) (err error) {
350         l := &listener{
351                 // we will need this to look up master interface by id
352                 socket: socket,
353         }
354
355         l.event = syscall.EpollEvent{
356                 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
357                 Fd:     int32(fd),
358         }
359         err = socket.addEvent(&l.event)
360         if err != nil {
361                 return fmt.Errorf("Failed to add event: ", err)
362         }
363
364         socket.listener = l
365
366         return nil
367 }
368
369 // addListener creates new UNIX domain socket, binds it to the address
370 // and marks it as listener
371 func (socket *Socket) addListener() (err error) {
372         // create socket
373         fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
374         if err != nil {
375                 return fmt.Errorf("Failed to create UNIX domain socket")
376         }
377         usa := &syscall.SockaddrUnix{Name: socket.filename}
378         // Bind to address and start listening
379         err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
380         if err != nil {
381                 return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err)
382         }
383         err = syscall.Bind(fd, usa)
384         if err != nil {
385                 return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err)
386         }
387         err = syscall.Listen(fd, syscall.SOMAXCONN)
388         if err != nil {
389                 return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err)
390         }
391
392         return socket.AddListener(fd)
393 }
394
395 // close closes a control channel, if the control channel is assigned an
396 // interface, the interface is disconnected
397 func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
398         if sendMsg == true {
399                 // first clear message queue so that the disconnect
400                 // message is the only message in queue
401                 cc.msgQueue = []controlMsg{}
402                 cc.msgEnqDisconnect(str)
403
404                 err = cc.sendMsg()
405                 if err != nil {
406                         return err
407                 }
408         }
409
410         err = cc.socket.delEvent(&cc.event)
411         if err != nil {
412                 return fmt.Errorf("Failed to del event: ", err)
413         }
414
415         // remove referance form socket
416         cc.socket.ccList.Remove(cc.listRef)
417
418         if cc.i != nil {
419                 err = cc.i.disconnect()
420                 if err != nil {
421                         return fmt.Errorf("Interface Disconnect: ", err)
422                 }
423         }
424
425         return nil
426 }
427
428 //addControlChannel returns a new controlChannel and adds it to the socket
429 func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) {
430         cc := &controlChannel{
431                 socket:      socket,
432                 i:           i,
433                 isConnected: false,
434         }
435
436         var err error
437
438         cc.event = syscall.EpollEvent{
439                 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
440                 Fd:     int32(fd),
441         }
442         err = socket.addEvent(&cc.event)
443         if err != nil {
444                 return nil, fmt.Errorf("Failed to add event: ", err)
445         }
446
447         cc.listRef = socket.ccList.PushBack(cc)
448
449         return cc, nil
450 }
451
452 func (cc *controlChannel) msgEnqAck() (err error) {
453         buf := new(bytes.Buffer)
454         err = binary.Write(buf, binary.LittleEndian, msgTypeAck)
455
456         msg := controlMsg{
457                 Buffer: buf,
458                 Fd:     -1,
459         }
460
461         cc.msgQueue = append(cc.msgQueue, msg)
462
463         return nil
464 }
465
466 func (cc *controlChannel) msgEnqHello() (err error) {
467         hello := MsgHello{
468                 VersionMin:      Version,
469                 VersionMax:      Version,
470                 MaxRegion:       255,
471                 MaxRingM2S:      255,
472                 MaxRingS2M:      255,
473                 MaxLog2RingSize: 14,
474         }
475
476         copy(hello.Name[:], []byte(cc.socket.appName))
477
478         buf := new(bytes.Buffer)
479         err = binary.Write(buf, binary.LittleEndian, msgTypeHello)
480         err = binary.Write(buf, binary.LittleEndian, hello)
481
482         msg := controlMsg{
483                 Buffer: buf,
484                 Fd:     -1,
485         }
486
487         cc.msgQueue = append(cc.msgQueue, msg)
488
489         return nil
490 }
491
492 func (cc *controlChannel) parseHello() (err error) {
493         var hello MsgHello
494
495         buf := bytes.NewReader(cc.data[msgTypeSize:])
496         err = binary.Read(buf, binary.LittleEndian, &hello)
497         if err != nil {
498                 return
499         }
500
501         if hello.VersionMin > Version || hello.VersionMax < Version {
502                 return fmt.Errorf("Incompatible memif version")
503         }
504
505         cc.i.run = cc.i.args.MemoryConfig
506
507         cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M)
508         cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S)
509         cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)
510
511         cc.i.remoteName = string(hello.Name[:])
512
513         return nil
514 }
515
516 func (cc *controlChannel) msgEnqInit() (err error) {
517         init := MsgInit{
518                 Version: Version,
519                 Id:      cc.i.args.Id,
520                 Mode:    cc.i.args.Mode,
521         }
522
523         copy(init.Name[:], []byte(cc.socket.appName))
524
525         buf := new(bytes.Buffer)
526         err = binary.Write(buf, binary.LittleEndian, msgTypeInit)
527         err = binary.Write(buf, binary.LittleEndian, init)
528
529         msg := controlMsg{
530                 Buffer: buf,
531                 Fd:     -1,
532         }
533
534         cc.msgQueue = append(cc.msgQueue, msg)
535
536         return nil
537 }
538
539 func (cc *controlChannel) parseInit() (err error) {
540         var init MsgInit
541
542         buf := bytes.NewReader(cc.data[msgTypeSize:])
543         err = binary.Read(buf, binary.LittleEndian, &init)
544         if err != nil {
545                 return
546         }
547
548         if init.Version != Version {
549                 return fmt.Errorf("Incompatible memif driver version")
550         }
551
552         // find peer interface
553         for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
554                 i, ok := elt.Value.(*Interface)
555                 if ok {
556                         if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil {
557                                 // verify secret
558                                 if i.args.Secret != init.Secret {
559                                         return fmt.Errorf("Invalid secret")
560                                 }
561                                 // interface is assigned to control channel
562                                 i.cc = cc
563                                 cc.i = i
564                                 cc.i.run = cc.i.args.MemoryConfig
565                                 cc.i.remoteName = string(init.Name[:])
566
567                                 return nil
568                         }
569                 }
570         }
571
572         return fmt.Errorf("Invalid interface id")
573 }
574
575 func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
576         if len(cc.i.regions) <= int(regionIndex) {
577                 return fmt.Errorf("Invalid region index")
578         }
579
580         addRegion := MsgAddRegion{
581                 Index: regionIndex,
582                 Size:  cc.i.regions[regionIndex].size,
583         }
584
585         buf := new(bytes.Buffer)
586         err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion)
587         err = binary.Write(buf, binary.LittleEndian, addRegion)
588
589         msg := controlMsg{
590                 Buffer: buf,
591                 Fd:     cc.i.regions[regionIndex].fd,
592         }
593
594         cc.msgQueue = append(cc.msgQueue, msg)
595
596         return nil
597 }
598
599 func (cc *controlChannel) parseAddRegion() (err error) {
600         var addRegion MsgAddRegion
601
602         buf := bytes.NewReader(cc.data[msgTypeSize:])
603         err = binary.Read(buf, binary.LittleEndian, &addRegion)
604         if err != nil {
605                 return
606         }
607
608         fd, err := cc.parseControlMsg()
609         if err != nil {
610                 return fmt.Errorf("parseControlMsg: %s", err)
611         }
612
613         if addRegion.Index > 255 {
614                 return fmt.Errorf("Invalid memory region index")
615         }
616
617         region := memoryRegion{
618                 size: addRegion.Size,
619                 fd:   fd,
620         }
621
622         cc.i.regions = append(cc.i.regions, region)
623
624         return nil
625 }
626
627 func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
628         var q Queue
629         var flags uint16 = 0
630
631         if ringType == ringTypeS2M {
632                 q = cc.i.txQueues[ringIndex]
633                 flags = msgAddRingFlagS2M
634         } else {
635                 q = cc.i.rxQueues[ringIndex]
636         }
637
638         addRing := MsgAddRing{
639                 Index:          ringIndex,
640                 Offset:         uint32(q.ring.offset),
641                 Region:         uint16(q.ring.region),
642                 RingSizeLog2:   uint8(q.ring.log2Size),
643                 Flags:          flags,
644                 PrivateHdrSize: 0,
645         }
646
647         buf := new(bytes.Buffer)
648         err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing)
649         err = binary.Write(buf, binary.LittleEndian, addRing)
650
651         msg := controlMsg{
652                 Buffer: buf,
653                 Fd:     q.interruptFd,
654         }
655
656         cc.msgQueue = append(cc.msgQueue, msg)
657
658         return nil
659 }
660
661 func (cc *controlChannel) parseAddRing() (err error) {
662         var addRing MsgAddRing
663
664         buf := bytes.NewReader(cc.data[msgTypeSize:])
665         err = binary.Read(buf, binary.LittleEndian, &addRing)
666         if err != nil {
667                 return
668         }
669
670         fd, err := cc.parseControlMsg()
671         if err != nil {
672                 return err
673         }
674
675         if addRing.Index >= cc.i.run.NumQueuePairs {
676                 return fmt.Errorf("invalid ring index")
677         }
678
679         q := Queue{
680                 i:           cc.i,
681                 interruptFd: fd,
682         }
683
684         if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
685                 q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
686                 cc.i.rxQueues = append(cc.i.rxQueues, q)
687         } else {
688                 q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
689                 cc.i.txQueues = append(cc.i.txQueues, q)
690         }
691
692         return nil
693 }
694
695 func (cc *controlChannel) msgEnqConnect() (err error) {
696         var connect MsgConnect
697         copy(connect.Name[:], []byte(cc.i.args.Name))
698
699         buf := new(bytes.Buffer)
700         err = binary.Write(buf, binary.LittleEndian, msgTypeConnect)
701         err = binary.Write(buf, binary.LittleEndian, connect)
702
703         msg := controlMsg{
704                 Buffer: buf,
705                 Fd:     -1,
706         }
707
708         cc.msgQueue = append(cc.msgQueue, msg)
709
710         return nil
711 }
712
713 func (cc *controlChannel) parseConnect() (err error) {
714         var connect MsgConnect
715
716         buf := bytes.NewReader(cc.data[msgTypeSize:])
717         err = binary.Read(buf, binary.LittleEndian, &connect)
718         if err != nil {
719                 return
720         }
721
722         cc.i.peerName = string(connect.Name[:])
723
724         err = cc.i.connect()
725         if err != nil {
726                 return err
727         }
728
729         cc.isConnected = true
730
731         return nil
732 }
733
734 func (cc *controlChannel) msgEnqConnected() (err error) {
735         var connected MsgConnected
736         copy(connected.Name[:], []byte(cc.i.args.Name))
737
738         buf := new(bytes.Buffer)
739         err = binary.Write(buf, binary.LittleEndian, msgTypeConnected)
740         err = binary.Write(buf, binary.LittleEndian, connected)
741
742         msg := controlMsg{
743                 Buffer: buf,
744                 Fd:     -1,
745         }
746
747         cc.msgQueue = append(cc.msgQueue, msg)
748
749         return nil
750 }
751
752 func (cc *controlChannel) parseConnected() (err error) {
753         var conn MsgConnected
754
755         buf := bytes.NewReader(cc.data[msgTypeSize:])
756         err = binary.Read(buf, binary.LittleEndian, &conn)
757         if err != nil {
758                 return
759         }
760
761         cc.i.peerName = string(conn.Name[:])
762
763         err = cc.i.connect()
764         if err != nil {
765                 return err
766         }
767
768         cc.isConnected = true
769
770         return nil
771 }
772
773 func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
774         dc := MsgDisconnect{
775                 // not implemented
776                 Code: 0,
777         }
778         copy(dc.String[:], str)
779
780         buf := new(bytes.Buffer)
781         err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect)
782         err = binary.Write(buf, binary.LittleEndian, dc)
783
784         msg := controlMsg{
785                 Buffer: buf,
786                 Fd:     -1,
787         }
788
789         cc.msgQueue = append(cc.msgQueue, msg)
790
791         return nil
792 }
793
794 func (cc *controlChannel) parseDisconnect() (err error) {
795         var dc MsgDisconnect
796
797         buf := bytes.NewReader(cc.data[msgTypeSize:])
798         err = binary.Read(buf, binary.LittleEndian, &dc)
799         if err != nil {
800                 return
801         }
802
803         err = cc.close(false, string(dc.String[:]))
804         if err != nil {
805                 return fmt.Errorf("Failed to disconnect control channel: ", err)
806         }
807
808         return nil
809 }
810
811 func (cc *controlChannel) parseMsg() error {
812         var msgType msgType
813         var err error
814
815         buf := bytes.NewReader(cc.data[:])
816         err = binary.Read(buf, binary.LittleEndian, &msgType)
817
818         if msgType == msgTypeAck {
819                 return nil
820         } else if msgType == msgTypeHello {
821                 // Configure
822                 err = cc.parseHello()
823                 if err != nil {
824                         goto error
825                 }
826                 // Initialize slave memif
827                 err = cc.i.initializeRegions()
828                 if err != nil {
829                         goto error
830                 }
831                 err = cc.i.initializeQueues()
832                 if err != nil {
833                         goto error
834                 }
835                 // Enqueue messages
836                 err = cc.msgEnqInit()
837                 if err != nil {
838                         goto error
839                 }
840                 for i := 0; i < len(cc.i.regions); i++ {
841                         err = cc.msgEnqAddRegion(uint16(i))
842                         if err != nil {
843                                 goto error
844                         }
845                 }
846                 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
847                         err = cc.msgEnqAddRing(ringTypeS2M, uint16(i))
848                         if err != nil {
849                                 goto error
850                         }
851                 }
852                 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
853                         err = cc.msgEnqAddRing(ringTypeM2S, uint16(i))
854                         if err != nil {
855                                 goto error
856                         }
857                 }
858                 err = cc.msgEnqConnect()
859                 if err != nil {
860                         goto error
861                 }
862         } else if msgType == msgTypeInit {
863                 err = cc.parseInit()
864                 if err != nil {
865                         goto error
866                 }
867
868                 err = cc.msgEnqAck()
869                 if err != nil {
870                         goto error
871                 }
872         } else if msgType == msgTypeAddRegion {
873                 err = cc.parseAddRegion()
874                 if err != nil {
875                         goto error
876                 }
877
878                 err = cc.msgEnqAck()
879                 if err != nil {
880                         goto error
881                 }
882         } else if msgType == msgTypeAddRing {
883                 err = cc.parseAddRing()
884                 if err != nil {
885                         goto error
886                 }
887
888                 err = cc.msgEnqAck()
889                 if err != nil {
890                         goto error
891                 }
892         } else if msgType == msgTypeConnect {
893                 err = cc.parseConnect()
894                 if err != nil {
895                         goto error
896                 }
897
898                 err = cc.msgEnqConnected()
899                 if err != nil {
900                         goto error
901                 }
902         } else if msgType == msgTypeConnected {
903                 err = cc.parseConnected()
904                 if err != nil {
905                         goto error
906                 }
907         } else if msgType == msgTypeDisconnect {
908                 err = cc.parseDisconnect()
909                 if err != nil {
910                         goto error
911                 }
912         } else {
913                 err = fmt.Errorf("unknown message %d", msgType)
914                 goto error
915         }
916
917         return nil
918
919 error:
920         err1 := cc.close(true, err.Error())
921         if err1 != nil {
922                 return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1)
923         }
924
925         return err
926 }
927
928 // parseControlMsg parses control message and returns file descriptor
929 // if any
930 func (cc *controlChannel) parseControlMsg() (fd int, err error) {
931         // Assert only called when we require FD
932         fd = -1
933
934         controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
935         if err != nil {
936                 return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
937         }
938
939         if len(controlMsgs) == 0 {
940                 return -1, fmt.Errorf("Missing control message")
941         }
942
943         for _, cmsg := range controlMsgs {
944                 if cmsg.Header.Level == syscall.SOL_SOCKET {
945                         if cmsg.Header.Type == syscall.SCM_RIGHTS {
946                                 FDs, err := syscall.ParseUnixRights(&cmsg)
947                                 if err != nil {
948                                         return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err)
949                                 }
950                                 if len(FDs) == 0 {
951                                         continue
952                                 }
953                                 // Only expect single FD
954                                 fd = FDs[0]
955                         }
956                 }
957         }
958
959         if fd == -1 {
960                 return -1, fmt.Errorf("Missing file descriptor")
961         }
962
963         return fd, nil
964 }