2 *------------------------------------------------------------------
3 * Copyright (c) 2020 Cisco and/or its affiliates.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at:
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *------------------------------------------------------------------
// maxEpollEvents is the capacity of the event array passed to EpollWait in
// StartPolling, i.e. the maximum number of events returned per wait.
30 const maxEpollEvents = 1
// maxControlLen is the size of each control channel's ancillary-data buffer
// filled by Recvmsg (used to receive passed file descriptors).
31 const maxControlLen = 256
// errorFdNotFound is the error text returned when an epoll event's fd belongs
// neither to the listener nor to any control channel of the socket.
33 const errorFdNotFound = "fd not found"
35 // controlMsg represents a message used in communication between memif peers
// Values are appended to controlChannel.msgQueue by the msgEnq* methods.
36 type controlMsg struct {
41 // listener represents a listener functionality of UNIX domain socket
42 type listener struct {
// event is the epoll registration of the listening fd; event.Fd is the fd
// passed to Accept in handleEvent and closed by close.
44 event syscall.EpollEvent
47 // controlChannel represents a communication channel between memif peers
48 // backed by UNIX domain socket
49 type controlChannel struct {
// event is the epoll registration of the connected socket; event.Fd is the
// fd that handleEvent reads with Recvmsg.
53 event syscall.EpollEvent
// control receives ancillary data from Recvmsg; parseControlMsg parses its
// first controlLen bytes to extract a passed file descriptor.
55 control [maxControlLen]byte
61 // Socket represents a UNIX domain socket used for communication
62 // between memif peers
// interfaceList holds the *Interface values known to this socket.
67 interfaceList *list.List
// wakeEvent registers the fd returned by eventFd() with epoll; StopPolling
// writes to it so that a blocked EpollWait returns.
70 wakeEvent syscall.EpollEvent
// stopPollChan is created by StartPolling and closed by StopPolling.
71 stopPollChan chan struct{}
75 // StopPolling stops polling events on the socket
//
// It closes stopPollChan and writes to the wake fd (created by eventFd() in
// NewSocket) so that a poll loop blocked in EpollWait returns and observes
// the closed channel. It only acts when stopPollChan is non-nil.
// NOTE(review): stopPollChan is not reset in the visible lines; a second
// call would close an already-closed channel and panic — confirm.
76 func (socket *Socket) StopPolling() error {
77 if socket.stopPollChan != nil {
79 close(socket.stopPollChan)
// Any non-zero 8-byte value written to an eventfd makes it readable.
// PutUvarint(buf, 1) only sets buf[0], which encodes the counter value 1
// on little-endian hosts.
81 buf := make([]byte, 8)
82 binary.PutUvarint(buf, 1)
83 n, err := syscall.Write(int(socket.wakeEvent.Fd), buf)
88 return fmt.Errorf("Failed to write to eventfd")
90 // wait until polling is stopped
97 // StartPolling starts polling and handling events on the socket,
98 // enabling communication between memif peers
//
// Errors from EpollWait and from event handlers are sent on errChan instead
// of being returned; a send blocks the poll loop until it is received
// (unless errChan is buffered). stopPollChan, closed by StopPolling, is
// checked before each wait.
99 func (socket *Socket) StartPolling(errChan chan<- error) {
100 socket.stopPollChan = make(chan struct{})
103 var events [maxEpollEvents]syscall.EpollEvent
104 defer socket.wg.Done()
108 case <-socket.stopPollChan:
// A timeout of -1 blocks until an fd is ready; StopPolling writes to the
// wake fd to interrupt this wait.
111 num, err := syscall.EpollWait(socket.epfd, events[:], -1)
113 errChan <- fmt.Errorf("EpollWait: %w", err)
117 for ev := 0; ev < num; ev++ {
// Index by ev, not 0: EpollWait fills up to maxEpollEvents entries.
118 if events[ev].Fd == socket.wakeEvent.Fd {
121 err = socket.handleEvent(&events[ev])
123 errChan <- fmt.Errorf("handleEvent: %w", err)
131 // addEvent adds event to epoll instance associated with the socket
//
// event.Fd is the fd registered with EPOLL_CTL_ADD and event.Events the
// conditions to watch. EpollCtl failures are returned prefixed "EpollCtl: ".
132 func (socket *Socket) addEvent(event *syscall.EpollEvent) error {
133 err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_ADD, int(event.Fd), event)
135 return fmt.Errorf("EpollCtl: %s", err)
140 // delEvent deletes event from epoll instance associated with the socket
//
// event.Fd is unregistered with EPOLL_CTL_DEL. EpollCtl failures are
// returned prefixed "EpollCtl: ".
141 func (socket *Socket) delEvent(event *syscall.EpollEvent) error {
142 err := syscall.EpollCtl(socket.epfd, syscall.EPOLL_CTL_DEL, int(event.Fd), event)
144 return fmt.Errorf("EpollCtl: %s", err)
149 // Delete deletes the socket
150 func (socket *Socket) Delete() (err error) {
151 for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
152 cc, ok := elt.Value.(*controlChannel)
154 err = cc.close(true, "Socket deleted")
160 for elt := socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
161 i, ok := elt.Value.(*Interface)
170 if socket.listener != nil {
171 err = socket.listener.close()
175 err = os.Remove(socket.filename)
181 err = socket.delEvent(&socket.wakeEvent)
183 return fmt.Errorf("Failed to delete event: ", err)
186 syscall.Close(socket.epfd)
191 // NewSocket returns a new Socket
//
// An empty filename selects DefaultSocketFilename. The socket gets its own
// epoll instance plus a wake fd (from eventFd) registered with it, which
// StopPolling writes to in order to interrupt EpollWait. On failure the fds
// created so far are closed and a nil Socket is returned.
192 func NewSocket(appName string, filename string) (socket *Socket, err error) {
196 interfaceList: list.New(),
199 if socket.filename == "" {
200 socket.filename = DefaultSocketFilename
203 socket.epfd, err = syscall.EpollCreate1(0)
if err != nil {
	return nil, fmt.Errorf("EpollCreate1: %w", err)
}
205 efd, err := eventFd()
if err != nil {
	syscall.Close(socket.epfd)
	return nil, fmt.Errorf("eventFd: %w", err)
}
206 socket.wakeEvent = syscall.EpollEvent{
207 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
210 err = socket.addEvent(&socket.wakeEvent)
// on error: release the wake fd and the epoll fd before returning
syscall.Close(int(socket.wakeEvent.Fd))
syscall.Close(socket.epfd)
212 return nil, fmt.Errorf("Failed to add event: %w", err)
218 // handleEvent handles epoll event
//
// The event is dispatched by fd: to the listener if its fd matches,
// otherwise to the first control channel in ccList with that fd. Returning
// right after the handler keeps the list walk safe even if the handler
// closes the channel (which removes it from ccList). If no owner matches,
// an error with text errorFdNotFound is returned.
219 func (socket *Socket) handleEvent(event *syscall.EpollEvent) error {
220 if socket.listener != nil && socket.listener.event.Fd == event.Fd {
221 return socket.listener.handleEvent(event)
224 for elt := socket.ccList.Front(); elt != nil; elt = elt.Next() {
225 cc, ok := elt.Value.(*controlChannel)
227 if cc.event.Fd == event.Fd {
228 return cc.handleEvent(event)
233 return fmt.Errorf(errorFdNotFound)
236 // handleEvent handles epoll event for listener
//
// EPOLLHUP and EPOLLERR close the listener and return an error. EPOLLIN
// accepts a connection, wraps it in a new control channel with no interface
// assigned yet, and queues a Hello message on it. Other event masks are
// reported as unexpected.
237 func (l *listener) handleEvent(event *syscall.EpollEvent) error {
239 if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
242 return fmt.Errorf("Failed to close listener after hang up event: %w", err)
244 return fmt.Errorf("Hang up: %s", l.socket.filename)
248 if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
251 return fmt.Errorf("Failed to close listener after receiving an error event: %w", err)
253 return fmt.Errorf("Received error event on listener %s", l.socket.filename)
257 if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
258 newFd, _, err := syscall.Accept(int(l.event.Fd))
260 return fmt.Errorf("Accept: %w", err)
263 cc, err := l.socket.addControlChannel(newFd, nil)
// on error: addControlChannel does not close the fd when it fails, so close
// the accepted connection here to avoid leaking it.
syscall.Close(newFd)
265 return fmt.Errorf("Failed to add control channel: %w", err)
268 err = cc.msgEnqHello()
270 return fmt.Errorf("msgEnqHello: %w", err)
281 return fmt.Errorf("Unexpected event: %#x", event.Events)
284 // handleEvent handles epoll event for control channel
//
// EPOLLHUP and EPOLLERR close the channel without sending a disconnect
// message and return an error. EPOLLIN reads one message and its ancillary
// data with Recvmsg into cc.data and cc.control.
285 func (cc *controlChannel) handleEvent(event *syscall.EpollEvent) error {
290 if (event.Events & syscall.EPOLLHUP) == syscall.EPOLLHUP {
291 // close cc, don't send msg
292 err := cc.close(false, "")
294 return fmt.Errorf("Failed to close control channel after hang up event: %w", err)
// Channels accepted by the listener have no interface until an Init
// message is parsed, so cc.i may still be nil here.
if cc.i == nil {
	return fmt.Errorf("Hang up: control channel without interface")
}
296 return fmt.Errorf("Hang up: %s", cc.i.GetName())
299 if (event.Events & syscall.EPOLLERR) == syscall.EPOLLERR {
300 // close cc, don't send msg
301 err := cc.close(false, "")
303 return fmt.Errorf("Failed to close control channel after receiving an error event: %w", err)
if cc.i == nil {
	return fmt.Errorf("Received error event on control channel without interface")
}
305 return fmt.Errorf("Received error event on control channel %s", cc.i.GetName())
308 if (event.Events & syscall.EPOLLIN) == syscall.EPOLLIN {
309 size, cc.controlLen, _, _, err = syscall.Recvmsg(int(cc.event.Fd), cc.data[:], cc.control[:], 0)
311 return fmt.Errorf("recvmsg: %w", err)
314 return fmt.Errorf("invalid message size %d", size)
330 return fmt.Errorf("Unexpected event: %#x", event.Events)
333 // close closes the listener
//
// It unregisters the listening fd from epoll and then closes it.
334 func (l *listener) close() error {
335 err := l.socket.delEvent(&l.event)
337 return fmt.Errorf("Failed to del event: %w", err)
339 err = syscall.Close(int(l.event.Fd))
341 return fmt.Errorf("Failed to close socket: %w", err)
346 // AddListener adds a listener to the socket. The fd must describe a
347 // UNIX domain socket already bound to a UNIX domain filename and
348 // marked as listener
//
// The fd is registered with the socket's epoll instance for EPOLLIN,
// EPOLLERR and EPOLLHUP. If registration fails, the fd is left open for
// the caller.
349 func (socket *Socket) AddListener(fd int) (err error) {
351 // we will need this to look up master interface by id
355 l.event = syscall.EpollEvent{
356 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
359 err = socket.addEvent(&l.event)
361 return fmt.Errorf("Failed to add event: %w", err)
369 // addListener creates new UNIX domain socket, binds it to the address
370 // and marks it as listener
//
// The socket is SOCK_SEQPACKET with SO_PASSCRED enabled and is bound to
// socket.filename. Once created, the fd is closed on every failure path,
// including a failing AddListener.
371 func (socket *Socket) addListener() (err error) {
373 fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
375 return fmt.Errorf("Failed to create UNIX domain socket: %w", err)
377 usa := &syscall.SockaddrUnix{Name: socket.filename}
379 // Bind to address and start listening
380 err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
// on error: release the fd before returning
syscall.Close(fd)
382 return fmt.Errorf("Failed to set socket option %s : %w", socket.filename, err)
384 err = syscall.Bind(fd, usa)
// on error: release the fd before returning
syscall.Close(fd)
386 return fmt.Errorf("Failed to bind socket %s : %w", socket.filename, err)
388 err = syscall.Listen(fd, syscall.SOMAXCONN)
// on error: release the fd before returning
syscall.Close(fd)
390 return fmt.Errorf("Failed to listen on socket %s : %w", socket.filename, err)
393 err = socket.AddListener(fd)
if err != nil {
	// AddListener leaves the fd open on failure.
	syscall.Close(fd)
	return err
}
return nil
396 // close closes a control channel, if the control channel is assigned an
397 // interface, the interface is disconnected
//
// sendMsg selects whether a disconnect message carrying str is queued for
// the peer (handleEvent passes false on hang up/error). The channel's fd is
// then removed from epoll and the channel is removed from the socket's
// ccList.
398 func (cc *controlChannel) close(sendMsg bool, str string) (err error) {
400 // first clear message queue so that the disconnect
401 // message is the only message in queue
402 cc.msgQueue = []controlMsg{}
// NOTE(review): the error from msgEnqDisconnect is ignored; teardown
// continues regardless.
403 cc.msgEnqDisconnect(str)
411 err = cc.socket.delEvent(&cc.event)
413 return fmt.Errorf("Failed to del event: %w", err)
416 // remove reference from socket
417 cc.socket.ccList.Remove(cc.listRef)
420 err = cc.i.disconnect()
422 return fmt.Errorf("Interface Disconnect: %w", err)
429 // addControlChannel returns a new controlChannel and adds it to the socket
//
// The fd is registered with epoll for EPOLLIN, EPOLLERR and EPOLLHUP and the
// channel is appended to ccList. i may be nil: the listener passes nil until
// the peer's Init message assigns an interface. If registration fails, the
// fd is not closed; the caller keeps ownership.
430 func (socket *Socket) addControlChannel(fd int, i *Interface) (*controlChannel, error) {
431 cc := &controlChannel{
439 cc.event = syscall.EpollEvent{
440 Events: syscall.EPOLLIN | syscall.EPOLLERR | syscall.EPOLLHUP,
443 err = socket.addEvent(&cc.event)
445 return nil, fmt.Errorf("Failed to add event: %w", err)
448 cc.listRef = socket.ccList.PushBack(cc)
// msgEnqAck queues an Ack message: msgTypeAck is encoded little-endian into
// a buffer and the resulting message is appended to cc.msgQueue.
453 func (cc *controlChannel) msgEnqAck() (err error) {
454 buf := new(bytes.Buffer)
455 err = binary.Write(buf, binary.LittleEndian, msgTypeAck)
462 cc.msgQueue = append(cc.msgQueue, msg)
// msgEnqHello queues a Hello message. The application name is copied into
// hello.Name (copy truncates to the field's length). Type and body are
// encoded little-endian; an error from encoding the message type is returned
// instead of being overwritten by the second Write.
467 func (cc *controlChannel) msgEnqHello() (err error) {
477 copy(hello.Name[:], []byte(cc.socket.appName))
479 buf := new(bytes.Buffer)
480 err = binary.Write(buf, binary.LittleEndian, msgTypeHello)
if err != nil {
	return err
}
481 err = binary.Write(buf, binary.LittleEndian, hello)
488 cc.msgQueue = append(cc.msgQueue, msg)
// parseHello decodes a Hello message from cc.data (after the message type),
// checks that Version lies within the peer's [VersionMin, VersionMax], and
// derives the runtime config cc.i.run from the interface's MemoryConfig,
// limited by the maxima the peer advertises.
493 func (cc *controlChannel) parseHello() (err error) {
496 buf := bytes.NewReader(cc.data[msgTypeSize:])
497 err = binary.Read(buf, binary.LittleEndian, &hello)
502 if hello.VersionMin > Version || hello.VersionMax < Version {
503 return fmt.Errorf("Incompatible memif version")
506 cc.i.run = cc.i.args.MemoryConfig
// Each queue pair uses one S2M and one M2S ring, so the pair count must
// respect both peer limits. The second min starts from the first result;
// recomputing it from the configured value would discard the S2M limit.
508 cc.i.run.NumQueuePairs = min16(cc.i.args.MemoryConfig.NumQueuePairs, hello.MaxRingS2M)
509 cc.i.run.NumQueuePairs = min16(cc.i.run.NumQueuePairs, hello.MaxRingM2S)
510 cc.i.run.Log2RingSize = min8(cc.i.args.MemoryConfig.Log2RingSize, hello.MaxLog2RingSize)
512 cc.i.remoteName = string(hello.Name[:])
// msgEnqInit queues an Init message in ethernet mode carrying the
// application name (copy truncates to the Name field's length). An error
// from encoding the message type is returned instead of being overwritten
// by the second Write.
517 func (cc *controlChannel) msgEnqInit() (err error) {
521 Mode: interfaceModeEthernet,
524 copy(init.Name[:], []byte(cc.socket.appName))
526 buf := new(bytes.Buffer)
527 err = binary.Write(buf, binary.LittleEndian, msgTypeInit)
if err != nil {
	return err
}
528 err = binary.Write(buf, binary.LittleEndian, init)
535 cc.msgQueue = append(cc.msgQueue, msg)
// parseInit decodes an Init message from cc.data (after the message type)
// and assigns this control channel a local interface with the same Id that
// is configured as master and not yet bound to a channel. The peer's Secret
// must equal the interface's. Returns an error on version mismatch, wrong
// secret, or when no interface matches.
540 func (cc *controlChannel) parseInit() (err error) {
543 buf := bytes.NewReader(cc.data[msgTypeSize:])
544 err = binary.Read(buf, binary.LittleEndian, &init)
549 if init.Version != Version {
550 return fmt.Errorf("Incompatible memif driver version")
553 // find peer interface
554 for elt := cc.socket.interfaceList.Front(); elt != nil; elt = elt.Next() {
555 i, ok := elt.Value.(*Interface)
557 if i.args.Id == init.Id && i.args.IsMaster && i.cc == nil {
// NOTE(review): this comparison is not constant-time; consider
// crypto/subtle if the secret must resist timing attacks.
559 if i.args.Secret != init.Secret {
560 return fmt.Errorf("Invalid secret")
562 // interface is assigned to control channel
565 cc.i.run = cc.i.args.MemoryConfig
566 cc.i.remoteName = string(init.Name[:])
573 return fmt.Errorf("Invalid interface id")
// msgEnqAddRegion queues an AddRegion message describing
// cc.i.regions[regionIndex]; the region's fd is set as the message's Fd.
// Returns an error if regionIndex is out of range or if encoding the
// message type fails (instead of letting the second Write overwrite it).
576 func (cc *controlChannel) msgEnqAddRegion(regionIndex uint16) (err error) {
577 if len(cc.i.regions) <= int(regionIndex) {
578 return fmt.Errorf("Invalid region index")
581 addRegion := MsgAddRegion{
583 Size: cc.i.regions[regionIndex].size,
586 buf := new(bytes.Buffer)
587 err = binary.Write(buf, binary.LittleEndian, msgTypeAddRegion)
if err != nil {
	return err
}
588 err = binary.Write(buf, binary.LittleEndian, addRegion)
592 Fd: cc.i.regions[regionIndex].fd,
595 cc.msgQueue = append(cc.msgQueue, msg)
// parseAddRegion decodes an AddRegion message and records the peer's memory
// region, whose fd arrives as SCM_RIGHTS ancillary data (parseControlMsg).
// The received fd is closed if the region index is rejected so that it
// does not leak.
600 func (cc *controlChannel) parseAddRegion() (err error) {
601 var addRegion MsgAddRegion
603 buf := bytes.NewReader(cc.data[msgTypeSize:])
604 err = binary.Read(buf, binary.LittleEndian, &addRegion)
609 fd, err := cc.parseControlMsg()
611 return fmt.Errorf("parseControlMsg: %w", err)
614 if addRegion.Index > 255 {
syscall.Close(fd)
615 return fmt.Errorf("Invalid memory region index")
618 region := memoryRegion{
619 size: addRegion.Size,
623 cc.i.regions = append(cc.i.regions, region)
// msgEnqAddRing queues an AddRing message for the queue at ringIndex. For
// ringTypeS2M the queue comes from cc.i.txQueues and the S2M flag is set;
// otherwise it comes from cc.i.rxQueues. The index is validated first (as
// msgEnqAddRegion does for regions), so a bad index yields an error instead
// of an index-out-of-range panic. An error from encoding the message type is
// returned instead of being overwritten by the second Write.
628 func (cc *controlChannel) msgEnqAddRing(ringType ringType, ringIndex uint16) (err error) {
if (ringType == ringTypeS2M && int(ringIndex) >= len(cc.i.txQueues)) ||
	(ringType != ringTypeS2M && int(ringIndex) >= len(cc.i.rxQueues)) {
	return fmt.Errorf("Invalid ring index")
}
632 if ringType == ringTypeS2M {
633 q = cc.i.txQueues[ringIndex]
634 flags = msgAddRingFlagS2M
636 q = cc.i.rxQueues[ringIndex]
639 addRing := MsgAddRing{
641 Offset: uint32(q.ring.offset),
642 Region: uint16(q.ring.region),
643 RingSizeLog2: uint8(q.ring.log2Size),
648 buf := new(bytes.Buffer)
649 err = binary.Write(buf, binary.LittleEndian, msgTypeAddRing)
if err != nil {
	return err
}
650 err = binary.Write(buf, binary.LittleEndian, addRing)
657 cc.msgQueue = append(cc.msgQueue, msg)
// parseAddRing decodes an AddRing message, obtains an fd from the ancillary
// data, and appends a new queue: rings flagged S2M go to cc.i.rxQueues,
// all others to cc.i.txQueues. The received fd is closed if the ring index
// is rejected so that it does not leak.
662 func (cc *controlChannel) parseAddRing() (err error) {
663 var addRing MsgAddRing
665 buf := bytes.NewReader(cc.data[msgTypeSize:])
666 err = binary.Read(buf, binary.LittleEndian, &addRing)
671 fd, err := cc.parseControlMsg()
676 if addRing.Index >= cc.i.run.NumQueuePairs {
syscall.Close(fd)
677 return fmt.Errorf("invalid ring index")
685 if (addRing.Flags & msgAddRingFlagS2M) == msgAddRingFlagS2M {
686 q.ring = newRing(int(addRing.Region), ringTypeS2M, int(addRing.Offset), int(addRing.RingSizeLog2))
687 cc.i.rxQueues = append(cc.i.rxQueues, q)
689 q.ring = newRing(int(addRing.Region), ringTypeM2S, int(addRing.Offset), int(addRing.RingSizeLog2))
690 cc.i.txQueues = append(cc.i.txQueues, q)
// msgEnqConnect queues a Connect message carrying the interface name
// cc.i.args.Name (copy truncates to the Name field's length). An error from
// encoding the message type is returned instead of being overwritten by the
// second Write.
696 func (cc *controlChannel) msgEnqConnect() (err error) {
697 var connect MsgConnect
698 copy(connect.Name[:], []byte(cc.i.args.Name))
700 buf := new(bytes.Buffer)
701 err = binary.Write(buf, binary.LittleEndian, msgTypeConnect)
if err != nil {
	return err
}
702 err = binary.Write(buf, binary.LittleEndian, connect)
709 cc.msgQueue = append(cc.msgQueue, msg)
// parseConnect decodes a Connect message from cc.data (after the message
// type), records the peer's interface name in cc.i.peerName and marks the
// channel connected.
714 func (cc *controlChannel) parseConnect() (err error) {
715 var connect MsgConnect
717 buf := bytes.NewReader(cc.data[msgTypeSize:])
718 err = binary.Read(buf, binary.LittleEndian, &connect)
// NOTE(review): the whole Name field is converted, so any zero padding is
// kept in peerName — confirm whether trimming is expected.
723 cc.i.peerName = string(connect.Name[:])
730 cc.isConnected = true
// msgEnqConnected queues a Connected message carrying the interface name
// cc.i.args.Name (copy truncates to the Name field's length). An error from
// encoding the message type is returned instead of being overwritten by the
// second Write.
735 func (cc *controlChannel) msgEnqConnected() (err error) {
736 var connected MsgConnected
737 copy(connected.Name[:], []byte(cc.i.args.Name))
739 buf := new(bytes.Buffer)
740 err = binary.Write(buf, binary.LittleEndian, msgTypeConnected)
if err != nil {
	return err
}
741 err = binary.Write(buf, binary.LittleEndian, connected)
748 cc.msgQueue = append(cc.msgQueue, msg)
// parseConnected decodes a Connected message from cc.data (after the message
// type), records the peer's interface name in cc.i.peerName and marks the
// channel connected.
753 func (cc *controlChannel) parseConnected() (err error) {
754 var conn MsgConnected
756 buf := bytes.NewReader(cc.data[msgTypeSize:])
757 err = binary.Read(buf, binary.LittleEndian, &conn)
// NOTE(review): the whole Name field is converted, so any zero padding is
// kept in peerName — confirm whether trimming is expected.
762 cc.i.peerName = string(conn.Name[:])
769 cc.isConnected = true
// msgEnqDisconnect queues a Disconnect message whose reason text str is
// copied into dc.String (copy truncates to the field's length). An error
// from encoding the message type is returned instead of being overwritten
// by the second Write.
774 func (cc *controlChannel) msgEnqDisconnect(str string) (err error) {
779 copy(dc.String[:], str)
781 buf := new(bytes.Buffer)
782 err = binary.Write(buf, binary.LittleEndian, msgTypeDisconnect)
if err != nil {
	return err
}
783 err = binary.Write(buf, binary.LittleEndian, dc)
790 cc.msgQueue = append(cc.msgQueue, msg)
// parseDisconnect decodes a Disconnect message and closes the channel
// without sending a reply, passing the peer's reason string to close.
795 func (cc *controlChannel) parseDisconnect() (err error) {
798 buf := bytes.NewReader(cc.data[msgTypeSize:])
799 err = binary.Read(buf, binary.LittleEndian, &dc)
804 err = cc.close(false, string(dc.String[:]))
806 return fmt.Errorf("Failed to disconnect control channel: %w", err)
// parseMsg decodes the message type at the start of cc.data and dispatches
// to the matching parse* handler, queuing any reply. On Hello (slave side)
// it also initializes regions and queues and enqueues Init, one AddRegion
// per region, AddRing for every S2M and M2S queue pair, and Connect. Unknown
// types are an error. On failure the channel is closed with a disconnect
// message carrying the error text.
812 func (cc *controlChannel) parseMsg() error {
816 buf := bytes.NewReader(cc.data[:])
817 err = binary.Read(buf, binary.LittleEndian, &msgType)
819 if msgType == msgTypeAck {
821 } else if msgType == msgTypeHello {
823 err = cc.parseHello()
827 // Initialize slave memif
828 err = cc.i.initializeRegions()
832 err = cc.i.initializeQueues()
837 err = cc.msgEnqInit()
841 for i := 0; i < len(cc.i.regions); i++ {
842 err = cc.msgEnqAddRegion(uint16(i))
847 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
848 err = cc.msgEnqAddRing(ringTypeS2M, uint16(i))
853 for i := 0; uint16(i) < cc.i.run.NumQueuePairs; i++ {
854 err = cc.msgEnqAddRing(ringTypeM2S, uint16(i))
859 err = cc.msgEnqConnect()
863 } else if msgType == msgTypeInit {
873 } else if msgType == msgTypeAddRegion {
874 err = cc.parseAddRegion()
883 } else if msgType == msgTypeAddRing {
884 err = cc.parseAddRing()
893 } else if msgType == msgTypeConnect {
894 err = cc.parseConnect()
899 err = cc.msgEnqConnected()
903 } else if msgType == msgTypeConnected {
904 err = cc.parseConnected()
908 } else if msgType == msgTypeDisconnect {
909 err = cc.parseDisconnect()
914 err = fmt.Errorf("unknown message %d", msgType)
921 err1 := cc.close(true, err.Error())
// err.Error() must not be the format string: any '%' in it would be
// interpreted, and the extra operands were never formatted.
923 return fmt.Errorf("%w: Failed to close control channel: %v", err, err1)
929 // parseControlMsg parses control message and returns file descriptor
//
// It parses the first controlLen bytes of cc.control (filled by Recvmsg)
// and returns the fd carried in an SOL_SOCKET/SCM_RIGHTS message. It errors
// when there is no control message or no fd; every error return uses -1.
931 func (cc *controlChannel) parseControlMsg() (fd int, err error) {
932 // Assert only called when we require FD
935 controlMsgs, err := syscall.ParseSocketControlMessage(cc.control[:cc.controlLen])
937 return -1, fmt.Errorf("syscall.ParseSocketControlMessage: %s", err)
940 if len(controlMsgs) == 0 {
941 return -1, fmt.Errorf("Missing control message")
944 for _, cmsg := range controlMsgs {
945 if cmsg.Header.Level == syscall.SOL_SOCKET {
946 if cmsg.Header.Type == syscall.SCM_RIGHTS {
947 FDs, err := syscall.ParseUnixRights(&cmsg)
949 return -1, fmt.Errorf("syscall.ParseUnixRights: %s", err)
954 // Only expect single FD
961 return -1, fmt.Errorf("Missing file descriptor")