1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package interpreter
16
17import (
18	"github.com/google/cel-go/common/operators"
19	"github.com/google/cel-go/common/types"
20	"github.com/google/cel-go/common/types/ref"
21
22	structpb "github.com/golang/protobuf/ptypes/struct"
23
24	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
25)
26
27type astPruner struct {
28	expr  *exprpb.Expr
29	state EvalState
30}
31
32// TODO Consider having a separate walk of the AST that finds common
33// subexpressions. This can be called before or after constant folding to find
34// common subexpressions.
35
36// PruneAst prunes the given AST based on the given EvalState and generates a new AST.
37// Given AST is copied on write and a new AST is returned.
38// Couple of typical use cases this interface would be:
39//
40// A)
41// 1) Evaluate expr with some unknowns,
42// 2) If result is unknown:
43//   a) PruneAst
44//   b) Goto 1
45// Functional call results which are known would be effectively cached across
46// iterations.
47//
48// B)
49// 1) Compile the expression (maybe via a service and maybe after checking a
50//    compiled expression does not exists in local cache)
51// 2) Prepare the environment and the interpreter. Activation might be empty.
52// 3) Eval the expression. This might return unknown or error or a concrete
53//    value.
54// 4) PruneAst
55// 4) Maybe cache the expression
56// This is effectively constant folding the expression. How the environment is
57// prepared in step 2 is flexible. For example, If the caller caches the
58// compiled and constant folded expressions, but is not willing to constant
59// fold(and thus cache results of) some external calls, then they can prepare
60// the overloads accordingly.
61func PruneAst(expr *exprpb.Expr, state EvalState) *exprpb.Expr {
62	pruner := &astPruner{
63		expr:  expr,
64		state: state}
65	newExpr, _ := pruner.prune(expr)
66	return newExpr
67}
68
69func (p *astPruner) createLiteral(node *exprpb.Expr, val *exprpb.Constant) *exprpb.Expr {
70	newExpr := *node
71	newExpr.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: val}
72	return &newExpr
73}
74
75func (p *astPruner) maybePruneAndOr(node *exprpb.Expr) (*exprpb.Expr, bool) {
76	if !p.existsWithUnknownValue(node.GetId()) {
77		return nil, false
78	}
79
80	call := node.GetCallExpr()
81
82	// We know result is unknown, so we have at least one unknown arg
83	// and if one side is a known value, we know we can ignore it.
84	if p.existsWithKnownValue(call.Args[0].GetId()) {
85		return call.Args[1], true
86	}
87
88	if p.existsWithKnownValue(call.Args[1].GetId()) {
89		return call.Args[0], true
90	}
91	return nil, false
92}
93
94func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) {
95	if !p.existsWithUnknownValue(node.GetId()) {
96		return nil, false
97	}
98
99	call := node.GetCallExpr()
100	condVal, condValueExists := p.value(call.Args[0].GetId())
101	if !condValueExists || types.IsUnknownOrError(condVal) {
102		return nil, false
103	}
104
105	if condVal.Value().(bool) {
106		return call.Args[1], true
107	}
108	return call.Args[2], true
109}
110
111func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) {
112	call := node.GetCallExpr()
113	if call.Function == operators.LogicalOr || call.Function == operators.LogicalAnd {
114		return p.maybePruneAndOr(node)
115	}
116	if call.Function == operators.Conditional {
117		return p.maybePruneConditional(node)
118	}
119
120	return nil, false
121}
122
123func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) {
124	if node == nil {
125		return node, false
126	}
127	if val, valueExists := p.value(node.GetId()); valueExists && !types.IsUnknownOrError(val) {
128
129		// TODO if we have a list or struct, create a list/struct
130		// expression. This is useful especially if these expressions
131		// are result of a function call.
132
133		switch val.Type() {
134		case types.BoolType:
135			return p.createLiteral(node,
136				&exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: val.Value().(bool)}}), true
137		case types.IntType:
138			return p.createLiteral(node,
139				&exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: val.Value().(int64)}}), true
140		case types.UintType:
141			return p.createLiteral(node,
142				&exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: val.Value().(uint64)}}), true
143		case types.StringType:
144			return p.createLiteral(node,
145				&exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: val.Value().(string)}}), true
146		case types.DoubleType:
147			return p.createLiteral(node,
148				&exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: val.Value().(float64)}}), true
149		case types.BytesType:
150			return p.createLiteral(node,
151				&exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: val.Value().([]byte)}}), true
152		case types.NullType:
153			return p.createLiteral(node,
154				&exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: val.Value().(structpb.NullValue)}}), true
155		}
156	}
157
158	// We have either an unknown/error value, or something we dont want to
159	// transform, or expression was not evaluated. If possible, drill down
160	// more.
161
162	switch node.ExprKind.(type) {
163	case *exprpb.Expr_SelectExpr:
164		if operand, pruned := p.prune(node.GetSelectExpr().Operand); pruned {
165			newExpr := *node
166			newSelect := *newExpr.GetSelectExpr()
167			newSelect.Operand = operand
168			newExpr.GetExprKind().(*exprpb.Expr_SelectExpr).SelectExpr = &newSelect
169			return &newExpr, true
170		}
171	case *exprpb.Expr_CallExpr:
172		if newExpr, pruned := p.maybePruneFunction(node); pruned {
173			newExpr, _ = p.prune(newExpr)
174			return newExpr, true
175		}
176		newCall := *node.GetCallExpr()
177		var prunedCall bool
178		var prunedArg bool
179		for i, arg := range node.GetCallExpr().Args {
180			if newCall.Args[i], prunedArg = p.prune(arg); prunedArg {
181				prunedCall = true
182			}
183		}
184		if newTarget, prunedTarget := p.prune(node.GetCallExpr().Target); prunedTarget {
185			prunedCall = true
186			newCall.Target = newTarget
187		}
188		if prunedCall {
189			newExpr := *node
190			newExpr.GetExprKind().(*exprpb.Expr_CallExpr).CallExpr = &newCall
191			return &newExpr, true
192		}
193	case *exprpb.Expr_ListExpr:
194		newList := *node.GetListExpr()
195		var prunedList bool
196		var prunedElem bool
197		for i, elem := range node.GetListExpr().Elements {
198			if newList.Elements[i], prunedElem = p.prune(elem); prunedElem {
199				prunedList = true
200			}
201		}
202		if prunedList {
203			newExpr := *node
204			newExpr.GetExprKind().(*exprpb.Expr_ListExpr).ListExpr = &newList
205			return &newExpr, true
206		}
207	case *exprpb.Expr_StructExpr:
208		newStruct := *node.GetStructExpr()
209		var prunedStruct bool
210		var prunedEntry bool
211		for i, entry := range node.GetStructExpr().Entries {
212			newEntry := *entry
213			if newKey, pruned := p.prune(entry.GetMapKey()); pruned {
214				prunedEntry = true
215				newEntry.GetKeyKind().(*exprpb.Expr_CreateStruct_Entry_MapKey).MapKey = newKey
216			}
217			if newValue, pruned := p.prune(entry.Value); pruned {
218				prunedEntry = true
219				newEntry.Value = newValue
220			}
221			if prunedEntry {
222				prunedStruct = true
223				newStruct.Entries[i] = &newEntry
224			}
225		}
226		if prunedStruct {
227			newExpr := *node
228			newExpr.GetExprKind().(*exprpb.Expr_StructExpr).StructExpr = &newStruct
229			return &newExpr, true
230		}
231	case *exprpb.Expr_ComprehensionExpr:
232		if newIterRange, pruned := p.prune(node.GetComprehensionExpr().IterRange); pruned {
233			newExpr := *node
234			newCompre := *newExpr.GetComprehensionExpr()
235			newCompre.IterRange = newIterRange
236			newExpr.GetExprKind().(*exprpb.Expr_ComprehensionExpr).ComprehensionExpr = &newCompre
237			return &newExpr, true
238		}
239	}
240	return node, false
241}
242
243func (p *astPruner) value(id int64) (ref.Val, bool) {
244	val, found := p.state.Value(id)
245	return val, (found && val != nil)
246}
247
248func (p *astPruner) existsWithUnknownValue(id int64) bool {
249	val, valueExists := p.value(id)
250	return valueExists && types.IsUnknown(val)
251}
252
253func (p *astPruner) existsWithKnownValue(id int64) bool {
254	val, valueExists := p.value(id)
255	return valueExists && !types.IsUnknown(val)
256}
257