1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package gengogrpc contains the gRPC code generator.
6package gengogrpc
7
8import (
9	"fmt"
10	"strconv"
11	"strings"
12
13	"google.golang.org/protobuf/compiler/protogen"
14
15	"google.golang.org/protobuf/types/descriptorpb"
16)
17
18const (
19	contextPackage = protogen.GoImportPath("context")
20	grpcPackage    = protogen.GoImportPath("google.golang.org/grpc")
21	codesPackage   = protogen.GoImportPath("google.golang.org/grpc/codes")
22	statusPackage  = protogen.GoImportPath("google.golang.org/grpc/status")
23)
24
25// GenerateFile generates a _grpc.pb.go file containing gRPC service definitions.
26func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
27	if len(file.Services) == 0 {
28		return nil
29	}
30	filename := file.GeneratedFilenamePrefix + "_grpc.pb.go"
31	g := gen.NewGeneratedFile(filename, file.GoImportPath)
32	g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.")
33	g.P()
34	g.P("package ", file.GoPackageName)
35	g.P()
36	GenerateFileContent(gen, file, g)
37	return g
38}
39
40// GenerateFileContent generates the gRPC service definitions, excluding the package statement.
41func GenerateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
42	if len(file.Services) == 0 {
43		return
44	}
45
46	// TODO: Remove this. We don't need to include these references any more.
47	g.P("// Reference imports to suppress errors if they are not otherwise used.")
48	g.P("var _ ", contextPackage.Ident("Context"))
49	g.P("var _ ", grpcPackage.Ident("ClientConnInterface"))
50	g.P()
51
52	g.P("// This is a compile-time assertion to ensure that this generated file")
53	g.P("// is compatible with the grpc package it is being compiled against.")
54	g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6"))
55	g.P()
56	for _, service := range file.Services {
57		genService(gen, file, g, service)
58	}
59}
60
61func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
62	clientName := service.GoName + "Client"
63
64	g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
65	g.P("//")
66	g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.")
67
68	// Client interface.
69	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
70		g.P("//")
71		g.P(deprecationComment)
72	}
73	g.Annotate(clientName, service.Location)
74	g.P("type ", clientName, " interface {")
75	for _, method := range service.Methods {
76		g.Annotate(clientName+"."+method.GoName, method.Location)
77		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
78			g.P(deprecationComment)
79		}
80		g.P(method.Comments.Leading,
81			clientSignature(g, method))
82	}
83	g.P("}")
84	g.P()
85
86	// Client structure.
87	g.P("type ", unexport(clientName), " struct {")
88	g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
89	g.P("}")
90	g.P()
91
92	// NewClient factory.
93	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
94		g.P(deprecationComment)
95	}
96	g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
97	g.P("return &", unexport(clientName), "{cc}")
98	g.P("}")
99	g.P()
100
101	var methodIndex, streamIndex int
102	// Client method implementations.
103	for _, method := range service.Methods {
104		if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
105			// Unary RPC method
106			genClientMethod(gen, file, g, method, methodIndex)
107			methodIndex++
108		} else {
109			// Streaming RPC method
110			genClientMethod(gen, file, g, method, streamIndex)
111			streamIndex++
112		}
113	}
114
115	// Server interface.
116	serverType := service.GoName + "Server"
117	g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
118	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
119		g.P("//")
120		g.P(deprecationComment)
121	}
122	g.Annotate(serverType, service.Location)
123	g.P("type ", serverType, " interface {")
124	for _, method := range service.Methods {
125		g.Annotate(serverType+"."+method.GoName, method.Location)
126		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
127			g.P(deprecationComment)
128		}
129		g.P(method.Comments.Leading,
130			serverSignature(g, method))
131	}
132	g.P("}")
133	g.P()
134
135	// Server Unimplemented struct for forward compatibility.
136	g.P("// Unimplemented", serverType, " can be embedded to have forward compatible implementations.")
137	g.P("type Unimplemented", serverType, " struct {")
138	g.P("}")
139	g.P()
140	for _, method := range service.Methods {
141		nilArg := ""
142		if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
143			nilArg = "nil,"
144		}
145		g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{")
146		g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
147		g.P("}")
148	}
149	g.P()
150
151	// Server registration.
152	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
153		g.P(deprecationComment)
154	}
155	serviceDescVar := "_" + service.GoName + "_serviceDesc"
156	g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {")
157	g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
158	g.P("}")
159	g.P()
160
161	// Server handler implementations.
162	var handlerNames []string
163	for _, method := range service.Methods {
164		hname := genServerMethod(gen, file, g, method)
165		handlerNames = append(handlerNames, hname)
166	}
167
168	// Service descriptor.
169	g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
170	g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
171	g.P("HandlerType: (*", serverType, ")(nil),")
172	g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
173	for i, method := range service.Methods {
174		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
175			continue
176		}
177		g.P("{")
178		g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
179		g.P("Handler: ", handlerNames[i], ",")
180		g.P("},")
181	}
182	g.P("},")
183	g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
184	for i, method := range service.Methods {
185		if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
186			continue
187		}
188		g.P("{")
189		g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
190		g.P("Handler: ", handlerNames[i], ",")
191		if method.Desc.IsStreamingServer() {
192			g.P("ServerStreams: true,")
193		}
194		if method.Desc.IsStreamingClient() {
195			g.P("ClientStreams: true,")
196		}
197		g.P("},")
198	}
199	g.P("},")
200	g.P("Metadata: \"", file.Desc.Path(), "\",")
201	g.P("}")
202	g.P()
203}
204
205func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
206	s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
207	if !method.Desc.IsStreamingClient() {
208		s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
209	}
210	s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
211	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
212		s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
213	} else {
214		s += method.Parent.GoName + "_" + method.GoName + "Client"
215	}
216	s += ", error)"
217	return s
218}
219
220func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
221	service := method.Parent
222	sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
223
224	if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
225		g.P(deprecationComment)
226	}
227	g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
228	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
229		g.P("out := new(", method.Output.GoIdent, ")")
230		g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
231		g.P("if err != nil { return nil, err }")
232		g.P("return out, nil")
233		g.P("}")
234		g.P()
235		return
236	}
237	streamType := unexport(service.GoName) + method.GoName + "Client"
238	serviceDescVar := "_" + service.GoName + "_serviceDesc"
239	g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`)
240	g.P("if err != nil { return nil, err }")
241	g.P("x := &", streamType, "{stream}")
242	if !method.Desc.IsStreamingClient() {
243		g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
244		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
245	}
246	g.P("return x, nil")
247	g.P("}")
248	g.P()
249
250	genSend := method.Desc.IsStreamingClient()
251	genRecv := method.Desc.IsStreamingServer()
252	genCloseAndRecv := !method.Desc.IsStreamingServer()
253
254	// Stream auxiliary types and methods.
255	g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
256	if genSend {
257		g.P("Send(*", method.Input.GoIdent, ") error")
258	}
259	if genRecv {
260		g.P("Recv() (*", method.Output.GoIdent, ", error)")
261	}
262	if genCloseAndRecv {
263		g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
264	}
265	g.P(grpcPackage.Ident("ClientStream"))
266	g.P("}")
267	g.P()
268
269	g.P("type ", streamType, " struct {")
270	g.P(grpcPackage.Ident("ClientStream"))
271	g.P("}")
272	g.P()
273
274	if genSend {
275		g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
276		g.P("return x.ClientStream.SendMsg(m)")
277		g.P("}")
278		g.P()
279	}
280	if genRecv {
281		g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
282		g.P("m := new(", method.Output.GoIdent, ")")
283		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
284		g.P("return m, nil")
285		g.P("}")
286		g.P()
287	}
288	if genCloseAndRecv {
289		g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
290		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
291		g.P("m := new(", method.Output.GoIdent, ")")
292		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
293		g.P("return m, nil")
294		g.P("}")
295		g.P()
296	}
297}
298
299func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
300	var reqArgs []string
301	ret := "error"
302	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
303		reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
304		ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)"
305	}
306	if !method.Desc.IsStreamingClient() {
307		reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
308	}
309	if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
310		reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
311	}
312	return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
313}
314
315func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
316	service := method.Parent
317	hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
318
319	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
320		g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
321		g.P("in := new(", method.Input.GoIdent, ")")
322		g.P("if err := dec(in); err != nil { return nil, err }")
323		g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
324		g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
325		g.P("Server: srv,")
326		g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",")
327		g.P("}")
328		g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
329		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))")
330		g.P("}")
331		g.P("return interceptor(ctx, in, info, handler)")
332		g.P("}")
333		g.P()
334		return hname
335	}
336	streamType := unexport(service.GoName) + method.GoName + "Server"
337	g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
338	if !method.Desc.IsStreamingClient() {
339		g.P("m := new(", method.Input.GoIdent, ")")
340		g.P("if err := stream.RecvMsg(m); err != nil { return err }")
341		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
342	} else {
343		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
344	}
345	g.P("}")
346	g.P()
347
348	genSend := method.Desc.IsStreamingServer()
349	genSendAndClose := !method.Desc.IsStreamingServer()
350	genRecv := method.Desc.IsStreamingClient()
351
352	// Stream auxiliary types and methods.
353	g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
354	if genSend {
355		g.P("Send(*", method.Output.GoIdent, ") error")
356	}
357	if genSendAndClose {
358		g.P("SendAndClose(*", method.Output.GoIdent, ") error")
359	}
360	if genRecv {
361		g.P("Recv() (*", method.Input.GoIdent, ", error)")
362	}
363	g.P(grpcPackage.Ident("ServerStream"))
364	g.P("}")
365	g.P()
366
367	g.P("type ", streamType, " struct {")
368	g.P(grpcPackage.Ident("ServerStream"))
369	g.P("}")
370	g.P()
371
372	if genSend {
373		g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
374		g.P("return x.ServerStream.SendMsg(m)")
375		g.P("}")
376		g.P()
377	}
378	if genSendAndClose {
379		g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
380		g.P("return x.ServerStream.SendMsg(m)")
381		g.P("}")
382		g.P()
383	}
384	if genRecv {
385		g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
386		g.P("m := new(", method.Input.GoIdent, ")")
387		g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
388		g.P("return m, nil")
389		g.P("}")
390		g.P()
391	}
392
393	return hname
394}
395
// deprecationComment is the comment line written above generated
// declarations whose service or method is marked deprecated.
const deprecationComment = `// Deprecated: Do not use.`
397
// unexport returns s with its first byte lower-cased, turning an exported Go
// identifier into an unexported one. An empty string is returned unchanged
// rather than causing an out-of-range slice panic.
func unexport(s string) string {
	if s == "" {
		return s
	}
	return strings.ToLower(s[:1]) + s[1:]
}
399