1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package pb
16
17import (
18	"fmt"
19	"reflect"
20	"strings"
21	"sync"
22
23	"github.com/golang/protobuf/proto"
24	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
25	structpb "github.com/golang/protobuf/ptypes/struct"
26	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
27)
28
29// NewTypeDescription produces a TypeDescription value for the fully-qualified proto type name
30// with a given descriptor.
31//
32// The type description creation method also expects the type to be marked clearly as a proto2 or
33// proto3 type, and accepts a typeResolver reference for resolving field TypeDescription during
34// lazily initialization of the type which is done atomically.
35func NewTypeDescription(typeName string, desc *descpb.DescriptorProto,
36	isProto3 bool, resolveType typeResolver) *TypeDescription {
37	return &TypeDescription{
38		typeName:    typeName,
39		isProto3:    isProto3,
40		desc:        desc,
41		resolveType: resolveType,
42	}
43}
44
// TypeDescription is a collection of type metadata relevant to expression
// checking and evaluation.
type TypeDescription struct {
	// typeName is the fully-qualified proto type name this description was created for.
	typeName string
	// isProto3 indicates whether the type was declared using proto3 syntax.
	isProto3 bool
	// desc is the message descriptor from which field metadata is derived.
	desc     *descpb.DescriptorProto

	// resolveType is used to lookup field types during type initialization.
	// The resolver may point to shared state; however, this state is guaranteed to be computed at
	// most one time.
	resolveType typeResolver
	// init guards the one-time computation of metadata performed by getMetadata.
	init        sync.Once
	// metadata holds the lazily computed field metadata; access it through getMetadata.
	metadata    *typeMetadata
}
59
// typeResolver accepts a type name and returns its TypeDescription, or an error when the name
// cannot be resolved.
// The typeResolver is used to resolve field types during lazy initialization of the type
// description metadata.
type typeResolver func(typeName string) (*TypeDescription, error)
64
// typeMetadata holds the field metadata that makeMetadata computes once per TypeDescription.
type typeMetadata struct {
	fields          map[string]*FieldDescription // fields by name (proto)
	// fieldIndices groups fields by Go struct field index; when no Go struct is available the
	// key is instead the field's position in the message descriptor.
	fieldIndices    map[int][]*FieldDescription  // fields by Go struct idx
	// fieldProperties is nil when proto.MessageType yields no struct type for the type name.
	fieldProperties *proto.StructProperties
	// reflectedType is the Go type returned by proto.MessageType, or nil if none was found.
	reflectedType   *reflect.Type
	// reflectedVal is a newly allocated instance (a pointer) of the underlying Go type.
	reflectedVal    *reflect.Value
	// emptyVal is the interface form of that new instance; DefaultValue returns it.
	emptyVal        interface{}
}
73
74// FieldCount returns the number of fields declared within the type.
75func (td *TypeDescription) FieldCount() int {
76	// The number of keys in the field indices map corresponds to the number
77	// of fields on the proto message.
78	return len(td.getMetadata().fieldIndices)
79}
80
81// FieldByName returns the FieldDescription associated with a field name.
82func (td *TypeDescription) FieldByName(name string) (*FieldDescription, bool) {
83	fd, found := td.getMetadata().fields[name]
84	return fd, found
85}
86
87// Name of the type.
88func (td *TypeDescription) Name() string {
89	return td.typeName
90}
91
92// ReflectType returns the reflected struct type of the generated proto struct.
93func (td *TypeDescription) ReflectType() reflect.Type {
94	if td.getMetadata().reflectedType == nil {
95		return nil
96	}
97	return *td.getMetadata().reflectedType
98}
99
100// DefaultValue returns an empty instance of the proto message associated with the type,
101// or nil for wrapper types.
102func (td *TypeDescription) DefaultValue() proto.Message {
103	val := td.getMetadata().emptyVal
104	if val == nil {
105		return nil
106	}
107	return val.(proto.Message)
108}
109
110// getMetadata computes the type field metadata used for determining field types and default
111// values. The call to makeMetadata within this method is guaranteed to be invoked exactly
112// once.
113func (td *TypeDescription) getMetadata() *typeMetadata {
114	td.init.Do(func() {
115		td.metadata = td.makeMetadata()
116	})
117	return td.metadata
118}
119
120func (td *TypeDescription) makeMetadata() *typeMetadata {
121	refType := proto.MessageType(td.typeName)
122	meta := &typeMetadata{
123		fields:       make(map[string]*FieldDescription),
124		fieldIndices: make(map[int][]*FieldDescription),
125	}
126	if refType != nil {
127		// Set the reflected type if non-nil.
128		meta.reflectedType = &refType
129
130		// Unwrap the pointer reference for the sake of later checks.
131		elemType := refType
132		if elemType.Kind() == reflect.Ptr {
133			elemType = elemType.Elem()
134		}
135		if elemType.Kind() == reflect.Struct {
136			meta.fieldProperties = proto.GetProperties(elemType)
137		}
138		refVal := reflect.New(elemType)
139		meta.reflectedVal = &refVal
140		if refVal.CanInterface() {
141			meta.emptyVal = refVal.Interface()
142		} else {
143			meta.emptyVal = reflect.Zero(elemType).Interface()
144		}
145	}
146
147	fieldIndexMap := make(map[string]int)
148	fieldDescMap := make(map[string]*descpb.FieldDescriptorProto)
149	for i, f := range td.desc.Field {
150		fieldDescMap[f.GetName()] = f
151		fieldIndexMap[f.GetName()] = i
152	}
153	if meta.fieldProperties != nil {
154		// This is a proper message type.
155		for i, prop := range meta.fieldProperties.Prop {
156			if strings.HasPrefix(prop.OrigName, "XXX_") {
157				// Book-keeping fields generated by protoc start with XXX_
158				continue
159			}
160			desc := fieldDescMap[prop.OrigName]
161			fd := td.newFieldDesc(*meta.reflectedType, desc, prop, i)
162			meta.fields[prop.OrigName] = fd
163			meta.fieldIndices[i] = append(meta.fieldIndices[i], fd)
164		}
165		for _, oneofProp := range meta.fieldProperties.OneofTypes {
166			desc := fieldDescMap[oneofProp.Prop.OrigName]
167			fd := td.newOneofFieldDesc(*meta.reflectedType, desc, oneofProp, oneofProp.Field)
168			meta.fields[oneofProp.Prop.OrigName] = fd
169			meta.fieldIndices[oneofProp.Field] = append(meta.fieldIndices[oneofProp.Field], fd)
170		}
171	} else {
172		for fieldName, desc := range fieldDescMap {
173			fd := td.newMapFieldDesc(desc)
174			meta.fields[fieldName] = fd
175			index := fieldIndexMap[fieldName]
176			meta.fieldIndices[index] = append(meta.fieldIndices[index], fd)
177		}
178	}
179	return meta
180}
181
182// Create a new field description for the proto field descriptor associated with the given type.
183// The field properties should never not be found when performing reflection on the type unless
184// there are fundamental changes to the backing proto library behavior.
185func (td *TypeDescription) newFieldDesc(
186	tdType reflect.Type,
187	desc *descpb.FieldDescriptorProto,
188	prop *proto.Properties,
189	index int) *FieldDescription {
190	getterName := fmt.Sprintf("Get%s", prop.Name)
191	getter, _ := tdType.MethodByName(getterName)
192	var field *reflect.StructField
193	if tdType.Kind() == reflect.Ptr {
194		tdType = tdType.Elem()
195	}
196	f, found := tdType.FieldByName(prop.Name)
197	if found {
198		field = &f
199	}
200	fieldDesc := &FieldDescription{
201		desc:      desc,
202		index:     index,
203		getter:    getter.Func,
204		field:     field,
205		prop:      prop,
206		isProto3:  td.isProto3,
207		isWrapper: isWrapperType(desc),
208	}
209	if desc.GetType() == descpb.FieldDescriptorProto_TYPE_MESSAGE {
210		typeName := sanitizeProtoName(desc.GetTypeName())
211		fieldType, _ := td.resolveType(typeName)
212		fieldDesc.td = fieldType
213		return fieldDesc
214	}
215	return fieldDesc
216}
217
218func (td *TypeDescription) newOneofFieldDesc(
219	tdType reflect.Type,
220	desc *descpb.FieldDescriptorProto,
221	oneofProp *proto.OneofProperties,
222	index int) *FieldDescription {
223	fieldDesc := td.newFieldDesc(tdType, desc, oneofProp.Prop, index)
224	fieldDesc.oneofProp = oneofProp
225	return fieldDesc
226}
227
228func (td *TypeDescription) newMapFieldDesc(desc *descpb.FieldDescriptorProto) *FieldDescription {
229	return &FieldDescription{
230		desc:     desc,
231		index:    int(desc.GetNumber()),
232		isProto3: td.isProto3,
233	}
234}
235
236func isWrapperType(desc *descpb.FieldDescriptorProto) bool {
237	if desc.GetType() != descpb.FieldDescriptorProto_TYPE_MESSAGE {
238		return false
239	}
240	switch sanitizeProtoName(desc.GetTypeName()) {
241	case "google.protobuf.BoolValue",
242		"google.protobuf.BytesValue",
243		"google.protobuf.DoubleValue",
244		"google.protobuf.FloatValue",
245		"google.protobuf.Int32Value",
246		"google.protobuf.Int64Value",
247		"google.protobuf.StringValue",
248		"google.protobuf.UInt32Value",
249		"google.protobuf.UInt64Value":
250		return true
251	}
252	return false
253}
254
// FieldDescription holds metadata related to fields declared within a type.
type FieldDescription struct {
	// getter is the reflected accessor method that obtains the field value.
	// It is the zero reflect.Value when no accessor was found or none was looked up.
	getter reflect.Value
	// field is the field location in a refValue.
	// The field will not be found for oneofs, but this is accounted for
	// by checking the 'desc' value which provides this information.
	field *reflect.StructField
	// isProto3 indicates whether the field is defined in a proto3 syntax.
	isProto3 bool
	// isWrapper indicates whether the field is a wrapper type.
	isWrapper bool

	// td is the type description for message typed fields.
	td *TypeDescription

	// proto descriptor data.
	// desc is the proto field descriptor.
	desc      *descpb.FieldDescriptorProto
	// index is the Go struct field position, or the proto field number for fields created
	// without Go struct properties.
	index     int
	// prop holds the struct field's proto properties; nil when no Go struct was available.
	prop      *proto.Properties
	// oneofProp is non-nil only for fields declared within a oneof.
	oneofProp *proto.OneofProperties
}
277
278// CheckedType returns the type-definition used at type-check time.
279func (fd *FieldDescription) CheckedType() *exprpb.Type {
280	if fd.IsMap() {
281		// Get the FieldDescriptors for the type arranged by their index within the
282		// generated Go struct.
283		fieldIndices := fd.getFieldIndicies()
284		// Map keys and values are represented as repeated entries in a list.
285		key := fieldIndices[0][0]
286		val := fieldIndices[1][0]
287		return &exprpb.Type{
288			TypeKind: &exprpb.Type_MapType_{
289				MapType: &exprpb.Type_MapType{
290					KeyType:   key.typeDefToType(),
291					ValueType: val.typeDefToType()}}}
292	}
293	if fd.IsRepeated() {
294		return &exprpb.Type{
295			TypeKind: &exprpb.Type_ListType_{
296				ListType: &exprpb.Type_ListType{
297					ElemType: fd.typeDefToType()}}}
298	}
299	return fd.typeDefToType()
300}
301
302// IsSet returns whether the field is set on the target value, per the proto presence conventions
303// of proto2 or proto3 accordingly.
304//
305// The input target may either be a reflect.Value or Go struct type.
306func (fd *FieldDescription) IsSet(target interface{}) bool {
307	t, ok := target.(reflect.Value)
308	if !ok {
309		t = reflect.ValueOf(target)
310	}
311	// For the case where the field is not a oneof, test whether the field is set on the target
312	// value assuming it is a struct. A field that is not set will be one of the following values:
313	// - nil for message and primitive typed fields in proto2
314	// - nil for message typed fields in proto3
315	// - empty for primitive typed fields in proto3
316	if fd.field != nil && !fd.IsOneof() {
317		t = reflect.Indirect(t)
318		if t.Kind() != reflect.Struct {
319			return false
320		}
321		return isFieldSet(t.FieldByIndex(fd.field.Index))
322	}
323	// When the field is nil or when the field is a oneof, call the accessor
324	// associated with this field name to determine whether the field value is
325	// the default.
326	fieldVal := fd.getter.Call([]reflect.Value{t})[0]
327	return isFieldSet(fieldVal)
328}
329
330// GetFrom returns the accessor method associated with the field on the proto generated struct.
331//
332// If the field is not set, the proto default value is returned instead.
333//
334// The input target may either be a reflect.Value or Go struct type.
335func (fd *FieldDescription) GetFrom(target interface{}) (interface{}, error) {
336	t, ok := target.(reflect.Value)
337	if !ok {
338		t = reflect.ValueOf(target)
339	}
340	var fieldVal reflect.Value
341	if fd.isProto3 && fd.field != nil && !fd.IsOneof() {
342		// The target object should always be a struct.
343		t = reflect.Indirect(t)
344		if t.Kind() != reflect.Struct {
345			return nil, fmt.Errorf("unsupported field selection target: %T", target)
346		}
347		fieldVal = t.FieldByIndex(fd.field.Index)
348	} else {
349		// The accessor method must be used for proto2 in order to properly handle
350		// default values.
351		// Additionally, proto3 oneofs require the use of the accessor to get the proper value.
352		fieldVal = fd.getter.Call([]reflect.Value{t})[0]
353	}
354	if isFieldSet(fieldVal) {
355		// Return the field value assuming it is set. For proto3 the value may be a zero value.
356		if fieldVal.CanInterface() {
357			return fieldVal.Interface(), nil
358		}
359		return reflect.Zero(fieldVal.Type()).Interface(), nil
360	}
361	if fd.IsWrapper() {
362		return structpb.NullValue_NULL_VALUE, nil
363	}
364	if fd.IsMessage() {
365		return fd.Type().DefaultValue(), nil
366	}
367	return nil, fmt.Errorf("no default value for field: %s", fd.Name())
368}
369
370// Index returns the field index within a reflected value.
371func (fd *FieldDescription) Index() int {
372	return fd.index
373}
374
375// IsEnum returns true if the field type refers to an enum value.
376func (fd *FieldDescription) IsEnum() bool {
377	return fd.desc.GetType() == descpb.FieldDescriptorProto_TYPE_ENUM
378}
379
380// IsMap returns true if the field is of map type.
381func (fd *FieldDescription) IsMap() bool {
382	if !fd.IsRepeated() || !fd.IsMessage() {
383		return false
384	}
385	if fd.td == nil {
386		return false
387	}
388	return fd.td.desc.GetOptions().GetMapEntry()
389}
390
391// IsMessage returns true if the field is of message type.
392func (fd *FieldDescription) IsMessage() bool {
393	return fd.desc.GetType() == descpb.FieldDescriptorProto_TYPE_MESSAGE
394}
395
396// IsOneof returns true if the field is declared within a oneof block.
397func (fd *FieldDescription) IsOneof() bool {
398	if fd.desc != nil {
399		return fd.desc.OneofIndex != nil
400	}
401	return fd.oneofProp != nil
402}
403
404// IsRepeated returns true if the field is a repeated value.
405//
406// This method will also return true for map values, so check whether the
407// field is also a map.
408func (fd *FieldDescription) IsRepeated() bool {
409	return *fd.desc.Label == descpb.FieldDescriptorProto_LABEL_REPEATED
410}
411
412// IsWrapper returns true if the field type is a primitive wrapper type.
413func (fd *FieldDescription) IsWrapper() bool {
414	return fd.isWrapper
415}
416
417// OneofType returns the reflect.Type value of a oneof field.
418//
419// Oneof field values are wrapped in a struct which contains one field whose
420// value is a proto.Message.
421func (fd *FieldDescription) OneofType() reflect.Type {
422	return fd.oneofProp.Type
423}
424
425// OrigName returns the snake_case name of the field as it was declared within
426// the proto. This is the same name format that is expected within expressions.
427func (fd *FieldDescription) OrigName() string {
428	if fd.desc != nil && fd.desc.Name != nil {
429		return *fd.desc.Name
430	}
431	return fd.prop.OrigName
432}
433
434// Name returns the CamelCase name of the field within the proto-based struct.
435func (fd *FieldDescription) Name() string {
436	return fd.prop.Name
437}
438
439// SupportsPresence returns true if the field supports presence detection.
440func (fd *FieldDescription) SupportsPresence() bool {
441	return !fd.IsRepeated() && (fd.IsMessage() || !fd.isProto3)
442}
443
444// String returns a struct-like field definition string.
445func (fd *FieldDescription) String() string {
446	return fmt.Sprintf("%s %s `oneof=%t`",
447		fd.TypeName(), fd.OrigName(), fd.IsOneof())
448}
449
450// Type returns the TypeDescription for the field.
451func (fd *FieldDescription) Type() *TypeDescription {
452	return fd.td
453}
454
455// TypeName returns the type name of the field.
456func (fd *FieldDescription) TypeName() string {
457	return sanitizeProtoName(fd.desc.GetTypeName())
458}
459
460func (fd *FieldDescription) getFieldIndicies() map[int][]*FieldDescription {
461	return fd.td.getMetadata().fieldIndices
462}
463
464func (fd *FieldDescription) typeDefToType() *exprpb.Type {
465	if fd.IsMessage() {
466		if wk, found := CheckedWellKnowns[fd.TypeName()]; found {
467			return wk
468		}
469		return checkedMessageType(fd.TypeName())
470	}
471	if fd.IsEnum() {
472		return checkedInt
473	}
474	if p, found := CheckedPrimitives[fd.desc.GetType()]; found {
475		return p
476	}
477	return CheckedPrimitives[fd.desc.GetType()]
478}
479
480func checkedMessageType(name string) *exprpb.Type {
481	return &exprpb.Type{
482		TypeKind: &exprpb.Type_MessageType{MessageType: name}}
483}
484
485func checkedPrimitive(primitive exprpb.Type_PrimitiveType) *exprpb.Type {
486	return &exprpb.Type{
487		TypeKind: &exprpb.Type_Primitive{Primitive: primitive}}
488}
489
490func checkedWellKnown(wellKnown exprpb.Type_WellKnownType) *exprpb.Type {
491	return &exprpb.Type{
492		TypeKind: &exprpb.Type_WellKnown{WellKnown: wellKnown}}
493}
494
495func checkedWrap(t *exprpb.Type) *exprpb.Type {
496	return &exprpb.Type{
497		TypeKind: &exprpb.Type_Wrapper{Wrapper: t.GetPrimitive()}}
498}
499
// isFieldSet reports whether a reflected field value counts as set. Only a nil pointer is
// treated as unset; values of every other kind, including invalid values, are reported as set.
func isFieldSet(refVal reflect.Value) bool {
	if refVal.Kind() == reflect.Ptr {
		return !refVal.IsNil()
	}
	return true
}
503