1/*
2 * gomacro - A Go interpreter with Lisp-like macros
3 *
4 * Copyright (C) 2017-2019 Massimiliano Ghilardi
5 *
6 *     This Source Code Form is subject to the terms of the Mozilla Public
7 *     License, v. 2.0. If a copy of the MPL was not distributed with this
8 *     file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 *
10 *
11 * call.go
12 *
13 *  Created on Apr 15, 2017
14 *      Author Massimiliano Ghilardi
15 */
16
17package fast
18
19import (
20	"bytes"
21	"fmt"
22	"go/ast"
23	"go/token"
24	r "reflect"
25
26	xr "github.com/cosmos72/gomacro/xreflect"
27)
28
// Call collects what the compiler needs to emit one function call:
// the callee, its compiled arguments, the result types and a few flags
// selecting how the call must be performed.
type Call struct {
	Fun      *Expr     // expression that yields the function being called
	Args     []*Expr   // compiled argument expressions
	OutTypes []xr.Type // types of the values produced by the call
	Builtin  bool      // if true, call is a builtin function
	Const    bool      // if true, call has no side effects and always returns the same result => it can be invoked at compile time
	Ellipsis bool      // if true, must use reflect.Value.CallSlice or equivalent to invoke the function
}
37
38func newCall1(fun *Expr, arg *Expr, isconst bool, outtypes ...xr.Type) *Call {
39	return &Call{
40		Fun:      fun,
41		Args:     []*Expr{arg},
42		OutTypes: outtypes,
43		Const:    isconst,
44	}
45}
46
47func (call *Call) MakeArgfunsX1() []func(*Env) r.Value {
48	args := call.Args
49	argfuns := make([]func(*Env) r.Value, len(args))
50	for i, arg := range args {
51		argfuns[i] = arg.AsX1()
52	}
53	return argfuns
54}
55
56// CallExpr compiles a function call or a type conversion
57func (c *Comp) CallExpr(node *ast.CallExpr) *Expr {
58	var fun *Expr
59	if len(node.Args) == 1 {
60		var t xr.Type
61		fun, t = c.Expr1OrType(node.Fun)
62		if t != nil {
63			return c.Convert(node.Args[0], t)
64		}
65	}
66	call := c.prepareCall(node, fun)
67	return c.call_any(call)
68}
69
70// callExpr compiles the common part between CallExpr and Go statement
71func (c *Comp) prepareCall(node *ast.CallExpr, fun *Expr) *Call {
72	if fun == nil {
73		fun = c.expr1(node.Fun, nil)
74	}
75	t := fun.Type
76	var builtin bool
77	var lastarg *Expr
78	if t.IdenticalTo(c.TypeOfBuiltin()) {
79		return c.callBuiltin(node, fun)
80	} else if t.IdenticalTo(c.TypeOfFunction()) {
81		fun, lastarg = c.callFunction(node, fun)
82		t = fun.Type
83		builtin = true
84	}
85	// compile args early, and use them to infer template function instantiation
86	var args []*Expr
87	if len(node.Args) == 1 {
88		// support foo(bar()) where bar() returns multiple values
89		arg := c.Expr(node.Args[0], nil)
90		if arg.NumOut() == 0 {
91			c.Errorf("function argument returns zero values: %v ", node.Args[0])
92		}
93		args = []*Expr{arg}
94	} else {
95		args = c.Exprs(node.Args)
96	}
97	if lastarg != nil {
98		args = append(args, lastarg)
99	}
100	switch t.Kind() {
101	case r.Func:
102	case r.Ptr:
103		if GENERICS_V1 && t.ReflectType() == rtypeOfPtrTemplateFunc {
104			fun = c.inferTemplateFunc(node, fun, args)
105			t = fun.Type
106			break
107		}
108		fallthrough
109	default:
110		c.Errorf("call of non-function: %v <%v>", node.Fun, t)
111		return nil
112	}
113	ellipsis := node.Ellipsis != token.NoPos
114	c.checkCallArgs(node, t, args, ellipsis)
115
116	outn := t.NumOut()
117	outtypes := make([]xr.Type, outn)
118	for i := 0; i < outn; i++ {
119		outtypes[i] = t.Out(i)
120	}
121	return &Call{Fun: fun, Args: args, OutTypes: outtypes, Builtin: builtin, Ellipsis: ellipsis}
122}
123
124// call_any emits a compiled function call
125func (c *Comp) call_any(call *Call) *Expr {
126	expr := &Expr{}
127	tout := call.OutTypes
128	nout := len(tout)
129	expr.SetTypes(tout)
130
131	maxdepth := c.Depth
132	// functions imported from other packages are constant too...
133	// but call_builtin does not know about them
134	if call.Fun.Const() {
135		if call.Builtin {
136			fun := c.call_builtin(call)
137			if _, untyped := fun.(UntypedLit); untyped {
138				// complex(), real(), imag() of untyped constants produce an untyped constant, not a function
139				expr.Value = fun
140				return expr
141			} else {
142				expr.Fun = fun
143			}
144		} else {
145			// normal calls do not expect function to be a constant.
146			call.Fun.WithFun()
147		}
148	}
149
150	if expr.Fun != nil {
151		// done already
152	} else if len(call.Args) == 1 && call.Args[0].NumOut() > 1 {
153		// support foo(bar()) where bar() returns multiple values.
154		//
155		// do NOT use this case for calls like fmt.Printf("foo") where the function
156		// formally expects two args but is variadic => accepts one arg too:
157		// fixes gophernotes issue 118
158		expr.Fun = call_multivalue(call, maxdepth)
159	} else if nout == 0 {
160		expr.Fun = c.call_ret0(call, maxdepth)
161	} else if nout == 1 {
162		expr.Fun = c.call_ret1(call, maxdepth)
163	} else {
164		expr.Fun = c.call_ret2plus(call, maxdepth)
165	}
166	// constant propagation - only if function returns a single value
167	if call.Const && len(call.OutTypes) == 1 {
168		expr.EvalConst(COptDefaults)
169		// c.Debugf("pre-computed result of constant call %v: %v <%v>", call, expr.Value, TypeOf(expr.Value))
170	}
171	return expr
172}
173
174func (c *Comp) checkCallArgs(node *ast.CallExpr, t xr.Type, args []*Expr, ellipsis bool) {
175	variadic := t.IsVariadic()
176	if ellipsis {
177		if variadic {
178			// a variadic function invoked as fun(x, y...)
179			// behaves exactly as a non-variadic function call:
180			// number and type of arguments must match
181			variadic = false
182		} else {
183			c.Errorf("invalid use of ... in call to non-variadic function <%v>: %v", t, node)
184			return
185		}
186	}
187	n := t.NumIn()
188	narg := len(args)
189	if narg == 1 {
190		// support foo(bar()) where bar() returns multiple values
191		narg = args[0].NumOut()
192	}
193	if narg < n-1 || (!variadic && narg != n) {
194		c.badCallArgNum(node.Fun, t, args)
195		return
196	}
197	var ti, tlast xr.Type
198	if variadic {
199		tlast = t.In(n - 1).Elem()
200	}
201	var convs []func(r.Value) r.Value
202	needconvs := false
203	multivalue := len(args) != narg
204	if multivalue {
205		convs = make([]func(r.Value) r.Value, narg)
206	}
207	for i := 0; i < narg; i++ {
208		if variadic && i >= n-1 {
209			ti = tlast
210		} else {
211			ti = t.In(i)
212		}
213		if multivalue {
214			// support foo(bar()) where bar() returns multiple values
215			targ := args[0].Out(i)
216			if targ == nil || !targ.AssignableTo(ti) {
217				c.Errorf("cannot use <%v> as <%v> in argument to %v", targ, ti, node.Fun)
218			} else if conv := c.Converter(targ, ti); conv != nil {
219				convs[i] = conv
220				args[0].Types[i] = ti
221				needconvs = true
222			}
223			continue
224		}
225		// one argument per parameter: foo(arg1, arg2 /*...*/)
226		arg := args[i]
227		if arg.Const() {
228			arg.ConstTo(ti)
229		} else if arg.Type == nil || !arg.Type.AssignableTo(ti) {
230			c.Errorf("cannot use <%v> as <%v> in argument to %v", arg.Type, ti, node.Fun)
231		} else {
232			arg.To(c, ti)
233		}
234	}
235	if !multivalue || !needconvs {
236		return
237	}
238	f := args[0].AsXV(COptDefaults)
239	args[0].Fun = func(env *Env) (r.Value, []r.Value) {
240		_, vs := f(env)
241		for i, conv := range convs {
242			if conv != nil {
243				vs[i] = conv(vs[i])
244			}
245		}
246		return vs[0], vs
247	}
248}
249
250func (call *Call) canOptimize() bool {
251	rtype := call.Fun.Type.ReflectType()
252	if rtype.Name() != "" {
253		// no optimization for named func type
254		return false
255	}
256	for i, n := 0, rtype.NumIn(); i < n; i++ {
257		ti := rtype.In(i)
258		if ti.Kind() == r.UnsafePointer || ti != xr.ReflectBasicTypes[ti.Kind()] {
259			// no optimization for func argument whose type is not a basic type
260			return false
261		}
262	}
263	for i, n := 0, rtype.NumOut(); i < n; i++ {
264		ti := rtype.Out(i)
265		if ti.Kind() == r.UnsafePointer || ti != xr.ReflectBasicTypes[ti.Kind()] {
266			// no optimization for func return value whose type is not a basic type
267			return false
268		}
269	}
270	return true
271}
272
273// mandatory optimization: fast_interpreter ASSUMES that expressions
274// returning bool, int, uint, float, complex, string do NOT wrap them in reflect.Value
275func (c *Comp) call_ret0(call *Call, maxdepth int) func(env *Env) {
276	if call.Ellipsis {
277		return call_ellipsis_ret0(call, maxdepth)
278	} else if call.Fun.Type.IsVariadic() {
279		return call_variadic_ret0(call, maxdepth)
280	}
281	// optimize fun(t1, t2)
282	var ret func(*Env)
283	if call.canOptimize() {
284		switch len(call.Args) {
285		case 0:
286			ret = c.call0ret0(call, maxdepth)
287		case 1:
288			ret = c.call1ret0(call, maxdepth)
289		case 2:
290			ret = c.call2ret0(call, maxdepth)
291		}
292	}
293	if ret == nil {
294		ret = c.callnret0(call, maxdepth)
295	}
296	return ret
297}
298
299// mandatory optimization: fast_interpreter ASSUMES that expressions
300// returning no values are compiled as func(*Env)
301func (c *Comp) callnret0(call *Call, maxdepth int) func(env *Env) {
302	exprfun := call.Fun.AsX1()
303	argfunsX1 := call.MakeArgfunsX1()
304	var ret func(*Env)
305	switch len(argfunsX1) {
306	case 0:
307		ret = func(env *Env) {
308			funv := exprfun(env)
309			callxr(funv, nil)
310		}
311	case 1:
312		argfun := argfunsX1[0]
313		ret = func(env *Env) {
314			funv := exprfun(env)
315			argv := []r.Value{
316				argfun(env),
317			}
318			callxr(funv, argv)
319		}
320	case 2:
321		ret = func(env *Env) {
322			funv := exprfun(env)
323			argv := []r.Value{
324				argfunsX1[0](env),
325				argfunsX1[1](env),
326			}
327			callxr(funv, argv)
328		}
329	default:
330		ret = func(env *Env) {
331			funv := exprfun(env)
332			argv := make([]r.Value, len(argfunsX1))
333			for i, argfun := range argfunsX1 {
334				argv[i] = argfun(env)
335			}
336			callxr(funv, argv)
337		}
338	}
339	return ret
340}
341
342// mandatory optimization: fast_interpreter ASSUMES that expressions
343// returning bool, int, uint, float, complex, string do NOT wrap them in reflect.Value
344func (c *Comp) call_ret1(call *Call, maxdepth int) I {
345	if call.Ellipsis {
346		return call_ellipsis_ret1(call, maxdepth)
347	} else if call.Fun.Type.IsVariadic() {
348		return call_variadic_ret1(call, maxdepth)
349	}
350	var ret I
351	if call.canOptimize() {
352		switch len(call.Args) {
353		case 0:
354			ret = c.call0ret1(call, maxdepth)
355		case 1:
356			ret = c.call1ret1(call, maxdepth)
357		case 2:
358			ret = c.call2ret1(call, maxdepth)
359		}
360	}
361	if ret == nil {
362		ret = c.callnret1(call, maxdepth)
363	}
364	return ret
365}
366
367// cannot optimize much here... fast_interpreter ASSUMES that expressions
368// returning multiple values actually return (reflect.Value, []reflect.Value)
369func (c *Comp) call_ret2plus(call *Call, maxdepth int) func(env *Env) (r.Value, []r.Value) {
370	if call.Ellipsis {
371		return call_ellipsis_ret2plus(call, maxdepth)
372	}
373	// no need to special case variadic functions here
374	expr := call.Fun
375	exprfun := expr.AsX1()
376	argfunsX1 := call.MakeArgfunsX1()
377	var ret func(*Env) (r.Value, []r.Value)
378	switch len(call.Args) {
379	case 0:
380		ret = func(env *Env) (r.Value, []r.Value) {
381			funv := exprfun(env)
382			retv := callxr(funv, nil)
383			return retv[0], retv
384		}
385	case 1:
386		argfun := argfunsX1[0]
387		ret = func(env *Env) (r.Value, []r.Value) {
388			funv := exprfun(env)
389			argv := []r.Value{
390				argfun(env),
391			}
392			retv := callxr(funv, argv)
393			return retv[0], retv
394		}
395	case 2:
396		argfuns := [2]func(*Env) r.Value{
397			argfunsX1[0],
398			argfunsX1[1],
399		}
400		ret = func(env *Env) (r.Value, []r.Value) {
401			funv := exprfun(env)
402			argv := []r.Value{
403				argfuns[0](env),
404				argfuns[1](env),
405			}
406			retv := callxr(funv, argv)
407			return retv[0], retv
408		}
409	case 3:
410		argfuns := [3]func(*Env) r.Value{
411			argfunsX1[0],
412			argfunsX1[1],
413			argfunsX1[2],
414		}
415		ret = func(env *Env) (r.Value, []r.Value) {
416			funv := exprfun(env)
417			argv := []r.Value{
418				argfuns[0](env),
419				argfuns[1](env),
420				argfuns[2](env),
421			}
422			retv := callxr(funv, argv)
423			return retv[0], retv
424		}
425	default:
426		// general case
427		ret = func(env *Env) (r.Value, []r.Value) {
428			funv := exprfun(env)
429			argv := make([]r.Value, len(argfunsX1))
430			for i, argfun := range argfunsX1 {
431				argv[i] = argfun(env)
432			}
433			retv := callxr(funv, argv)
434			return retv[0], retv
435		}
436	}
437	return ret
438}
439
// callxr is a replacement for reflect.Value.Call() that correctly handles
// functions wrapped in xr.Forward: when fun has kind Interface, the value
// it contains is called instead of fun itself.
func callxr(fun r.Value, args []r.Value) []r.Value {
	callee := fun
	if callee.Kind() == r.Interface {
		// unwrap once, to reach the function stored in the interface
		callee = callee.Elem()
	}
	return callee.Call(args)
}
448
// callslicexr is the reflect.Value.CallSlice() counterpart of callxr:
// when fun has kind Interface, the value it contains is invoked instead
// of fun itself. As with CallSlice, the last element of args is the slice
// supplying the variadic parameters.
func callslicexr(fun r.Value, args []r.Value) []r.Value {
	callee := fun
	if callee.Kind() == r.Interface {
		// unwrap once, to reach the function stored in the interface
		callee = callee.Elem()
	}
	return callee.CallSlice(args)
}
455
456func (c *Comp) badCallArgNum(fun ast.Expr, t xr.Type, args []*Expr) *Call {
457	prefix := "not enough"
458	n := t.NumIn()
459	nargs := len(args)
460	if nargs > n {
461		prefix = "too many"
462	}
463	have := bytes.Buffer{}
464	for i, arg := range args {
465		if i == 0 {
466			fmt.Fprintf(&have, "%v", arg.Type)
467		} else {
468			fmt.Fprintf(&have, ", %v", arg.Type)
469		}
470	}
471	want := bytes.Buffer{}
472	for i := 0; i < n; i++ {
473		if i == 0 {
474			fmt.Fprintf(&want, "%v", t.In(i))
475		} else {
476			fmt.Fprintf(&want, ", %v", t.In(i))
477		}
478	}
479	c.Errorf("%s arguments in call to %v:\n\thave (%s)\n\twant (%s)", prefix, fun, have.Bytes(), want.Bytes())
480	return nil
481}
482