1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package types
16
17import (
18	"reflect"
19
20	"github.com/golang/protobuf/proto"
21	"github.com/golang/protobuf/ptypes"
22
23	"github.com/google/cel-go/common/types/pb"
24	"github.com/google/cel-go/common/types/ref"
25
26	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
27	anypb "github.com/golang/protobuf/ptypes/any"
28	dpb "github.com/golang/protobuf/ptypes/duration"
29	structpb "github.com/golang/protobuf/ptypes/struct"
30	tpb "github.com/golang/protobuf/ptypes/timestamp"
31	wrapperspb "github.com/golang/protobuf/ptypes/wrappers"
32
33	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
34)
35
36type protoTypeRegistry struct {
37	revTypeMap map[string]ref.Type
38	pbdb       *pb.Db
39}
40
41// NewRegistry accepts a list of proto message instances and returns a type
42// provider which can create new instances of the provided message or any
43// message that proto depends upon in its FileDescriptor.
44func NewRegistry(types ...proto.Message) ref.TypeRegistry {
45	p := &protoTypeRegistry{
46		revTypeMap: make(map[string]ref.Type),
47		pbdb:       pb.NewDb(),
48	}
49	p.RegisterType(
50		BoolType,
51		BytesType,
52		DoubleType,
53		DurationType,
54		IntType,
55		ListType,
56		MapType,
57		NullType,
58		StringType,
59		TimestampType,
60		TypeType,
61		UintType)
62
63	for _, msgType := range types {
64		err := p.RegisterMessage(msgType)
65		if err != nil {
66			panic(err)
67		}
68	}
69	return p
70}
71
72// NewEmptyRegistry returns a registry which is completely unconfigured.
73func NewEmptyRegistry() ref.TypeRegistry {
74	return &protoTypeRegistry{
75		revTypeMap: make(map[string]ref.Type),
76		pbdb:       pb.NewDb(),
77	}
78}
79
80func (p *protoTypeRegistry) EnumValue(enumName string) ref.Val {
81	enumVal, err := p.pbdb.DescribeEnum(enumName)
82	if err != nil {
83		return NewErr("unknown enum name '%s'", enumName)
84	}
85	return Int(enumVal.Value())
86}
87
88func (p *protoTypeRegistry) FindFieldType(messageType string,
89	fieldName string) (*ref.FieldType, bool) {
90	msgType, err := p.pbdb.DescribeType(messageType)
91	if err != nil {
92		return nil, false
93	}
94	field, found := msgType.FieldByName(fieldName)
95	if !found {
96		return nil, false
97	}
98	return &ref.FieldType{
99			Type:             field.CheckedType(),
100			SupportsPresence: field.SupportsPresence(),
101			IsSet:            field.IsSet,
102			GetFrom:          field.GetFrom},
103		true
104}
105
106func (p *protoTypeRegistry) FindIdent(identName string) (ref.Val, bool) {
107	if t, found := p.revTypeMap[identName]; found {
108		return t.(ref.Val), true
109	}
110	if enumVal, err := p.pbdb.DescribeEnum(identName); err == nil {
111		return Int(enumVal.Value()), true
112	}
113	return nil, false
114}
115
116func (p *protoTypeRegistry) FindType(typeName string) (*exprpb.Type, bool) {
117	if _, err := p.pbdb.DescribeType(typeName); err != nil {
118		return nil, false
119	}
120	if typeName != "" && typeName[0] == '.' {
121		typeName = typeName[1:]
122	}
123	return &exprpb.Type{
124		TypeKind: &exprpb.Type_Type{
125			Type: &exprpb.Type{
126				TypeKind: &exprpb.Type_MessageType{
127					MessageType: typeName}}}}, true
128}
129
130func (p *protoTypeRegistry) NewValue(typeName string, fields map[string]ref.Val) ref.Val {
131	td, err := p.pbdb.DescribeType(typeName)
132	if err != nil {
133		return NewErr("unknown type '%s'", typeName)
134	}
135	refType := td.ReflectType()
136	// create the new type instance.
137	value := reflect.New(refType.Elem())
138	pbValue := value.Elem()
139
140	// for all of the field names referenced, set the provided value.
141	for name, value := range fields {
142		fd, found := td.FieldByName(name)
143		if !found {
144			return NewErr("no such field '%s'", name)
145		}
146		refField := pbValue.Field(fd.Index())
147		if !refField.IsValid() {
148			return NewErr("no such field '%s'", name)
149		}
150
151		dstType := refField.Type()
152		// Oneof fields are defined with wrapper structs that have a single proto.Message
153		// field value. The oneof wrapper is not a proto.Message instance.
154		if fd.IsOneof() {
155			oneofVal := reflect.New(fd.OneofType().Elem())
156			refField.Set(oneofVal)
157			refField = oneofVal.Elem().Field(0)
158			dstType = refField.Type()
159		}
160		fieldValue, err := value.ConvertToNative(dstType)
161		if err != nil {
162			return &Err{err}
163		}
164		refField.Set(reflect.ValueOf(fieldValue))
165	}
166	return p.NativeToValue(value.Interface())
167}
168
169func (p *protoTypeRegistry) RegisterDescriptor(fileDesc *descpb.FileDescriptorProto) error {
170	fd, err := p.pbdb.RegisterDescriptor(fileDesc)
171	if err != nil {
172		return err
173	}
174	return p.registerAllTypes(fd)
175}
176
177func (p *protoTypeRegistry) RegisterMessage(message proto.Message) error {
178	fd, err := p.pbdb.RegisterMessage(message)
179	if err != nil {
180		return err
181	}
182	return p.registerAllTypes(fd)
183}
184
185func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error {
186	for _, t := range types {
187		p.revTypeMap[t.TypeName()] = t
188	}
189	// TODO: generate an error when the type name is registered more than once.
190	return nil
191}
192
193func (p *protoTypeRegistry) registerAllTypes(fd *pb.FileDescription) error {
194	for _, typeName := range fd.GetTypeNames() {
195		err := p.RegisterType(NewObjectTypeValue(typeName))
196		if err != nil {
197			return err
198		}
199	}
200	return nil
201}
202
203// NativeToValue converts various "native" types to ref.Val with this specific implementation
204// providing support for custom proto-based types.
205//
206// This method should be the inverse of ref.Val.ConvertToNative.
207func (p *protoTypeRegistry) NativeToValue(value interface{}) ref.Val {
208	switch v := value.(type) {
209	case ref.Val:
210		return v
211	// Adapt common types and aggregate specializations using the DefaultTypeAdapter.
212	case bool, *bool,
213		float32, *float32, float64, *float64,
214		int, *int, int32, *int32, int64, *int64,
215		string, *string,
216		uint, *uint, uint32, *uint32, uint64, *uint64,
217		[]byte,
218		[]string,
219		map[string]string:
220		return DefaultTypeAdapter.NativeToValue(value)
221	// Adapt well-known proto-types using the DefaultTypeAdapter.
222	case *dpb.Duration,
223		*tpb.Timestamp,
224		*structpb.ListValue,
225		structpb.NullValue,
226		*structpb.Struct,
227		*structpb.Value,
228		*wrapperspb.BoolValue,
229		*wrapperspb.BytesValue,
230		*wrapperspb.DoubleValue,
231		*wrapperspb.FloatValue,
232		*wrapperspb.Int32Value,
233		*wrapperspb.Int64Value,
234		*wrapperspb.StringValue,
235		*wrapperspb.UInt32Value,
236		*wrapperspb.UInt64Value:
237		return DefaultTypeAdapter.NativeToValue(value)
238	// Override the Any type by ensuring that custom proto-types are considered on recursive calls.
239	case *anypb.Any:
240		if v == nil {
241			return NewErr("unsupported type conversion: '%T'", value)
242		}
243		unpackedAny := ptypes.DynamicAny{}
244		if ptypes.UnmarshalAny(v, &unpackedAny) != nil {
245			return NewErr("unknown type: '%s'", v.GetTypeUrl())
246		}
247		return p.NativeToValue(unpackedAny.Message)
248	// Convert custom proto types to CEL values based on type's presence within the pb.Db.
249	case proto.Message:
250		typeName := proto.MessageName(v)
251		td, err := p.pbdb.DescribeType(typeName)
252		if err != nil {
253			return NewErr("unknown type: '%s'", typeName)
254		}
255		typeVal, found := p.FindIdent(typeName)
256		if !found {
257			return NewErr("unknown type: '%s'", typeName)
258		}
259		return NewObject(p, td, typeVal.(*TypeValue), v)
260	// Override default handling for list and maps to ensure that blends of Go + proto types
261	// are appropriately adapted on recursive calls or subsequent inspection of the aggregate
262	// value.
263	default:
264		refValue := reflect.ValueOf(value)
265		if refValue.Kind() == reflect.Ptr {
266			if refValue.IsNil() {
267				return NewErr("unsupported type conversion: '%T'", value)
268			}
269			refValue = refValue.Elem()
270		}
271		refKind := refValue.Kind()
272		switch refKind {
273		case reflect.Array, reflect.Slice:
274			return NewDynamicList(p, value)
275		case reflect.Map:
276			return NewDynamicMap(p, value)
277		}
278	}
279	// By default return the default type adapter's conversion to CEL.
280	return DefaultTypeAdapter.NativeToValue(value)
281}
282
// defaultTypeAdapter converts Go native values and well-known protobuf types
// to CEL values. It carries no state.
type defaultTypeAdapter struct{}

