1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package parser
16
17import (
18	"fmt"
19	"strconv"
20	"strings"
21
22	"github.com/google/cel-go/common/operators"
23
24	exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
25)
26
27// Unparse takes an input expression and source position information and generates a human-readable
28// expression.
29//
30// Note, unparsing an AST will often generate the same expression as was originally parsed, but some
31// formatting may be lost in translation, notably:
32//
33// - All quoted literals are doubled quoted.
34// - Byte literals are represented as octal escapes (same as Google SQL).
35// - Floating point values are converted to the small number of digits needed to represent the value.
36// - Spacing around punctuation marks may be lost.
37// - Parentheses will only be applied when they affect operator precedence.
38func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo) (string, error) {
39	un := &unparser{info: info}
40	err := un.visit(expr)
41	if err != nil {
42		return "", err
43	}
44	return un.str.String(), nil
45}
46
47// unparser visits an expression to reconstruct a human-readable string from an AST.
48type unparser struct {
49	str    strings.Builder
50	offset int32
51	// TODO: use the source info to rescontruct macros into function calls.
52	info *exprpb.SourceInfo
53}
54
55func (un *unparser) visit(expr *exprpb.Expr) error {
56	switch expr.ExprKind.(type) {
57	case *exprpb.Expr_CallExpr:
58		return un.visitCall(expr)
59	// TODO: Comprehensions are currently not supported.
60	case *exprpb.Expr_ComprehensionExpr:
61		return un.visitComprehension(expr)
62	case *exprpb.Expr_ConstExpr:
63		return un.visitConst(expr)
64	case *exprpb.Expr_IdentExpr:
65		return un.visitIdent(expr)
66	case *exprpb.Expr_ListExpr:
67		return un.visitList(expr)
68	case *exprpb.Expr_SelectExpr:
69		return un.visitSelect(expr)
70	case *exprpb.Expr_StructExpr:
71		return un.visitStruct(expr)
72	}
73	return fmt.Errorf("unsupported expr: %v", expr)
74}
75
76func (un *unparser) visitCall(expr *exprpb.Expr) error {
77	c := expr.GetCallExpr()
78	fun := c.GetFunction()
79	switch fun {
80	// ternary operator
81	case operators.Conditional:
82		return un.visitCallConditional(expr)
83	// index operator
84	case operators.Index:
85		return un.visitCallIndex(expr)
86	// unary operators
87	case operators.LogicalNot, operators.Negate:
88		return un.visitCallUnary(expr)
89	// binary operators
90	case operators.Add,
91		operators.Divide,
92		operators.Equals,
93		operators.Greater,
94		operators.GreaterEquals,
95		operators.In,
96		operators.Less,
97		operators.LessEquals,
98		operators.LogicalAnd,
99		operators.LogicalOr,
100		operators.Modulo,
101		operators.Multiply,
102		operators.NotEquals,
103		operators.OldIn,
104		operators.Subtract:
105		return un.visitCallBinary(expr)
106	// standard function calls.
107	default:
108		return un.visitCallFunc(expr)
109	}
110}
111
112func (un *unparser) visitCallBinary(expr *exprpb.Expr) error {
113	c := expr.GetCallExpr()
114	fun := c.GetFunction()
115	args := c.GetArgs()
116	lhs := args[0]
117	// add parens if the current operator is lower precedence than the lhs expr operator.
118	lhsParen := isComplexOperatorWithRespectTo(fun, lhs)
119	rhs := args[1]
120	// add parens if the current operator is lower precedence than the rhs expr operator,
121	// or the same precedence and the operator is left recursive.
122	rhsParen := isComplexOperatorWithRespectTo(fun, rhs)
123	if !rhsParen && isLeftRecursive(fun) {
124		rhsParen = isSamePrecedence(fun, rhs)
125	}
126	err := un.visitMaybeNested(lhs, lhsParen)
127	if err != nil {
128		return err
129	}
130	unmangled, found := operators.FindReverseBinaryOperator(fun)
131	if !found {
132		return fmt.Errorf("cannot unmangle operator: %s", fun)
133	}
134	un.str.WriteString(" ")
135	un.str.WriteString(unmangled)
136	un.str.WriteString(" ")
137	return un.visitMaybeNested(rhs, rhsParen)
138}
139
140func (un *unparser) visitCallConditional(expr *exprpb.Expr) error {
141	c := expr.GetCallExpr()
142	args := c.GetArgs()
143	// add parens if operand is a conditional itself.
144	nested := isSamePrecedence(operators.Conditional, args[0]) ||
145		isComplexOperator(args[0])
146	err := un.visitMaybeNested(args[0], nested)
147	if err != nil {
148		return err
149	}
150	un.str.WriteString(" ? ")
151	// add parens if operand is a conditional itself.
152	nested = isSamePrecedence(operators.Conditional, args[1]) ||
153		isComplexOperator(args[1])
154	err = un.visitMaybeNested(args[1], nested)
155	if err != nil {
156		return err
157	}
158	un.str.WriteString(" : ")
159	// add parens if operand is a conditional itself.
160	nested = isSamePrecedence(operators.Conditional, args[2]) ||
161		isComplexOperator(args[2])
162
163	return un.visitMaybeNested(args[2], nested)
164}
165
166func (un *unparser) visitCallFunc(expr *exprpb.Expr) error {
167	c := expr.GetCallExpr()
168	fun := c.GetFunction()
169	args := c.GetArgs()
170	if c.GetTarget() != nil {
171		nested := isBinaryOrTernaryOperator(c.GetTarget())
172		err := un.visitMaybeNested(c.GetTarget(), nested)
173		if err != nil {
174			return err
175		}
176		un.str.WriteString(".")
177	}
178	un.str.WriteString(fun)
179	un.str.WriteString("(")
180	for i, arg := range args {
181		err := un.visit(arg)
182		if err != nil {
183			return err
184		}
185		if i < len(args)-1 {
186			un.str.WriteString(", ")
187		}
188	}
189	un.str.WriteString(")")
190	return nil
191}
192
193func (un *unparser) visitCallIndex(expr *exprpb.Expr) error {
194	c := expr.GetCallExpr()
195	args := c.GetArgs()
196	nested := isBinaryOrTernaryOperator(args[0])
197	err := un.visitMaybeNested(args[0], nested)
198	if err != nil {
199		return err
200	}
201	un.str.WriteString("[")
202	err = un.visit(args[1])
203	if err != nil {
204		return err
205	}
206	un.str.WriteString("]")
207	return nil
208}
209
210func (un *unparser) visitCallUnary(expr *exprpb.Expr) error {
211	c := expr.GetCallExpr()
212	fun := c.GetFunction()
213	args := c.GetArgs()
214	unmangled, found := operators.FindReverse(fun)
215	if !found {
216		return fmt.Errorf("cannot unmangle operator: %s", fun)
217	}
218	un.str.WriteString(unmangled)
219	nested := isComplexOperator(args[0])
220	return un.visitMaybeNested(args[0], nested)
221}
222
223func (un *unparser) visitComprehension(expr *exprpb.Expr) error {
224	// TODO: introduce a macro expansion map between the top-level comprehension id and the
225	// function call that the macro replaces.
226	return fmt.Errorf("unimplemented : %v", expr)
227}
228
229func (un *unparser) visitConst(expr *exprpb.Expr) error {
230	c := expr.GetConstExpr()
231	switch c.ConstantKind.(type) {
232	case *exprpb.Constant_BoolValue:
233		un.str.WriteString(strconv.FormatBool(c.GetBoolValue()))
234	case *exprpb.Constant_BytesValue:
235		// bytes constants are surrounded with b"<bytes>"
236		b := c.GetBytesValue()
237		un.str.WriteString(`b"`)
238		un.str.WriteString(bytesToOctets(b))
239		un.str.WriteString(`"`)
240	case *exprpb.Constant_DoubleValue:
241		// represent the float using the minimum required digits
242		d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64)
243		un.str.WriteString(d)
244	case *exprpb.Constant_Int64Value:
245		i := strconv.FormatInt(c.GetInt64Value(), 10)
246		un.str.WriteString(i)
247	case *exprpb.Constant_NullValue:
248		un.str.WriteString("null")
249	case *exprpb.Constant_StringValue:
250		// strings will be double quoted with quotes escaped.
251		un.str.WriteString(strconv.Quote(c.GetStringValue()))
252	case *exprpb.Constant_Uint64Value:
253		// uint literals have a 'u' suffix.
254		ui := strconv.FormatUint(c.GetUint64Value(), 10)
255		un.str.WriteString(ui)
256		un.str.WriteString("u")
257	default:
258		return fmt.Errorf("unimplemented : %v", expr)
259	}
260	return nil
261}
262
263func (un *unparser) visitIdent(expr *exprpb.Expr) error {
264	un.str.WriteString(expr.GetIdentExpr().GetName())
265	return nil
266}
267
268func (un *unparser) visitList(expr *exprpb.Expr) error {
269	l := expr.GetListExpr()
270	elems := l.GetElements()
271	un.str.WriteString("[")
272	for i, elem := range elems {
273		err := un.visit(elem)
274		if err != nil {
275			return err
276		}
277		if i < len(elems)-1 {
278			un.str.WriteString(", ")
279		}
280	}
281	un.str.WriteString("]")
282	return nil
283}
284
285func (un *unparser) visitSelect(expr *exprpb.Expr) error {
286	sel := expr.GetSelectExpr()
287	// handle the case when the select expression was generated by the has() macro.
288	if sel.GetTestOnly() {
289		un.str.WriteString("has(")
290	}
291	nested := !sel.GetTestOnly() && isBinaryOrTernaryOperator(sel.GetOperand())
292	err := un.visitMaybeNested(sel.GetOperand(), nested)
293	if err != nil {
294		return err
295	}
296	un.str.WriteString(".")
297	un.str.WriteString(sel.GetField())
298	if sel.GetTestOnly() {
299		un.str.WriteString(")")
300	}
301	return nil
302}
303
304func (un *unparser) visitStruct(expr *exprpb.Expr) error {
305	s := expr.GetStructExpr()
306	// If the message name is non-empty, then this should be treated as message construction.
307	if s.GetMessageName() != "" {
308		return un.visitStructMsg(expr)
309	}
310	// Otherwise, build a map.
311	return un.visitStructMap(expr)
312}
313
314func (un *unparser) visitStructMsg(expr *exprpb.Expr) error {
315	m := expr.GetStructExpr()
316	entries := m.GetEntries()
317	un.str.WriteString(m.GetMessageName())
318	un.str.WriteString("{")
319	for i, entry := range entries {
320		f := entry.GetFieldKey()
321		un.str.WriteString(f)
322		un.str.WriteString(": ")
323		v := entry.GetValue()
324		err := un.visit(v)
325		if err != nil {
326			return err
327		}
328		if i < len(entries)-1 {
329			un.str.WriteString(", ")
330		}
331	}
332	un.str.WriteString("}")
333	return nil
334}
335
336func (un *unparser) visitStructMap(expr *exprpb.Expr) error {
337	m := expr.GetStructExpr()
338	entries := m.GetEntries()
339	un.str.WriteString("{")
340	for i, entry := range entries {
341		k := entry.GetMapKey()
342		err := un.visit(k)
343		if err != nil {
344			return err
345		}
346		un.str.WriteString(": ")
347		v := entry.GetValue()
348		err = un.visit(v)
349		if err != nil {
350			return err
351		}
352		if i < len(entries)-1 {
353			un.str.WriteString(", ")
354		}
355	}
356	un.str.WriteString("}")
357	return nil
358}
359
360func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error {
361	if nested {
362		un.str.WriteString("(")
363	}
364	err := un.visit(expr)
365	if err != nil {
366		return err
367	}
368	if nested {
369		un.str.WriteString(")")
370	}
371	return nil
372}
373
374// isLeftRecursive indicates whether the parser resolves the call in a left-recursive manner as
375// this can have an effect of how parentheses affect the order of operations in the AST.
376func isLeftRecursive(op string) bool {
377	return op != operators.LogicalAnd && op != operators.LogicalOr
378}
379
380// isSamePrecedence indicates whether the precedence of the input operator is the same as the
381// precedence of the (possible) operation represented in the input Expr.
382//
383// If the expr is not a Call, the result is false.
384func isSamePrecedence(op string, expr *exprpb.Expr) bool {
385	if expr.GetCallExpr() == nil {
386		return false
387	}
388	c := expr.GetCallExpr()
389	other := c.GetFunction()
390	return operators.Precedence(op) == operators.Precedence(other)
391}
392
393// isLowerPrecedence indicates whether the precedence of the input operator is lower precedence
394// than the (possible) operation represented in the input Expr.
395//
396// If the expr is not a Call, the result is false.
397func isLowerPrecedence(op string, expr *exprpb.Expr) bool {
398	if expr.GetCallExpr() == nil {
399		return false
400	}
401	c := expr.GetCallExpr()
402	other := c.GetFunction()
403	return operators.Precedence(op) < operators.Precedence(other)
404}
405
406// Indicates whether the expr is a complex operator, i.e., a call expression
407// with 2 or more arguments.
408func isComplexOperator(expr *exprpb.Expr) bool {
409	if expr.GetCallExpr() != nil && len(expr.GetCallExpr().GetArgs()) >= 2 {
410		return true
411	}
412	return false
413}
414
415// Indicates whether it is a complex operation compared to another.
416// expr is *not* considered complex if it is not a call expression or has
417// less than two arguments, or if it has a higher precedence than op.
418func isComplexOperatorWithRespectTo(op string, expr *exprpb.Expr) bool {
419	if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 {
420		return false
421	}
422	return isLowerPrecedence(op, expr)
423}
424
425// Indicate whether this is a binary or ternary operator.
426func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool {
427	if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 {
428		return false
429	}
430	_, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction())
431	return isBinaryOp || isSamePrecedence(operators.Conditional, expr)
432}
433
// bytesToOctets renders each byte of byteVal as a backslash followed by exactly three
// octal digits (for example 0x08 becomes \010), producing the body of a b"..." literal.
// An empty or nil slice yields the empty string.
func bytesToOctets(byteVal []byte) string {
	var out strings.Builder
	// Every byte expands to exactly four characters: '\' plus three octal digits.
	out.Grow(4 * len(byteVal))
	for i := range byteVal {
		fmt.Fprintf(&out, "\\%03o", byteVal[i])
	}
	return out.String()
}
443