Recognize stat_dir_type_empty
[govpp.git] / binapigen / gen_rpc.go
1 //  Copyright (c) 2020 Cisco and/or its affiliates.
2 //
3 //  Licensed under the Apache License, Version 2.0 (the "License");
4 //  you may not use this file except in compliance with the License.
5 //  You may obtain a copy of the License at:
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 //  Unless required by applicable law or agreed to in writing, software
10 //  distributed under the License is distributed on an "AS IS" BASIS,
11 //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 //  See the License for the specific language governing permissions and
13 //  limitations under the License.
14
15 package binapigen
16
17 import (
18         "fmt"
19         "path"
20
21         "github.com/sirupsen/logrus"
22 )
23
24 func init() {
25         RegisterPlugin("rpc", GenerateRPC)
26 }
27
28 // library dependencies
29 const (
30         contextPkg = GoImportPath("context")
31         ioPkg      = GoImportPath("io")
32 )
33
34 // generated names
35 const (
36         serviceApiName    = "RPCService"    // name for the RPC service interface
37         serviceImplName   = "serviceClient" // name for the RPC service implementation
38         serviceClientName = "ServiceClient" // name for the RPC service client
39
40         // TODO: register service descriptor
41         //serviceDescType = "ServiceDesc"             // name for service descriptor type
42         //serviceDescName = "_ServiceRPC_serviceDesc" // name for service descriptor var
43 )
44
45 func GenerateRPC(gen *Generator, file *File) *GenFile {
46         if file.Service == nil {
47                 return nil
48         }
49
50         logf("----------------------------")
51         logf(" Generate RPC - %s", file.Desc.Name)
52         logf("----------------------------")
53
54         filename := path.Join(file.FilenamePrefix, file.Desc.Name+"_rpc.ba.go")
55         g := gen.NewGenFile(filename, file.GoImportPath)
56         g.file = file
57
58         // generate file header
59         g.P("// Code generated by GoVPP's binapi-generator. DO NOT EDIT.")
60         g.P()
61         g.P("package ", file.PackageName)
62         g.P()
63
64         // generate RPC service
65         if len(file.Service.RPCs) > 0 {
66                 genService(g, file.Service)
67         }
68
69         return g
70 }
71
72 func genService(g *GenFile, svc *Service) {
73         // generate comment
74         g.P("// ", serviceApiName, " defines RPC service ", g.file.Desc.Name, ".")
75
76         // generate service interface
77         g.P("type ", serviceApiName, " interface {")
78         for _, rpc := range svc.RPCs {
79                 g.P(rpcMethodSignature(g, rpc))
80         }
81         g.P("}")
82         g.P()
83
84         // generate client implementation
85         g.P("type ", serviceImplName, " struct {")
86         g.P("conn ", govppApiPkg.Ident("Connection"))
87         g.P("}")
88         g.P()
89
90         // generate client constructor
91         g.P("func New", serviceClientName, "(conn ", govppApiPkg.Ident("Connection"), ") ", serviceApiName, " {")
92         g.P("return &", serviceImplName, "{conn}")
93         g.P("}")
94         g.P()
95
96         msgControlPingReply, ok := g.gen.messagesByName["control_ping_reply"]
97         if !ok {
98                 logrus.Fatalf("no message for %v", "control_ping_reply")
99         }
100         msgControlPing, ok := g.gen.messagesByName["control_ping"]
101         if !ok {
102                 logrus.Fatalf("no message for %v", "control_ping")
103         }
104
105         for _, rpc := range svc.RPCs {
106                 logf(" gen RPC: %v (%s)", rpc.GoName, rpc.VPP.Request)
107
108                 g.P("func (c *", serviceImplName, ") ", rpcMethodSignature(g, rpc), " {")
109                 if rpc.VPP.Stream {
110                         streamImpl := fmt.Sprintf("%s_%sClient", serviceImplName, rpc.GoName)
111                         streamApi := fmt.Sprintf("%s_%sClient", serviceApiName, rpc.GoName)
112
113                         msgDetails := rpc.MsgReply
114                         var msgReply *Message
115                         if rpc.MsgStream != nil {
116                                 msgDetails = rpc.MsgStream
117                                 msgReply = rpc.MsgReply
118                         } else {
119                                 msgDetails = rpc.MsgReply
120                                 msgReply = msgControlPingReply
121                         }
122
123                         g.P("stream, err := c.conn.NewStream(ctx)")
124                         g.P("if err != nil { return nil, err }")
125                         g.P("x := &", streamImpl, "{stream}")
126                         g.P("if err := x.Stream.SendMsg(in); err != nil {")
127                         g.P("   return nil, err")
128                         g.P("}")
129                         if rpc.MsgStream == nil {
130                                 g.P("if err = x.Stream.SendMsg(&", msgControlPing.GoIdent, "{}); err != nil {")
131                                 g.P("   return nil, err")
132                                 g.P("}")
133                         }
134                         g.P("return x, nil")
135                         g.P("}")
136                         g.P()
137                         g.P("type ", streamApi, " interface {")
138                         g.P("   Recv() (*", msgDetails.GoIdent, ", error)")
139                         g.P("   ", govppApiPkg.Ident("Stream"))
140                         g.P("}")
141                         g.P()
142
143                         g.P("type ", streamImpl, " struct {")
144                         g.P("   ", govppApiPkg.Ident("Stream"))
145                         g.P("}")
146                         g.P()
147
148                         g.P("func (c *", streamImpl, ") Recv() (*", msgDetails.GoIdent, ", error) {")
149                         g.P("   msg, err := c.Stream.RecvMsg()")
150                         g.P("   if err != nil { return nil, err }")
151                         g.P("   switch m := msg.(type) {")
152                         g.P("   case *", msgDetails.GoIdent, ":")
153                         g.P("           return m, nil")
154                         g.P("   case *", msgReply.GoIdent, ":")
155                         g.P("           return nil, ", ioPkg.Ident("EOF"))
156                         g.P("   default:")
157                         g.P("           return nil, ", fmtPkg.Ident("Errorf"), "(\"unexpected message: %T %v\", m, m)")
158                         g.P("}")
159                 } else if rpc.MsgReply != nil {
160                         g.P("out := new(", rpc.MsgReply.GoIdent, ")")
161                         g.P("err := c.conn.Invoke(ctx, in, out)")
162                         g.P("if err != nil { return nil, err }")
163                         if retvalField := getRetvalField(rpc.MsgReply); retvalField != nil {
164                                 if fieldType := getFieldType(g, retvalField); fieldType == "int32" {
165                                         g.P("return out, ", govppApiPkg.Ident("RetvalToVPPApiError"), "(out.", retvalField.GoName, ")")
166                                 } else {
167                                         g.P("return out, ", govppApiPkg.Ident("RetvalToVPPApiError"), "(int32(out.", retvalField.GoName, "))")
168                                 }
169                         } else {
170                                 g.P("return out, nil")
171                         }
172                 } else {
173                         g.P("stream, err := c.conn.NewStream(ctx)")
174                         g.P("if err != nil { return err }")
175                         g.P("err = stream.SendMsg(in)")
176                         g.P("if err != nil { return err }")
177                         g.P("return nil")
178                 }
179                 g.P("}")
180                 g.P()
181         }
182
183         // TODO: generate service descriptor
184         /*fmt.Fprintf(w, "var %s = api.%s{\n", serviceDescName, serviceDescType)
185           fmt.Fprintf(w, "\tServiceName: \"%s\",\n", ctx.moduleName)
186           fmt.Fprintf(w, "\tHandlerType: (*%s)(nil),\n", serviceApiName)
187           fmt.Fprintf(w, "\tMethods: []api.MethodDesc{\n")
188           for _, method := range rpcs {
189                 fmt.Fprintf(w, "\t  {\n")
190                 fmt.Fprintf(w, "\t    MethodName: \"%s\",\n", method.Name)
191                 fmt.Fprintf(w, "\t  },\n")
192           }
193           fmt.Fprintf(w, "\t},\n")
194           //fmt.Fprintf(w, "\tCompatibility: %s,\n", messageCrcName)
195           //fmt.Fprintf(w, "\tMetadata: reflect.TypeOf((*%s)(nil)).Elem().PkgPath(),\n", serviceApiName)
196           fmt.Fprintf(w, "\tMetadata: \"%s\",\n", ctx.inputFile)
197           fmt.Fprintln(w, "}")*/
198
199         g.P()
200 }
201
202 func rpcMethodSignature(g *GenFile, rpc *RPC) string {
203         s := rpc.GoName + "(ctx " + g.GoIdent(contextPkg.Ident("Context"))
204         s += ", in *" + g.GoIdent(rpc.MsgRequest.GoIdent) + ") ("
205         if rpc.VPP.Stream {
206                 s += serviceApiName + "_" + rpc.GoName + "Client, "
207         } else if rpc.MsgReply != nil {
208                 s += "*" + g.GoIdent(rpc.MsgReply.GoIdent) + ", "
209         }
210         s += "error)"
211         return s
212 }