Recover possible panic in EncodeMsg and improve debug logs 88/14788/1
authorOndrej Fabry <ofabry@cisco.com>
Wed, 12 Sep 2018 15:38:36 +0000 (17:38 +0200)
committerOndrej Fabry <ofabry@cisco.com>
Wed, 12 Sep 2018 15:38:36 +0000 (17:38 +0200)
Change-Id: I771c171ae30a957f4436e7f4ba834d8a38d02f80
Signed-off-by: Ondrej Fabry <ofabry@cisco.com>
codec/msg_codec.go
codec/msg_codec_test.go [new file with mode: 0644]
core/request_handler.go

index 9d3f614..67628a4 100644 (file)
@@ -53,24 +53,32 @@ type VppOtherHeader struct {
 }
 
 // EncodeMsg encodes provided `Message` structure into its binary-encoded data representation.
-func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) {
+func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) (data []byte, err error) {
        if msg == nil {
                return nil, errors.New("nil message passed in")
        }
 
+       // try to recover panic which might possibly occur in struc.Pack call
+       defer func() {
+               if r := recover(); r != nil {
+                       var ok bool
+                       if err, ok = r.(error); !ok {
+                               err = fmt.Errorf("%v", r)
+                       }
+                       err = fmt.Errorf("panic occurred: %v", err)
+               }
+       }()
+
        var header interface{}
 
        // encode message header
        switch msg.GetMessageType() {
        case api.RequestMessage:
                header = &VppRequestHeader{VlMsgID: msgID}
-
        case api.ReplyMessage:
                header = &VppReplyHeader{VlMsgID: msgID}
-
        case api.EventMessage:
                header = &VppEventHeader{VlMsgID: msgID}
-
        default:
                header = &VppOtherHeader{VlMsgID: msgID}
        }
@@ -79,13 +87,13 @@ func (*MsgCodec) EncodeMsg(msg api.Message, msgID uint16) ([]byte, error) {
 
        // encode message header
        if err := struc.Pack(buf, header); err != nil {
-               return nil, fmt.Errorf("unable to encode message header: %v, error %v", header, err)
+               return nil, fmt.Errorf("failed to encode message header: %+v, error: %v", header, err)
        }
 
        // encode message content
        if reflect.TypeOf(msg).Elem().NumField() > 0 {
                if err := struc.Pack(buf, msg); err != nil {
-                       return nil, fmt.Errorf("unable to encode message data: %v, error %v", header, err)
+                       return nil, fmt.Errorf("failed to encode message data: %+v, error: %v", data, err)
                }
        }
 
@@ -104,13 +112,10 @@ func (*MsgCodec) DecodeMsg(data []byte, msg api.Message) error {
        switch msg.GetMessageType() {
        case api.RequestMessage:
                header = new(VppRequestHeader)
-
        case api.ReplyMessage:
                header = new(VppReplyHeader)
-
        case api.EventMessage:
                header = new(VppEventHeader)
-
        default:
                header = new(VppOtherHeader)
        }
@@ -119,12 +124,12 @@ func (*MsgCodec) DecodeMsg(data []byte, msg api.Message) error {
 
        // decode message header
        if err := struc.Unpack(buf, header); err != nil {
-               return fmt.Errorf("unable to decode message header: %+v, error %v", data, err)
+               return fmt.Errorf("failed to decode message header: %+v, error: %v", header, err)
        }
 
        // decode message content
        if err := struc.Unpack(buf, msg); err != nil {
-               return fmt.Errorf("unable to decode message data: %+v, error %v", data, err)
+               return fmt.Errorf("failed to decode message data: %+v, error: %v", data, err)
        }
 
        return nil
diff --git a/codec/msg_codec_test.go b/codec/msg_codec_test.go
new file mode 100644 (file)
index 0000000..cd1240e
--- /dev/null
@@ -0,0 +1,63 @@
+package codec
+
+import (
+       "bytes"
+       "testing"
+
+       "git.fd.io/govpp.git/api"
+)
+
+type MyMsg struct {
+       Index uint16
+       Label []byte `struc:"[16]byte"`
+       Port  uint16
+}
+
+func (*MyMsg) GetMessageName() string {
+       return "my_msg"
+}
+func (*MyMsg) GetCrcString() string {
+       return "xxxxx"
+}
+func (*MyMsg) GetMessageType() api.MessageType {
+       return api.OtherMessage
+}
+
+func TestEncode(t *testing.T) {
+       tests := []struct {
+               name    string
+               msg     api.Message
+               msgID   uint16
+               expData []byte
+       }{
+               {name: "basic",
+                       msg:     &MyMsg{Index: 1, Label: []byte("Abcdef"), Port: 1000},
+                       msgID:   100,
+                       expData: []byte{0x00, 0x64, 0x00, 0x01, 0x41, 0x62, 0x63, 0x64, 0x65, 0x66, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xE8},
+               },
+       }
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       c := &MsgCodec{}
+
+                       data, err := c.EncodeMsg(test.msg, test.msgID)
+                       if err != nil {
+                               t.Fatalf("expected nil error, got: %v", err)
+                       }
+                       if !bytes.Equal(data, test.expData) {
+                               t.Fatalf("expected data: % 0X, got: % 0X", test.expData, data)
+                       }
+               })
+       }
+}
+
+func TestEncodePanic(t *testing.T) {
+       c := &MsgCodec{}
+
+       msg := &MyMsg{Index: 1, Label: []byte("thisIsLongerThan16Bytes"), Port: 1000}
+
+       _, err := c.EncodeMsg(msg, 100)
+       if err == nil {
+               t.Fatalf("expected non-nil error, got: %v", err)
+       }
+}
index c042948..e52e262 100644 (file)
@@ -39,7 +39,9 @@ func (c *Connection) watchRequests(ch *Channel) {
                                c.releaseAPIChannel(ch)
                                return
                        }
-                       c.processRequest(ch, req)
+                       if err := c.processRequest(ch, req); err != nil {
+                               sendReplyError(ch, req, err)
+                       }
                }
        }
 }
