Change module name to go.fd.io/govpp
[govpp.git] / proxy / server.go
index 472ad16..11c443f 100644 (file)
@@ -1,4 +1,4 @@
-//  Copyright (c) 2019 Cisco and/or its affiliates.
+//  Copyright (c) 2021 Cisco and/or its affiliates.
 //
 //  Licensed under the Apache License, Version 2.0 (the "License");
 //  you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
 package proxy
 
 import (
+       "context"
        "errors"
        "fmt"
        "reflect"
@@ -22,9 +23,9 @@ import (
        "sync/atomic"
        "time"
 
-       "git.fd.io/govpp.git/adapter"
-       "git.fd.io/govpp.git/api"
-       "git.fd.io/govpp.git/core"
+       "go.fd.io/govpp/adapter"
+       "go.fd.io/govpp/api"
+       "go.fd.io/govpp/core"
 )
 
 const (
@@ -55,6 +56,7 @@ type StatsResponse struct {
        IfaceStats *api.InterfaceStats
        ErrStats   *api.ErrorStats
        BufStats   *api.BufferStats
+       MemStats   *api.MemoryStats
 }
 
 // StatsRPC is a RPC server for proxying client request to api.StatsProvider.
@@ -75,7 +77,7 @@ type StatsRPC struct {
 // proxying request to given api.StatsProvider.
 func NewStatsRPC(stats adapter.StatsAPI) (*StatsRPC, error) {
        rpc := new(StatsRPC)
-       if err := rpc.Connect(stats); err != nil {
+       if err := rpc.connect(stats); err != nil {
                return nil, err
        }
        return rpc, nil
@@ -84,7 +86,7 @@ func NewStatsRPC(stats adapter.StatsAPI) (*StatsRPC, error) {
 func (s *StatsRPC) watchConnection() {
        heartbeatTicker := time.NewTicker(10 * time.Second).C
        atomic.StoreUint32(&s.available, 1)
-       log.Println("enabling statsRPC service")
+       log.Debugln("enabling statsRPC service")
 
        count := 0
        prev := new(api.SystemStats)
@@ -127,7 +129,7 @@ func (s *StatsRPC) watchConnection() {
                                                s.statsConn, err = core.ConnectStats(s.stats)
                                                if err == nil {
                                                        atomic.StoreUint32(&s.available, 1)
-                                                       log.Println("enabling statsRPC service")
+                                                       log.Debugln("enabling statsRPC service")
                                                        break
                                                }
                                                time.Sleep(5 * time.Second)
@@ -144,7 +146,7 @@ func (s *StatsRPC) watchConnection() {
        }
 }
 
-func (s *StatsRPC) Connect(stats adapter.StatsAPI) error {
+func (s *StatsRPC) connect(stats adapter.StatsAPI) error {
        if atomic.LoadUint32(&s.isConnected) == 1 {
                return errors.New("connection already exists")
        }
@@ -161,7 +163,7 @@ func (s *StatsRPC) Connect(stats adapter.StatsAPI) error {
        return nil
 }
 
-func (s *StatsRPC) Disconnect() {
+func (s *StatsRPC) disconnect() {
        if atomic.LoadUint32(&s.isConnected) == 1 {
                atomic.StoreUint32(&s.isConnected, 0)
                close(s.done)
@@ -176,7 +178,7 @@ func (s *StatsRPC) serviceAvailable() bool {
 
 func (s *StatsRPC) GetStats(req StatsRequest, resp *StatsResponse) error {
        if !s.serviceAvailable() {
-               log.Println(statsErrorMsg)
+               log.Print(statsErrorMsg)
                return errors.New("server does not support 'get stats' at this time, try again later")
        }
        log.Debugf("StatsRPC.GetStats - REQ: %+v", req)
@@ -200,6 +202,9 @@ func (s *StatsRPC) GetStats(req StatsRequest, resp *StatsResponse) error {
        case "buffer":
                resp.BufStats = new(api.BufferStats)
                return s.statsConn.GetBufferStats(resp.BufStats)
+       case "memory":
+               resp.MemStats = new(api.MemoryStats)
+               return s.statsConn.GetMemoryStats(resp.MemStats)
        default:
                return fmt.Errorf("unknown stats type: %s", req.StatsType)
        }
@@ -222,15 +227,21 @@ type BinapiCompatibilityRequest struct {
 }
 
 type BinapiCompatibilityResponse struct {
-       CompatibleMsgs   []string
-       IncompatibleMsgs []string
+       CompatibleMsgs   map[string][]string
+       IncompatibleMsgs map[string][]string
 }
 
-// BinapiRPC is a RPC server for proxying client request to api.Channel.
+// BinapiRPC is a RPC server for proxying client request to api.Channel
+// or api.Stream.
 type BinapiRPC struct {
        binapiConn *core.Connection
        binapi     adapter.VppAPI
 
+       streamsLock sync.Mutex
+       // local ID, different from api.Stream ID
+       maxStreamID uint32
+       streams     map[uint32]api.Stream
+
        events chan core.ConnectionEvent
        done   chan struct{}
        // non-zero if the RPC service is available
@@ -243,7 +254,7 @@ type BinapiRPC struct {
 // proxying request to given api.Channel.
 func NewBinapiRPC(binapi adapter.VppAPI) (*BinapiRPC, error) {
        rpc := new(BinapiRPC)
-       if err := rpc.Connect(binapi); err != nil {
+       if err := rpc.connect(binapi); err != nil {
                return nil, err
        }
        return rpc, nil
@@ -263,7 +274,7 @@ func (s *BinapiRPC) watchConnection() {
                        case core.Connected:
                                if !s.serviceAvailable() {
                                        atomic.StoreUint32(&s.available, 1)
-                                       log.Println("enabling binapiRPC service")
+                                       log.Debugln("enabling binapiRPC service")
                                }
                        case core.Disconnected:
                                if s.serviceAvailable() {
@@ -290,7 +301,7 @@ func (s *BinapiRPC) watchConnection() {
        }
 }
 
-func (s *BinapiRPC) Connect(binapi adapter.VppAPI) error {
+func (s *BinapiRPC) connect(binapi adapter.VppAPI) error {
        if atomic.LoadUint32(&s.isConnected) == 1 {
                return errors.New("connection already exists")
        }
@@ -307,7 +318,7 @@ func (s *BinapiRPC) Connect(binapi adapter.VppAPI) error {
        return nil
 }
 
-func (s *BinapiRPC) Disconnect() {
+func (s *BinapiRPC) disconnect() {
        if atomic.LoadUint32(&s.isConnected) == 1 {
                atomic.StoreUint32(&s.isConnected, 0)
                close(s.done)
@@ -320,9 +331,104 @@ func (s *BinapiRPC) serviceAvailable() bool {
        return atomic.LoadUint32(&s.available) == 1
 }
 
+type RPCStreamReqResp struct {
+       ID  uint32
+       Msg api.Message
+}
+
+func (s *BinapiRPC) NewAPIStream(req RPCStreamReqResp, resp *RPCStreamReqResp) error {
+       if !s.serviceAvailable() {
+               log.Print(binapiErrorMsg)
+               return errors.New("server does not support RPC calls at this time, try again later")
+       }
+       log.Debugf("BinapiRPC.NewAPIStream - REQ: %#v", req)
+
+       stream, err := s.binapiConn.NewStream(context.Background())
+       if err != nil {
+               return err
+       }
+
+       if s.streams == nil {
+               s.streams = make(map[uint32]api.Stream)
+       }
+
+       s.streamsLock.Lock()
+       s.maxStreamID++
+       s.streams[s.maxStreamID] = stream
+       resp.ID = s.maxStreamID
+       s.streamsLock.Unlock()
+
+       return nil
+}
+
+func (s *BinapiRPC) SendMessage(req RPCStreamReqResp, resp *RPCStreamReqResp) error {
+       if !s.serviceAvailable() {
+               log.Print(binapiErrorMsg)
+               return errors.New("server does not support RPC calls at this time, try again later")
+       }
+       log.Debugf("BinapiRPC.SendMessage - REQ: %#v", req)
+
+       stream, err := s.getStream(req.ID)
+       if err != nil {
+               return err
+       }
+
+       return stream.SendMsg(req.Msg)
+}
+
+func (s *BinapiRPC) ReceiveMessage(req RPCStreamReqResp, resp *RPCStreamReqResp) error {
+       if !s.serviceAvailable() {
+               log.Print(binapiErrorMsg)
+               return errors.New("server does not support RPC calls at this time, try again later")
+       }
+       log.Debugf("BinapiRPC.ReceiveMessage - REQ: %#v", req)
+
+       stream, err := s.getStream(req.ID)
+       if err != nil {
+               return err
+       }
+
+       resp.Msg, err = stream.RecvMsg()
+       return err
+}
+
+func (s *BinapiRPC) CloseStream(req RPCStreamReqResp, resp *RPCStreamReqResp) error {
+       if !s.serviceAvailable() {
+               log.Print(binapiErrorMsg)
+               return errors.New("server does not support RPC calls at this time, try again later")
+       }
+       log.Debugf("BinapiRPC.CloseStream - REQ: %#v", req)
+
+       stream, err := s.getStream(req.ID)
+       if err != nil {
+               return err
+       }
+
+       s.streamsLock.Lock()
+       delete(s.streams, req.ID)
+       s.streamsLock.Unlock()
+
+       return stream.Close()
+}
+
+func (s *BinapiRPC) getStream(id uint32) (api.Stream, error) {
+       s.streamsLock.Lock()
+       stream := s.streams[id]
+       s.streamsLock.Unlock()
+
+       if stream == nil || reflect.ValueOf(stream).IsNil() {
+               s.streamsLock.Lock()
+               // delete the stream in case it is still in the map
+               delete(s.streams, id)
+               s.streamsLock.Unlock()
+               return nil, errors.New("BinapiRPC stream closed")
+       }
+       return stream, nil
+}
+
 func (s *BinapiRPC) Invoke(req BinapiRequest, resp *BinapiResponse) error {
        if !s.serviceAvailable() {
-               log.Println(binapiErrorMsg)
+               log.Print(binapiErrorMsg)
                return errors.New("server does not support 'invoke' at this time, try again later")
        }
        log.Debugf("BinapiRPC.Invoke - REQ: %#v", req)
@@ -364,7 +470,7 @@ func (s *BinapiRPC) Invoke(req BinapiRequest, resp *BinapiResponse) error {
 
 func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCompatibilityResponse) error {
        if !s.serviceAvailable() {
-               log.Println(binapiErrorMsg)
+               log.Print(binapiErrorMsg)
                return errors.New("server does not support 'compatibility check' at this time, try again later")
        }
        log.Debugf("BinapiRPC.Compatiblity - REQ: %#v", req)
@@ -375,25 +481,37 @@ func (s *BinapiRPC) Compatibility(req BinapiCompatibilityRequest, resp *BinapiCo
        }
        defer ch.Close()
 
-       resp.CompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
-       resp.IncompatibleMsgs = make([]string, 0, len(req.MsgNameCrcs))
+       resp.CompatibleMsgs = make(map[string][]string)
+       resp.IncompatibleMsgs = make(map[string][]string)
 
-       for _, msg := range req.MsgNameCrcs {
-               val, ok := api.GetRegisteredMessages()[msg]
-               if !ok {
-                       resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
-                       continue
+       for path, messages := range api.GetRegisteredMessages() {
+               resp.IncompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
+               resp.CompatibleMsgs[path] = make([]string, 0, len(req.MsgNameCrcs))
+
+               for _, msg := range req.MsgNameCrcs {
+                       val, ok := messages[msg]
+                       if !ok {
+                               resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+                               continue
+                       }
+                       if err = ch.CheckCompatiblity(val); err != nil {
+                               resp.IncompatibleMsgs[path] = append(resp.IncompatibleMsgs[path], msg)
+                       } else {
+                               resp.CompatibleMsgs[path] = append(resp.CompatibleMsgs[path], msg)
+                       }
                }
+       }
 
-               if err = ch.CheckCompatiblity(val); err != nil {
-                       resp.IncompatibleMsgs = append(resp.IncompatibleMsgs, msg)
+       compatible := false
+       for path, incompatibleMsgs := range resp.IncompatibleMsgs {
+               if len(incompatibleMsgs) == 0 {
+                       compatible = true
                } else {
-                       resp.CompatibleMsgs = append(resp.CompatibleMsgs, msg)
+                       log.Debugf("messages are incompatible for path %s", path)
                }
        }
-
-       if len(resp.IncompatibleMsgs) > 0 {
-               return fmt.Errorf("compatibility check failed for messages: %v", resp.IncompatibleMsgs)
+       if !compatible {
+               return errors.New("compatibility check failed")
        }
 
        return nil