1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package types
16
17import (
18	"errors"
19	"fmt"
20	"reflect"
21	"strings"
22
23	"github.com/golang/protobuf/proto"
24	"github.com/golang/protobuf/ptypes"
25
26	"github.com/google/cel-go/common/types/ref"
27	"github.com/google/cel-go/common/types/traits"
28
29	structpb "github.com/golang/protobuf/ptypes/struct"
30)
31
// baseMap is a reflection based map implementation designed to handle a variety of map-like types.
type baseMap struct {
	// TypeAdapter converts native Go keys and values into CEL ref.Val values; Find uses it
	// for looked-up values and it is handed to the map iterator for keys.
	ref.TypeAdapter
	// value is the wrapped native Go value, returned unchanged by Value().
	value    interface{}
	// refValue is the reflection handle for value, used for lookups, iteration, sizing and
	// native conversion. The stringMap methods re-create it from value if it is ever unset.
	refValue reflect.Value
}
38
// stringMap is a specialization to improve the performance of simple key, value pair lookups by
// string as this is the most common usage of maps.
type stringMap struct {
	// baseMap supplies the generic reflection-based behaviour (native conversion, equality,
	// iteration) that the stringMap methods delegate to.
	*baseMap
	// mapStrStr is the same map as baseMap.value, kept with its concrete type so lookups and
	// sizing avoid reflection.
	mapStrStr map[string]string
}
45
46// NewDynamicMap returns a traits.Mapper value with dynamic key, value pairs.
47func NewDynamicMap(adapter ref.TypeAdapter, value interface{}) traits.Mapper {
48	return &baseMap{
49		TypeAdapter: adapter,
50		value:       value,
51		refValue:    reflect.ValueOf(value)}
52}
53
54// NewStringStringMap returns a specialized traits.Mapper with string keys and values.
55func NewStringStringMap(adapter ref.TypeAdapter, value map[string]string) traits.Mapper {
56	return &stringMap{
57		baseMap:   &baseMap{TypeAdapter: adapter, value: value},
58		mapStrStr: value,
59	}
60}
61
62var (
63	// MapType singleton.
64	MapType = NewTypeValue("map",
65		traits.ContainerType,
66		traits.IndexerType,
67		traits.IterableType,
68		traits.SizerType)
69)
70
71// Contains implements the traits.Container interface method.
72func (m *baseMap) Contains(index ref.Val) ref.Val {
73	val, found := m.Find(index)
74	// When the index is not found and val is non-nil, this is an error or unknown value.
75	if !found && val != nil && IsUnknownOrError(val) {
76		return val
77	}
78	return Bool(found)
79}
80
81// Contains implements the traits.Container interface method.
82func (m *stringMap) Contains(index ref.Val) ref.Val {
83	val, found := m.Find(index)
84	// When the index is not found and val is non-nil, this is an error or unknown value.
85	if !found && val != nil && IsUnknownOrError(val) {
86		return val
87	}
88	return Bool(found)
89}
90
91// ConvertToNative implements the ref.Val interface method.
92func (m *baseMap) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
93	switch typeDesc {
94	case anyValueType:
95		json, err := m.ConvertToNative(jsonStructType)
96		if err != nil {
97			return nil, err
98		}
99		return ptypes.MarshalAny(json.(proto.Message))
100	case jsonValueType, jsonStructType:
101		jsonEntries, err :=
102			m.ConvertToNative(reflect.TypeOf(map[string]*structpb.Value{}))
103		if err != nil {
104			return nil, err
105		}
106		jsonMap := &structpb.Struct{
107			Fields: jsonEntries.(map[string]*structpb.Value)}
108		if typeDesc == jsonStructType {
109			return jsonMap, nil
110		}
111		return &structpb.Value{
112			Kind: &structpb.Value_StructValue{
113				StructValue: jsonMap}}, nil
114	}
115
116	// Unwrap pointers, but track their use.
117	isPtr := false
118	if typeDesc.Kind() == reflect.Ptr {
119		tk := typeDesc
120		typeDesc = typeDesc.Elem()
121		if typeDesc.Kind() == reflect.Ptr {
122			return nil, fmt.Errorf("unsupported type conversion to '%v'", tk)
123		}
124		isPtr = true
125	}
126
127	// If the map is already assignable to the desired type return it, e.g. interfaces and
128	// maps with the same key value types.
129	if reflect.TypeOf(m).AssignableTo(typeDesc) {
130		return m, nil
131	}
132
133	// Establish some basic facts about the map key and value types.
134	thisType := m.refValue.Type()
135	thisKey := thisType.Key()
136	thisKeyKind := thisKey.Kind()
137	thisElem := thisType.Elem()
138	thisElemKind := thisElem.Kind()
139
140	switch typeDesc.Kind() {
141	// Map conversion.
142	case reflect.Map:
143		otherKey := typeDesc.Key()
144		otherKeyKind := otherKey.Kind()
145		otherElem := typeDesc.Elem()
146		otherElemKind := otherElem.Kind()
147		if otherKeyKind == thisKeyKind && otherElemKind == thisElemKind {
148			return m.value, nil
149		}
150		elemCount := m.Size().(Int)
151		nativeMap := reflect.MakeMapWithSize(typeDesc, int(elemCount))
152		it := m.Iterator()
153		for it.HasNext() == True {
154			key := it.Next()
155			refKeyValue, err := key.ConvertToNative(otherKey)
156			if err != nil {
157				return nil, err
158			}
159			refElemValue, err := m.Get(key).ConvertToNative(otherElem)
160			if err != nil {
161				return nil, err
162			}
163			nativeMap.SetMapIndex(
164				reflect.ValueOf(refKeyValue),
165				reflect.ValueOf(refElemValue))
166		}
167		return nativeMap.Interface(), nil
168	case reflect.Struct:
169		if thisKeyKind != reflect.String && thisKeyKind != reflect.Interface {
170			break
171		}
172		nativeStructPtr := reflect.New(typeDesc)
173		nativeStruct := nativeStructPtr.Elem()
174		it := m.Iterator()
175		for it.HasNext() == True {
176			key := it.Next()
177			// Ensure the field name being referenced is exported.
178			// Only exported (public) field names can be set by reflection, where the name
179			// must be at least one character in length and start with an upper-case letter.
180			fieldName := string(key.ConvertToType(StringType).(String))
181			switch len(fieldName) {
182			case 0:
183				return nil, errors.New("type conversion error, unsupported empty field")
184			case 1:
185				fieldName = strings.ToUpper(fieldName)
186			default:
187				fieldName = strings.ToUpper(fieldName[0:1]) + fieldName[1:]
188			}
189			fieldRef := nativeStruct.FieldByName(fieldName)
190			if !fieldRef.IsValid() {
191				return nil, fmt.Errorf(
192					"type conversion error, no such field '%s' in type '%v'",
193					fieldName, typeDesc)
194			}
195			fieldValue, err := m.Get(key).ConvertToNative(fieldRef.Type())
196			if err != nil {
197				return nil, err
198			}
199			fieldRef.Set(reflect.ValueOf(fieldValue))
200		}
201		if isPtr {
202			return nativeStructPtr.Interface(), nil
203		}
204		return nativeStruct.Interface(), nil
205	}
206	return nil, fmt.Errorf("type conversion error from map to '%v'", typeDesc)
207}
208
209// ConvertToNative implements the ref.Val interface method.
210func (m *stringMap) ConvertToNative(refType reflect.Type) (interface{}, error) {
211	if !m.baseMap.refValue.IsValid() {
212		m.baseMap.refValue = reflect.ValueOf(m.value)
213	}
214	return m.baseMap.ConvertToNative(refType)
215}
216
217// ConvertToType implements the ref.Val interface method.
218func (m *baseMap) ConvertToType(typeVal ref.Type) ref.Val {
219	switch typeVal {
220	case MapType:
221		return m
222	case TypeType:
223		return MapType
224	}
225	return NewErr("type conversion error from '%s' to '%s'", MapType, typeVal)
226}
227
228// ConvertToType implements the ref.Val interface method.
229func (m *stringMap) ConvertToType(typeVal ref.Type) ref.Val {
230	switch typeVal {
231	case MapType:
232		return m
233	default:
234		return m.baseMap.ConvertToType(typeVal)
235	}
236}
237
238// Equal implements the ref.Val interface method.
239func (m *baseMap) Equal(other ref.Val) ref.Val {
240	if MapType != other.Type() {
241		return ValOrErr(other, "no such overload")
242	}
243	otherMap := other.(traits.Mapper)
244	if m.Size() != otherMap.Size() {
245		return False
246	}
247	it := m.Iterator()
248	for it.HasNext() == True {
249		key := it.Next()
250		thisVal, _ := m.Find(key)
251		otherVal, found := otherMap.Find(key)
252		if !found {
253			if otherVal == nil {
254				return False
255			}
256			return ValOrErr(otherVal, "no such overload")
257		}
258		valEq := thisVal.Equal(otherVal)
259		if valEq != True {
260			return valEq
261		}
262	}
263	return True
264}
265
266// Equal implements the ref.Val interface method.
267func (m *stringMap) Equal(other ref.Val) ref.Val {
268	if !m.baseMap.refValue.IsValid() {
269		m.baseMap.refValue = reflect.ValueOf(m.value)
270	}
271	return m.baseMap.Equal(other)
272}
273
274// Find implements the traits.Mapper interface method.
275func (m *baseMap) Find(key ref.Val) (ref.Val, bool) {
276	// TODO: There are multiple reasons why a Get could fail. Typically, this is because the key
277	// does not exist in the map; however, it's possible that the value cannot be converted to
278	// the desired type. Refine this strategy to disambiguate these cases.
279	if IsUnknownOrError(key) {
280		return key, false
281	}
282	thisKeyType := m.refValue.Type().Key()
283	nativeKey, err := key.ConvertToNative(thisKeyType)
284	if err != nil {
285		return &Err{err}, false
286	}
287	nativeKeyVal := reflect.ValueOf(nativeKey)
288	value := m.refValue.MapIndex(nativeKeyVal)
289	if !value.IsValid() {
290		return nil, false
291	}
292	return m.NativeToValue(value.Interface()), true
293}
294
295// Find implements the traits.Mapper interface method.
296func (m *stringMap) Find(key ref.Val) (ref.Val, bool) {
297	strKey, ok := key.(String)
298	if !ok {
299		return ValOrErr(key, "no such overload"), false
300	}
301	val, found := m.mapStrStr[string(strKey)]
302	if !found {
303		return nil, false
304	}
305	return String(val), true
306}
307
308// Get implements the traits.Indexer interface method.
309func (m *baseMap) Get(key ref.Val) ref.Val {
310	v, found := m.Find(key)
311	if !found {
312		return ValOrErr(v, "no such key: %v", key)
313	}
314	return v
315}
316
317// Get implements the traits.Indexer interface method.
318func (m *stringMap) Get(key ref.Val) ref.Val {
319	v, found := m.Find(key)
320	if !found {
321		return ValOrErr(v, "no such key: %v", key)
322	}
323	return v
324}
325
326// Iterator implements the traits.Iterable interface method.
327func (m *baseMap) Iterator() traits.Iterator {
328	mapKeys := m.refValue.MapKeys()
329	return &mapIterator{
330		baseIterator: &baseIterator{},
331		TypeAdapter:  m.TypeAdapter,
332		mapValue:     m,
333		mapKeys:      mapKeys,
334		cursor:       0,
335		len:          int(m.Size().(Int))}
336}
337
338// Iterator implements the traits.Iterable interface method.
339func (m *stringMap) Iterator() traits.Iterator {
340	if !m.baseMap.refValue.IsValid() {
341		m.baseMap.refValue = reflect.ValueOf(m.value)
342	}
343	return m.baseMap.Iterator()
344}
345
346// Size implements the traits.Sizer interface method.
347func (m *baseMap) Size() ref.Val {
348	return Int(m.refValue.Len())
349}
350
351// Size implements the traits.Sizer interface method.
352func (m *stringMap) Size() ref.Val {
353	return Int(len(m.mapStrStr))
354}
355
356// Type implements the ref.Val interface method.
357func (m *baseMap) Type() ref.Type {
358	return MapType
359}
360
361// Value implements the ref.Val interface method.
362func (m *baseMap) Value() interface{} {
363	return m.value
364}
365
// mapIterator iterates over a snapshot of a map's keys, converting each native key to a CEL
// value as it is returned.
type mapIterator struct {
	// baseIterator is embedded; presumably it supplies the remaining ref.Val methods shared by
	// iterators (defined elsewhere in the package — TODO confirm).
	*baseIterator
	// TypeAdapter converts native keys to CEL values in Next.
	ref.TypeAdapter
	// mapValue is the map being iterated; it is not read by the methods in this file.
	mapValue traits.Mapper
	// mapKeys is the key snapshot taken when the iterator was created.
	mapKeys  []reflect.Value
	// cursor is the index in mapKeys of the next key to return.
	cursor   int
	// len is the number of keys; HasNext compares cursor against it.
	len      int
}
374
375// HasNext implements the traits.Iterator interface method.
376func (it *mapIterator) HasNext() ref.Val {
377	return Bool(it.cursor < it.len)
378}
379
380// Next implements the traits.Iterator interface method.
381func (it *mapIterator) Next() ref.Val {
382	if it.HasNext() == True {
383		index := it.cursor
384		it.cursor++
385		refKey := it.mapKeys[index]
386		return it.NativeToValue(refKey.Interface())
387	}
388	return nil
389}
390