1/*
2 * gomacro - A Go interpreter with Lisp-like macros
3 *
4 * Copyright (C) 2017-2019 Massimiliano Ghilardi
5 *
6 *     This Source Code Form is subject to the terms of the Mozilla Public
7 *     License, v. 2.0. If a copy of the MPL was not distributed with this
8 *     file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 *
10 *
11 * binary_eql.go
12 *
13 *  Created on Apr 02, 2017
14 *      Author Massimiliano Ghilardi
15 */
16
17package fast
18
19import (
20	"go/ast"
21	"go/token"
22	r "reflect"
23
24	"github.com/cosmos72/gomacro/base/reflect"
25	. "github.com/cosmos72/gomacro/base"
26)
27
28:import (
29	"fmt"
30	"go/ast"
31	"go/token"
32	r "reflect"
33)
34
35
// upcasefirstbyte returns str with its first byte converted to upper case
// when that byte is an ASCII lowercase letter 'a'..'z'.
// Any other string, including the empty one, is returned unchanged.
:func upcasefirstbyte(str string) string {
	if str == "" || str[0] < 'a' || str[0] > 'z' {
		return str
	}
	return string(str[0]-'a'+'A') + str[1:]
}
44
// makekind returns the AST of the selector expression r.Xxx, where Xxx is the
// name of the type denoted by typ with its first letter upper-cased
// (for example int8 -> r.Int8), i.e. the matching reflect.Kind constant.
:func makekind(typ ast.Node) ast.Node {
	name := upcasefirstbyte(EvalType(typ).Name())

	// go/ast.SelectorExpr requires the foo in r.foo to be an *ast.Ident,
	// so the name cannot be unquoted there: patch Sel after quasiquoting
	sel := ~"{r . foo}
	sel.Sel = &ast.Ident{Name: name}
	return sel
}
53
54
// convertvalue1 returns the AST of an expression that extracts a value of type
// typ from val, an expression of type reflect.Value (callers pass yv or xv).
//
// If typ evaluates to a nil type, val itself is returned, i.e. the result stays
// wrapped in a reflect.Value. Otherwise the result is val.Xxx(), where Xxx is
// the type name without trailing digits and with the first letter upper-cased
// (uintptr uses Uint). Unless the accessor already returns exactly typ
// (bool, int64, uint64, float64, complex128, string), the call is wrapped in
// the conversion typ(...).
:func convertvalue1(typ, val ast.Node) ast.Node {
	var t r.Type = EvalType(typ)
	if t == nil {
		// keep the result wrapped in a reflect.Value
		return val
	}
	// drop trailing digits, so that for example int8 ... int64 all map to
	// reflect.Value.Int() and float32, float64 map to reflect.Value.Float()
	name := t.Name()
	end := len(name)
	for end > 0 && name[end-1] >= '0' && name[end-1] <= '9' {
		end--
	}
	name = name[:end]
	if name == "uintptr" {
		name = "uint" // use reflect.Value.Uint()
	}
	accessor := ~"{~,val . foo} // we modify it destructively
	accessor.Sel = &ast.Ident{Name: upcasefirstbyte(name)}

	var result ast.Node
	switch t.Kind() {
	case r.Bool, r.Int64, r.Uint64, r.Float64, r.Complex128, r.String:
		// the accessor already returns exactly this type
		result = ~"{~,accessor ()}
	default:
		// narrow the int64, uint64... returned by the accessor to typ
		result = ~"{~,typ ( ~,accessor () )}
	}
	return result
}
88
// eqlneq is a macro that expands to statements assigning to fun a closure
// that compares two operands of type typ.
//
// opnode must evaluate to a token.Token: the comparison operator
// (token.EQL or token.NEQ at the call sites).
// xconst and yconst must evaluate to bools telling whether x or y is a constant.
//
// The expansion site must have these variables in scope:
// - x and/or y: the operands' Fun.
// - xv or yv: the reflect.Value of the constant operand.
// - fun: the variable receiving the closure.
// When xconst == yconst, both operands are treated as non-constant closures.
:macro eqlneq(opnode, xconst, yconst, typ ast.Node) ast.Node {

	// the return type of Eval() and EvalType() varies. better check early.
	xc, yc := Eval(xconst).(bool), Eval(yconst).(bool)
	optoken := Eval(opnode).(token.Token)

	if xc == yc {
		// both operands are closures: call both at runtime.
		// The && in the quasiquote is only a placeholder, replaced below by optoken.
		var expr *ast.BinaryExpr = ~"{x(env) && y(env)} // quasiquote, we modify it destructively
		expr.Op = optoken

		return ~"{
			x := x.(func(*Env) ~,typ)
			y := y.(func(*Env) ~,typ)
			fun = func(env *Env) bool {
				return ~,expr
			}
		}
	} else if yc {
		// y is a constant: unwrap yv once, outside the closure
		var expr *ast.BinaryExpr = ~"{x(env) && y} // quasiquote, we modify it destructively
		expr.Op = optoken

		yconv := convertvalue1(typ, ~'yv)
		return ~"{
			x := x.(func(*Env) ~,typ)
			y := ~,yconv
			fun = func(env *Env) bool {
				return ~,expr
			}
		}
	} else {
		// x is a constant: unwrap xv once, outside the closure
		var expr *ast.BinaryExpr = ~"{x && y(env)} // quasiquote, we modify it destructively
		expr.Op = optoken

		xconv := convertvalue1(typ, ~'xv)
		return ~"{
			x := ~,xconv
			y := y.(func(*Env) ~,typ)
			fun = func(env *Env) bool {
				return ~,expr
			}
		}
	}
}
132
// eqlneqs is a macro that expands to `switch k { ... }`, with one case per type
// listed in the block types. Each case body is an eqlneq expansion for that
// type, forwarding opnode, xconst and yconst unchanged.
//
// Listing int expands to cases for int, int8, int16, int32 and int64.
// Listing uint expands to cases for uint, uint8, uint16, uint32, uint64 and uintptr.
// Any other type gets a single case.
// The expansion site must declare k, the reflect.Kind being switched on.
:macro eqlneqs(opnode, xconst, yconst, types ast.Node) ast.Node {
	typelist := types.(*ast.BlockStmt).List
	caselist := make([]ast.Stmt, 0, len(typelist))
	for _, typ := range typelist {
		t := EvalType(typ)
		if t.Kind() == r.Int {
			// shortcut for all int* types
			for _, typ := range []ast.Expr{~'int, ~'int8, ~'int16, ~'int32, ~'int64} {
				kind := makekind(typ)
				caselist = append(caselist, ~"{case ~,kind: eqlneq; ~,opnode; ~,xconst; ~,yconst; ~,typ})
			}
		} else if t.Kind() == r.Uint {
			// shortcut for all uint* types
			for _, typ := range []ast.Expr{~'uint, ~'uint8, ~'uint16, ~'uint32, ~'uint64, ~'uintptr} {
				kind := makekind(typ)
				caselist = append(caselist, ~"{case ~,kind: eqlneq; ~,opnode; ~,xconst; ~,yconst; ~,typ})
			}
		} else {
			kind := makekind(typ)
			caselist = append(caselist, ~"{case ~,kind: eqlneq; ~,opnode; ~,xconst; ~,yconst; ~,typ})
		}
	}
	return ~"{ switch k { ~,@caselist } }
}
157
// Eql compiles the comparison xe == ye. node is the binary expression being
// compiled and is forwarded to the helpers.
//
// Comparisons against the nil literal are delegated to eqlneqNil.
// The eqlneqs macro handles operands of the same kind among bool, integers,
// floats, complex and string. It expands to a closure specialized on that kind
// and on which operand, if any, is a constant.
// Every other case falls back to the reflection-based eqlneqMisc: operands of
// different kinds, arrays, interfaces, pointers, structs...
func (c *Comp) Eql(node *ast.BinaryExpr, xe *Expr, ye *Expr) *Expr {
	if xe.IsNil() || ye.IsNil() {
		if xe.IsNil() && ye.IsNil() {
			// nil == nil is not valid Go
			return c.invalidBinaryExpr(node, xe, ye)
		}
		// nil == expr or expr == nil
		return c.eqlneqNil(node, xe, ye)
	}
	// == requires BOTH operands to be comparable
	if !xe.Type.Comparable() || !ye.Type.Comparable() {
		return c.invalidBinaryExpr(node, xe, ye)
	}
	xc, yc := xe.Const(), ye.Const()
	if xe.Type.Kind() != r.Interface && ye.Type.Kind() != r.Interface {
		// without an interface operand, both operands must have the same type:
		// let toSameFuncType unify them
		c.toSameFuncType(node, xe, ye)
	}
	k := xe.Type.Kind()
	yk := ye.Type.Kind() // may differ from k

	// if both x and y are constants, BinaryExpr will invoke EvalConst()
	// on our return value. no need to optimize that.
	var fun func(env *Env) bool
	if k != yk {
		// call c.eqlneqMisc() below
	} else if xc == yc {
		x, y := xe.Fun, ye.Fun
		{eqlneqs; token.EQL; false; false; { bool; int; uint; float32; float64; complex64; complex128; string } }
	} else if yc {
		x := xe.Fun
		yv := r.ValueOf(ye.Value)
		if k == r.Bool && yv.Bool() {
			// xe == true has the same value as xe. Wrap xe's closure instead of
			// returning xe itself: a comparison always has type bool, while xe
			// may have a named boolean type.
			if f, ok := x.(func(*Env) bool); ok {
				return c.exprBool(f)
			}
		}
		{eqlneqs; token.EQL; false; true; { bool; int; uint; float32; float64; complex64; complex128; string } }
	} else {
		xv := r.ValueOf(xe.Value)
		y := ye.Fun
		if k == r.Bool && xv.Bool() {
			// true == ye has the same value as ye. Wrap ye's closure so that
			// the result has type bool even if ye has a named boolean type.
			if f, ok := y.(func(*Env) bool); ok {
				return c.exprBool(f)
			}
		}
		{eqlneqs; token.EQL; true; false; { bool; int; uint; float32; float64; complex64; complex128; string } }
	}
	if fun != nil {
		return c.exprBool(fun)
	}
	return c.eqlneqMisc(node, xe, ye)
}
211
// Neq compiles the comparison xe != ye. node is the binary expression being
// compiled and is forwarded to the helpers.
//
// Comparisons against the nil literal are delegated to eqlneqNil.
// The eqlneqs macro handles operands of the same kind among bool, integers,
// floats, complex and string. It expands to a closure specialized on that kind
// and on which operand, if any, is a constant.
// Every other case falls back to the reflection-based eqlneqMisc: operands of
// different kinds, arrays, interfaces, pointers, structs...
func (c *Comp) Neq(node *ast.BinaryExpr, xe *Expr, ye *Expr) *Expr {
	if xe.IsNil() || ye.IsNil() {
		if xe.IsNil() && ye.IsNil() {
			// nil != nil is not valid Go
			return c.invalidBinaryExpr(node, xe, ye)
		}
		// nil != expr or expr != nil
		return c.eqlneqNil(node, xe, ye)
	}
	// != requires BOTH operands to be comparable
	if !xe.Type.Comparable() || !ye.Type.Comparable() {
		return c.invalidBinaryExpr(node, xe, ye)
	}
	xc, yc := xe.Const(), ye.Const()
	if xe.Type.Kind() != r.Interface && ye.Type.Kind() != r.Interface {
		// without an interface operand, both operands must have the same type:
		// let toSameFuncType unify them
		c.toSameFuncType(node, xe, ye)
	}
	k := xe.Type.Kind()
	yk := ye.Type.Kind() // may differ from k

	// if both x and y are constants, BinaryExpr will invoke EvalConst()
	// on our return value. no need to optimize that.
	var fun func(env *Env) bool
	if k != yk {
		// call c.eqlneqMisc() below
	} else if xc == yc {
		x, y := xe.Fun, ye.Fun
		{eqlneqs; token.NEQ; false; false; { bool; int; uint; float32; float64; complex64; complex128; string } }
	} else if yc {
		x := xe.Fun
		yv := r.ValueOf(ye.Value)
		if k == r.Bool && !yv.Bool() {
			// xe != false has the same value as xe. Wrap xe's closure instead of
			// returning xe itself: a comparison always has type bool, while xe
			// may have a named boolean type.
			if f, ok := x.(func(*Env) bool); ok {
				return c.exprBool(f)
			}
		}
		{eqlneqs; token.NEQ; false; true; { bool; int; uint; float32; float64; complex64; complex128; string } }
	} else {
		xv := r.ValueOf(xe.Value)
		y := ye.Fun
		if k == r.Bool && !xv.Bool() {
			// false != ye has the same value as ye. Wrap ye's closure so that
			// the result has type bool even if ye has a named boolean type.
			if f, ok := y.(func(*Env) bool); ok {
				return c.exprBool(f)
			}
		}
		{eqlneqs; token.NEQ; true; false; { bool; int; uint; float32; float64; complex64; complex128; string } }
	}
	if fun != nil {
		return c.exprBool(fun)
	}
	return c.eqlneqMisc(node, xe, ye)
}
265
266// compare arrays, interfaces, pointers, structs
267func (c *Comp) eqlneqMisc(node *ast.BinaryExpr, xe *Expr, ye *Expr) *Expr {
268	var fun func(*Env) bool
269
270	x := xe.AsX1()
271	y := ye.AsX1()
272	t1 := xe.Type
273	t2 := ye.Type
274	extractor1 := c.extractor(t1)
275	extractor2 := c.extractor(t2)
276
277	if node.Op == token.EQL {
278		fun = func(env *Env) bool {
279			v1 := x(env)
280			v2 := y(env)
281			if v1 == Nil || v2 == Nil {
282				return v1 == v2
283			}
284			t1, t2 := t1, t2
285			if extractor1 != nil {
286				v1, t1 = extractor1(v1)
287			}
288			if extractor2 != nil {
289				v2, t2 = extractor2(v2)
290			}
291			if v1 == Nil || v2 == Nil {
292				return v1 == v2
293			}
294			return v1.Interface() == v2.Interface() &&
295				(t1 == nil || t2 == nil || t1.IdenticalTo(t2))
296		}
297	} else {
298		fun = func(env *Env) bool {
299			v1 := x(env)
300			v2 := y(env)
301			if v1 == Nil || v2 == Nil {
302				return v1 != v2
303			}
304			t1, t2 := t1, t2
305			if extractor1 != nil {
306				v1, t1 = extractor1(v1)
307			}
308			if extractor2 != nil {
309				v2, t2 = extractor2(v2)
310			}
311			if v1 == Nil || v2 == Nil {
312				return v1 != v2
313			}
314			return v1.Interface() != v2.Interface() ||
315				(t1 != nil && t2 != nil && !t1.IdenticalTo(t2))
316		}
317	}
318	return c.exprBool(fun)
319}
320
321func (c *Comp) eqlneqNil(node *ast.BinaryExpr, xe *Expr, ye *Expr) *Expr {
322	var e *Expr
323	if ye.IsNil() {
324		e = xe
325	} else {
326		e = ye
327	}
328	// e can be a constant... for example nil == nil
329	if !reflect.IsNillableKind(e.Type.Kind()) {
330		return c.invalidBinaryExpr(node, xe, ye)
331	}
332
333	var fun func(env *Env) bool
334	if f, ok := e.Fun.(func(env *Env) (r.Value, []r.Value)); ok {
335		e.CheckX1() // to warn or error as appropriate
336		if node.Op == token.EQL {
337			fun = func(env *Env) bool {
338				v, _ := f(env)
339				vnil := v == Nil || reflect.IsNillableKind(v.Kind()) && v.IsNil()
340				return vnil
341			}
342		} else {
343			fun = func(env *Env) bool {
344				v, _ := f(env)
345				vnil := v == Nil || reflect.IsNillableKind(v.Kind()) && v.IsNil()
346				return !vnil
347			}
348		}
349	} else {
350		f := e.AsX1()
351		if node.Op == token.EQL {
352			fun = func(env *Env) bool {
353				v := f(env)
354				vnil := v == Nil || reflect.IsNillableKind(v.Kind()) && v.IsNil()
355				return vnil
356			}
357		} else {
358			fun = func(env *Env) bool {
359				v := f(env)
360				vnil := v == Nil || reflect.IsNillableKind(v.Kind()) && v.IsNil()
361				return !vnil
362			}
363		}
364	}
365	return c.exprBool(fun)
366}
367