1/*
2 *
3 * Copyright 2020 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package main
20
21import (
22	"fmt"
23	"strconv"
24	"strings"
25
26	"google.golang.org/protobuf/compiler/protogen"
27
28	"google.golang.org/protobuf/types/descriptorpb"
29)
30
// Import paths of the packages referenced by the generated code. Identifiers
// in those packages are produced with Ident, e.g. grpcPackage.Ident("Server"),
// and handed to protogen's GeneratedFile for emission.
const (
	contextPackage = protogen.GoImportPath("context")
	grpcPackage    = protogen.GoImportPath("google.golang.org/grpc")
	codesPackage   = protogen.GoImportPath("google.golang.org/grpc/codes")
	statusPackage  = protogen.GoImportPath("google.golang.org/grpc/status")
)
37
38// generateFile generates a _grpc.pb.go file containing gRPC service definitions.
39func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
40	if len(file.Services) == 0 {
41		return nil
42	}
43	filename := file.GeneratedFilenamePrefix + "_grpc.pb.go"
44	g := gen.NewGeneratedFile(filename, file.GoImportPath)
45	g.P("// Code generated by protoc-gen-go-grpc. DO NOT EDIT.")
46	g.P()
47	g.P("package ", file.GoPackageName)
48	g.P()
49	generateFileContent(gen, file, g)
50	return g
51}
52
53// generateFileContent generates the gRPC service definitions, excluding the package statement.
54func generateFileContent(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile) {
55	if len(file.Services) == 0 {
56		return
57	}
58
59	g.P("// This is a compile-time assertion to ensure that this generated file")
60	g.P("// is compatible with the grpc package it is being compiled against.")
61	g.P("const _ = ", grpcPackage.Ident("SupportPackageIsVersion6"))
62	g.P()
63	for _, service := range file.Services {
64		genService(gen, file, g, service)
65	}
66}
67
68func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
69	clientName := service.GoName + "Client"
70
71	g.P("// ", clientName, " is the client API for ", service.GoName, " service.")
72	g.P("//")
73	g.P("// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.")
74
75	// Client interface.
76	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
77		g.P("//")
78		g.P(deprecationComment)
79	}
80	g.Annotate(clientName, service.Location)
81	g.P("type ", clientName, " interface {")
82	for _, method := range service.Methods {
83		g.Annotate(clientName+"."+method.GoName, method.Location)
84		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
85			g.P(deprecationComment)
86		}
87		g.P(method.Comments.Leading,
88			clientSignature(g, method))
89	}
90	g.P("}")
91	g.P()
92
93	// Client structure.
94	g.P("type ", unexport(clientName), " struct {")
95	g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
96	g.P("}")
97	g.P()
98
99	// NewClient factory.
100	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
101		g.P(deprecationComment)
102	}
103	g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
104	g.P("return &", unexport(clientName), "{cc}")
105	g.P("}")
106	g.P()
107
108	var methodIndex, streamIndex int
109	// Client method implementations.
110	for _, method := range service.Methods {
111		if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
112			// Unary RPC method
113			genClientMethod(gen, file, g, method, methodIndex)
114			methodIndex++
115		} else {
116			// Streaming RPC method
117			genClientMethod(gen, file, g, method, streamIndex)
118			streamIndex++
119		}
120	}
121
122	mustOrShould := "must"
123	if !*requireUnimplemented {
124		mustOrShould = "should"
125	}
126
127	// Server interface.
128	serverType := service.GoName + "Server"
129	g.P("// ", serverType, " is the server API for ", service.GoName, " service.")
130	g.P("// All implementations ", mustOrShould, " embed Unimplemented", serverType)
131	g.P("// for forward compatibility")
132	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
133		g.P("//")
134		g.P(deprecationComment)
135	}
136	g.Annotate(serverType, service.Location)
137	g.P("type ", serverType, " interface {")
138	for _, method := range service.Methods {
139		g.Annotate(serverType+"."+method.GoName, method.Location)
140		if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
141			g.P(deprecationComment)
142		}
143		g.P(method.Comments.Leading,
144			serverSignature(g, method))
145	}
146	if *requireUnimplemented {
147		g.P("mustEmbedUnimplemented", serverType, "()")
148	}
149	g.P("}")
150	g.P()
151
152	// Server Unimplemented struct for forward compatibility.
153	g.P("// Unimplemented", serverType, " ", mustOrShould, " be embedded to have forward compatible implementations.")
154	g.P("type Unimplemented", serverType, " struct {")
155	g.P("}")
156	g.P()
157	for _, method := range service.Methods {
158		nilArg := ""
159		if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
160			nilArg = "nil,"
161		}
162		g.P("func (*Unimplemented", serverType, ") ", serverSignature(g, method), "{")
163		g.P("return ", nilArg, statusPackage.Ident("Errorf"), "(", codesPackage.Ident("Unimplemented"), `, "method `, method.GoName, ` not implemented")`)
164		g.P("}")
165	}
166	if *requireUnimplemented {
167		g.P("func (*Unimplemented", serverType, ") mustEmbedUnimplemented", serverType, "() {}")
168	}
169	g.P()
170
171	// Server registration.
172	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
173		g.P(deprecationComment)
174	}
175	serviceDescVar := "_" + service.GoName + "_serviceDesc"
176	g.P("func Register", service.GoName, "Server(s *", grpcPackage.Ident("Server"), ", srv ", serverType, ") {")
177	g.P("s.RegisterService(&", serviceDescVar, `, srv)`)
178	g.P("}")
179	g.P()
180
181	// Server handler implementations.
182	var handlerNames []string
183	for _, method := range service.Methods {
184		hname := genServerMethod(gen, file, g, method)
185		handlerNames = append(handlerNames, hname)
186	}
187
188	// Service descriptor.
189	g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
190	g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
191	g.P("HandlerType: (*", serverType, ")(nil),")
192	g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
193	for i, method := range service.Methods {
194		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
195			continue
196		}
197		g.P("{")
198		g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
199		g.P("Handler: ", handlerNames[i], ",")
200		g.P("},")
201	}
202	g.P("},")
203	g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
204	for i, method := range service.Methods {
205		if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
206			continue
207		}
208		g.P("{")
209		g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
210		g.P("Handler: ", handlerNames[i], ",")
211		if method.Desc.IsStreamingServer() {
212			g.P("ServerStreams: true,")
213		}
214		if method.Desc.IsStreamingClient() {
215			g.P("ClientStreams: true,")
216		}
217		g.P("},")
218	}
219	g.P("},")
220	g.P("Metadata: \"", file.Desc.Path(), "\",")
221	g.P("}")
222	g.P()
223}
224
225func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
226	s := method.GoName + "(ctx " + g.QualifiedGoIdent(contextPackage.Ident("Context"))
227	if !method.Desc.IsStreamingClient() {
228		s += ", in *" + g.QualifiedGoIdent(method.Input.GoIdent)
229	}
230	s += ", opts ..." + g.QualifiedGoIdent(grpcPackage.Ident("CallOption")) + ") ("
231	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
232		s += "*" + g.QualifiedGoIdent(method.Output.GoIdent)
233	} else {
234		s += method.Parent.GoName + "_" + method.GoName + "Client"
235	}
236	s += ", error)"
237	return s
238}
239
240func genClientMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, index int) {
241	service := method.Parent
242	sname := fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
243
244	if method.Desc.Options().(*descriptorpb.MethodOptions).GetDeprecated() {
245		g.P(deprecationComment)
246	}
247	g.P("func (c *", unexport(service.GoName), "Client) ", clientSignature(g, method), "{")
248	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
249		g.P("out := new(", method.Output.GoIdent, ")")
250		g.P(`err := c.cc.Invoke(ctx, "`, sname, `", in, out, opts...)`)
251		g.P("if err != nil { return nil, err }")
252		g.P("return out, nil")
253		g.P("}")
254		g.P()
255		return
256	}
257	streamType := unexport(service.GoName) + method.GoName + "Client"
258	serviceDescVar := "_" + service.GoName + "_serviceDesc"
259	g.P("stream, err := c.cc.NewStream(ctx, &", serviceDescVar, ".Streams[", index, `], "`, sname, `", opts...)`)
260	g.P("if err != nil { return nil, err }")
261	g.P("x := &", streamType, "{stream}")
262	if !method.Desc.IsStreamingClient() {
263		g.P("if err := x.ClientStream.SendMsg(in); err != nil { return nil, err }")
264		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
265	}
266	g.P("return x, nil")
267	g.P("}")
268	g.P()
269
270	genSend := method.Desc.IsStreamingClient()
271	genRecv := method.Desc.IsStreamingServer()
272	genCloseAndRecv := !method.Desc.IsStreamingServer()
273
274	// Stream auxiliary types and methods.
275	g.P("type ", service.GoName, "_", method.GoName, "Client interface {")
276	if genSend {
277		g.P("Send(*", method.Input.GoIdent, ") error")
278	}
279	if genRecv {
280		g.P("Recv() (*", method.Output.GoIdent, ", error)")
281	}
282	if genCloseAndRecv {
283		g.P("CloseAndRecv() (*", method.Output.GoIdent, ", error)")
284	}
285	g.P(grpcPackage.Ident("ClientStream"))
286	g.P("}")
287	g.P()
288
289	g.P("type ", streamType, " struct {")
290	g.P(grpcPackage.Ident("ClientStream"))
291	g.P("}")
292	g.P()
293
294	if genSend {
295		g.P("func (x *", streamType, ") Send(m *", method.Input.GoIdent, ") error {")
296		g.P("return x.ClientStream.SendMsg(m)")
297		g.P("}")
298		g.P()
299	}
300	if genRecv {
301		g.P("func (x *", streamType, ") Recv() (*", method.Output.GoIdent, ", error) {")
302		g.P("m := new(", method.Output.GoIdent, ")")
303		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
304		g.P("return m, nil")
305		g.P("}")
306		g.P()
307	}
308	if genCloseAndRecv {
309		g.P("func (x *", streamType, ") CloseAndRecv() (*", method.Output.GoIdent, ", error) {")
310		g.P("if err := x.ClientStream.CloseSend(); err != nil { return nil, err }")
311		g.P("m := new(", method.Output.GoIdent, ")")
312		g.P("if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err }")
313		g.P("return m, nil")
314		g.P("}")
315		g.P()
316	}
317}
318
319func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
320	var reqArgs []string
321	ret := "error"
322	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
323		reqArgs = append(reqArgs, g.QualifiedGoIdent(contextPackage.Ident("Context")))
324		ret = "(*" + g.QualifiedGoIdent(method.Output.GoIdent) + ", error)"
325	}
326	if !method.Desc.IsStreamingClient() {
327		reqArgs = append(reqArgs, "*"+g.QualifiedGoIdent(method.Input.GoIdent))
328	}
329	if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
330		reqArgs = append(reqArgs, method.Parent.GoName+"_"+method.GoName+"Server")
331	}
332	return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
333}
334
335func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
336	service := method.Parent
337	hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
338
339	if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
340		g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
341		g.P("in := new(", method.Input.GoIdent, ")")
342		g.P("if err := dec(in); err != nil { return nil, err }")
343		g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
344		g.P("info := &", grpcPackage.Ident("UnaryServerInfo"), "{")
345		g.P("Server: srv,")
346		g.P("FullMethod: ", strconv.Quote(fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.GoName)), ",")
347		g.P("}")
348		g.P("handler := func(ctx ", contextPackage.Ident("Context"), ", req interface{}) (interface{}, error) {")
349		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(ctx, req.(*", method.Input.GoIdent, "))")
350		g.P("}")
351		g.P("return interceptor(ctx, in, info, handler)")
352		g.P("}")
353		g.P()
354		return hname
355	}
356	streamType := unexport(service.GoName) + method.GoName + "Server"
357	g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
358	if !method.Desc.IsStreamingClient() {
359		g.P("m := new(", method.Input.GoIdent, ")")
360		g.P("if err := stream.RecvMsg(m); err != nil { return err }")
361		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(m, &", streamType, "{stream})")
362	} else {
363		g.P("return srv.(", service.GoName, "Server).", method.GoName, "(&", streamType, "{stream})")
364	}
365	g.P("}")
366	g.P()
367
368	genSend := method.Desc.IsStreamingServer()
369	genSendAndClose := !method.Desc.IsStreamingServer()
370	genRecv := method.Desc.IsStreamingClient()
371
372	// Stream auxiliary types and methods.
373	g.P("type ", service.GoName, "_", method.GoName, "Server interface {")
374	if genSend {
375		g.P("Send(*", method.Output.GoIdent, ") error")
376	}
377	if genSendAndClose {
378		g.P("SendAndClose(*", method.Output.GoIdent, ") error")
379	}
380	if genRecv {
381		g.P("Recv() (*", method.Input.GoIdent, ", error)")
382	}
383	g.P(grpcPackage.Ident("ServerStream"))
384	g.P("}")
385	g.P()
386
387	g.P("type ", streamType, " struct {")
388	g.P(grpcPackage.Ident("ServerStream"))
389	g.P("}")
390	g.P()
391
392	if genSend {
393		g.P("func (x *", streamType, ") Send(m *", method.Output.GoIdent, ") error {")
394		g.P("return x.ServerStream.SendMsg(m)")
395		g.P("}")
396		g.P()
397	}
398	if genSendAndClose {
399		g.P("func (x *", streamType, ") SendAndClose(m *", method.Output.GoIdent, ") error {")
400		g.P("return x.ServerStream.SendMsg(m)")
401		g.P("}")
402		g.P()
403	}
404	if genRecv {
405		g.P("func (x *", streamType, ") Recv() (*", method.Input.GoIdent, ", error) {")
406		g.P("m := new(", method.Input.GoIdent, ")")
407		g.P("if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err }")
408		g.P("return m, nil")
409		g.P("}")
410		g.P()
411	}
412
413	return hname
414}
415
// deprecationComment is the comment line emitted before generated
// declarations whose service or method has the deprecated option set in the
// .proto definition.
const deprecationComment = "// Deprecated: Do not use."
417
// unexport returns s with its first byte lower-cased, turning an exported
// identifier such as "FooClient" into "fooClient". An empty string is
// returned unchanged instead of panicking on the s[:1] slice.
func unexport(s string) string {
	if s == "" {
		return s
	}
	return strings.ToLower(s[:1]) + s[1:]
}
419