2 *------------------------------------------------------------------
3 * Copyright (c) 2020 Cisco and/or its affiliates.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at:
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *------------------------------------------------------------------
// maxEpollEvents sizes the events buffer that StartPolling passes to
// EpollWait, so at most one event is returned per wait.
30 const maxEpollEvents = 1
// maxControlLen sizes controlChannel.control, the buffer that receives the
// ancillary (control-message) data from Recvmsg.
31 const maxControlLen = 256
// errorFdNotFound is the error text Socket.handleEvent returns when an epoll
// event's fd matches neither the listener, the interrupt fd, nor any control
// channel.
33 const errorFdNotFound = "fd not found"
35 // controlMsg represents a message used in communication between memif peers
36 type controlMsg struct {
// NOTE(review): the fields are not visible in this excerpt. Values of this
// type are queued on controlChannel.msgQueue by the msgEnq* helpers.
41 // listener represents a listener functionality of UNIX domain socket
42 type listener struct {
// event is the epoll registration of the listening socket; its Fd is the fd
// passed to syscall.Accept in listener.handleEvent and closed by
// listener.close.
44 event syscall.EpollEvent
47 // controlChannel represents a communication channel between memif peers
48 // backed by UNIX domain socket
49 type controlChannel struct {
// event is the epoll registration of the connected socket; its Fd is the fd
// read with Recvmsg in controlChannel.handleEvent.
53 event syscall.EpollEvent
// control receives the ancillary data from Recvmsg; controlLen (set by the
// same call) marks how much of it parseControlMsg should parse.
55 control [maxControlLen]byte
61 // Socket represents a UNIX domain socket used for communication
62 // between memif peers
// NOTE(review): the `type Socket struct {` line and several fields are not
// visible in this excerpt.
// interfaceList holds *Interface values (see the type assertions in Delete
// and handleEvent).
67 interfaceList *list.List
// wakeEvent registers the fd returned by eventFd() (see NewSocket) with
// epoll; StopPolling writes to it so a blocked EpollWait returns.
71 wakeEvent syscall.EpollEvent
// stopPollChan is created by StartPolling and closed by StopPolling to tell
// the polling loop to exit.
72 stopPollChan chan struct{}
76 type interrupt struct {
// NOTE(review): the other fields are not visible here. event is presumably
// the epoll registration of an interrupt fd (compare Socket.addInterrupt),
// but confirm.
78 event syscall.EpollEvent
// NOTE(review): the fields of memifInterrupt are not visible in this excerpt,
// so its role cannot be confirmed from here.
81 type memifInterrupt struct {
86 // StopPolling stops polling events on the socket
// It closes stopPollChan and writes to the wake eventfd so that a poller
// blocked in EpollWait returns and can observe the closed channel.
87 func (socket *Socket) StopPolling() error {
88 if socket.stopPollChan != nil {
90 close(socket.stopPollChan)
// NOTE(review): an eventfd write takes an 8-byte host-endian uint64.
// PutUvarint(buf, 1) writes the single byte 0x01 followed by zeros, which
// equals 1 only on little-endian hosts. binary.NativeEndian.PutUint64
// (Go 1.21+) states the intent directly.
92 buf := make([]byte, 8)
93 binary.PutUvarint(buf, 1)
94 n, err := syscall.Write(int(socket.wakeEvent.Fd), buf[:])
// NOTE(review): this runtime message has the typo "Faild", and it includes
// neither err nor the short-write count n.
99 return fmt.Errorf("Faild to write to eventfd")
101 // wait until polling is stopped
// NOTE(review): the wait itself is not visible here. StartPolling defers
// socket.wg.Done(), so this presumably waits on socket.wg.
108 // StartPolling starts polling and handling events on the socket,
109 // enabling communication between memif peers
// Errors from EpollWait and from handling an event are sent on errChan.
110 func (socket *Socket) StartPolling(errChan chan<- error) {
111 socket.stopPollChan = make(chan struct{})
// events holds at most maxEpollEvents entries per EpollWait call.
114 var events [maxEpollEvents]syscall.EpollEvent
115 defer socket.wg.Done()
// Exit once StopPolling has closed stopPollChan.
119 case <-socket.stopPollChan:
// Timeout -1: block until at least one registered fd is ready.
122 num, err := syscall.EpollWait(socket.epfd, events[:], -1)
// NOTE(review): this format string has no verb, so err is rendered as
// "%!(EXTRA ...)"; use "EpollWait: %w". EpollWait also returns EINTR when a
// signal interrupts it, which usually warrants a retry, not an error report.
124 errChan <- fmt.Errorf("EpollWait: ", err)
128 for ev := 0; ev < num; ev++ {
// NOTE(review): the loop reads events[0] instead of events[ev]. That is only
// correct while maxEpollEvents == 1; raising it would re-handle the first
// event num times and drop the rest.
129 if events[0].Fd == socket.wakeEvent.Fd {
132 err = socket.handleEvent(&events[0])
// NOTE(review): missing verb again; use "handleEvent: %w".
134 errChan <- fmt.Errorf("handleEvent: ", err)
142 // addEvent adds event to epoll instance associated with the socket
// The fd to watch is taken from event.Fd; an EpollCtl failure is returned
// prefixed with "EpollCtl: ".
143 func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
144 err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event)
// NOTE(review): %s flattens the cause; %w would let callers inspect the errno
// with errors.Is.
146 return fmt.Errorf("EpollCtl: %s", err)
151 // delEvent deletes event from the epoll instance associated with the socket
// The fd to remove is taken from event.Fd; an EpollCtl failure is returned
// prefixed with "EpollCtl: ".
152 func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
153 err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event)
// NOTE(review): %w instead of %s would preserve the underlying errno.
155 return fmt.Errorf("EpollCtl: %s", err)
160 // Delete deletes the socket
// It closes each control channel with the reason "Socket deleted", walks the
// interface list, closes the listener (if any), removes the socket file,
// unregisters the wake event from epoll and closes the epoll fd.
161 func (socket *Socket) Delete() (err error) {
// NOTE(review): cc.close removes cc.listRef from socket.ccList, and
// addControlChannel sets cc.listRef to the element PushBack returned, which
// is this same elt. container/list's Remove clears the element's links, so
// elt.Next() then yields nil and the loop stops after the first channel.
// Capture next := elt.Next() before calling cc.close.
162 for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
163 cc, ok := elt.Value.(*controlChannel)
165 err = cc.close(true, "Socket deleted")
// NOTE(review): the body of this loop is only partly visible. If it removes
// elt from interfaceList, the same Next-after-Remove problem applies.
171 for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
172 i, ok := elt.Value.(*Interface)
181 if socket.listener != nil {
182 err = socket.listener.close()
186 err = os.Remove(socket.filename)
192 err = socket.delEvent(&socket.wakeEvent)
// NOTE(review): format string has no verb; use "Failed to delete event: %w".
194 return fmt.Errorf("Failed to delete event: ", err)
// NOTE(review): the Close error is discarded. No close of the wake eventfd
// is visible either; confirm it is not leaked.
197 syscall.Close(socket.epfd)
202 // NewSocket returns a new Socket
// An empty filename falls back to DefaultSocketFilename. The function creates
// an epoll instance and a wake fd via eventFd(), and registers the wake event
// for EPOLLIN|EPOLLERR|EPOLLHUP.
203 func NewSocket(appName string, filename string) (socket *Socket, err error) {
207 interfaceList: list.New(),
210 if socket.filename == "" {
211 socket.filename = DefaultSocketFilename
// NOTE(review): the EpollCreate1 error is discarded. A failure here only
// surfaces later, as a confusing EpollCtl/EpollWait error.
214 socket.epfd, _ = syscall.EpollCreate1(0)
// NOTE(review): err from eventFd() is not checked before efd is used; the
// addEvent call below overwrites it.
216 efd, err := eventFd()
217 socket.wakeEvent = syscall.EpollEvent{
218 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
221 err = socket.addEvent(&socket.wakeEvent)
// NOTE(review): format string has no verb; use "Failed to add event: %w".
// No close of the epoll fd or the eventfd is visible on this path.
223 return nil, fmt.Errorf("Failed to add event: ", err)
229 // handleEvent handles epoll event
// Dispatch order:
//   - the listener fd;
//   - the interrupt fd of the last interface in interfaceList, when it has an
//     InterruptFunc;
//   - each control channel.
// If nothing matches, it returns an error with text errorFdNotFound.
230 func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
231 if socket.listener != nil && socket.listener.event.Fd == event.Fd {
232 return socket.listener.handleEvent(event)
// NOTE(review): Back() returns nil on an empty list, so this panics when no
// interface has been added. It also checks only the most recently added
// interface's interrupt fd.
234 intf := socket.interfaceList.Back().Value.(*Interface)
235 if intf.args.InterruptFunc != nil {
236 if int(event.Fd) == int(intf.args.InterruptFd) {
// Read from the interrupt fd (b is declared outside the visible lines);
// the Read result and error are ignored.
238 syscall.Read(int(event.Fd), b)
239 intf.onInterrupt(intf)
244 for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
245 cc, ok := elt.Value.(*controlChannel)
247 if cc.event.Fd == event.Fd {
248 return cc.handleEvent(event)
// NOTE(review): errors.New(errorFdNotFound) is the idiomatic form when there
// is nothing to format.
253 return fmt.Errorf(errorFdNotFound)
// addInterrupt registers an EPOLLIN-only epoll event with the socket's epoll
// instance. It presumably watches fd, but the Fd field assignment is not
// visible here.
256 func (socket *Socket) addInterrupt(fd int) (err error) {
// NOTE(review): this comment is identical to the one in AddListener and looks
// copied; its relevance to interrupts cannot be confirmed from here.
258 // we will need this to look up master interface by id
262 l.event = syscall.EpollEvent{
263 Events: syscall.EPOLLIN,
266 err = socket.addEvent(&l.event)
// NOTE(review): format string has no verb; use "Failed to add event: %w".
268 return fmt.Errorf("Failed to add event: ", err)
275 // handleEvent handles epoll event for listener
// On EPOLLHUP or EPOLLERR it returns an error; the error texts indicate the
// listener is closed first. On EPOLLIN it accepts a connection, wraps it in a
// new control channel with no interface (nil), and queues a hello message.
// Any other event mask yields "Unexpected event".
276 func (l *listener) handleEvent(event *syscall.EpollEvent) error {
278 if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
// NOTE(review): each fmt.Errorf here, and below, that passes an argument
// without a verb renders it as "%!(EXTRA ...)". Add %w / %s.
281 return fmt.Errorf("Failed to close listener after hang up event: ", err)
283 return fmt.Errorf("Hang up: ", l.socket.filename)
287 if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
290 return fmt.Errorf("Failed to close listener after receiving an error event: ", err)
292 return fmt.Errorf("Received error event on listener ", l.socket.filename)
296 if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
297 newFd, _, err := syscall.Accept(int(l.event.Fd))
299 return fmt.Errorf("Accept: %s", err)
302 cc, err := l.socket.addControlChannel(newFd, nil)
// NOTE(review): confirm newFd is closed if addControlChannel fails; no close
// is visible on this path.
304 return fmt.Errorf("Failed to add control channel: %s", err)
307 err = cc.msgEnqHello()
309 return fmt.Errorf("msgEnqHello: %s", err)
320 return fmt.Errorf("Unexpected event: ", event.Events)
323 // handleEvent handles epoll event for control channel
// On EPOLLHUP or EPOLLERR it closes the channel without sending a disconnect
// message and returns an error. On EPOLLIN it receives one message with
// Recvmsg into cc.data, with ancillary data going into cc.control and its
// length into cc.controlLen. Any other event mask yields "Unexpected event".
324 func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
329 if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
330 // close cc, don't send msg
331 err := cc.close(false, "")
// NOTE(review): this Errorf and the other verb-less ones below render their
// argument as "%!(EXTRA ...)"; add %w / %s.
333 return fmt.Errorf("Failed to close control channel after hang up event: ", err)
// NOTE(review): channels accepted by the listener are created with a nil
// interface (addControlChannel(newFd, nil)). If cc.i is still nil here,
// cc.i.GetName() may dereference nil; confirm.
335 return fmt.Errorf("Hang up: ", cc.i.GetName())
338 if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
339 // close cc, don't send msg
340 err := cc.close(false, "")
342 return fmt.Errorf("Failed to close control channel after receiving an error event: ", err)
344 return fmt.Errorf("Received error event on control channel ", cc.i.GetName())
347 if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
348 size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
350 return fmt.Errorf("recvmsg: %s", err)
353 return fmt.Errorf("invalid message size %d", size)
369 return fmt.Errorf("Unexpected event: ", event.Events)
372 // close closes the listener
// It unregisters the listening fd from epoll and then closes it.
373 func (l *listener) close() error {
374 err := l.socket.delEvent(&l.event)
// NOTE(review): format string has no verb; use "Failed to del event: %w".
376 return fmt.Errorf("Failed to del event: ", err)
378 err = syscall.Close(int(l.event.Fd))
// NOTE(review): format string has no verb; use "Failed to close socket: %w".
380 return fmt.Errorf("Failed to close socket: ", err)
385 // AddListener adds a listener to the socket. The fd must describe a
386 // UNIX domain socket already bound to a UNIX domain filename and
387 // marked as listener
// The listener's epoll event (EPOLLIN|EPOLLERR|EPOLLHUP) is registered with
// the socket's epoll instance.
388 func (socket *Socket) AddListener(fd int) (err error) {
390 // we will need this to look up master interface by id
394 l.event = syscall.EpollEvent{
395 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
398 err = socket.addEvent(&l.event)
// NOTE(review): format string has no verb; use "Failed to add event: %w".
400 return fmt.Errorf("Failed to add event: ", err)
408 // addListener creates new UNIX domain socket, binds it to the address
409 // and marks it as listener
// The socket is AF_UNIX/SOCK_SEQPACKET with SO_PASSCRED enabled. It is bound
// to socket.filename, listens with a SOMAXCONN backlog, and is then handed to
// AddListener.
410 func (socket *Socket) addListener() (err error) {
412 fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
// NOTE(review): the underlying err is dropped from this message.
414 return fmt.Errorf("Failed to create UNIX domain socket")
416 usa := &syscall.SockaddrUnix{Name: socket.filename}
417 // Bind to address and start listening
418 err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
// NOTE(review): no syscall.Close(fd) is visible on this or the following
// error returns, so fd would leak when a step fails.
420 return fmt.Errorf("Failed to set socket option %s : %v", socket.filename, err)
422 err = syscall.Bind(fd, usa)
424 return fmt.Errorf("Failed to bind socket %s : %v", socket.filename, err)
426 err = syscall.Listen(fd, syscall.SOMAXCONN)
428 return fmt.Errorf("Failed to listen on socket %s : %v", socket.filename, err)
431 return socket.AddListener(fd)
434 // close closes a control channel, if the control channel is assigned an
435 // interface, the interface is disconnected
// Steps:
//   - Discard pending messages and queue a disconnect message carrying str.
//     This is presumably guarded by sendMsg; the condition line is not
//     visible.
//   - Unregister the channel from epoll.
//   - Remove it from socket.ccList.
//   - Disconnect the interface.
436 func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
438 // first clear message queue so that the disconnect
439 // message is the only message in queue
440 cc.msgQueue = []controlMsg{}
// NOTE(review): the error from msgEnqDisconnect is ignored.
441 cc.msgEnqDisconnect(str)
449 err = cc.socket.delEvent(&cc.event)
// NOTE(review): format string has no verb; use "Failed to del event: %w".
// Returning here also skips the ccList removal and the interface disconnect.
451 return fmt.Errorf("Failed to del event: ", err)
454 // remove reference from socket
455 cc.socket.ccList.Remove(cc.listRef)
458 err = cc.i.disconnect()
// NOTE(review): format string has no verb; use "Interface Disconnect: %w".
460 return fmt.Errorf("Interface Disconnect: ", err)
467 // addControlChannel returns a new controlChannel and adds it to the socket
// The channel's epoll event (EPOLLIN|EPOLLERR|EPOLLHUP) is registered, and the
// channel is appended to socket.ccList. cc.listRef keeps the list element so
// close can remove it. i may be nil, as for connections accepted by the
// listener.
468 func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) {
469 cc := &controlChannel{
477 cc.event = syscall.EpollEvent{
478 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
481 err = socket.addEvent(&cc.event)
// NOTE(review): format string has no verb; use "Failed to add event: %w".
483 return nil, fmt.Errorf("Failed to add event: ", err)
486 cc.listRef = socket.ccList.PushBack(cc)
// msgEnqAck encodes msgTypeAck (little-endian) into a buffer and appends the
// resulting message to cc.msgQueue.
491 func (cc *controlChannel) msgEnqAck() (err error) {
492 buf := new(bytes.Buffer)
493 err = binary.Write(buf, binary.LittleEndian, msgTypeAck)
500 cc.msgQueue = append(cc.msgQueue, msg)
// msgEnqHello queues a hello message whose Name is cc.socket.appName.
// copy truncates the name if it is longer than hello.Name.
505 func (cc *controlChannel) msgEnqHello() (err error) {
515 copy(hello.Name[:], []byte(cc.socket.appName))
517 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten by the second
// without being checked.
518 err = binary.Write(buf, binary.LittleEndian, msgTypeHello)
519 err = binary.Write(buf, binary.LittleEndian, hello)
526 cc.msgQueue = append(cc.msgQueue, msg)
// parseHello decodes a hello message from cc.data after the type header and
// rejects peers whose [VersionMin, VersionMax] range excludes Version. It then
// sets cc.i.run from the local MemoryConfig, capped by the peer's advertised
// maxima, and records the peer's name.
531 func (cc *controlChannel) parseHello() (err error) {
534 buf := bytes.NewReader(cc.data[msgTypeSize:])
535 err = binary.Read(buf, binary.LittleEndian, &hello)
540 if hello.VersionMin > Version || hello.VersionMax < Version {
541 return fmt.Errorf("Incompatible memif version")
544 cc.i.run = cc.i.args.MemoryConfig
// NOTE(review): the second assignment overwrites the first, so MaxRingS2M is
// ignored. It presumably should read
// min16(cc.i.run.NumQueuePairs, hello.MaxRingM2S) so both limits apply.
546 cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M)
547 cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingM2S)
548 cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)
// NOTE(review): converting the whole Name array keeps any trailing zero bytes;
// trim at the first NUL if a clean name is wanted.
550 cc.i.remoteName = string(hello.Name[:])
// msgEnqInit queues an init message carrying the interface Mode and, as Name,
// cc.socket.appName (truncated by copy if too long).
555 func (cc *controlChannel) msgEnqInit() (err error) {
559 Mode: cc.i.args.Mode,
562 copy(init.Name[:], []byte(cc.socket.appName))
564 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten unchecked.
565 err = binary.Write(buf, binary.LittleEndian, msgTypeInit)
566 err = binary.Write(buf, binary.LittleEndian, init)
573 cc.msgQueue = append(cc.msgQueue, msg)
// parseInit decodes an init message and requires init.Version == Version. It
// then looks for a master interface with a matching Id and no control channel
// yet. If the secret matches, that interface is bound to this control
// channel. It returns "Invalid secret" on a secret mismatch and "Invalid
// interface id" when no interface qualifies.
578 func (cc *controlChannel) parseInit() (err error) {
581 buf := bytes.NewReader(cc.data[msgTypeSize:])
582 err = binary.Read(buf, binary.LittleEndian, &init)
587 if init.Version != Version {
588 return fmt.Errorf("Incompatible memif driver version")
591 // find peer interface
592 for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
593 i, ok := elt.Value.(*Interface)
595 if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil {
// NOTE(review): if Secret is security-sensitive, a constant-time comparison
// (crypto/subtle) would avoid a timing side channel.
597 if i.args.Secret != init.Secret {
598 return fmt.Errorf("Invalid secret")
600 // interface is assigned to control channel
603 cc.i.run = cc.i.args.MemoryConfig
// NOTE(review): converting the whole Name array keeps any trailing zero bytes.
604 cc.i.remoteName = string(init.Name[:])
611 return fmt.Errorf("Invalid interface id")
// msgEnqAddRegion queues an add-region message describing
// cc.i.regions[regionIndex] (including its size). The queued message also
// carries the region's fd. It returns "Invalid region index" when
// regionIndex is out of range.
614 func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
615 if len(cc.i.regions) <= int(regionIndex) {
616 return fmt.Errorf("Invalid region index")
619 addRegion := MsgAddRegion{
621 Size: cc.i.regions[regionIndex].size,
624 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten unchecked.
625 err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion)
626 err = binary.Write(buf, binary.LittleEndian, addRegion)
630 Fd: cc.i.regions[regionIndex].fd,
633 cc.msgQueue = append(cc.msgQueue, msg)
// parseAddRegion decodes an add-region message, obtains the region fd from
// the ancillary data via parseControlMsg, and appends a new memoryRegion of
// the advertised size to cc.i.regions.
638 func (cc *controlChannel) parseAddRegion() (err error) {
639 var addRegion MsgAddRegion
641 buf := bytes.NewReader(cc.data[msgTypeSize:])
642 err = binary.Read(buf, binary.LittleEndian, &addRegion)
647 fd, err := cc.parseControlMsg()
649 return fmt.Errorf("parseControlMsg: %s", err)
// NOTE(review): this check runs after fd has been received, and no close of
// fd is visible before the early return, so an invalid index leaks the fd.
// The region is also appended at the end of cc.i.regions rather than placed
// at addRegion.Index; confirm peers always send indices in order.
652 if addRegion.Index > 255 {
653 return fmt.Errorf("Invalid memory region index")
656 region := memoryRegion{
657 size: addRegion.Size,
661 cc.i.regions = append(cc.i.regions, region)
// msgEnqAddRing queues an add-ring message for queue ringIndex. S2M rings are
// taken from cc.i.txQueues and flagged with msgAddRingFlagS2M; other rings
// come from cc.i.rxQueues. The message carries the ring's offset, region and
// log2 size.
666 func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
// NOTE(review): ringIndex indexes txQueues/rxQueues with no bounds check in
// the visible lines; an out-of-range index panics.
670 if ringType == ringTypeS2M {
671 q = cc.i.txQueues[ringIndex]
672 flags = msgAddRingFlagS2M
674 q = cc.i.rxQueues[ringIndex]
677 addRing := MsgAddRing{
679 Offset: uint32(q.ring.offset),
680 Region: uint16(q.ring.region),
681 RingSizeLog2: uint8(q.ring.log2Size),
686 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten unchecked.
687 err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing)
688 err = binary.Write(buf, binary.LittleEndian, addRing)
695 cc.msgQueue = append(cc.msgQueue, msg)
// parseAddRing decodes an add-ring message, obtains an fd from the ancillary
// data via parseControlMsg, and rejects ring indices >= cc.i.run.NumQueuePairs.
// It then builds the ring and appends the queue: S2M rings go to rxQueues,
// all others to txQueues.
700 func (cc *controlChannel) parseAddRing() (err error) {
701 var addRing MsgAddRing
703 buf := bytes.NewReader(cc.data[msgTypeSize:])
704 err = binary.Read(buf, binary.LittleEndian, &addRing)
709 fd, err := cc.parseControlMsg()
// NOTE(review): confirm fd is closed when this index check fails; no close is
// visible. Queues are appended in arrival order rather than placed at
// addRing.Index.
714 if addRing.Index >= cc.i.run.NumQueuePairs {
715 return fmt.Errorf("invalid ring index")
723 if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
724 q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
725 cc.i.rxQueues = append(cc.i.rxQueues, q)
727 q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
728 cc.i.txQueues = append(cc.i.txQueues, q)
// msgEnqConnect queues a connect message whose Name is cc.i.args.Name,
// truncated by copy if longer than connect.Name.
734 func (cc *controlChannel) msgEnqConnect() (err error) {
735 var connect MsgConnect
736 copy(connect.Name[:], []byte(cc.i.args.Name))
738 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten unchecked.
739 err = binary.Write(buf, binary.LittleEndian, msgTypeConnect)
740 err = binary.Write(buf, binary.LittleEndian, connect)
747 cc.msgQueue = append(cc.msgQueue, msg)
// parseConnect decodes a connect message, records the peer's name in
// cc.i.peerName, and marks the control channel as connected.
752 func (cc *controlChannel) parseConnect() (err error) {
753 var connect MsgConnect
755 buf := bytes.NewReader(cc.data[msgTypeSize:])
756 err = binary.Read(buf, binary.LittleEndian, &connect)
// NOTE(review): converting the whole Name array keeps any trailing zero bytes.
761 cc.i.peerName = string(connect.Name[:])
767 cc.isConnected = true
// msgEnqConnected queues a connected message whose Name is cc.i.args.Name,
// truncated by copy if longer than connected.Name.
772 func (cc *controlChannel) msgEnqConnected() (err error) {
773 var connected MsgConnected
774 copy(connected.Name[:], []byte(cc.i.args.Name))
776 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten unchecked.
777 err = binary.Write(buf, binary.LittleEndian, msgTypeConnected)
778 err = binary.Write(buf, binary.LittleEndian, connected)
785 cc.msgQueue = append(cc.msgQueue, msg)
// parseConnected decodes a connected message, records the peer's name in
// cc.i.peerName, and marks the control channel as connected.
790 func (cc *controlChannel) parseConnected() (err error) {
791 var conn MsgConnected
793 buf := bytes.NewReader(cc.data[msgTypeSize:])
794 err = binary.Read(buf, binary.LittleEndian, &conn)
// NOTE(review): converting the whole Name array keeps any trailing zero bytes.
799 cc.i.peerName = string(conn.Name[:])
805 cc.isConnected = true
// msgEnqDisconnect queues a disconnect message carrying str as its reason
// text. copy truncates str if it is longer than dc.String.
810 func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
815 copy(dc.String[:], str)
817 buf := new(bytes.Buffer)
// NOTE(review): the first binary.Write error is overwritten unchecked.
818 err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect)
819 err = binary.Write(buf, binary.LittleEndian, dc)
826 cc.msgQueue = append(cc.msgQueue, msg)
// parseDisconnect decodes a disconnect message and closes the control channel
// without sending a reply, passing along the peer's reason string.
831 func (cc *controlChannel) parseDisconnect() (err error) {
834 buf := bytes.NewReader(cc.data[msgTypeSize:])
835 err = binary.Read(buf, binary.LittleEndian, &dc)
// NOTE(review): converting the whole String array keeps any trailing zero
// bytes in the reason.
840 err = cc.close(false, string(dc.String[:]))
// NOTE(review): format string has no verb; use
// "Failed to disconnect control channel: %w".
842 return fmt.Errorf("Failed to disconnect control channel: ", err)
// parseMsg reads the little-endian message type at the start of cc.data and
// dispatches to the matching parse* handler, queueing any replies.
// For hello, it also initializes the slave's regions and queues, then queues:
//   - an init message;
//   - one add-region per region;
//   - S2M and M2S add-ring messages for each queue pair;
//   - a connect message.
// On error the control channel is closed with the error text as the reason.
848 func (cc *controlChannel) parseMsg() error {
852 buf := bytes.NewReader(cc.data[:])
853 err = binary.Read(buf, binary.LittleEndian, &msgType)
// NOTE(review): this if / else-if chain on msgType would read more clearly
// as a switch.
855 if msgType == msgTypeAck {
857 } else if msgType == msgTypeHello {
859 err = cc.parseHello()
863 // Initialize slave memif
864 err = cc.i.initializeRegions()
868 err = cc.i.initializeQueues()
873 err = cc.msgEnqInit()
877 for i := 0; i < len(cc.i.regions); i++ {
878 err = cc.msgEnqAddRegion(uint16(i))
883 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
884 err = cc.msgEnqAddRing(ringTypeS2M, uint16(i))
889 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
890 err = cc.msgEnqAddRing(ringTypeM2S, uint16(i))
895 err = cc.msgEnqConnect()
899 } else if msgType == msgTypeInit {
909 } else if msgType == msgTypeAddRegion {
910 err = cc.parseAddRegion()
919 } else if msgType == msgTypeAddRing {
920 err = cc.parseAddRing()
929 } else if msgType == msgTypeConnect {
930 err = cc.parseConnect()
935 err = cc.msgEnqConnected()
939 } else if msgType == msgTypeConnected {
940 err = cc.parseConnected()
944 } else if msgType == msgTypeDisconnect {
945 err = cc.parseDisconnect()
950 err = fmt.Errorf("unknown message %d", msgType)
957 err1 := cc.close(true, err.Error())
// NOTE(review): err.Error() is used as the format string, so any '%' in it is
// misinterpreted and the remaining arguments print as "%!(EXTRA ...)". Use
// fmt.Errorf("%w: Failed to close control channel: %v", err, err1).
959 return fmt.Errorf(err.Error(), ": Failed to close control channel: ", err1)
965 // parseControlMsg parses control message and returns file descriptor
// It parses the ancillary data in cc.control[:cc.controlLen] and looks for a
// SOL_SOCKET / SCM_RIGHTS message to take the fd from. Only one fd is
// expected. The SOL_SOCKET/SCM_RIGHTS filter skips other control messages.
// Returns -1 and an error when parsing fails or no fd is present.
967 func (cc *controlChannel) parseControlMsg() (fd int, err error) {
968 // Assert only called when we require FD
971 controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
973 return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
976 if len(controlMsgs) == 0 {
977 return -1, fmt.Errorf("Missing control message")
980 for _, cmsg := range controlMsgs {
981 if cmsg.Header.Level == syscall.SOL_SOCKET {
982 if cmsg.Header.Type == syscall.SCM_RIGHTS {
// NOTE(review): this := shadows the named result err inside the loop; the
// explicit returns make it harmless, but it is easy to misread.
983 FDs, err := syscall.ParseUnixRights(&cmsg)
985 return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err)
990 // Only expect single FD
// NOTE(review): if a peer sends more than one fd, confirm the extras are
// closed; that handling is not visible here.
997 return -1, fmt.Errorf("Missing file descriptor")