1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package main
5
6import (
7	"flag"
8	"fmt"
9	"runtime/debug"
10	"strconv"
11	"strings"
12
13	"google.golang.org/protobuf/compiler/protogen"
14)
15
// config holds the plugin parameters, which main parses from the protoc
// plugin parameter string via protogen's ParamFunc.
type config struct {
	// json reports whether JSONMarshal/JSONUnmarshal methods should be
	// generated on the encoding type.
	json bool

	// protolib selects the protobuf runtime that the generated encoding calls
	// into. "google.golang.org/protobuf" and "github.com/gogo/protobuf" are
	// special-cased by generateEncoding; any other value is used as an import
	// path providing Marshal/Unmarshal (and JSON variants) functions.
	protolib string
}
20
21func main() {
22	var flags flag.FlagSet
23	var conf config
24	flags.StringVar(&conf.protolib, "protolib", "google.golang.org/protobuf", "which protobuf library to use for encoding")
25	flags.BoolVar(&conf.json, "json", true, "generate encoders with json support")
26
27	protogen.Options{
28		ParamFunc: flags.Set,
29	}.Run(func(plugin *protogen.Plugin) error {
30		for _, f := range plugin.Files {
31			if !f.Generate || len(f.Services) == 0 {
32				continue
33			}
34			generateFile(plugin, f, conf)
35		}
36		return nil
37	})
38}
39
40func generateFile(plugin *protogen.Plugin, file *protogen.File, conf config) {
41	gf := plugin.NewGeneratedFile(file.GeneratedFilenamePrefix+"_drpc.pb.go", file.GoImportPath)
42	d := &drpc{gf, file}
43
44	d.P("// Code generated by protoc-gen-go-drpc. DO NOT EDIT.")
45	if bi, ok := debug.ReadBuildInfo(); ok {
46		d.P("// protoc-gen-go-drpc version: ", bi.Main.Version)
47	}
48	d.P("// source: ", file.Desc.Path())
49	d.P()
50	d.P("package ", file.GoPackageName)
51	d.P()
52
53	d.generateEncoding(conf)
54	for _, service := range file.Services {
55		d.generateService(service)
56	}
57}
58
// drpc is the generation context for one .proto file. It embeds the protogen
// output file, so P, QualifiedGoIdent, etc. can be called on it directly, and
// it keeps the source *protogen.File for per-file data such as the descriptor
// Go name used by EncodingName.
type drpc struct {
	*protogen.GeneratedFile
	file *protogen.File
}
63
64//
65// name helpers
66//
67
68func (d *drpc) Ident(path, ident string) string {
69	return d.QualifiedGoIdent(protogen.GoImportPath(path).Ident(ident))
70}
71
72func (d *drpc) EncodingName() string {
73	return "drpcEncoding_" + d.file.GoDescriptorIdent.GoName
74}
75
76func (d *drpc) RPCGoString(method *protogen.Method) string {
77	return strconv.Quote(fmt.Sprintf("/%s/%s", method.Parent.Desc.FullName(), method.Desc.Name()))
78}
79
80func (d *drpc) InputType(method *protogen.Method) string {
81	return d.QualifiedGoIdent(method.Input.GoIdent)
82}
83
84func (d *drpc) OutputType(method *protogen.Method) string {
85	return d.QualifiedGoIdent(method.Output.GoIdent)
86}
87
88func (d *drpc) ClientIface(service *protogen.Service) string {
89	return "DRPC" + service.GoName + "Client"
90}
91
92func (d *drpc) ClientImpl(service *protogen.Service) string {
93	return "drpc" + service.GoName + "Client"
94}
95
96func (d *drpc) ServerIface(service *protogen.Service) string {
97	return "DRPC" + service.GoName + "Server"
98}
99
100func (d *drpc) ServerImpl(service *protogen.Service) string {
101	return "drpc" + service.GoName + "Server"
102}
103
104func (d *drpc) ServerUnimpl(service *protogen.Service) string {
105	return "DRPC" + service.GoName + "UnimplementedServer"
106}
107
108func (d *drpc) ServerDesc(service *protogen.Service) string {
109	return "DRPC" + service.GoName + "Description"
110}
111
112func (d *drpc) ClientStreamIface(method *protogen.Method) string {
113	return "DRPC" +
114		strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" +
115		strings.ReplaceAll(method.GoName, "_", "__") +
116		"Client"
117}
118
119func (d *drpc) ClientStreamImpl(method *protogen.Method) string {
120	return "drpc" +
121		strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" +
122		strings.ReplaceAll(method.GoName, "_", "__") +
123		"Client"
124}
125
126func (d *drpc) ServerStreamIface(method *protogen.Method) string {
127	return "DRPC" +
128		strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" +
129		strings.ReplaceAll(method.GoName, "_", "__") +
130		"Stream"
131}
132
133func (d *drpc) ServerStreamImpl(method *protogen.Method) string {
134	return "drpc" +
135		strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" +
136		strings.ReplaceAll(method.GoName, "_", "__") +
137		"Stream"
138}
139
140//
141// encoding generation
142//
143
144func (d *drpc) generateEncoding(conf config) {
145	d.P("type ", d.EncodingName(), " struct{}")
146	d.P()
147
148	switch conf.protolib {
149	case "google.golang.org/protobuf":
150		d.P("func (", d.EncodingName(), ") Marshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
151		d.P("return ", d.Ident("google.golang.org/protobuf/proto", "Marshal"), "(msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))")
152		d.P("}")
153		d.P()
154
155		d.P("func (", d.EncodingName(), ") MarshalAppend(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
156		d.P("return ", d.Ident("google.golang.org/protobuf/proto", "MarshalOptions"), "{}.MarshalAppend(buf, msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))")
157		d.P("}")
158		d.P()
159
160		d.P("func (", d.EncodingName(), ") Unmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {")
161		d.P("return ", d.Ident("google.golang.org/protobuf/proto", "Unmarshal"), "(buf, msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))")
162		d.P("}")
163		d.P()
164
165		if conf.json {
166			d.P("func (", d.EncodingName(), ") JSONMarshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
167			d.P("return ", d.Ident("google.golang.org/protobuf/encoding/protojson", "Marshal"), "(msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))")
168			d.P("}")
169			d.P()
170
171			d.P("func (", d.EncodingName(), ") JSONUnmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {")
172			d.P("return ", d.Ident("google.golang.org/protobuf/encoding/protojson", "Unmarshal"), "(buf, msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))")
173			d.P("}")
174			d.P()
175		}
176
177	case "github.com/gogo/protobuf":
178		d.P("func (", d.EncodingName(), ") Marshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
179		d.P("return ", d.Ident("github.com/gogo/protobuf/proto", "Marshal"), "(msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))")
180		d.P("}")
181		d.P()
182
183		d.P("func (", d.EncodingName(), ") Unmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {")
184		d.P("return ", d.Ident("github.com/gogo/protobuf/proto", "Unmarshal"), "(buf, msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))")
185		d.P("}")
186		d.P()
187
188		if conf.json {
189			d.P("func (", d.EncodingName(), ") JSONMarshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
190			d.P("var buf ", d.Ident("bytes", "Buffer"))
191			d.P("err := new(", d.Ident("github.com/gogo/protobuf/jsonpb", "Marshaler"), ").Marshal(&buf, msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))")
192			d.P("if err != nil {")
193			d.P("return nil, err")
194			d.P("}")
195			d.P("return buf.Bytes(), nil")
196			d.P("}")
197			d.P()
198
199			d.P("func (", d.EncodingName(), ") JSONUnmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {")
200			d.P("return ", d.Ident("github.com/gogo/protobuf/jsonpb", "Unmarshal"), "(", d.Ident("bytes", "NewReader"), "(buf), msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))")
201			d.P("}")
202			d.P()
203		}
204
205	default:
206		d.P("func (", d.EncodingName(), ") Marshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
207		d.P("return ", d.Ident(conf.protolib, "Marshal"), "(msg)")
208		d.P("}")
209		d.P()
210
211		d.P("func (", d.EncodingName(), ") Unmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {")
212		d.P("return ", d.Ident(conf.protolib, "Unmarshal"), "(buf, msg)")
213		d.P("}")
214		d.P()
215
216		if conf.json {
217			d.P("func (", d.EncodingName(), ") JSONMarshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {")
218			d.P("return ", d.Ident(conf.protolib, "JSONMarshal"), "(msg)")
219			d.P("}")
220			d.P()
221
222			d.P("func (", d.EncodingName(), ") JSONUnmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {")
223			d.P("return ", d.Ident(conf.protolib, "JSONUnmarshal"), "(buf, msg)")
224			d.P("}")
225			d.P()
226		}
227
228	}
229}
230
231//
232// service generation
233//
234
235func (d *drpc) generateService(service *protogen.Service) {
236	// Client interface
237	d.P("type ", d.ClientIface(service), " interface {")
238	d.P("DRPCConn() ", d.Ident("storj.io/drpc", "Conn"))
239	d.P()
240	for _, method := range service.Methods {
241		d.P(d.generateClientSignature(method))
242	}
243	d.P("}")
244	d.P()
245
246	// Client implementation
247	d.P("type ", d.ClientImpl(service), " struct {")
248	d.P("cc ", d.Ident("storj.io/drpc", "Conn"))
249	d.P("}")
250	d.P()
251
252	// Client constructor
253	d.P("func New", d.ClientIface(service), "(cc ", d.Ident("storj.io/drpc", "Conn"), ") ", d.ClientIface(service), " {")
254	d.P("return &", d.ClientImpl(service), "{cc}")
255	d.P("}")
256	d.P()
257
258	// Client method implementations
259	d.P("func (c *", d.ClientImpl(service), ") DRPCConn() ", d.Ident("storj.io/drpc", "Conn"), "{ return c.cc }")
260	d.P()
261	for _, method := range service.Methods {
262		d.generateClientMethod(method)
263	}
264
265	// Server interface
266	d.P("type ", d.ServerIface(service), " interface {")
267	for _, method := range service.Methods {
268		d.P(d.generateServerSignature(method))
269	}
270	d.P("}")
271	d.P()
272
273	// Server Unimplemented struct
274	d.P("type ", d.ServerUnimpl(service), " struct {}")
275	d.P()
276	for _, method := range service.Methods {
277		d.generateUnimplementedServerMethod(method)
278	}
279	d.P()
280
281	// Server description.
282	d.P("type ", d.ServerDesc(service), " struct{}")
283	d.P()
284	d.P("func (", d.ServerDesc(service), ") NumMethods() int { return ", len(service.Methods), " }")
285	d.P()
286	d.P("func (", d.ServerDesc(service), ") Method(n int) (string, ", d.Ident("storj.io/drpc", "Encoding"), ", ", d.Ident("storj.io/drpc", "Receiver"), ", interface{}, bool) {")
287	d.P("switch n {")
288	for i, method := range service.Methods {
289		d.P("case ", i, ":")
290		d.P("return ", d.RPCGoString(method), ", ", d.EncodingName(), "{}, ")
291		d.generateServerReceiver(method)
292		d.P("}, ", d.ServerIface(service), ".", method.GoName, ", true")
293	}
294	d.P("default:")
295	d.P(`return "", nil, nil, nil, false`)
296	d.P("}")
297	d.P("}")
298	d.P()
299
300	// Registration helper
301	d.P("func DRPCRegister", service.GoName, "(mux ", d.Ident("storj.io/drpc", "Mux"), ", impl ", d.ServerIface(service), ") error {")
302	d.P("return mux.Register(impl, ", d.ServerDesc(service), "{})")
303	d.P("}")
304
305	// Server methods
306	for _, method := range service.Methods {
307		d.generateServerMethod(method)
308	}
309}
310
311//
312// client methods
313//
314
315func (d *drpc) generateClientSignature(method *protogen.Method) string {
316	reqArg := ", in *" + d.InputType(method)
317	if method.Desc.IsStreamingClient() {
318		reqArg = ""
319	}
320	respName := "*" + d.OutputType(method)
321	if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() {
322		respName = d.ClientStreamIface(method)
323	}
324	return fmt.Sprintf("%s(ctx %s%s) (%s, error)", method.GoName, d.Ident("context", "Context"), reqArg, respName)
325}
326
327func (d *drpc) generateClientMethod(method *protogen.Method) {
328	recvType := d.ClientImpl(method.Parent)
329	outType := d.OutputType(method)
330	inType := d.InputType(method)
331
332	d.P("func (c *", recvType, ") ", d.generateClientSignature(method), "{")
333	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
334		d.P("out := new(", outType, ")")
335		d.P("err := c.cc.Invoke(ctx, ", d.RPCGoString(method), ", ", d.EncodingName(), "{}, in, out)")
336		d.P("if err != nil { return nil, err }")
337		d.P("return out, nil")
338		d.P("}")
339		d.P()
340		return
341	}
342
343	d.P("stream, err := c.cc.NewStream(ctx, ", d.RPCGoString(method), ", ", d.EncodingName(), "{})")
344	d.P("if err != nil { return nil, err }")
345	d.P("x := &", d.ClientStreamImpl(method), "{stream}")
346	if !method.Desc.IsStreamingClient() {
347		d.P("if err := x.MsgSend(in, ", d.EncodingName(), "{}); err != nil { return nil, err }")
348		d.P("if err := x.CloseSend(); err != nil { return nil, err }")
349	}
350	d.P("return x, nil")
351	d.P("}")
352	d.P()
353
354	genSend := method.Desc.IsStreamingClient()
355	genRecv := method.Desc.IsStreamingServer()
356	genCloseAndRecv := !method.Desc.IsStreamingServer()
357
358	// Stream auxiliary types and methods.
359	d.P("type ", d.ClientStreamIface(method), " interface {")
360	d.P(d.Ident("storj.io/drpc", "Stream"))
361	if genSend {
362		d.P("Send(*", inType, ") error")
363	}
364	if genRecv {
365		d.P("Recv() (*", outType, ", error)")
366	}
367	if genCloseAndRecv {
368		d.P("CloseAndRecv() (*", outType, ", error)")
369	}
370	d.P("}")
371	d.P()
372
373	d.P("type ", d.ClientStreamImpl(method), " struct {")
374	d.P(d.Ident("storj.io/drpc", "Stream"))
375	d.P("}")
376	d.P()
377
378	if genSend {
379		d.P("func (x *", d.ClientStreamImpl(method), ") Send(m *", inType, ") error {")
380		d.P("return x.MsgSend(m, ", d.EncodingName(), "{})")
381		d.P("}")
382		d.P()
383	}
384	if genRecv {
385		d.P("func (x *", d.ClientStreamImpl(method), ") Recv() (*", outType, ", error) {")
386		d.P("m := new(", outType, ")")
387		d.P("if err := x.MsgRecv(m, ", d.EncodingName(), "{}); err != nil { return nil, err }")
388		d.P("return m, nil")
389		d.P("}")
390		d.P()
391
392		d.P("func (x *", d.ClientStreamImpl(method), ") RecvMsg(m *", outType, ") error {")
393		d.P("return x.MsgRecv(m, ", d.EncodingName(), "{})")
394		d.P("}")
395		d.P()
396	}
397	if genCloseAndRecv {
398		d.P("func (x *", d.ClientStreamImpl(method), ") CloseAndRecv() (*", outType, ", error) {")
399		d.P("if err := x.CloseSend(); err != nil { return nil, err }")
400		d.P("m := new(", outType, ")")
401		d.P("if err := x.MsgRecv(m, ", d.EncodingName(), "{}); err != nil { return nil, err }")
402		d.P("return m, nil")
403		d.P("}")
404		d.P()
405
406		d.P("func (x *", d.ClientStreamImpl(method), ") CloseAndRecvMsg(m *", outType, ") error {")
407		d.P("if err := x.CloseSend(); err != nil { return err }")
408		d.P("return x.MsgRecv(m, ", d.EncodingName(), "{})")
409		d.P("}")
410		d.P()
411	}
412}
413
414//
415// server methods
416//
417
418func (d *drpc) generateServerSignature(method *protogen.Method) string {
419	var reqArgs []string
420	ret := "error"
421	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
422		reqArgs = append(reqArgs, d.Ident("context", "Context"))
423		ret = "(*" + d.OutputType(method) + ", error)"
424	}
425	if !method.Desc.IsStreamingClient() {
426		reqArgs = append(reqArgs, "*"+d.InputType(method))
427	}
428	if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() {
429		reqArgs = append(reqArgs, d.ServerStreamIface(method))
430	}
431	return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
432}
433
434func (d *drpc) generateUnimplementedServerMethod(method *protogen.Method) {
435	d.P("func (s *", d.ServerUnimpl(method.Parent), ") ", d.generateServerSignature(method), " {")
436	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
437		d.P("return nil, ", d.Ident("storj.io/drpc/drpcerr", "WithCode"), "(", d.Ident("errors", "New"), "(\"Unimplemented\"), ", d.Ident("storj.io/drpc/drpcerr", "Unimplemented"), ")")
438	} else {
439		d.P("return ", d.Ident("storj.io/drpc/drpcerr", "WithCode"), "(", d.Ident("errors", "New"), "(\"Unimplemented\"), ", d.Ident("storj.io/drpc/drpcerr", "Unimplemented"), ")")
440	}
441	d.P("}")
442	d.P()
443}
444
445func (d *drpc) generateServerReceiver(method *protogen.Method) {
446	d.P("func (srv interface{}, ctx context.Context, in1, in2 interface{}) (" + d.Ident("storj.io/drpc", "Message") + ", error) {")
447	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
448		d.P("return srv.(", d.ServerIface(method.Parent), ").")
449	} else {
450		d.P("return nil, srv.(", d.ServerIface(method.Parent), ").")
451	}
452	d.P(method.GoName, "(")
453
454	n := 1
455	if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() {
456		d.P("ctx,")
457	}
458	if !method.Desc.IsStreamingClient() {
459		d.P("in", n, ".(*", d.InputType(method), "),")
460		n++
461	}
462	if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() {
463		d.P("&", d.ServerStreamImpl(method), "{in", n, ".(", d.Ident("storj.io/drpc", "Stream"), ")},")
464	}
465	d.P(")")
466}
467
468func (d *drpc) generateServerMethod(method *protogen.Method) {
469	genSend := method.Desc.IsStreamingServer()
470	genSendAndClose := !method.Desc.IsStreamingServer()
471	genRecv := method.Desc.IsStreamingClient()
472
473	// Stream auxiliary types and methods.
474	d.P("type ", d.ServerStreamIface(method), " interface {")
475	d.P(d.Ident("storj.io/drpc", "Stream"))
476	if genSend {
477		d.P("Send(*", d.OutputType(method), ") error")
478	}
479	if genSendAndClose {
480		d.P("SendAndClose(*", d.OutputType(method), ") error")
481	}
482	if genRecv {
483		d.P("Recv() (*", d.InputType(method), ", error)")
484	}
485	d.P("}")
486	d.P()
487
488	d.P("type ", d.ServerStreamImpl(method), " struct {")
489	d.P(d.Ident("storj.io/drpc", "Stream"))
490	d.P("}")
491	d.P()
492
493	if genSend {
494		d.P("func (x *", d.ServerStreamImpl(method), ") Send(m *", d.OutputType(method), ") error {")
495		d.P("return x.MsgSend(m, ", d.EncodingName(), "{})")
496		d.P("}")
497		d.P()
498	}
499
500	if genSendAndClose {
501		d.P("func (x *", d.ServerStreamImpl(method), ") SendAndClose(m *", d.OutputType(method), ") error {")
502		d.P("if err := x.MsgSend(m, ", d.EncodingName(), "{}); err != nil { return err }")
503		d.P("return x.CloseSend()")
504		d.P("}")
505		d.P()
506	}
507
508	if genRecv {
509		d.P("func (x *", d.ServerStreamImpl(method), ") Recv() (*", d.InputType(method), ", error) {")
510		d.P("m := new(", d.InputType(method), ")")
511		d.P("if err := x.MsgRecv(m, ", d.EncodingName(), "{}); err != nil { return nil, err }")
512		d.P("return m, nil")
513		d.P("}")
514		d.P()
515
516		d.P("func (x *", d.ServerStreamImpl(method), ") RecvMsg(m *", d.InputType(method), ") error {")
517		d.P("return x.MsgRecv(m, ", d.EncodingName(), "{})")
518		d.P("}")
519		d.P()
520	}
521}
522