@@ -50,39 +52,36 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
        if atomic.LoadUint32(&c.connected) == 0 {
                err := ErrNotConnected
                log.Errorf("processing request failed: %v", err)
-               sendReplyError(ch, req, err)
                return err
        }
 
        // retrieve message ID
        msgID, err := c.GetMessageID(req.msg)
        if err != nil {
-               err = fmt.Errorf("unable to retrieve message ID: %v", err)
                log.WithFields(logger.Fields{
                        "msg_name": req.msg.GetMessageName(),
                        "msg_crc":  req.msg.GetCrcString(),
                        "seq_num":  req.seqNum,
-               }).Error(err)
-               sendReplyError(ch, req, err)
-               return err
+                       "error":    err,
+               }).Errorf("failed to retrieve message ID")
+               return fmt.Errorf("unable to retrieve message ID: %v", err)
        }
 
        // encode the message into binary
        data, err := c.codec.EncodeMsg(req.msg, msgID)
        if err != nil {
-               err = fmt.Errorf("unable to encode the messge: %v", err)
                log.WithFields(logger.Fields{
                        "channel":  ch.id,
                        "msg_id":   msgID,
                        "msg_name": req.msg.GetMessageName(),
                        "seq_num":  req.seqNum,
-               }).Error(err)
-               sendReplyError(ch, req, err)
-               return err
+                       "error":    err,
+               }).Errorf("failed to encode message: %#v", req.msg)
+               return fmt.Errorf("unable to encode the message: %v", err)
        }
 
-       // get context
        context := packRequestContext(ch.id, req.multi, req.seqNum)
+
        if log.Level == logger.DebugLevel { // for performance reasons - logrus does some processing even if debugs are disabled
                log.WithFields(logger.Fields{
                        "channel":  ch.id,
@@ -104,7 +103,6 @@ func (c *Connection) processRequest(ch *Channel, req *vppRequest) error {
                        "msg_id":  msgID,
                        "seq_num": req.seqNum,
                }).Error(err)
-               sendReplyError(ch, req, err)
                return err
        }