// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package types

import (
	"reflect"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/ptypes"

	"github.com/google/cel-go/common/types/pb"
	"github.com/google/cel-go/common/types/ref"

	descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
	anypb "github.com/golang/protobuf/ptypes/any"
	dpb "github.com/golang/protobuf/ptypes/duration"
	structpb "github.com/golang/protobuf/ptypes/struct"
	tpb "github.com/golang/protobuf/ptypes/timestamp"
	wrapperspb "github.com/golang/protobuf/ptypes/wrappers"

	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

// protoTypeRegistry is the ref.TypeRegistry implementation returned by
// NewRegistry and NewEmptyRegistry. It pairs a name-indexed table of
// registered CEL types with a protobuf descriptor database.
type protoTypeRegistry struct {
	// revTypeMap maps a type name to the ref.Type registered under it. It is
	// written by RegisterType and read by FindIdent.
	revTypeMap map[string]ref.Type
	// pbdb holds the registered protobuf file descriptions. Message and enum
	// lookups are resolved against it.
	pbdb       *pb.Db
}

// NewRegistry accepts a list of proto message instances and returns a type
// provider which can create new instances of the provided message or any
// message that proto depends upon in its FileDescriptor.
//
// The returned registry is also preloaded with the built-in CEL types.
// NewRegistry panics if any of the provided messages fails to register.
44func NewRegistry(types ...proto.Message) ref.TypeRegistry { 45 p := &protoTypeRegistry{ 46 revTypeMap: make(map[string]ref.Type), 47 pbdb: pb.NewDb(), 48 } 49 p.RegisterType( 50 BoolType, 51 BytesType, 52 DoubleType, 53 DurationType, 54 IntType, 55 ListType, 56 MapType, 57 NullType, 58 StringType, 59 TimestampType, 60 TypeType, 61 UintType) 62 63 for _, msgType := range types { 64 err := p.RegisterMessage(msgType) 65 if err != nil { 66 panic(err) 67 } 68 } 69 return p 70} 71 72// NewEmptyRegistry returns a registry which is completely unconfigured. 73func NewEmptyRegistry() ref.TypeRegistry { 74 return &protoTypeRegistry{ 75 revTypeMap: make(map[string]ref.Type), 76 pbdb: pb.NewDb(), 77 } 78} 79 80func (p *protoTypeRegistry) EnumValue(enumName string) ref.Val { 81 enumVal, err := p.pbdb.DescribeEnum(enumName) 82 if err != nil { 83 return NewErr("unknown enum name '%s'", enumName) 84 } 85 return Int(enumVal.Value()) 86} 87 88func (p *protoTypeRegistry) FindFieldType(messageType string, 89 fieldName string) (*ref.FieldType, bool) { 90 msgType, err := p.pbdb.DescribeType(messageType) 91 if err != nil { 92 return nil, false 93 } 94 field, found := msgType.FieldByName(fieldName) 95 if !found { 96 return nil, false 97 } 98 return &ref.FieldType{ 99 Type: field.CheckedType(), 100 SupportsPresence: field.SupportsPresence(), 101 IsSet: field.IsSet, 102 GetFrom: field.GetFrom}, 103 true 104} 105 106func (p *protoTypeRegistry) FindIdent(identName string) (ref.Val, bool) { 107 if t, found := p.revTypeMap[identName]; found { 108 return t.(ref.Val), true 109 } 110 if enumVal, err := p.pbdb.DescribeEnum(identName); err == nil { 111 return Int(enumVal.Value()), true 112 } 113 return nil, false 114} 115 116func (p *protoTypeRegistry) FindType(typeName string) (*exprpb.Type, bool) { 117 if _, err := p.pbdb.DescribeType(typeName); err != nil { 118 return nil, false 119 } 120 if typeName != "" && typeName[0] == '.' 
{ 121 typeName = typeName[1:] 122 } 123 return &exprpb.Type{ 124 TypeKind: &exprpb.Type_Type{ 125 Type: &exprpb.Type{ 126 TypeKind: &exprpb.Type_MessageType{ 127 MessageType: typeName}}}}, true 128} 129 130func (p *protoTypeRegistry) NewValue(typeName string, fields map[string]ref.Val) ref.Val { 131 td, err := p.pbdb.DescribeType(typeName) 132 if err != nil { 133 return NewErr("unknown type '%s'", typeName) 134 } 135 refType := td.ReflectType() 136 // create the new type instance. 137 value := reflect.New(refType.Elem()) 138 pbValue := value.Elem() 139 140 // for all of the field names referenced, set the provided value. 141 for name, value := range fields { 142 fd, found := td.FieldByName(name) 143 if !found { 144 return NewErr("no such field '%s'", name) 145 } 146 refField := pbValue.Field(fd.Index()) 147 if !refField.IsValid() { 148 return NewErr("no such field '%s'", name) 149 } 150 151 dstType := refField.Type() 152 // Oneof fields are defined with wrapper structs that have a single proto.Message 153 // field value. The oneof wrapper is not a proto.Message instance. 
154 if fd.IsOneof() { 155 oneofVal := reflect.New(fd.OneofType().Elem()) 156 refField.Set(oneofVal) 157 refField = oneofVal.Elem().Field(0) 158 dstType = refField.Type() 159 } 160 fieldValue, err := value.ConvertToNative(dstType) 161 if err != nil { 162 return &Err{err} 163 } 164 refField.Set(reflect.ValueOf(fieldValue)) 165 } 166 return p.NativeToValue(value.Interface()) 167} 168 169func (p *protoTypeRegistry) RegisterDescriptor(fileDesc *descpb.FileDescriptorProto) error { 170 fd, err := p.pbdb.RegisterDescriptor(fileDesc) 171 if err != nil { 172 return err 173 } 174 return p.registerAllTypes(fd) 175} 176 177func (p *protoTypeRegistry) RegisterMessage(message proto.Message) error { 178 fd, err := p.pbdb.RegisterMessage(message) 179 if err != nil { 180 return err 181 } 182 return p.registerAllTypes(fd) 183} 184 185func (p *protoTypeRegistry) RegisterType(types ...ref.Type) error { 186 for _, t := range types { 187 p.revTypeMap[t.TypeName()] = t 188 } 189 // TODO: generate an error when the type name is registered more than once. 190 return nil 191} 192 193func (p *protoTypeRegistry) registerAllTypes(fd *pb.FileDescription) error { 194 for _, typeName := range fd.GetTypeNames() { 195 err := p.RegisterType(NewObjectTypeValue(typeName)) 196 if err != nil { 197 return err 198 } 199 } 200 return nil 201} 202 203// NativeToValue converts various "native" types to ref.Val with this specific implementation 204// providing support for custom proto-based types. 205// 206// This method should be the inverse of ref.Val.ConvertToNative. 207func (p *protoTypeRegistry) NativeToValue(value interface{}) ref.Val { 208 switch v := value.(type) { 209 case ref.Val: 210 return v 211 // Adapt common types and aggregate specializations using the DefaultTypeAdapter. 
212 case bool, *bool, 213 float32, *float32, float64, *float64, 214 int, *int, int32, *int32, int64, *int64, 215 string, *string, 216 uint, *uint, uint32, *uint32, uint64, *uint64, 217 []byte, 218 []string, 219 map[string]string: 220 return DefaultTypeAdapter.NativeToValue(value) 221 // Adapt well-known proto-types using the DefaultTypeAdapter. 222 case *dpb.Duration, 223 *tpb.Timestamp, 224 *structpb.ListValue, 225 structpb.NullValue, 226 *structpb.Struct, 227 *structpb.Value, 228 *wrapperspb.BoolValue, 229 *wrapperspb.BytesValue, 230 *wrapperspb.DoubleValue, 231 *wrapperspb.FloatValue, 232 *wrapperspb.Int32Value, 233 *wrapperspb.Int64Value, 234 *wrapperspb.StringValue, 235 *wrapperspb.UInt32Value, 236 *wrapperspb.UInt64Value: 237 return DefaultTypeAdapter.NativeToValue(value) 238 // Override the Any type by ensuring that custom proto-types are considered on recursive calls. 239 case *anypb.Any: 240 if v == nil { 241 return NewErr("unsupported type conversion: '%T'", value) 242 } 243 unpackedAny := ptypes.DynamicAny{} 244 if ptypes.UnmarshalAny(v, &unpackedAny) != nil { 245 return NewErr("unknown type: '%s'", v.GetTypeUrl()) 246 } 247 return p.NativeToValue(unpackedAny.Message) 248 // Convert custom proto types to CEL values based on type's presence within the pb.Db. 249 case proto.Message: 250 typeName := proto.MessageName(v) 251 td, err := p.pbdb.DescribeType(typeName) 252 if err != nil { 253 return NewErr("unknown type: '%s'", typeName) 254 } 255 typeVal, found := p.FindIdent(typeName) 256 if !found { 257 return NewErr("unknown type: '%s'", typeName) 258 } 259 return NewObject(p, td, typeVal.(*TypeValue), v) 260 // Override default handling for list and maps to ensure that blends of Go + proto types 261 // are appropriately adapted on recursive calls or subsequent inspection of the aggregate 262 // value. 
263 default: 264 refValue := reflect.ValueOf(value) 265 if refValue.Kind() == reflect.Ptr { 266 if refValue.IsNil() { 267 return NewErr("unsupported type conversion: '%T'", value) 268 } 269 refValue = refValue.Elem() 270 } 271 refKind := refValue.Kind() 272 switch refKind { 273 case reflect.Array, reflect.Slice: 274 return NewDynamicList(p, value) 275 case reflect.Map: 276 return NewDynamicMap(p, value) 277 } 278 } 279 // By default return the default type adapter's conversion to CEL. 280 return DefaultTypeAdapter.NativeToValue(value) 281} 282 283// defaultTypeAdapter converts go native types to CEL values. 284type defaultTypeAdapter struct{} 285 286var ( 287 // DefaultTypeAdapter adapts canonical CEL types from their equivalent Go values. 288 DefaultTypeAdapter = &defaultTypeAdapter{} 289) 290 291// NativeToValue implements the ref.TypeAdapter interface. 292func (a *defaultTypeAdapter) NativeToValue(value interface{}) ref.Val { 293 switch value.(type) { 294 case nil: 295 return NullValue 296 case *Bool: 297 if ptr := value.(*Bool); ptr != nil { 298 return ptr 299 } 300 case *Bytes: 301 if ptr := value.(*Bytes); ptr != nil { 302 return ptr 303 } 304 case *Double: 305 if ptr := value.(*Double); ptr != nil { 306 return ptr 307 } 308 case *Int: 309 if ptr := value.(*Int); ptr != nil { 310 return ptr 311 } 312 case *String: 313 if ptr := value.(*String); ptr != nil { 314 return ptr 315 } 316 case *Uint: 317 if ptr := value.(*Uint); ptr != nil { 318 return ptr 319 } 320 case ref.Val: 321 return value.(ref.Val) 322 case bool: 323 return Bool(value.(bool)) 324 case int: 325 return Int(value.(int)) 326 case int32: 327 return Int(value.(int32)) 328 case int64: 329 return Int(value.(int64)) 330 case uint: 331 return Uint(value.(uint)) 332 case uint32: 333 return Uint(value.(uint32)) 334 case uint64: 335 return Uint(value.(uint64)) 336 case float32: 337 return Double(value.(float32)) 338 case float64: 339 return Double(value.(float64)) 340 case string: 341 return 
String(value.(string)) 342 case *bool: 343 if ptr := value.(*bool); ptr != nil { 344 return Bool(*ptr) 345 } 346 case *float32: 347 if ptr := value.(*float32); ptr != nil { 348 return Double(*ptr) 349 } 350 case *float64: 351 if ptr := value.(*float64); ptr != nil { 352 return Double(*ptr) 353 } 354 case *int: 355 if ptr := value.(*int); ptr != nil { 356 return Int(*ptr) 357 } 358 case *int32: 359 if ptr := value.(*int32); ptr != nil { 360 return Int(*ptr) 361 } 362 case *int64: 363 if ptr := value.(*int64); ptr != nil { 364 return Int(*ptr) 365 } 366 case *string: 367 if ptr := value.(*string); ptr != nil { 368 return String(*ptr) 369 } 370 case *uint: 371 if ptr := value.(*uint); ptr != nil { 372 return Uint(*ptr) 373 } 374 case *uint32: 375 if ptr := value.(*uint32); ptr != nil { 376 return Uint(*ptr) 377 } 378 case *uint64: 379 if ptr := value.(*uint64); ptr != nil { 380 return Uint(*ptr) 381 } 382 case []byte: 383 return Bytes(value.([]byte)) 384 case []string: 385 return NewStringList(a, value.([]string)) 386 case map[string]string: 387 return NewStringStringMap(a, value.(map[string]string)) 388 case *dpb.Duration: 389 if ptr := value.(*dpb.Duration); ptr != nil { 390 return Duration{ptr} 391 } 392 case *structpb.ListValue: 393 if ptr := value.(*structpb.ListValue); ptr != nil { 394 return NewJSONList(a, ptr) 395 } 396 case structpb.NullValue, *structpb.NullValue: 397 return NullValue 398 case *structpb.Struct: 399 if ptr := value.(*structpb.Struct); ptr != nil { 400 return NewJSONStruct(a, ptr) 401 } 402 case *structpb.Value: 403 v := value.(*structpb.Value) 404 if v == nil { 405 return NullValue 406 } 407 switch v.Kind.(type) { 408 case *structpb.Value_BoolValue: 409 return a.NativeToValue(v.GetBoolValue()) 410 case *structpb.Value_ListValue: 411 return a.NativeToValue(v.GetListValue()) 412 case *structpb.Value_NullValue: 413 return NullValue 414 case *structpb.Value_NumberValue: 415 return a.NativeToValue(v.GetNumberValue()) 416 case 
*structpb.Value_StringValue: 417 return a.NativeToValue(v.GetStringValue()) 418 case *structpb.Value_StructValue: 419 return a.NativeToValue(v.GetStructValue()) 420 } 421 case *tpb.Timestamp: 422 if ptr := value.(*tpb.Timestamp); ptr != nil { 423 return Timestamp{ptr} 424 } 425 case *anypb.Any: 426 val := value.(*anypb.Any) 427 if val == nil { 428 return NewErr("unsupported type conversion") 429 } 430 unpackedAny := ptypes.DynamicAny{} 431 if ptypes.UnmarshalAny(val, &unpackedAny) != nil { 432 return NewErr("unknown type: %s", val.GetTypeUrl()) 433 } 434 return a.NativeToValue(unpackedAny.Message) 435 case *wrapperspb.BoolValue: 436 val := value.(*wrapperspb.BoolValue) 437 if val == nil { 438 return NewErr("unsupported type conversion") 439 } 440 return Bool(val.GetValue()) 441 case *wrapperspb.BytesValue: 442 val := value.(*wrapperspb.BytesValue) 443 if val == nil { 444 return NewErr("unsupported type conversion") 445 } 446 return Bytes(val.GetValue()) 447 case *wrapperspb.DoubleValue: 448 val := value.(*wrapperspb.DoubleValue) 449 if val == nil { 450 return NewErr("unsupported type conversion") 451 } 452 return Double(val.GetValue()) 453 case *wrapperspb.FloatValue: 454 val := value.(*wrapperspb.FloatValue) 455 if val == nil { 456 return NewErr("unsupported type conversion") 457 } 458 return Double(val.GetValue()) 459 case *wrapperspb.Int32Value: 460 val := value.(*wrapperspb.Int32Value) 461 if val == nil { 462 return NewErr("unsupported type conversion") 463 } 464 return Int(val.GetValue()) 465 case *wrapperspb.Int64Value: 466 val := value.(*wrapperspb.Int64Value) 467 if val == nil { 468 return NewErr("unsupported type conversion") 469 } 470 return Int(val.GetValue()) 471 case *wrapperspb.StringValue: 472 val := value.(*wrapperspb.StringValue) 473 if val == nil { 474 return NewErr("unsupported type conversion") 475 } 476 return String(val.GetValue()) 477 case *wrapperspb.UInt32Value: 478 val := value.(*wrapperspb.UInt32Value) 479 if val == nil { 480 return 
NewErr("unsupported type conversion") 481 } 482 return Uint(val.GetValue()) 483 case *wrapperspb.UInt64Value: 484 val := value.(*wrapperspb.UInt64Value) 485 if val == nil { 486 return NewErr("unsupported type conversion") 487 } 488 return Uint(val.GetValue()) 489 default: 490 refValue := reflect.ValueOf(value) 491 if refValue.Kind() == reflect.Ptr { 492 if refValue.IsNil() { 493 return NewErr("unsupported type conversion: '%T'", value) 494 } 495 refValue = refValue.Elem() 496 } 497 refKind := refValue.Kind() 498 switch refKind { 499 case reflect.Array, reflect.Slice: 500 return NewDynamicList(a, value) 501 case reflect.Map: 502 return NewDynamicMap(a, value) 503 // type aliases of primitive types cannot be asserted as that type, but rather need 504 // to be downcast to int32 before being converted to a CEL representation. 505 case reflect.Int32: 506 intType := reflect.TypeOf(int32(0)) 507 return Int(refValue.Convert(intType).Interface().(int32)) 508 case reflect.Int64: 509 intType := reflect.TypeOf(int64(0)) 510 return Int(refValue.Convert(intType).Interface().(int64)) 511 case reflect.Uint32: 512 uintType := reflect.TypeOf(uint32(0)) 513 return Uint(refValue.Convert(uintType).Interface().(uint32)) 514 case reflect.Uint64: 515 uintType := reflect.TypeOf(uint64(0)) 516 return Uint(refValue.Convert(uintType).Interface().(uint64)) 517 case reflect.Float32: 518 doubleType := reflect.TypeOf(float32(0)) 519 return Double(refValue.Convert(doubleType).Interface().(float32)) 520 case reflect.Float64: 521 doubleType := reflect.TypeOf(float64(0)) 522 return Double(refValue.Convert(doubleType).Interface().(float64)) 523 } 524 } 525 return NewErr("unsupported type conversion: '%T'", value) 526} 527