1package descriptor
2
3import (
4	"fmt"
5	"strings"
6
7	"github.com/golang/protobuf/protoc-gen-go/descriptor"
8	gogen "github.com/golang/protobuf/protoc-gen-go/generator"
9	"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
10)
11
12// IsWellKnownType returns true if the provided fully qualified type name is considered 'well-known'.
13func IsWellKnownType(typeName string) bool {
14	_, ok := wellKnownTypeConv[typeName]
15	return ok
16}
17
// GoPackage represents a golang package.
type GoPackage struct {
	// Path is the package path to the package.
	Path string
	// Name is the package name of the package.
	Name string
	// Alias is an alias of the package unique within the current invocation of grpc-gateway generator.
	Alias string
}

// Standard reports whether the import is a golang standard package.
// This is a heuristic: any import path containing a dot is treated as
// non-standard, everything else (including an empty path) as standard.
func (p GoPackage) Standard() bool {
	return strings.IndexByte(p.Path, '.') < 0
}

// String returns a string representation of this package in the form of an
// import line in golang: the quoted path, preceded by the alias and a single
// space when an alias is set.
func (p GoPackage) String() string {
	quotedPath := fmt.Sprintf("%q", p.Path)
	if p.Alias != "" {
		return p.Alias + " " + quotedPath
	}
	return quotedPath
}
40
// File wraps descriptor.FileDescriptorProto for richer features.
type File struct {
	// FileDescriptorProto is the raw file descriptor. Its fields and getters
	// (e.g. Package, Syntax, GetPackage, GetSyntax) are promoted to File.
	*descriptor.FileDescriptorProto
	// GoPkg is the go package of the go file generated from this file.
	GoPkg GoPackage
	// Messages is the list of messages defined in this file.
	Messages []*Message
	// Enums is the list of enums defined in this file.
	Enums []*Enum
	// Services is the list of services defined in this file.
	Services []*Service
}
53
54// proto2 determines if the syntax of the file is proto2.
55func (f *File) proto2() bool {
56	return f.Syntax == nil || f.GetSyntax() == "proto2"
57}
58
// Message describes a protocol buffer message type.
type Message struct {
	// File is the file where the message is defined.
	File *File
	// Outers is a list of outer messages if this message is a nested type.
	// FQMN and GoType place them, in order, before the message's own name.
	Outers []string
	// DescriptorProto is the raw message descriptor; its getters (e.g. GetName,
	// GetOneofDecl) are promoted to Message.
	*descriptor.DescriptorProto
	// Fields is the list of fields of this message.
	// NOTE(review): presumably in declaration order — confirm at the construction site.
	Fields []*Field

	// Index is proto path index of this message in File.
	Index int
}
71
72// FQMN returns a fully qualified message name of this message.
73func (m *Message) FQMN() string {
74	components := []string{""}
75	if m.File.Package != nil {
76		components = append(components, m.File.GetPackage())
77	}
78	components = append(components, m.Outers...)
79	components = append(components, m.GetName())
80	return strings.Join(components, ".")
81}
82
83// GoType returns a go type name for the message type.
84// It prefixes the type name with the package alias if
85// its belonging package is not "currentPackage".
86func (m *Message) GoType(currentPackage string) string {
87	var components []string
88	components = append(components, m.Outers...)
89	components = append(components, m.GetName())
90
91	name := strings.Join(components, "_")
92	if m.File.GoPkg.Path == currentPackage {
93		return name
94	}
95	pkg := m.File.GoPkg.Name
96	if alias := m.File.GoPkg.Alias; alias != "" {
97		pkg = alias
98	}
99	return fmt.Sprintf("%s.%s", pkg, name)
100}
101
// Enum describes a protocol buffer enum type.
type Enum struct {
	// File is the file where the enum is defined.
	File *File
	// Outers is a list of outer messages if this enum is a nested type.
	// FQEN and GoType place them, in order, before the enum's own name.
	Outers []string
	// EnumDescriptorProto is the raw enum descriptor; its getters (e.g.
	// GetName) are promoted to Enum.
	*descriptor.EnumDescriptorProto

	// Index is presumably the proto path index of this enum in File, mirroring
	// Message.Index. NOTE(review): confirm at the construction site.
	Index int
}
112
113// FQEN returns a fully qualified enum name of this enum.
114func (e *Enum) FQEN() string {
115	components := []string{""}
116	if e.File.Package != nil {
117		components = append(components, e.File.GetPackage())
118	}
119	components = append(components, e.Outers...)
120	components = append(components, e.GetName())
121	return strings.Join(components, ".")
122}
123
124// GoType returns a go type name for the enum type.
125// It prefixes the type name with the package alias if
126// its belonging package is not "currentPackage".
127func (e *Enum) GoType(currentPackage string) string {
128	var components []string
129	components = append(components, e.Outers...)
130	components = append(components, e.GetName())
131
132	name := strings.Join(components, "_")
133	if e.File.GoPkg.Path == currentPackage {
134		return name
135	}
136	pkg := e.File.GoPkg.Name
137	if alias := e.File.GoPkg.Alias; alias != "" {
138		pkg = alias
139	}
140	return fmt.Sprintf("%s.%s", pkg, name)
141}
142
// Service wraps descriptor.ServiceDescriptorProto for richer features.
type Service struct {
	// File is the file where this service is defined.
	File *File
	// ServiceDescriptorProto is the raw service descriptor; its getters (e.g.
	// GetName) are promoted to Service.
	*descriptor.ServiceDescriptorProto
	// Methods is the list of methods defined in this service.
	Methods []*Method
}
151
152// FQSN returns the fully qualified service name of this service.
153func (s *Service) FQSN() string {
154	components := []string{""}
155	if s.File.Package != nil {
156		components = append(components, s.File.GetPackage())
157	}
158	components = append(components, s.GetName())
159	return strings.Join(components, ".")
160}
161
// Method wraps descriptor.MethodDescriptorProto for richer features.
type Method struct {
	// Service is the service which this method belongs to.
	Service *Service
	// MethodDescriptorProto is the raw method descriptor; its getters (e.g.
	// GetName) are promoted to Method.
	*descriptor.MethodDescriptorProto

	// RequestType is the message type of requests to this method.
	RequestType *Message
	// ResponseType is the message type of responses from this method.
	ResponseType *Message
	Bindings     []*Binding // Bindings lists the HTTP endpoints bound to this method.
}
174
175// FQMN returns a fully qualified rpc method name of this method.
176func (m *Method) FQMN() string {
177	components := []string{}
178	components = append(components, m.Service.FQSN())
179	components = append(components, m.GetName())
180	return strings.Join(components, ".")
181}
182
// Binding describes how an HTTP endpoint is bound to a gRPC method.
type Binding struct {
	// Method is the method which the endpoint is bound to.
	Method *Method
	// Index is a zero-origin index of the binding in the target method.
	Index int
	// PathTmpl is the path template where this method is mapped to.
	PathTmpl httprule.Template
	// HTTPMethod is the HTTP method which this method is mapped to.
	HTTPMethod string
	// PathParams is the list of parameters provided in HTTP request paths.
	PathParams []Parameter
	// Body describes parameters provided in HTTP request body.
	// ExplicitParams treats a nil Body as "no body binding".
	Body *Body
	// ResponseBody describes the field in the response struct to marshal in the HTTP response body.
	ResponseBody *Body
}
200
201// ExplicitParams returns a list of explicitly bound parameters of "b",
202// i.e. a union of field path for body and field paths for path parameters.
203func (b *Binding) ExplicitParams() []string {
204	var result []string
205	if b.Body != nil {
206		result = append(result, b.Body.FieldPath.String())
207	}
208	for _, p := range b.PathParams {
209		result = append(result, p.FieldPath.String())
210	}
211	return result
212}
213
// Field wraps descriptor.FieldDescriptorProto for richer features.
type Field struct {
	// Message is the message type which this field belongs to.
	Message *Message
	// FieldMessage is the message type of the field.
	// NOTE(review): presumably nil unless the field is of message type — confirm.
	FieldMessage *Message
	// FieldDescriptorProto is the raw field descriptor; getters such as
	// GetType, GetLabel, GetTypeName and the OneofIndex field are promoted.
	*descriptor.FieldDescriptorProto
}
222
// Parameter is a parameter provided in http requests.
type Parameter struct {
	// FieldPath is a path to a proto field which this parameter is mapped to.
	// It is embedded, so FieldPath's methods (String, AssignableExpr, ...)
	// are promoted to Parameter.
	FieldPath
	// Target is the proto field which this parameter is mapped to.
	Target *Field
	// Method is the method which this parameter is used for.
	Method *Method
}
232
233// ConvertFuncExpr returns a go expression of a converter function.
234// The converter function converts a string into a value for the parameter.
235func (p Parameter) ConvertFuncExpr() (string, error) {
236	tbl := proto3ConvertFuncs
237	if !p.IsProto2() && p.IsRepeated() {
238		tbl = proto3RepeatedConvertFuncs
239	} else if p.IsProto2() && !p.IsRepeated() {
240		tbl = proto2ConvertFuncs
241	} else if p.IsProto2() && p.IsRepeated() {
242		tbl = proto2RepeatedConvertFuncs
243	}
244	typ := p.Target.GetType()
245	conv, ok := tbl[typ]
246	if !ok {
247		conv, ok = wellKnownTypeConv[p.Target.GetTypeName()]
248	}
249	if !ok {
250		return "", fmt.Errorf("unsupported field type %s of parameter %s in %s.%s", typ, p.FieldPath, p.Method.Service.GetName(), p.Method.GetName())
251	}
252	return conv, nil
253}
254
255// IsEnum returns true if the field is an enum type, otherwise false is returned.
256func (p Parameter) IsEnum() bool {
257	return p.Target.GetType() == descriptor.FieldDescriptorProto_TYPE_ENUM
258}
259
260// IsRepeated returns true if the field is repeated, otherwise false is returned.
261func (p Parameter) IsRepeated() bool {
262	return p.Target.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED
263}
264
265// IsProto2 returns true if the field is proto2, otherwise false is returned.
266func (p Parameter) IsProto2() bool {
267	return p.Target.Message.File.proto2()
268}
269
// Body describes an HTTP (request|response) body to be sent to the (method|client).
// This is used in the body and response_body options in google.api.HttpRule.
type Body struct {
	// FieldPath is a path to a proto field which the (request|response) body is mapped to.
	// The (request|response) body is mapped to the (request|response) type itself if FieldPath is empty.
	FieldPath FieldPath
}
277
278// AssignableExpr returns an assignable expression in Go to be used to initialize method request object.
279// It starts with "msgExpr", which is the go expression of the method request object.
280func (b Body) AssignableExpr(msgExpr string) string {
281	return b.FieldPath.AssignableExpr(msgExpr)
282}
283
284// FieldPath is a path to a field from a request message.
285type FieldPath []FieldPathComponent
286
287// String returns a string representation of the field path.
288func (p FieldPath) String() string {
289	var components []string
290	for _, c := range p {
291		components = append(components, c.Name)
292	}
293	return strings.Join(components, ".")
294}
295
296// IsNestedProto3 indicates whether the FieldPath is a nested Proto3 path.
297func (p FieldPath) IsNestedProto3() bool {
298	if len(p) > 1 && !p[0].Target.Message.File.proto2() {
299		return true
300	}
301	return false
302}
303
304// AssignableExpr is an assignable expression in Go to be used to assign a value to the target field.
305// It starts with "msgExpr", which is the go expression of the method request object.
306func (p FieldPath) AssignableExpr(msgExpr string) string {
307	l := len(p)
308	if l == 0 {
309		return msgExpr
310	}
311
312	var preparations []string
313	components := msgExpr
314	for i, c := range p {
315		// Check if it is a oneOf field.
316		if c.Target.OneofIndex != nil {
317			index := c.Target.OneofIndex
318			msg := c.Target.Message
319			oneOfName := gogen.CamelCase(msg.GetOneofDecl()[*index].GetName())
320			oneofFieldName := msg.GetName() + "_" + c.AssignableExpr()
321
322			components = components + "." + oneOfName
323			s := `if %s == nil {
324				%s =&%s{}
325			} else if _, ok := %s.(*%s); !ok {
326				return nil, metadata, grpc.Errorf(codes.InvalidArgument, "expect type: *%s, but: %%t\n",%s)
327			}`
328
329			preparations = append(preparations, fmt.Sprintf(s, components, components, oneofFieldName, components, oneofFieldName, oneofFieldName, components))
330			components = components + ".(*" + oneofFieldName + ")"
331		}
332
333		if i == l-1 {
334			components = components + "." + c.AssignableExpr()
335			continue
336		}
337		components = components + "." + c.ValueExpr()
338	}
339
340	preparations = append(preparations, components)
341	return strings.Join(preparations, "\n")
342}
343
// FieldPathComponent is a path component in FieldPath.
type FieldPathComponent struct {
	// Name is a name of the proto field which this component corresponds to.
	// Its CamelCase form is what AssignableExpr and ValueExpr emit.
	// TODO(yugui) is this necessary?
	Name string
	// Target is the proto field which this component corresponds to.
	Target *Field
}
352
353// AssignableExpr returns an assignable expression in go for this field.
354func (c FieldPathComponent) AssignableExpr() string {
355	return gogen.CamelCase(c.Name)
356}
357
358// ValueExpr returns an expression in go for this field.
359func (c FieldPathComponent) ValueExpr() string {
360	if c.Target.Message.File.proto2() {
361		return fmt.Sprintf("Get%s()", gogen.CamelCase(c.Name))
362	}
363	return gogen.CamelCase(c.Name)
364}
365
366var (
367	proto3ConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{
368		descriptor.FieldDescriptorProto_TYPE_DOUBLE:  "runtime.Float64",
369		descriptor.FieldDescriptorProto_TYPE_FLOAT:   "runtime.Float32",
370		descriptor.FieldDescriptorProto_TYPE_INT64:   "runtime.Int64",
371		descriptor.FieldDescriptorProto_TYPE_UINT64:  "runtime.Uint64",
372		descriptor.FieldDescriptorProto_TYPE_INT32:   "runtime.Int32",
373		descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64",
374		descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32",
375		descriptor.FieldDescriptorProto_TYPE_BOOL:    "runtime.Bool",
376		descriptor.FieldDescriptorProto_TYPE_STRING:  "runtime.String",
377		// FieldDescriptorProto_TYPE_GROUP
378		// FieldDescriptorProto_TYPE_MESSAGE
379		descriptor.FieldDescriptorProto_TYPE_BYTES:    "runtime.Bytes",
380		descriptor.FieldDescriptorProto_TYPE_UINT32:   "runtime.Uint32",
381		descriptor.FieldDescriptorProto_TYPE_ENUM:     "runtime.Enum",
382		descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32",
383		descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64",
384		descriptor.FieldDescriptorProto_TYPE_SINT32:   "runtime.Int32",
385		descriptor.FieldDescriptorProto_TYPE_SINT64:   "runtime.Int64",
386	}
387
388	proto3RepeatedConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{
389		descriptor.FieldDescriptorProto_TYPE_DOUBLE:  "runtime.Float64Slice",
390		descriptor.FieldDescriptorProto_TYPE_FLOAT:   "runtime.Float32Slice",
391		descriptor.FieldDescriptorProto_TYPE_INT64:   "runtime.Int64Slice",
392		descriptor.FieldDescriptorProto_TYPE_UINT64:  "runtime.Uint64Slice",
393		descriptor.FieldDescriptorProto_TYPE_INT32:   "runtime.Int32Slice",
394		descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64Slice",
395		descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32Slice",
396		descriptor.FieldDescriptorProto_TYPE_BOOL:    "runtime.BoolSlice",
397		descriptor.FieldDescriptorProto_TYPE_STRING:  "runtime.StringSlice",
398		// FieldDescriptorProto_TYPE_GROUP
399		// FieldDescriptorProto_TYPE_MESSAGE
400		descriptor.FieldDescriptorProto_TYPE_BYTES:    "runtime.BytesSlice",
401		descriptor.FieldDescriptorProto_TYPE_UINT32:   "runtime.Uint32Slice",
402		descriptor.FieldDescriptorProto_TYPE_ENUM:     "runtime.EnumSlice",
403		descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32Slice",
404		descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64Slice",
405		descriptor.FieldDescriptorProto_TYPE_SINT32:   "runtime.Int32Slice",
406		descriptor.FieldDescriptorProto_TYPE_SINT64:   "runtime.Int64Slice",
407	}
408
409	proto2ConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{
410		descriptor.FieldDescriptorProto_TYPE_DOUBLE:  "runtime.Float64P",
411		descriptor.FieldDescriptorProto_TYPE_FLOAT:   "runtime.Float32P",
412		descriptor.FieldDescriptorProto_TYPE_INT64:   "runtime.Int64P",
413		descriptor.FieldDescriptorProto_TYPE_UINT64:  "runtime.Uint64P",
414		descriptor.FieldDescriptorProto_TYPE_INT32:   "runtime.Int32P",
415		descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64P",
416		descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32P",
417		descriptor.FieldDescriptorProto_TYPE_BOOL:    "runtime.BoolP",
418		descriptor.FieldDescriptorProto_TYPE_STRING:  "runtime.StringP",
419		// FieldDescriptorProto_TYPE_GROUP
420		// FieldDescriptorProto_TYPE_MESSAGE
421		// FieldDescriptorProto_TYPE_BYTES
422		// TODO(yugui) Handle bytes
423		descriptor.FieldDescriptorProto_TYPE_UINT32:   "runtime.Uint32P",
424		descriptor.FieldDescriptorProto_TYPE_ENUM:     "runtime.EnumP",
425		descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32P",
426		descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64P",
427		descriptor.FieldDescriptorProto_TYPE_SINT32:   "runtime.Int32P",
428		descriptor.FieldDescriptorProto_TYPE_SINT64:   "runtime.Int64P",
429	}
430
431	proto2RepeatedConvertFuncs = map[descriptor.FieldDescriptorProto_Type]string{
432		descriptor.FieldDescriptorProto_TYPE_DOUBLE:  "runtime.Float64Slice",
433		descriptor.FieldDescriptorProto_TYPE_FLOAT:   "runtime.Float32Slice",
434		descriptor.FieldDescriptorProto_TYPE_INT64:   "runtime.Int64Slice",
435		descriptor.FieldDescriptorProto_TYPE_UINT64:  "runtime.Uint64Slice",
436		descriptor.FieldDescriptorProto_TYPE_INT32:   "runtime.Int32Slice",
437		descriptor.FieldDescriptorProto_TYPE_FIXED64: "runtime.Uint64Slice",
438		descriptor.FieldDescriptorProto_TYPE_FIXED32: "runtime.Uint32Slice",
439		descriptor.FieldDescriptorProto_TYPE_BOOL:    "runtime.BoolSlice",
440		descriptor.FieldDescriptorProto_TYPE_STRING:  "runtime.StringSlice",
441		// FieldDescriptorProto_TYPE_GROUP
442		// FieldDescriptorProto_TYPE_MESSAGE
443		// FieldDescriptorProto_TYPE_BYTES
444		// TODO(maros7) Handle bytes
445		descriptor.FieldDescriptorProto_TYPE_UINT32:   "runtime.Uint32Slice",
446		descriptor.FieldDescriptorProto_TYPE_ENUM:     "runtime.EnumSlice",
447		descriptor.FieldDescriptorProto_TYPE_SFIXED32: "runtime.Int32Slice",
448		descriptor.FieldDescriptorProto_TYPE_SFIXED64: "runtime.Int64Slice",
449		descriptor.FieldDescriptorProto_TYPE_SINT32:   "runtime.Int32Slice",
450		descriptor.FieldDescriptorProto_TYPE_SINT64:   "runtime.Int64Slice",
451	}
452
453	wellKnownTypeConv = map[string]string{
454		".google.protobuf.Timestamp":   "runtime.Timestamp",
455		".google.protobuf.Duration":    "runtime.Duration",
456		".google.protobuf.StringValue": "runtime.StringValue",
457		".google.protobuf.FloatValue":  "runtime.FloatValue",
458		".google.protobuf.DoubleValue": "runtime.DoubleValue",
459		".google.protobuf.BoolValue":   "runtime.BoolValue",
460		".google.protobuf.BytesValue":  "runtime.BytesValue",
461		".google.protobuf.Int32Value":  "runtime.Int32Value",
462		".google.protobuf.UInt32Value": "runtime.UInt32Value",
463		".google.protobuf.Int64Value":  "runtime.Int64Value",
464		".google.protobuf.UInt64Value": "runtime.UInt64Value",
465	}
466)
467