// DefaultTypeAdapter adapts canonical CEL types from their equivalent Go values.
var DefaultTypeAdapter = &defaultTypeAdapter{}
290
291// NativeToValue implements the ref.TypeAdapter interface.
292func (a *defaultTypeAdapter) NativeToValue(value interface{}) ref.Val {
293	switch value.(type) {
294	case nil:
295		return NullValue
296	case *Bool:
297		if ptr := value.(*Bool); ptr != nil {
298			return ptr
299		}
300	case *Bytes:
301		if ptr := value.(*Bytes); ptr != nil {
302			return ptr
303		}
304	case *Double:
305		if ptr := value.(*Double); ptr != nil {
306			return ptr
307		}
308	case *Int:
309		if ptr := value.(*Int); ptr != nil {
310			return ptr
311		}
312	case *String:
313		if ptr := value.(*String); ptr != nil {
314			return ptr
315		}
316	case *Uint:
317		if ptr := value.(*Uint); ptr != nil {
318			return ptr
319		}
320	case ref.Val:
321		return value.(ref.Val)
322	case bool:
323		return Bool(value.(bool))
324	case int:
325		return Int(value.(int))
326	case int32:
327		return Int(value.(int32))
328	case int64:
329		return Int(value.(int64))
330	case uint:
331		return Uint(value.(uint))
332	case uint32:
333		return Uint(value.(uint32))
334	case uint64:
335		return Uint(value.(uint64))
336	case float32:
337		return Double(value.(float32))
338	case float64:
339		return Double(value.(float64))
340	case string:
341		return String(value.(string))
342	case *bool:
343		if ptr := value.(*bool); ptr != nil {
344			return Bool(*ptr)
345		}
346	case *float32:
347		if ptr := value.(*float32); ptr != nil {
348			return Double(*ptr)
349		}
350	case *float64:
351		if ptr := value.(*float64); ptr != nil {
352			return Double(*ptr)
353		}
354	case *int:
355		if ptr := value.(*int); ptr != nil {
356			return Int(*ptr)
357		}
358	case *int32:
359		if ptr := value.(*int32); ptr != nil {
360			return Int(*ptr)
361		}
362	case *int64:
363		if ptr := value.(*int64); ptr != nil {
364			return Int(*ptr)
365		}
366	case *string:
367		if ptr := value.(*string); ptr != nil {
368			return String(*ptr)
369		}
370	case *uint:
371		if ptr := value.(*uint); ptr != nil {
372			return Uint(*ptr)
373		}
374	case *uint32:
375		if ptr := value.(*uint32); ptr != nil {
376			return Uint(*ptr)
377		}
378	case *uint64:
379		if ptr := value.(*uint64); ptr != nil {
380			return Uint(*ptr)
381		}
382	case []byte:
383		return Bytes(value.([]byte))
384	case []string:
385		return NewStringList(a, value.([]string))
386	case map[string]string:
387		return NewStringStringMap(a, value.(map[string]string))
388	case *dpb.Duration:
389		if ptr := value.(*dpb.Duration); ptr != nil {
390			return Duration{ptr}
391		}
392	case *structpb.ListValue:
393		if ptr := value.(*structpb.ListValue); ptr != nil {
394			return NewJSONList(a, ptr)
395		}
396	case structpb.NullValue, *structpb.NullValue:
397		return NullValue
398	case *structpb.Struct:
399		if ptr := value.(*structpb.Struct); ptr != nil {
400			return NewJSONStruct(a, ptr)
401		}
402	case *structpb.Value:
403		v := value.(*structpb.Value)
404		if v == nil {
405			return NullValue
406		}
407		switch v.Kind.(type) {
408		case *structpb.Value_BoolValue:
409			return a.NativeToValue(v.GetBoolValue())
410		case *structpb.Value_ListValue:
411			return a.NativeToValue(v.GetListValue())
412		case *structpb.Value_NullValue:
413			return NullValue
414		case *structpb.Value_NumberValue:
415			return a.NativeToValue(v.GetNumberValue())
416		case *structpb.Value_StringValue:
417			return a.NativeToValue(v.GetStringValue())
418		case *structpb.Value_StructValue:
419			return a.NativeToValue(v.GetStructValue())
420		}
421	case *tpb.Timestamp:
422		if ptr := value.(*tpb.Timestamp); ptr != nil {
423			return Timestamp{ptr}
424		}
425	case *anypb.Any:
426		val := value.(*anypb.Any)
427		if val == nil {
428			return NewErr("unsupported type conversion")
429		}
430		unpackedAny := ptypes.DynamicAny{}
431		if ptypes.UnmarshalAny(val, &unpackedAny) != nil {
432			return NewErr("unknown type: %s", val.GetTypeUrl())
433		}
434		return a.NativeToValue(unpackedAny.Message)
435	case *wrapperspb.BoolValue:
436		val := value.(*wrapperspb.BoolValue)
437		if val == nil {
438			return NewErr("unsupported type conversion")
439		}
440		return Bool(val.GetValue())
441	case *wrapperspb.BytesValue:
442		val := value.(*wrapperspb.BytesValue)
443		if val == nil {
444			return NewErr("unsupported type conversion")
445		}
446		return Bytes(val.GetValue())
447	case *wrapperspb.DoubleValue:
448		val := value.(*wrapperspb.DoubleValue)
449		if val == nil {
450			return NewErr("unsupported type conversion")
451		}
452		return Double(val.GetValue())
453	case *wrapperspb.FloatValue:
454		val := value.(*wrapperspb.FloatValue)
455		if val == nil {
456			return NewErr("unsupported type conversion")
457		}
458		return Double(val.GetValue())
459	case *wrapperspb.Int32Value:
460		val := value.(*wrapperspb.Int32Value)
461		if val == nil {
462			return NewErr("unsupported type conversion")
463		}
464		return Int(val.GetValue())
465	case *wrapperspb.Int64Value:
466		val := value.(*wrapperspb.Int64Value)
467		if val == nil {
468			return NewErr("unsupported type conversion")
469		}
470		return Int(val.GetValue())
471	case *wrapperspb.StringValue:
472		val := value.(*wrapperspb.StringValue)
473		if val == nil {
474			return NewErr("unsupported type conversion")
475		}
476		return String(val.GetValue())
477	case *wrapperspb.UInt32Value:
478		val := value.(*wrapperspb.UInt32Value)
479		if val == nil {
480			return NewErr("unsupported type conversion")
481		}
482		return Uint(val.GetValue())
483	case *wrapperspb.UInt64Value:
484		val := value.(*wrapperspb.UInt64Value)
485		if val == nil {
486			return NewErr("unsupported type conversion")
487		}
488		return Uint(val.GetValue())
489	default:
490		refValue := reflect.ValueOf(value)
491		if refValue.Kind() == reflect.Ptr {
492			if refValue.IsNil() {
493				return NewErr("unsupported type conversion: '%T'", value)
494			}
495			refValue = refValue.Elem()
496		}
497		refKind := refValue.Kind()
498		switch refKind {
499		case reflect.Array, reflect.Slice:
500			return NewDynamicList(a, value)
501		case reflect.Map:
502			return NewDynamicMap(a, value)
503		// type aliases of primitive types cannot be asserted as that type, but rather need
504		// to be downcast to int32 before being converted to a CEL representation.
505		case reflect.Int32:
506			intType := reflect.TypeOf(int32(0))
507			return Int(refValue.Convert(intType).Interface().(int32))
508		case reflect.Int64:
509			intType := reflect.TypeOf(int64(0))
510			return Int(refValue.Convert(intType).Interface().(int64))
511		case reflect.Uint32:
512			uintType := reflect.TypeOf(uint32(0))
513			return Uint(refValue.Convert(uintType).Interface().(uint32))
514		case reflect.Uint64:
515			uintType := reflect.TypeOf(uint64(0))
516			return Uint(refValue.Convert(uintType).Interface().(uint64))
517		case reflect.Float32:
518			doubleType := reflect.TypeOf(float32(0))
519			return Double(refValue.Convert(doubleType).Interface().(float32))
520		case reflect.Float64:
521			doubleType := reflect.TypeOf(float64(0))
522			return Double(refValue.Convert(doubleType).Interface().(float64))
523		}
524	}
525	return NewErr("unsupported type conversion: '%T'", value)
526}
527