1// Copyright 2018 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// Package gengogrpc contains the gRPC code generator. 6package gengogrpc 7 8import ( 9 "fmt" 10 "strconv" 11 "strings" 12 13 "google.golang.org/protobuf/compiler/protogen" 14 15 "google.golang.org/protobuf/types/descriptorpb" 16) 17 18const ( 19 contextPackage = protogen.GoImportPath("context") 20 grpcPackage = protogen.GoImportPath("google.golang.org/grpc") 21 codesPackage = protogen.GoImportPath("google.golang.org/grpc/codes") 22 statusPackage = protogen.GoImportPath("google.golang.org/grpc/status") 23) 24 25// GenerateFile generates a _grpc.pb.go file containing gRPC service definitions. 26func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile { 27 if len(file.Services) == 0 { 28 return nil 29 } 30 filename := file.GeneratedFilenamePrefix + "_grpc.pb.go" 31 g := gen.NewGeneratedFile(filename, file.GoImportPath) 32 g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.") 33 g.P() 34 g.P("package ", file.GoPackageName) 35 g.P() 36 GenerateFileContent(gen, file, g) 37 return g 38} 39 40// GenerateFileContent generates the gRPC service definitions, excluding the package statement. 41func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) { 42 if len(file.Services) == 0 { 43 return 44 } 45 46 // TODO: Remove this. We don't need to include these references any more. 47 g.P("// Reference imports to suppress errors if they are not otherwise used.") 48 g.P("var _ ", contextPackage.Ident("Context")) 49 g.P("var _ ", grpcPackage.Ident("ClientConnInterface")) 50 g.P() 51 52 g.P("// This is a compile-time assertion to ensure that this generated file") 53 g.P("// is compatible with the grpc package it is being compiled against.") 54 g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6")) 55 g.P() 56 for _, service := range file.Services { 57 genService(gen, file, g, service) 58 } 59} 60 61func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) { 62 clientName := service.GoName + "Client" 63 64 g.P("// ", clientName, " is the client API for ", service.GoName, " service.") 65 g.P("//") 66 g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.") 67 68 // Client interface. 69 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 70 g.P("//") 71 g.P(deprecationComment) 72 } 73 g.Annotate(clientName, service.Location) 74 g.P("type ", clientName, " interface {") 75 for _, method := range service.Methods { 76 g.Annotate(clientName+"."+method.GoName, method.Location) 77 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { 78 g.P(deprecationComment) 79 } 80 g.P(method.Comments.Leading, 81 clientSignature(g, method)) 82 } 83 g.P("}") 84 g.P() 85 86 // Client structure. 87 g.P("type ", unexport(clientName), " struct {") 88 g.P("cc ", grpcPackage.Ident("ClientConnInterface")) 89 g.P("}") 90 g.P() 91 92 // NewClient factory. 93 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 94 g.P(deprecationComment) 95 } 96 g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {") 97 g.P("return &", unexport(clientName), "{cc}") 98 g.P("}") 99 g.P() 100 101 var methodIndex, streamIndex int 102 // Client method implementations. 103 for _, method := range service.Methods { 104 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 105 // Unary RPC method 106 genClientMethod(gen, file, g, method, methodIndex) 107 methodIndex++ 108 } else { 109 // Streaming RPC method 110 genClientMethod(gen, file, g, method, streamIndex) 111 streamIndex++ 112 } 113 } 114 115 // Server interface. 116 serverType := service.GoName + "Server" 117 g.P("// ", serverType, " is the server API for ", service.GoName, " service.") 118 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 119 g.P("//") 120 g.P(deprecationComment) 121 } 122 g.Annotate(serverType, service.Location) 123 g.P("type ", serverType, " interface {") 124 for _, method := range service.Methods { 125 g.Annotate(serverType+"."+method.GoName, method.Location) 126 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { 127 g.P(deprecationComment) 128 } 129 g.P(method.Comments.Leading, 130 serverSignature(g, method)) 131 } 132 g.P("}") 133 g.P() 134 135 // Server Unimplemented struct for forward compatibility. 136 g.P("// Unimplemented", serverType, " can be embedded to have forward compatible implementations.") 137 g.P("type Unimplemented", serverType, " struct {") 138 g.P("}") 139 g.P() 140 for _, method := range service.Methods { 141 nilArg := "" 142 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 143 nilArg = "nil," 144 } 145 g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{") 146 g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`) 147 g.P("}") 148 } 149 g.P() 150 151 // Server registration. 152 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 153 g.P(deprecationComment) 154 } 155 serviceDescVar := "_" + service.GoName + "_serviceDesc" 156 g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {") 157 g.P("s.RegisterService(&", serviceDescVar, `, srv)`) 158 g.P("}") 159 g.P() 160 161 // Server handler implementations. 162 var handlerNames []string 163 for _, method := range service.Methods { 164 hname := genServerMethod(gen, file, g, method) 165 handlerNames = append(handlerNames, hname) 166 } 167 168 // Service descriptor. 169 g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {") 170 g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",") 171 g.P("HandlerType: (*", serverType, ")(nil),") 172 g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{") 173 for i, method := range service.Methods { 174 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { 175 continue 176 } 177 g.P("{") 178 g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",") 179 g.P("Handler: ", handlerNames[i], ",") 180 g.P("},") 181 } 182 g.P("},") 183 g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{") 184 for i, method := range service.Methods { 185 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 186 continue 187 } 188 g.P("{") 189 g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",") 190 g.P("Handler: ", handlerNames[i], ",") 191 if method.Desc.IsStreamingServer() { 192 g.P("ServerStreams: true,") 193 } 194 if method.Desc.IsStreamingClient() { 195 g.P("ClientStreams: true,") 196 } 197 g.P("},") 198 } 199 g.P("},") 200 g.P("Metadata: \"", file.Desc.Path(), "\",") 201 g.P("}") 202 g.P() 203} 204 205func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string { 206 s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context")) 207 if !method.Desc.IsStreamingClient() { 208 s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent) 209 } 210 s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") (" 211 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 212 s += "*" + g.QualifiedGoIdent(method.Output.GoIdent) 213 } else { 214 s += method.Parent.GoName + "_" + method.GoName + "Client" 215 } 216 s += ", error)" 217 return s 218} 219 220func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) { 221 service := method.Parent 222 sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name()) 223 224 if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() { 225 g.P(deprecationComment) 226 } 227 g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{") 228 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 229 g.P("out := new(", method.Output.GoIdent, ")") 230 g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`) 231 g.P("if err != nil { return nil, err }") 232 g.P("return out, nil") 233 g.P("}") 234 g.P() 235 return 236 } 237 streamType := unexport(service.GoName) + method.GoName + "Client" 238 serviceDescVar := "_" + service.GoName + "_serviceDesc" 239 g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`) 240 g.P("if err != nil { return nil, err }") 241 g.P("x := &", streamType, "{stream}") 242 if !method.Desc.IsStreamingClient() { 243 g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }") 244 g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") 245 } 246 g.P("return x, nil") 247 g.P("}") 248 g.P() 249 250 genSend := method.Desc.IsStreamingClient() 251 genRecv := method.Desc.IsStreamingServer() 252 genCloseAndRecv := !method.Desc.IsStreamingServer() 253 254 // Stream auxiliary types and methods. 255 g.P("type ", service.GoName, "_", method.GoName, "Client interface {") 256 if genSend { 257 g.P("Send(*", method.Input.GoIdent, ") error") 258 } 259 if genRecv { 260 g.P("Recv() (*", method.Output.GoIdent, ", error)") 261 } 262 if genCloseAndRecv { 263 g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)") 264 } 265 g.P(grpcPackage.Ident("ClientStream")) 266 g.P("}") 267 g.P() 268 269 g.P("type ", streamType, " struct {") 270 g.P(grpcPackage.Ident("ClientStream")) 271 g.P("}") 272 g.P() 273 274 if genSend { 275 g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {") 276 g.P("return x.ClientStream.SendMsg(m)") 277 g.P("}") 278 g.P() 279 } 280 if genRecv { 281 g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {") 282 g.P("m := new(", method.Output.GoIdent, ")") 283 g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") 284 g.P("return m, nil") 285 g.P("}") 286 g.P() 287 } 288 if genCloseAndRecv { 289 g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {") 290 g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }") 291 g.P("m := new(", method.Output.GoIdent, ")") 292 g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }") 293 g.P("return m, nil") 294 g.P("}") 295 g.P() 296 } 297} 298 299func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string { 300 var reqArgs []string 301 ret := "error" 302 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 303 reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context"))) 304 ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)" 305 } 306 if !method.Desc.IsStreamingClient() { 307 reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent)) 308 } 309 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { 310 reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server") 311 } 312 return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret 313} 314 315func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string { 316 service := method.Parent 317 hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName) 318 319 if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() { 320 g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {") 321 g.P("in := new(", method.Input.GoIdent, ")") 322 g.P("if err := dec(in); err != nil { return nil, err }") 323 g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }") 324 g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{") 325 g.P("Server: srv,") 326 g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",") 327 g.P("}") 328 g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {") 329 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))") 330 g.P("}") 331 g.P("return interceptor(ctx, in, info, handler)") 332 g.P("}") 333 g.P() 334 return hname 335 } 336 streamType := unexport(service.GoName) + method.GoName + "Server" 337 g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {") 338 if !method.Desc.IsStreamingClient() { 339 g.P("m := new(", method.Input.GoIdent, ")") 340 g.P("if err := stream.RecvMsg(m); err != nil { return err }") 341 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})") 342 } else { 343 g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})") 344 } 345 g.P("}") 346 g.P() 347 348 genSend := method.Desc.IsStreamingServer() 349 genSendAndClose := !method.Desc.IsStreamingServer() 350 genRecv := method.Desc.IsStreamingClient() 351 352 // Stream auxiliary types and methods. 353 g.P("type ", service.GoName, "_", method.GoName, "Server interface {") 354 if genSend { 355 g.P("Send(*", method.Output.GoIdent, ") error") 356 } 357 if genSendAndClose { 358 g.P("SendAndClose(*", method.Output.GoIdent, ") error") 359 } 360 if genRecv { 361 g.P("Recv() (*", method.Input.GoIdent, ", error)") 362 } 363 g.P(grpcPackage.Ident("ServerStream")) 364 g.P("}") 365 g.P() 366 367 g.P("type ", streamType, " struct {") 368 g.P(grpcPackage.Ident("ServerStream")) 369 g.P("}") 370 g.P() 371 372 if genSend { 373 g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {") 374 g.P("return x.ServerStream.SendMsg(m)") 375 g.P("}") 376 g.P() 377 } 378 if genSendAndClose { 379 g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {") 380 g.P("return x.ServerStream.SendMsg(m)") 381 g.P("}") 382 g.P() 383 } 384 if genRecv { 385 g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {") 386 g.P("m := new(", method.Input.GoIdent, ")") 387 g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }") 388 g.P("return m, nil") 389 g.P("}") 390 g.P() 391 } 392 393 return hname 394} 395 396const deprecationComment = "// Deprecated: Do not use." 397 398func unexport(s string) string { return strings.ToLower(s[:1]) + s[1:] } 399