1/*
2 * gomacro - A Go interpreter with Lisp-like macros
3 *
4 * Copyright (C) 2018-2019 Massimiliano Ghilardi
5 *
6 *     This Source Code Form is subject to the terms of the Mozilla Public
7 *     License, v. 2.0. If a copy of the MPL was not distributed with this
8 *     file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 *
10 *
11 * template_infer.go
12 *
13 *  Created on Jun 06, 2018
14 *      Author Massimiliano Ghilardi
15 */
16
17package fast
18
19import (
20	"fmt"
21	"go/ast"
22	"go/token"
23	r "reflect"
24
25	"github.com/cosmos72/gomacro/base/untyped"
26	xr "github.com/cosmos72/gomacro/xreflect"
27)
28
29type inferType struct {
30	Type    xr.Type
31	Untyped untyped.Kind // for untyped literals
32	Value   I            // in case we infer a constant, not a type
33	Exact   bool
34}
35
36func (inf *inferType) String() string {
37	if inf.Value != nil {
38		return fmt.Sprint(inf.Value)
39	}
40	var s string
41	if inf.Type != nil {
42		s = inf.Type.String()
43	} else {
44		s = inf.Untyped.String()
45	}
46	return "<" + s + ">"
47}
48
49// type inference on template functions
50type inferFuncType struct {
51	comp     *Comp
52	tfun     *TemplateFunc
53	funcname string
54	inferred map[string]inferType
55	patterns []ast.Expr
56	targs    []inferType
57	call     *ast.CallExpr // for error messages
58}
59
60func (inf *inferFuncType) String() string {
61	return inf.tfun.Signature(inf.funcname)
62}
63
// inferTemplateFunc compiles a call to the template function fun whose template
// arguments were omitted: they are inferred from args, the already compiled
// call arguments.
//
// It locates the scope that declares the template function, collects one
// parameter type expression ("pattern") per declared parameter, matches each
// call argument type against its pattern, and finally instantiates the template
// function with the inferred template arguments.
//
// Errors are reported with c.Errorf.
// NOTE(review): the code after each c.Errorf assumes it does not return (e.g. it
// dereferences tfun and upc) — presumably Errorf panics; confirm.
func (c *Comp) inferTemplateFunc(call *ast.CallExpr, fun *Expr, args []*Expr) *Expr {
	tfun, ok := fun.Value.(*TemplateFunc)
	if !ok {
		c.Errorf("internal error: Comp.inferTemplateFunc() invoked on non-template function %v: %v", fun.Type, call.Fun)
	}
	var upc *Comp
	var funcname string
	{
		ident, ok := call.Fun.(*ast.Ident)
		if !ok {
			c.Errorf("unimplemented type inference on non-name template function %v: %v", call.Fun, call)
		}
		if fun.Sym == nil {
			c.Errorf("unimplemented type inference on non-symbol template function %v %#v: %v", call.Fun, fun, call)
		}
		// find the scope where fun is declared
		funcname = ident.Name
		fbind := &fun.Sym.Bind
		for upc = c; upc != nil; upc = upc.Outer {
			if bind, ok := upc.Binds[funcname]; ok && bind.Name == fbind.Name && bind.Desc == fbind.Desc && bind.Type.IdenticalTo(fbind.Type) {
				break
			}
		}
	}
	if upc == nil {
		c.Errorf("internal error: Comp.inferTemplateFunc() failed to determine the scope containing template function declaration: %v", call.Fun)
	}

	master := tfun.Master
	typ := master.Decl.Type

	var patterns []ast.Expr
	ellipsis := call.Ellipsis != token.NoPos
	variadic := false
	// collect template function param types expressions, one per parameter
	if fields := typ.Params; fields != nil {
		for _, field := range fields.List {
			// an unnamed parameter, as in func(int), has no Names
			// but is still exactly one parameter
			count := len(field.Names)
			if count == 0 {
				count = 1
			}
			for j := 0; j < count; j++ {
				patterns = append(patterns, field.Type)
			}
		}
		if n := len(fields.List); n != 0 {
			_, variadic = fields.List[n-1].Type.(*ast.Ellipsis)
		}
	}
	if variadic && !ellipsis {
		c.Errorf("unimplemented type inference on variadic template function: %v", call)
	} else if !variadic && ellipsis {
		c.Errorf("invalid use of ... in call to non-variadic template function: %v", call)
	}
	if variadic && ellipsis {
		// in f(a, xs...) the last argument is passed unchanged as the variadic
		// parameter ...T, so its type must match []T.
		// Replace only our copy of the pattern: the declaration AST is not modified.
		last := len(patterns) - 1
		dots := patterns[last].(*ast.Ellipsis)
		patterns[last] = &ast.ArrayType{Lbrack: dots.Pos(), Elt: dots.Elt}
	}

	// collect call arg types
	var targs []inferType
	if len(args) == 1 && args[0].NumOut() != 1 {
		// a single multi-valued (or void) call expression, as in f(g()):
		// each of its results is a separate argument
		arg := args[0]
		targs = make([]inferType, arg.NumOut())
		for i := range targs {
			targs[i] = inferType{Type: arg.Out(i)}
		}
	} else {
		// single-valued arguments, each possibly an untyped constant.
		// A single argument is classified the same way as several arguments,
		// so that an untyped constant is not mistaken for a typed value.
		targs = make([]inferType, len(args))
		for i, arg := range args {
			if kind := arg.UntypedKind(); kind != untyped.None {
				targs[i] = inferType{Untyped: kind}
			} else {
				targs[i] = inferType{Type: arg.Type}
			}
		}
	}
	if nargs := len(targs); nargs != len(patterns) {
		c.Errorf("template function %v has %d params, cannot call with %d values: %v", tfun, len(patterns), nargs, call)
	}
	inferred := make(map[string]inferType, len(master.Params))
	for _, name := range master.Params {
		inferred[name] = inferType{}
	}
	inf := inferFuncType{comp: c, tfun: tfun, funcname: funcname, inferred: inferred, patterns: patterns, targs: targs, call: call}
	vals, types := inf.args()
	maker := &templateMaker{
		comp: upc, sym: fun.Sym, ifun: fun.Sym.Value,
		exprs: nil, vals: vals, types: types,
		ikey: makeTemplateKey(vals, types),
		pos:  inf.call.Pos(),
	}
	return c.templateFunc(maker, call)
}
152
153// infer type of template function from arguments
154func (inf *inferFuncType) args() (vals []I, types []xr.Type) {
155	exact := false // allow implicit type conversions
156
157	// first pass: types and typed constants
158	for i, targ := range inf.targs {
159		node := inf.patterns[i]
160		if targ.Type != nil {
161			inf.arg(node, targ.Type, exact)
162		} else if targ.Untyped != untyped.None {
163			// skip untyped constant, handled below
164		} else if targ.Value != nil {
165			inf.constant(node, targ.Value, exact)
166		} else {
167			inf.fail(node, targ)
168		}
169	}
170
171	// second pass: untyped constants
172	for i, targ := range inf.targs {
173		if targ.Type == nil && targ.Untyped != untyped.None {
174			inf.untyped(inf.patterns[i], targ.Untyped, exact)
175		}
176	}
177
178	params := inf.tfun.Master.Params
179	n := len(params)
180	vals = make([]I, n)
181	types = make([]xr.Type, n)
182	for i, name := range params {
183		inferred, ok := inf.inferred[name]
184		if !ok || inferred.Type == nil {
185			inf.comp.Errorf("failed to infer %v in call to template function: %v", name, inf.call)
186		}
187		types[i] = inferred.Type
188		vals[i] = inferred.Value
189	}
190	return vals, types
191}
192
// arg partially infers the template parameters of inf.tfun by matching pattern,
// a parameter type expression, against targ, the type of the corresponding
// call argument. exact == true forbids implicit conversions.
//
// Composite patterns are walked iteratively: each step matches the outer type
// constructor and descends into its element type, until an identifier is
// reached or a helper returns a nil pattern.
func (inf *inferFuncType) arg(pattern ast.Expr, targ xr.Type, exact bool) {
	nstars := 0
	for {
		if targ == nil {
			inf.fail(pattern, targ)
		}
		switch node := pattern.(type) {
		case *ast.Ident:
			inf.ident(node, targ, exact)
			return
		case *ast.ParenExpr:
			pattern = node.X
		case *ast.StarExpr:
			inf.is(node, targ, r.Ptr)
			pattern, targ = node.X, targ.Elem()
			// NOTE(review): only pointers nested inside another pointer
			// switch to exact matching
			if nstars != 0 {
				exact = true
			}
			nstars++
		case *ast.ArrayType:
			pattern, targ, exact = inf.arrayType(node, targ, exact)
		case *ast.ChanType:
			pattern, targ, exact = inf.chanType(node, targ, exact)
		case *ast.MapType:
			pattern, targ, exact = inf.mapType(node, targ, exact)
		case *ast.FuncType:
			if pattern, targ, exact = inf.funcType(node, targ, exact); pattern == nil {
				return
			}
		case *ast.IndexExpr:
			// function's parameter is itself a template
			if pattern, targ, exact = inf.templateType(node, targ, exact); pattern == nil {
				return
			}
		case *ast.InterfaceType:
			if pattern, targ, exact = inf.interfaceType(node, targ, exact); pattern == nil {
				return
			}
		case *ast.SelectorExpr:
			// packagename.typename
			if pattern, targ, exact = inf.selector(node, targ, exact); pattern == nil {
				return
			}
		case *ast.StructType:
			if pattern, targ, exact = inf.structType(node, targ, exact); pattern == nil {
				return
			}
		default:
			inf.unimplemented(node, targ)
			return
		}
	}
}
258
259// partially infer type of template function from an array or slice parameter
260func (inf *inferFuncType) arrayType(node *ast.ArrayType, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
261	if node.Len == nil {
262		inf.is(node, targ, r.Slice)
263	} else {
264		inf.is(node, targ, r.Array)
265		if _, ok := node.Len.(*ast.Ellipsis); !ok {
266			// [n]array
267			inf.constant(node.Len, targ.Len(), exact)
268		}
269	}
270	return node.Elt, targ.Elem(), true
271}
272
// chanType matches a channel parameter type against targ and returns the
// element pattern and element type for further matching (element types must
// match exactly).
//
// Following Go assignability rules, the argument channel must support every
// direction the parameter requires: a bidirectional channel can be passed
// where a send-only or receive-only channel is expected, but a directional
// channel cannot be passed as a bidirectional one. When exact is true,
// the directions must be identical.
func (inf *inferFuncType) chanType(node *ast.ChanType, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	inf.is(node, targ, r.Chan)
	tdir := targ.ChanDir()
	dir := reflectChanDir(node.Dir)
	// previously `dir&tdir == 0`, which wrongly accepted e.g. a <-chan T
	// argument for a chan T parameter
	if tdir&dir != dir || (exact && dir != tdir) {
		inf.fail(node, targ)
	}
	return node.Value, targ.Elem(), true
}
283
// constant is meant to match a constant expression appearing in a parameter
// type (for example the length of an array type) against val.
// Not implemented yet: it always reports an error through inf.comp.ErrorAt.
func (inf *inferFuncType) constant(node ast.Expr, val I, exact bool) {
	// TODO: match node against val, honoring exact
	c := inf.comp
	c.ErrorAt(node.Pos(), "unimplemented type inference: template function with parameter type %v and argument %v: %v",
		node, val, inf.call)
}
290
// funcType is meant to match a func parameter type against targ.
// Not implemented yet: it reports an error and returns a nil pattern,
// which makes arg stop walking the parameter type.
func (inf *inferFuncType) funcType(node *ast.FuncType, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	// TODO: match the parameter and result types of node against those of targ
	next, nextType, nextExact := inf.unimplemented(node, targ)
	return next, nextType, nextExact
}
296
// ident matches a parameter type that is a bare identifier against targ,
// the type of the corresponding call argument.
//
// If the identifier is one of the template parameters, targ is merged with
// what has been inferred for it so far (see combine). Otherwise it should name
// an existing type: targ must be assignable to it or, when exact is true,
// identical to it. Element types of slices, arrays, maps and channels are
// matched with exact == true, because e.g. []Impl is not assignable to []Iface
// even when Impl is assignable to Iface.
func (inf *inferFuncType) ident(node *ast.Ident, targ xr.Type, exact bool) {
	c := inf.comp
	name := node.Name
	if inferred, ok := inf.inferred[name]; ok {
		// inferring one of the function template parameters
		inf.combine(node, &inferred, inferType{Type: targ, Exact: exact})
		inf.inferred[name] = inferred
		return
	}
	// name must be an existing type.
	// NOTE(review): a name that does not resolve to a type is silently accepted here
	t := c.TryResolveType(name)
	if t == nil {
		return
	}
	if exact {
		if !targ.IdenticalTo(t) {
			c.ErrorAt(node.Pos(),
				"type inference: in %v, mismatched types for %v: %v is not identical to %v: %v",
				inf, name, targ, t, inf.call)
		}
	} else if !targ.AssignableTo(t) {
		c.ErrorAt(node.Pos(),
			"type inference: in %v, mismatched types for %v: %v cannot be assigned to %v: %v",
			inf, name, targ, t, inf.call)
	}
}
320
// untyped is meant to match a parameter type against an untyped constant
// argument of the given kind. Only an identifier parameter type is accepted;
// anything else is reported through inf.fail. Even for identifiers,
// inference is not implemented yet and an error is reported.
func (inf *inferFuncType) untyped(node ast.Expr, kind untyped.Kind, exact bool) {
	id, isIdent := node.(*ast.Ident)
	if !isIdent {
		inf.fail(node, kind)
	}
	// TODO: infer a default type for the untyped constant, honoring exact
	inf.unimplemented(id, kind)
}
328
// combine merges the constraint `with`, derived from a single call argument,
// into *inferred, the current inference result for one template parameter.
//
// If nothing was inferred yet, with.Type is taken as-is. Otherwise, when the
// two types are not identical:
//   - both exact: inference fails;
//   - only *inferred exact: with.Type must be assignable to inferred.Type,
//     which is kept;
//   - only `with` exact: inferred.Type must be assignable to with.Type,
//     which replaces it;
//   - neither exact: the type the other one is assignable to wins
//     (if both directions work, an unnamed with.Type is preferred);
//     if neither is assignable to the other, inference fails.
//
// Once an exact constraint has been seen, the result stays exact.
func (inf *inferFuncType) combine(node ast.Expr, inferred *inferType, with inferType) {
	targ := with.Type
	exact := with.Exact
	switch {
	case inferred.Type == nil:
		inferred.Type = targ
	case inferred.Type.IdenticalTo(targ):
		// nothing new
	case exact && inferred.Exact:
		inf.fail3(node, inferred, targ)
	case inferred.Exact:
		// inferred.Type is fixed: targ must be implicitly convertible to it.
		// The original code failed when targ WAS assignable (inverted condition).
		if !targ.AssignableTo(inferred.Type) {
			inf.fail3(node, inferred, targ)
		}
	case exact:
		// targ is fixed: the previously inferred type must be convertible to it
		if inferred.Type.AssignableTo(targ) {
			inferred.Type = targ
		} else {
			inf.fail3(node, inferred, targ)
		}
	default:
		fwd := targ.AssignableTo(inferred.Type)
		rev := inferred.Type.AssignableTo(targ)
		switch {
		case fwd && rev:
			// mutually assignable: switch to targ only if it is unnamed
			if !targ.Named() {
				inferred.Type = targ
			}
		case fwd:
			// keep inferred.Type: targ converts to it
		case rev:
			inferred.Type = targ
		default:
			inf.fail3(node, inferred, targ)
		}
	}
	if exact {
		inferred.Exact = true
	}
}
367
// interfaceType is meant to match an interface parameter type against targ.
// Not implemented yet: it reports an error and returns a nil pattern,
// which makes arg stop walking the parameter type.
func (inf *inferFuncType) interfaceType(node *ast.InterfaceType, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	// TODO: match the methods of node against those of targ
	next, nextType, nextExact := inf.unimplemented(node, targ)
	return next, nextType, nextExact
}
373
// mapType matches a map parameter type against targ. The key is matched
// immediately (exactly), and the value pattern and value type are returned
// for further matching (also exactly).
func (inf *inferFuncType) mapType(node *ast.MapType, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	inf.is(node, targ, r.Map)
	keyType := targ.Key()
	inf.arg(node.Key, keyType, true)
	valueType := targ.Elem()
	return node.Value, valueType, true
}
380
// selector is meant to match a qualified parameter type (packagename.typename)
// against targ. Not implemented yet: it reports an error and returns a nil
// pattern, which makes arg stop walking the parameter type.
func (inf *inferFuncType) selector(node *ast.SelectorExpr, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	// TODO: resolve the imported type and compare it with targ
	next, nextType, nextExact := inf.unimplemented(node, targ)
	return next, nextType, nextExact
}
386
// structType is meant to match a struct parameter type against targ.
// Not implemented yet: it reports an error and returns a nil pattern,
// which makes arg stop walking the parameter type.
func (inf *inferFuncType) structType(node *ast.StructType, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	// TODO: match the fields of node against those of targ
	next, nextType, nextExact := inf.unimplemented(node, targ)
	return next, nextType, nextExact
}
392
// templateType is meant to match a parameter type that is itself a template
// instantiation (an index expression such as Pair#[K, V]) against targ.
// Not implemented yet: it reports an error and returns a nil pattern,
// which makes arg stop walking the parameter type.
func (inf *inferFuncType) templateType(node *ast.IndexExpr, targ xr.Type, exact bool) (ast.Expr, xr.Type, bool) {
	// TODO: match the template arguments of node against targ
	next, nextType, nextExact := inf.unimplemented(node, targ)
	return next, nextType, nextExact
}
398
// is reports a type inference failure at node unless targ has the given reflect kind.
func (inf *inferFuncType) is(node ast.Expr, targ xr.Type, kind r.Kind) {
	if targ.Kind() == kind {
		return
	}
	inf.fail(node, targ)
}
404
405func (inf *inferFuncType) fail(node ast.Expr, targ I) {
406	inf.comp.ErrorAt(node.Pos(),
407		"type inference: in %v, parameter %v cannot match argument type %v: %v",
408		inf, node, targ, inf.call)
409}
410
411func (inf *inferFuncType) fail3(node ast.Expr, tinferred *inferType, targ xr.Type) {
412	inf.comp.ErrorAt(node.Pos(),
413		"type inference: in %v, parameter %v cannot match both %v and <%v>: %v",
414		inf, node, tinferred, targ, inf.call)
415}
416
417func (inf *inferFuncType) unimplemented(node ast.Expr, targ I) (ast.Expr, xr.Type, bool) {
418	inf.comp.ErrorAt(node.Pos(), "unimplemented type inference: in %v, parameter type %v with argument type %v: %v",
419		inf, node, targ, inf.call)
420	return nil, nil, false
421}
422
423var chandirs = map[ast.ChanDir]r.ChanDir{
424	ast.RECV:            r.RecvDir,
425	ast.SEND:            r.SendDir,
426	ast.RECV | ast.SEND: r.BothDir,
427}
428
429func reflectChanDir(dir ast.ChanDir) r.ChanDir {
430	return chandirs[dir]
431}
432