1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package types
16
17import (
18	"fmt"
19	"reflect"
20
21	"github.com/golang/protobuf/jsonpb"
22	"github.com/golang/protobuf/proto"
23	"github.com/golang/protobuf/ptypes"
24
25	"github.com/google/cel-go/common/types/pb"
26	"github.com/google/cel-go/common/types/ref"
27
28	structpb "github.com/golang/protobuf/ptypes/struct"
29)
30
31type protoObj struct {
32	ref.TypeAdapter
33	value     proto.Message
34	refValue  reflect.Value
35	typeDesc  *pb.TypeDescription
36	typeValue *TypeValue
37	isAny     bool
38}
39
40// NewObject returns an object based on a proto.Message value which handles
41// conversion between protobuf type values and expression type values.
42// Objects support indexing and iteration.
43//
44// Note: the type value is pulled from the list of registered types within the
45// type provider. If the proto type is not registered within the type provider,
46// then this will result in an error within the type adapter / provider.
47func NewObject(adapter ref.TypeAdapter,
48	typeDesc *pb.TypeDescription,
49	typeValue *TypeValue,
50	value proto.Message) ref.Val {
51	return &protoObj{
52		TypeAdapter: adapter,
53		value:       value,
54		refValue:    reflect.ValueOf(value),
55		typeDesc:    typeDesc,
56		typeValue:   typeValue}
57}
58
59func (o *protoObj) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
60	pb := o.Value().(proto.Message)
61	switch typeDesc {
62	case anyValueType:
63		if o.isAny {
64			return pb, nil
65		}
66		return ptypes.MarshalAny(pb)
67	case jsonValueType:
68		// Marshal the proto to JSON first, and then rehydrate as protobuf.Value as there is no
69		// support for direct conversion from proto.Message to protobuf.Value.
70		jsonTxt, err := (&jsonpb.Marshaler{}).MarshalToString(pb)
71		if err != nil {
72			return nil, err
73		}
74		json := &structpb.Value{}
75		err = jsonpb.UnmarshalString(jsonTxt, json)
76		if err != nil {
77			return nil, err
78		}
79		return json, nil
80	}
81	if o.refValue.Type().AssignableTo(typeDesc) {
82		return pb, nil
83	}
84	return nil, fmt.Errorf("type conversion error from '%v' to '%v'",
85		o.refValue.Type(), typeDesc)
86}
87
88func (o *protoObj) ConvertToType(typeVal ref.Type) ref.Val {
89	switch typeVal {
90	default:
91		if o.Type().TypeName() == typeVal.TypeName() {
92			return o
93		}
94	case TypeType:
95		return o.typeValue
96	}
97	return NewErr("type conversion error from '%s' to '%s'",
98		o.typeDesc.Name(), typeVal)
99}
100
101func (o *protoObj) Equal(other ref.Val) ref.Val {
102	if o.typeDesc.Name() != other.Type().TypeName() {
103		return ValOrErr(other, "no such overload")
104	}
105	return Bool(proto.Equal(o.value, other.Value().(proto.Message)))
106}
107
108// IsSet tests whether a field which is defined is set to a non-default value.
109func (o *protoObj) IsSet(field ref.Val) ref.Val {
110	protoFieldName, ok := field.(String)
111	if !ok {
112		return ValOrErr(field, "no such overload")
113	}
114	protoFieldStr := string(protoFieldName)
115	fd, found := o.typeDesc.FieldByName(protoFieldStr)
116	if !found {
117		return NewErr("no such field '%s'", field)
118	}
119	if !fd.SupportsPresence() {
120		return NewErr("field does not support presence testing.")
121	}
122	if fd.IsSet(o.refValue) {
123		return True
124	}
125	return False
126}
127
128func (o *protoObj) Get(index ref.Val) ref.Val {
129	protoFieldName, ok := index.(String)
130	if !ok {
131		return ValOrErr(index, "no such overload")
132	}
133	protoFieldStr := string(protoFieldName)
134	fd, found := o.typeDesc.FieldByName(protoFieldStr)
135	if !found {
136		return NewErr("no such field '%s'", index)
137	}
138	fv, err := fd.GetFrom(o.refValue)
139	if err != nil {
140		return NewErr(err.Error())
141	}
142	return o.NativeToValue(fv)
143}
144
145func (o *protoObj) Type() ref.Type {
146	return o.typeValue
147}
148
149func (o *protoObj) Value() interface{} {
150	return o.value
151}
152