1// Copyright (C) 2019 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package main 5 6import ( 7 "flag" 8 "fmt" 9 "runtime/debug" 10 "strconv" 11 "strings" 12 13 "google.golang.org/protobuf/compiler/protogen" 14) 15 16type config struct { 17 protolib string 18 json bool 19} 20 21func main() { 22 var flags flag.FlagSet 23 var conf config 24 flags.StringVar(&conf.protolib, "protolib", "google.golang.org/protobuf", "which protobuf library to use for encoding") 25 flags.BoolVar(&conf.json, "json", true, "generate encoders with json support") 26 27 protogen.Options{ 28 ParamFunc: flags.Set, 29 }.Run(func(plugin *protogen.Plugin) error { 30 for _, f := range plugin.Files { 31 if !f.Generate || len(f.Services) == 0 { 32 continue 33 } 34 generateFile(plugin, f, conf) 35 } 36 return nil 37 }) 38} 39 40func generateFile(plugin *protogen.Plugin, file *protogen.File, conf config) { 41 gf := plugin.NewGeneratedFile(file.GeneratedFilenamePrefix+"_drpc.pb.go", file.GoImportPath) 42 d := &drpc{gf, file} 43 44 d.P("// Code generated by protoc-gen-go-drpc. DO NOT EDIT.") 45 if bi, ok := debug.ReadBuildInfo(); ok { 46 d.P("// protoc-gen-go-drpc version: ", bi.Main.Version) 47 } 48 d.P("// source: ", file.Desc.Path()) 49 d.P() 50 d.P("package ", file.GoPackageName) 51 d.P() 52 53 d.generateEncoding(conf) 54 for _, service := range file.Services { 55 d.generateService(service) 56 } 57} 58 59type drpc struct { 60 *protogen.GeneratedFile 61 file *protogen.File 62} 63 64// 65// name helpers 66// 67 68func (d *drpc) Ident(path, ident string) string { 69 return d.QualifiedGoIdent(protogen.GoImportPath(path).Ident(ident)) 70} 71 72func (d *drpc) EncodingName() string { 73 return "drpcEncoding_" + d.file.GoDescriptorIdent.GoName 74} 75 76func (d *drpc) RPCGoString(method *protogen.Method) string { 77 return strconv.Quote(fmt.Sprintf("/%s/%s", method.Parent.Desc.FullName(), method.Desc.Name())) 78} 79 80func (d *drpc) InputType(method *protogen.Method) string { 81 return d.QualifiedGoIdent(method.Input.GoIdent) 82} 83 84func (d *drpc) OutputType(method *protogen.Method) string { 85 return d.QualifiedGoIdent(method.Output.GoIdent) 86} 87 88func (d *drpc) ClientIface(service *protogen.Service) string { 89 return "DRPC" + service.GoName + "Client" 90} 91 92func (d *drpc) ClientImpl(service *protogen.Service) string { 93 return "drpc" + service.GoName + "Client" 94} 95 96func (d *drpc) ServerIface(service *protogen.Service) string { 97 return "DRPC" + service.GoName + "Server" 98} 99 100func (d *drpc) ServerImpl(service *protogen.Service) string { 101 return "drpc" + service.GoName + "Server" 102} 103 104func (d *drpc) ServerUnimpl(service *protogen.Service) string { 105 return "DRPC" + service.GoName + "UnimplementedServer" 106} 107 108func (d *drpc) ServerDesc(service *protogen.Service) string { 109 return "DRPC" + service.GoName + "Description" 110} 111 112func (d *drpc) ClientStreamIface(method *protogen.Method) string { 113 return "DRPC" + 114 strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" + 115 strings.ReplaceAll(method.GoName, "_", "__") + 116 "Client" 117} 118 119func (d *drpc) ClientStreamImpl(method *protogen.Method) string { 120 return "drpc" + 121 strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" + 122 strings.ReplaceAll(method.GoName, "_", "__") + 123 "Client" 124} 125 126func (d *drpc) ServerStreamIface(method *protogen.Method) string { 127 return "DRPC" + 128 strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" + 129 strings.ReplaceAll(method.GoName, "_", "__") + 130 "Stream" 131} 132 133func (d *drpc) ServerStreamImpl(method *protogen.Method) string { 134 return "drpc" + 135 strings.ReplaceAll(method.Parent.GoName, "_", "__") + "_" + 136 strings.ReplaceAll(method.GoName, "_", "__") + 137 "Stream" 138} 139 140// 141// encoding generation 142// 143 144func (d *drpc) generateEncoding(conf config) { 145 d.P("type ", d.EncodingName(), " struct{}") 146 d.P() 147 148 switch conf.protolib { 149 case "google.golang.org/protobuf": 150 d.P("func (", d.EncodingName(), ") Marshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 151 d.P("return ", d.Ident("google.golang.org/protobuf/proto", "Marshal"), "(msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))") 152 d.P("}") 153 d.P() 154 155 d.P("func (", d.EncodingName(), ") MarshalAppend(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 156 d.P("return ", d.Ident("google.golang.org/protobuf/proto", "MarshalOptions"), "{}.MarshalAppend(buf, msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))") 157 d.P("}") 158 d.P() 159 160 d.P("func (", d.EncodingName(), ") Unmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {") 161 d.P("return ", d.Ident("google.golang.org/protobuf/proto", "Unmarshal"), "(buf, msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))") 162 d.P("}") 163 d.P() 164 165 if conf.json { 166 d.P("func (", d.EncodingName(), ") JSONMarshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 167 d.P("return ", d.Ident("google.golang.org/protobuf/encoding/protojson", "Marshal"), "(msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))") 168 d.P("}") 169 d.P() 170 171 d.P("func (", d.EncodingName(), ") JSONUnmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {") 172 d.P("return ", d.Ident("google.golang.org/protobuf/encoding/protojson", "Unmarshal"), "(buf, msg.(", d.Ident("google.golang.org/protobuf/proto", "Message"), "))") 173 d.P("}") 174 d.P() 175 } 176 177 case "github.com/gogo/protobuf": 178 d.P("func (", d.EncodingName(), ") Marshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 179 d.P("return ", d.Ident("github.com/gogo/protobuf/proto", "Marshal"), "(msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))") 180 d.P("}") 181 d.P() 182 183 d.P("func (", d.EncodingName(), ") Unmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {") 184 d.P("return ", d.Ident("github.com/gogo/protobuf/proto", "Unmarshal"), "(buf, msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))") 185 d.P("}") 186 d.P() 187 188 if conf.json { 189 d.P("func (", d.EncodingName(), ") JSONMarshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 190 d.P("var buf ", d.Ident("bytes", "Buffer")) 191 d.P("err := new(", d.Ident("github.com/gogo/protobuf/jsonpb", "Marshaler"), ").Marshal(&buf, msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))") 192 d.P("if err != nil {") 193 d.P("return nil, err") 194 d.P("}") 195 d.P("return buf.Bytes(), nil") 196 d.P("}") 197 d.P() 198 199 d.P("func (", d.EncodingName(), ") JSONUnmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {") 200 d.P("return ", d.Ident("github.com/gogo/protobuf/jsonpb", "Unmarshal"), "(", d.Ident("bytes", "NewReader"), "(buf), msg.(", d.Ident("github.com/gogo/protobuf/proto", "Message"), "))") 201 d.P("}") 202 d.P() 203 } 204 205 default: 206 d.P("func (", d.EncodingName(), ") Marshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 207 d.P("return ", d.Ident(conf.protolib, "Marshal"), "(msg)") 208 d.P("}") 209 d.P() 210 211 d.P("func (", d.EncodingName(), ") Unmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {") 212 d.P("return ", d.Ident(conf.protolib, "Unmarshal"), "(buf, msg)") 213 d.P("}") 214 d.P() 215 216 if conf.json { 217 d.P("func (", d.EncodingName(), ") JSONMarshal(msg ", d.Ident("storj.io/drpc", "Message"), ") ([]byte, error) {") 218 d.P("return ", d.Ident(conf.protolib, "JSONMarshal"), "(msg)") 219 d.P("}") 220 d.P() 221 222 d.P("func (", d.EncodingName(), ") JSONUnmarshal(buf []byte, msg ", d.Ident("storj.io/drpc", "Message"), ") error {") 223 d.P("return ", d.Ident(conf.protolib, "JSONUnmarshal"), "(buf, msg)") 224 d.P("}") 225 d.P() 226 } 227 228 } 229} 230 231// 232// service generation 233// 234 235func (d *drpc) generateService(service *protogen.Service) { 236 // Client interface 237 d.P("type ", d.ClientIface(service), " interface {") 238 d.P("DRPCConn() ", d.Ident("storj.io/drpc", "Conn")) 239 d.P() 240 for _, method := range service.Methods { 241 d.P(d.generateClientSignature(method)) 242 } 243 d.P("}") 244 d.P() 245 246 // Client implementation 247 d.P("type ", d.ClientImpl(service), " struct {") 248 d.P("cc ", d.Ident("storj.io/drpc", "Conn")) 249 d.P("}") 250 d.P() 251 252 // Client constructor 253 d.P("func New", d.ClientIface(service), "(cc ", d.Ident("storj.io/drpc", "Conn"), ") ", d.ClientIface(service), " {") 254 d.P("return &", d.ClientImpl(service), "{cc}") 255 d.P("}") 256 d.P() 257 258 // Client method implementations 259 d.P("func (c *", d.ClientImpl(service), ") DRPCConn() ", d.Ident("storj.io/drpc", "Conn"), "{ return c.cc }") 260 d.P() 261 for _, method := range service.Methods { 262 d.generateClientMethod(method) 263 } 264 265 // Server interface 266 d.P("type ", d.ServerIface(service), " interface {") 267 for _, method := range service.Methods { 268 d.P(d.generateServerSignature(method)) 269 } 270 d.P("}") 271 d.P() 272 273 // Server Unimplemented struct 274 d.P("type ", d.ServerUnimpl(service), " struct {}") 275 d.P() 276 for _, method := range service.Methods { 277 d.generateUnimplementedServerMethod(method) 278 } 279 d.P() 280 281 // Server description. 282 d.P("type ", d.ServerDesc(service), " struct{}") 283 d.P() 284 d.P("func (", d.ServerDesc(service), ") NumMethods() int { return ", len(service.Methods), " }") 285 d.P() 286 d.P("func (", d.ServerDesc(service), ") Method(n int) (string, ", d.Ident("storj.io/drpc", "Encoding"), ", ", d.Ident("storj.io/drpc", "Receiver"), ", interface{}, bool) {") 287 d.P("switch n {") 288 for i, method := range service.Methods { 289 d.P("case ", i, ":") 290 d.P("return ", d.RPCGoString(method), ", ", d.EncodingName(), "{}, ") 291 d.generateServerReceiver(method) 292 d.P("}, ", d.ServerIface(service), ".", method.GoName, ", true") 293 } 294 d.P("default:") 295 d.P(`return "", nil, nil, nil, false`) 296 d.P("}") 297 d.P("}") 298 d.P() 299 300 // Registration helper 301 d.P("func DRPCRegister", service.GoName, "(mux ", d.Ident("storj.io/drpc", "Mux"), ", impl ", d.ServerIface(service), ") error {") 302 d.P("return mux.Register(impl, ", d.ServerDesc(service), "{})") 303 d.P("}") 304 305 // Server methods 306 for _, method := range service.Methods { 307 d.generateServerMethod(method) 308 } 309} 310 311// 312// client methods 313// 314 315func (d *drpc) generateClientSignature(method *protogen.Method) string { 316 reqArg := ", in *" + d.InputType(method) 317 if method.Desc.IsStreamingClient() { 318 reqArg = "" 319 } 320 respName := "*" + d.OutputType(method) 321 if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() { 322 respName = d.ClientStreamIface(method) 323 } 324 return fmt.Sprintf("%s(ctx %s%s) (%s, error)", method.GoName, d.Ident("context", "Context"), reqArg, respName) 325} 326 327func (d *drpc) generateClientMethod(method *protogen.Method) { 328 recvType := d.ClientImpl(method.Parent) 329 outType := d.OutputType(method) 330 inType := d.InputType(method) 331 332 d.P("func (c *", recvType, ") ", d.generateClientSignature(method), "{") 333 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 334 d.P("out := new(", outType, ")") 335 d.P("err := c.cc.Invoke(ctx, ", d.RPCGoString(method), ", ", d.EncodingName(), "{}, in, out)") 336 d.P("if err != nil { return nil, err }") 337 d.P("return out, nil") 338 d.P("}") 339 d.P() 340 return 341 } 342 343 d.P("stream, err := c.cc.NewStream(ctx, ", d.RPCGoString(method), ", ", d.EncodingName(), "{})") 344 d.P("if err != nil { return nil, err }") 345 d.P("x := &", d.ClientStreamImpl(method), "{stream}") 346 if !method.Desc.IsStreamingClient() { 347 d.P("if err := x.MsgSend(in, ", d.EncodingName(), "{}); err != nil { return nil, err }") 348 d.P("if err := x.CloseSend(); err != nil { return nil, err }") 349 } 350 d.P("return x, nil") 351 d.P("}") 352 d.P() 353 354 genSend := method.Desc.IsStreamingClient() 355 genRecv := method.Desc.IsStreamingServer() 356 genCloseAndRecv := !method.Desc.IsStreamingServer() 357 358 // Stream auxiliary types and methods. 359 d.P("type ", d.ClientStreamIface(method), " interface {") 360 d.P(d.Ident("storj.io/drpc", "Stream")) 361 if genSend { 362 d.P("Send(*", inType, ") error") 363 } 364 if genRecv { 365 d.P("Recv() (*", outType, ", error)") 366 } 367 if genCloseAndRecv { 368 d.P("CloseAndRecv() (*", outType, ", error)") 369 } 370 d.P("}") 371 d.P() 372 373 d.P("type ", d.ClientStreamImpl(method), " struct {") 374 d.P(d.Ident("storj.io/drpc", "Stream")) 375 d.P("}") 376 d.P() 377 378 if genSend { 379 d.P("func (x *", d.ClientStreamImpl(method), ") Send(m *", inType, ") error {") 380 d.P("return x.MsgSend(m, ", d.EncodingName(), "{})") 381 d.P("}") 382 d.P() 383 } 384 if genRecv { 385 d.P("func (x *", d.ClientStreamImpl(method), ") Recv() (*", outType, ", error) {") 386 d.P("m := new(", outType, ")") 387 d.P("if err := x.MsgRecv(m, ", d.EncodingName(), "{}); err != nil { return nil, err }") 388 d.P("return m, nil") 389 d.P("}") 390 d.P() 391 392 d.P("func (x *", d.ClientStreamImpl(method), ") RecvMsg(m *", outType, ") error {") 393 d.P("return x.MsgRecv(m, ", d.EncodingName(), "{})") 394 d.P("}") 395 d.P() 396 } 397 if genCloseAndRecv { 398 d.P("func (x *", d.ClientStreamImpl(method), ") CloseAndRecv() (*", outType, ", error) {") 399 d.P("if err := x.CloseSend(); err != nil { return nil, err }") 400 d.P("m := new(", outType, ")") 401 d.P("if err := x.MsgRecv(m, ", d.EncodingName(), "{}); err != nil { return nil, err }") 402 d.P("return m, nil") 403 d.P("}") 404 d.P() 405 406 d.P("func (x *", d.ClientStreamImpl(method), ") CloseAndRecvMsg(m *", outType, ") error {") 407 d.P("if err := x.CloseSend(); err != nil { return err }") 408 d.P("return x.MsgRecv(m, ", d.EncodingName(), "{})") 409 d.P("}") 410 d.P() 411 } 412} 413 414// 415// server methods 416// 417 418func (d *drpc) generateServerSignature(method *protogen.Method) string { 419 var reqArgs []string 420 ret := "error" 421 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 422 reqArgs = append(reqArgs, d.Ident("context", "Context")) 423 ret = "(*" + d.OutputType(method) + ", error)" 424 } 425 if !method.Desc.IsStreamingClient() { 426 reqArgs = append(reqArgs, "*"+d.InputType(method)) 427 } 428 if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() { 429 reqArgs = append(reqArgs, d.ServerStreamIface(method)) 430 } 431 return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret 432} 433 434func (d *drpc) generateUnimplementedServerMethod(method *protogen.Method) { 435 d.P("func (s *", d.ServerUnimpl(method.Parent), ") ", d.generateServerSignature(method), " {") 436 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 437 d.P("return nil, ", d.Ident("storj.io/drpc/drpcerr", "WithCode"), "(", d.Ident("errors", "New"), "(\"Unimplemented\"), ", d.Ident("storj.io/drpc/drpcerr", "Unimplemented"), ")") 438 } else { 439 d.P("return ", d.Ident("storj.io/drpc/drpcerr", "WithCode"), "(", d.Ident("errors", "New"), "(\"Unimplemented\"), ", d.Ident("storj.io/drpc/drpcerr", "Unimplemented"), ")") 440 } 441 d.P("}") 442 d.P() 443} 444 445func (d *drpc) generateServerReceiver(method *protogen.Method) { 446 d.P("func (srv interface{}, ctx context.Context, in1, in2 interface{}) (" + d.Ident("storj.io/drpc", "Message") + ", error) {") 447 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 448 d.P("return srv.(", d.ServerIface(method.Parent), ").") 449 } else { 450 d.P("return nil, srv.(", d.ServerIface(method.Parent), ").") 451 } 452 d.P(method.GoName, "(") 453 454 n := 1 455 if !method.Desc.IsStreamingServer() && !method.Desc.IsStreamingClient() { 456 d.P("ctx,") 457 } 458 if !method.Desc.IsStreamingClient() { 459 d.P("in", n, ".(*", d.InputType(method), "),") 460 n++ 461 } 462 if method.Desc.IsStreamingServer() || method.Desc.IsStreamingClient() { 463 d.P("&", d.ServerStreamImpl(method), "{in", n, ".(", d.Ident("storj.io/drpc", "Stream"), ")},") 464 } 465 d.P(")") 466} 467 468func (d *drpc) generateServerMethod(method *protogen.Method) { 469 genSend := method.Desc.IsStreamingServer() 470 genSendAndClose := !method.Desc.IsStreamingServer() 471 genRecv := method.Desc.IsStreamingClient() 472 473 // Stream auxiliary types and methods. 474 d.P("type ", d.ServerStreamIface(method), " interface {") 475 d.P(d.Ident("storj.io/drpc", "Stream")) 476 if genSend { 477 d.P("Send(*", d.OutputType(method), ") error") 478 } 479 if genSendAndClose { 480 d.P("SendAndClose(*", d.OutputType(method), ") error") 481 } 482 if genRecv { 483 d.P("Recv() (*", d.InputType(method), ", error)") 484 } 485 d.P("}") 486 d.P() 487 488 d.P("type ", d.ServerStreamImpl(method), " struct {") 489 d.P(d.Ident("storj.io/drpc", "Stream")) 490 d.P("}") 491 d.P() 492 493 if genSend { 494 d.P("func (x *", d.ServerStreamImpl(method), ") Send(m *", d.OutputType(method), ") error {") 495 d.P("return x.MsgSend(m, ", d.EncodingName(), "{})") 496 d.P("}") 497 d.P() 498 } 499 500 if genSendAndClose { 501 d.P("func (x *", d.ServerStreamImpl(method), ") SendAndClose(m *", d.OutputType(method), ") error {") 502 d.P("if err := x.MsgSend(m, ", d.EncodingName(), "{}); err != nil { return err }") 503 d.P("return x.CloseSend()") 504 d.P("}") 505 d.P() 506 } 507 508 if genRecv { 509 d.P("func (x *", d.ServerStreamImpl(method), ") Recv() (*", d.InputType(method), ", error) {") 510 d.P("m := new(", d.InputType(method), ")") 511 d.P("if err := x.MsgRecv(m, ", d.EncodingName(), "{}); err != nil { return nil, err }") 512 d.P("return m, nil") 513 d.P("}") 514 d.P() 515 516 d.P("func (x *", d.ServerStreamImpl(method), ") RecvMsg(m *", d.InputType(method), ") error {") 517 d.P("return x.MsgRecv(m, ", d.EncodingName(), "{})") 518 d.P("}") 519 d.P() 520 } 521} 522