1/* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package main 20 21import ( 22 "fmt" 23 "strconv" 24 "strings" 25 26 "google.golang.org/protobuf/compiler/protogen" 27 28 "google.golang.org/protobuf/types/descriptorpb" 29) 30 31const ( 32 contextPackage = protogen.GoImportPath("context") 33 grpcPackage = protogen.GoImportPath("google.golang.org/grpc") 34 codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes") 35 statusPackage = protogen.GoImportPath("google.golang.org/grpc/status") 36) 37 38// generateFile generates a _grpc.pb.go file containing gRPC service definitions. 39func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile { 40 if len(file.Services) == 0 { 41 return nil 42 } 43 filename := file.GeneratedFilenamePrefix + "_grpc.pb.go" 44 g := gen.NewGeneratedFile(filename, file.GoImportPath) 45 g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.") 46 g.P() 47 g.P("package ", file.GoPackageName) 48 g.P() 49 generateFileContent(gen, file, g) 50 return g 51} 52 53// generateFileContent generates the gRPC service definitions, excluding the package statement. 54func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { 55 if len(file.Services) == 0 { 56 return 57 } 58 59 g.P("// This is a compile-time assertion to ensure that this generated file") 60 g.P("// is compatible with the grpc package it is being compiled against.") 61 g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6")) 62 g.P() 63 for _, service := range file.Services { 64 genService(gen, file, g, service) 65 } 66} 67 68func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) { 69 clientName := service.GoName + "Client" 70 71 g.P("// ", clientName, " is the client API for ", service.GoName, " service.") 72 g.P("//") 73 g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.") 74 75 // Client interface. 76 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 77 g.P("//") 78 g.P(deprecationComment) 79 } 80 g.Annotate(clientName, service.Location) 81 g.P("type ", clientName, " interface {") 82 for _, method := range service.Methods { 83 g.Annotate(clientName+"."+method.GoName, method.Location) 84 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { 85 g.P(deprecationComment) 86 } 87 g.P(method.Comments.Leading, 88 clientSignature(g, method)) 89 } 90 g.P("}") 91 g.P() 92 93 // Client structure. 94 g.P("type ", unexport(clientName), " struct {") 95 g.P("cc ", grpcPackage.Ident("ClientConnInterface")) 96 g.P("}") 97 g.P() 98 99 // NewClient factory. 100 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 101 g.P(deprecationComment) 102 } 103 g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {") 104 g.P("return &", unexport(clientName), "{cc}") 105 g.P("}") 106 g.P() 107 108 var methodIndex, streamIndex int 109 // Client method implementations. 110 for _, method := range service.Methods { 111 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 112 // Unary RPC method 113 genClientMethod(gen, file, g, method, methodIndex) 114 methodIndex++ 115 } else { 116 // Streaming RPC method 117 genClientMethod(gen, file, g, method, streamIndex) 118 streamIndex++ 119 } 120 } 121 122 mustOrShould := "must" 123 if !*requireUnimplemented { 124 mustOrShould = "should" 125 } 126 127 // Server interface. 128 serverType := service.GoName + "Server" 129 g.P("// ", serverType, " is the server API for ", service.GoName, " service.") 130 g.P("// All implementations ", mustOrShould, " embed Unimplemented", serverType) 131 g.P("// for forward compatibility") 132 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 133 g.P("//") 134 g.P(deprecationComment) 135 } 136 g.Annotate(serverType, service.Location) 137 g.P("type ", serverType, " interface {") 138 for _, method := range service.Methods { 139 g.Annotate(serverType+"."+method.GoName, method.Location) 140 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { 141 g.P(deprecationComment) 142 } 143 g.P(method.Comments.Leading, 144 serverSignature(g, method)) 145 } 146 if *requireUnimplemented { 147 g.P("mustEmbedUnimplemented", serverType, "()") 148 } 149 g.P("}") 150 g.P() 151 152 // Server Unimplemented struct for forward compatibility. 153 g.P("// Unimplemented", serverType, " ", mustOrShould, " be embedded to have forward compatible implementations.") 154 g.P("type Unimplemented", serverType, " struct {") 155 g.P("}") 156 g.P() 157 for _, method := range service.Methods { 158 nilArg := "" 159 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 160 nilArg = "nil," 161 } 162 g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{") 163 g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`) 164 g.P("}") 165 } 166 if *requireUnimplemented { 167 g.P("func (*Unimplemented", serverType, ") mustEmbedUnimplemented", serverType, "() {}") 168 } 169 g.P() 170 171 // Server registration. 172 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 173 g.P(deprecationComment) 174 } 175 serviceDescVar := "_" + service.GoName + "_serviceDesc" 176 g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {") 177 g.P("s.RegisterService(&", serviceDescVar, `, srv)`) 178 g.P("}") 179 g.P() 180 181 // Server handler implementations. 182 var handlerNames []string 183 for _, method := range service.Methods { 184 hname := genServerMethod(gen, file, g, method) 185 handlerNames = append(handlerNames, hname) 186 } 187 188 // Service descriptor. 189 g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {") 190 g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",") 191 g.P("HandlerType: (*", serverType, ")(nil),") 192 g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{") 193 for i, method := range service.Methods { 194 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { 195 continue 196 } 197 g.P("{") 198 g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",") 199 g.P("Handler: ", handlerNames[i], ",") 200 g.P("},") 201 } 202 g.P("},") 203 g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{") 204 for i, method := range service.Methods { 205 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 206 continue 207 } 208 g.P("{") 209 g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",") 210 g.P("Handler: ", handlerNames[i], ",") 211 if method.Desc.IsStreamingServer() { 212 g.P("ServerStreams: true,") 213 } 214 if method.Desc.IsStreamingClient() { 215 g.P("ClientStreams: true,") 216 } 217 g.P("},") 218 } 219 g.P("},") 220 g.P("Metadata: \"", file.Desc.Path(), "\",") 221 g.P("}") 222 g.P() 223} 224 225func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { 226 s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) 227 if !method.Desc.IsStreamingClient() { 228 s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent) 229 } 230 s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") (" 231 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 232 s += "*" + g.QualifiedGoIdent(method.Output.GoIdent) 233 } else { 234 s += method.Parent.GoName + "_" + method.GoName + "Client" 235 } 236 s += ", error)" 237 return s 238} 239 240func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) { 241 service := method.Parent 242 sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) 243 244 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { 245 g.P(deprecationComment) 246 } 247 g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{") 248 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 249 g.P("out := new(", method.Output.GoIdent, ")") 250 g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`) 251 g.P("if err != nil { return nil, err }") 252 g.P("return out, nil") 253 g.P("}") 254 g.P() 255 return 256 } 257 streamType := unexport(service.GoName) + method.GoName + "Client" 258 serviceDescVar := "_" + service.GoName + "_serviceDesc" 259 g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`) 260 g.P("if err != nil { return nil, err }") 261 g.P("x := &", streamType, "{stream}") 262 if !method.Desc.IsStreamingClient() { 263 g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }") 264 g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") 265 } 266 g.P("return x, nil") 267 g.P("}") 268 g.P() 269 270 genSend := method.Desc.IsStreamingClient() 271 genRecv := method.Desc.IsStreamingServer() 272 genCloseAndRecv := !method.Desc.IsStreamingServer() 273 274 // Stream auxiliary types and methods. 275 g.P("type ", service.GoName, "_", method.GoName, "Client interface {") 276 if genSend { 277 g.P("Send(*", method.Input.GoIdent, ") error") 278 } 279 if genRecv { 280 g.P("Recv() (*", method.Output.GoIdent, ", error)") 281 } 282 if genCloseAndRecv { 283 g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)") 284 } 285 g.P(grpcPackage.Ident("ClientStream")) 286 g.P("}") 287 g.P() 288 289 g.P("type ", streamType, " struct {") 290 g.P(grpcPackage.Ident("ClientStream")) 291 g.P("}") 292 g.P() 293 294 if genSend { 295 g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {") 296 g.P("return x.ClientStream.SendMsg(m)") 297 g.P("}") 298 g.P() 299 } 300 if genRecv { 301 g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {") 302 g.P("m := new(", method.Output.GoIdent, ")") 303 g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") 304 g.P("return m, nil") 305 g.P("}") 306 g.P() 307 } 308 if genCloseAndRecv { 309 g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {") 310 g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") 311 g.P("m := new(", method.Output.GoIdent, ")") 312 g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") 313 g.P("return m, nil") 314 g.P("}") 315 g.P() 316 } 317} 318 319func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { 320 var reqArgs []string 321 ret := "error" 322 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 323 reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context"))) 324 ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)" 325 } 326 if !method.Desc.IsStreamingClient() { 327 reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent)) 328 } 329 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { 330 reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server") 331 } 332 return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret 333} 334 335func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string { 336 service := method.Parent 337 hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName) 338 339 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 340 g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {") 341 g.P("in := new(", method.Input.GoIdent, ")") 342 g.P("if err := dec(in); err != nil { return nil, err }") 343 g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }") 344 g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{") 345 g.P("Server: srv,") 346 g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",") 347 g.P("}") 348 g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {") 349 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))") 350 g.P("}") 351 g.P("return interceptor(ctx, in, info, handler)") 352 g.P("}") 353 g.P() 354 return hname 355 } 356 streamType := unexport(service.GoName) + method.GoName + "Server" 357 g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {") 358 if !method.Desc.IsStreamingClient() { 359 g.P("m := new(", method.Input.GoIdent, ")") 360 g.P("if err := stream.RecvMsg(m); err != nil { return err }") 361 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})") 362 } else { 363 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})") 364 } 365 g.P("}") 366 g.P() 367 368 genSend := method.Desc.IsStreamingServer() 369 genSendAndClose := !method.Desc.IsStreamingServer() 370 genRecv := method.Desc.IsStreamingClient() 371 372 // Stream auxiliary types and methods. 373 g.P("type ", service.GoName, "_", method.GoName, "Server interface {") 374 if genSend { 375 g.P("Send(*", method.Output.GoIdent, ") error") 376 } 377 if genSendAndClose { 378 g.P("SendAndClose(*", method.Output.GoIdent, ") error") 379 } 380 if genRecv { 381 g.P("Recv() (*", method.Input.GoIdent, ", error)") 382 } 383 g.P(grpcPackage.Ident("ServerStream")) 384 g.P("}") 385 g.P() 386 387 g.P("type ", streamType, " struct {") 388 g.P(grpcPackage.Ident("ServerStream")) 389 g.P("}") 390 g.P() 391 392 if genSend { 393 g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {") 394 g.P("return x.ServerStream.SendMsg(m)") 395 g.P("}") 396 g.P() 397 } 398 if genSendAndClose { 399 g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {") 400 g.P("return x.ServerStream.SendMsg(m)") 401 g.P("}") 402 g.P() 403 } 404 if genRecv { 405 g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {") 406 g.P("m := new(", method.Input.GoIdent, ")") 407 g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }") 408 g.P("return m, nil") 409 g.P("}") 410 g.P() 411 } 412 413 return hname 414} 415 416const deprecationComment = "// Deprecated: Do not use." 417 418func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] } 419