1// Resolve function calls and variable types
2
3package parser
4
5import (
6	"fmt"
7	"reflect"
8	"sort"
9
10	. "github.com/benhoyt/goawk/internal/ast"
11	. "github.com/benhoyt/goawk/lexer"
12)
13
// varType classifies a variable as a scalar, an array, or not yet known.
type varType int

const (
	// typeUnknown is the zero value: the type has not been determined.
	typeUnknown varType = iota
	// typeScalar marks a variable used as a scalar.
	typeScalar
	// typeArray marks a variable used as an array.
	typeArray
)

// String returns a readable name for the type. Any value other than
// typeScalar or typeArray is reported as "Unknown".
func (t varType) String() string {
	if t == typeScalar {
		return "Scalar"
	}
	if t == typeArray {
		return "Array"
	}
	return "Unknown"
}
32
33// typeInfo records type information for a single variable
34type typeInfo struct {
35	typ      varType
36	ref      *VarExpr
37	scope    VarScope
38	index    int
39	callName string
40	argIndex int
41}
42
43// Used by printVarTypes when debugTypes is turned on
44func (t typeInfo) String() string {
45	var scope string
46	switch t.scope {
47	case ScopeGlobal:
48		scope = "Global"
49	case ScopeLocal:
50		scope = "Local"
51	default:
52		scope = "Special"
53	}
54	return fmt.Sprintf("typ=%s ref=%p scope=%s index=%d callName=%q argIndex=%d",
55		t.typ, t.ref, scope, t.index, t.callName, t.argIndex)
56}
57
58// A single variable reference (normally scalar)
59type varRef struct {
60	funcName string
61	ref      *VarExpr
62	isArg    bool
63	pos      Position
64}
65
66// A single array reference
67type arrayRef struct {
68	funcName string
69	ref      *ArrayExpr
70	pos      Position
71}
72
73// Initialize the resolver
74func (p *parser) initResolve() {
75	p.varTypes = make(map[string]map[string]typeInfo)
76	p.varTypes[""] = make(map[string]typeInfo) // globals
77	p.functions = make(map[string]int)
78	p.arrayRef("ARGV", Position{1, 1}) // interpreter relies on ARGV being present
79	p.multiExprs = make(map[*MultiExpr]Position, 3)
80}
81
82// Signal the start of a function
83func (p *parser) startFunction(name string, params []string) {
84	p.funcName = name
85	p.varTypes[name] = make(map[string]typeInfo)
86}
87
88// Signal the end of a function
89func (p *parser) stopFunction() {
90	p.funcName = ""
91}
92
93// Add function by name with given index
94func (p *parser) addFunction(name string, index int) {
95	p.functions[name] = index
96}
97
98// Records a call to a user function (for resolving indexes later)
99type userCall struct {
100	call   *UserCallExpr
101	pos    Position
102	inFunc string
103}
104
105// Record a user call site
106func (p *parser) recordUserCall(call *UserCallExpr, pos Position) {
107	p.userCalls = append(p.userCalls, userCall{call, pos, p.funcName})
108}
109
110// After parsing, resolve all user calls to their indexes. Also
111// ensures functions called have actually been defined, and that
112// they're not being called with too many arguments.
113func (p *parser) resolveUserCalls(prog *Program) {
114	// Number the native funcs (order by name to get consistent order)
115	nativeNames := make([]string, 0, len(p.nativeFuncs))
116	for name := range p.nativeFuncs {
117		nativeNames = append(nativeNames, name)
118	}
119	sort.Strings(nativeNames)
120	nativeIndexes := make(map[string]int, len(nativeNames))
121	for i, name := range nativeNames {
122		nativeIndexes[name] = i
123	}
124
125	for _, c := range p.userCalls {
126		// AWK-defined functions take precedence over native Go funcs
127		index, ok := p.functions[c.call.Name]
128		if !ok {
129			f, haveNative := p.nativeFuncs[c.call.Name]
130			if !haveNative {
131				panic(&ParseError{c.pos, fmt.Sprintf("undefined function %q", c.call.Name)})
132			}
133			typ := reflect.TypeOf(f)
134			if !typ.IsVariadic() && len(c.call.Args) > typ.NumIn() {
135				panic(&ParseError{c.pos, fmt.Sprintf("%q called with more arguments than declared", c.call.Name)})
136			}
137			c.call.Native = true
138			c.call.Index = nativeIndexes[c.call.Name]
139			continue
140		}
141		function := prog.Functions[index]
142		if len(c.call.Args) > len(function.Params) {
143			panic(&ParseError{c.pos, fmt.Sprintf("%q called with more arguments than declared", c.call.Name)})
144		}
145		c.call.Index = index
146	}
147}
148
149// For arguments that are variable references, we don't know the
150// type based on context, so mark the types for these as unknown.
151func (p *parser) processUserCallArg(funcName string, arg Expr, index int) {
152	if varExpr, ok := arg.(*VarExpr); ok {
153		scope, varFuncName := p.getScope(varExpr.Name)
154		ref := p.varTypes[varFuncName][varExpr.Name].ref
155		if ref == varExpr {
156			// Only applies if this is the first reference to this
157			// variable (otherwise we know the type already)
158			p.varTypes[varFuncName][varExpr.Name] = typeInfo{typeUnknown, ref, scope, 0, funcName, index}
159		}
160		// Mark the last related varRef (the most recent one) as a
161		// call argument for later error handling
162		p.varRefs[len(p.varRefs)-1].isArg = true
163	}
164}
165
166// Determine scope of given variable reference (and funcName if it's
167// a local, otherwise empty string)
168func (p *parser) getScope(name string) (VarScope, string) {
169	switch {
170	case p.locals[name]:
171		return ScopeLocal, p.funcName
172	case SpecialVarIndex(name) > 0:
173		return ScopeSpecial, ""
174	default:
175		return ScopeGlobal, ""
176	}
177}
178
179// Record a variable (scalar) reference and return the *VarExpr (but
180// VarExpr.Index won't be set till later)
181func (p *parser) varRef(name string, pos Position) *VarExpr {
182	scope, funcName := p.getScope(name)
183	expr := &VarExpr{scope, 0, name}
184	p.varRefs = append(p.varRefs, varRef{funcName, expr, false, pos})
185	info := p.varTypes[funcName][name]
186	if info.typ == typeUnknown {
187		p.varTypes[funcName][name] = typeInfo{typeScalar, expr, scope, 0, info.callName, 0}
188	}
189	return expr
190}
191
192// Record an array reference and return the *ArrayExpr (but
193// ArrayExpr.Index won't be set till later)
194func (p *parser) arrayRef(name string, pos Position) *ArrayExpr {
195	scope, funcName := p.getScope(name)
196	if scope == ScopeSpecial {
197		panic(p.error("can't use scalar %q as array", name))
198	}
199	expr := &ArrayExpr{scope, 0, name}
200	p.arrayRefs = append(p.arrayRefs, arrayRef{funcName, expr, pos})
201	info := p.varTypes[funcName][name]
202	if info.typ == typeUnknown {
203		p.varTypes[funcName][name] = typeInfo{typeArray, nil, scope, 0, info.callName, 0}
204	}
205	return expr
206}
207
208// Print variable type information (for debugging) on p.debugWriter
209func (p *parser) printVarTypes(prog *Program) {
210	fmt.Fprintf(p.debugWriter, "scalars: %v\n", prog.Scalars)
211	fmt.Fprintf(p.debugWriter, "arrays: %v\n", prog.Arrays)
212	funcNames := []string{}
213	for funcName := range p.varTypes {
214		funcNames = append(funcNames, funcName)
215	}
216	sort.Strings(funcNames)
217	for _, funcName := range funcNames {
218		if funcName != "" {
219			fmt.Fprintf(p.debugWriter, "function %s\n", funcName)
220		} else {
221			fmt.Fprintf(p.debugWriter, "globals\n")
222		}
223		varNames := []string{}
224		for name := range p.varTypes[funcName] {
225			varNames = append(varNames, name)
226		}
227		sort.Strings(varNames)
228		for _, name := range varNames {
229			info := p.varTypes[funcName][name]
230			fmt.Fprintf(p.debugWriter, "  %s: %s\n", name, info)
231		}
232	}
233}
234
// maxResolveIterations caps the number of type-resolution passes in
// resolveVars; if passes are still making progress beyond this limit,
// resolution gives up with an error.
const maxResolveIterations = 10000
237
238// Resolve unknown variables types and generate variable indexes and
239// name-to-index mappings for interpreter
240func (p *parser) resolveVars(prog *Program) {
241	// First go through all unknown types and try to determine the
242	// type from the parameter type in that function definition. May
243	// need multiple passes depending on the order of functions. This
244	// is not particularly efficient, but on realistic programs it's
245	// not an issue.
246	for i := 0; ; i++ {
247		progressed := false
248		for funcName, infos := range p.varTypes {
249			for name, info := range infos {
250				if info.scope == ScopeSpecial || info.typ != typeUnknown {
251					// It's a special var or type is already known
252					continue
253				}
254				funcIndex, ok := p.functions[info.callName]
255				if !ok {
256					// Function being called is a native function
257					continue
258				}
259				// Determine var type based on type of this parameter
260				// in the called function (if we know that)
261				paramName := prog.Functions[funcIndex].Params[info.argIndex]
262				typ := p.varTypes[info.callName][paramName].typ
263				if typ != typeUnknown {
264					if p.debugTypes {
265						fmt.Fprintf(p.debugWriter, "resolving %s:%s to %s\n",
266							funcName, name, typ)
267					}
268					info.typ = typ
269					p.varTypes[funcName][name] = info
270					progressed = true
271				}
272			}
273		}
274		if !progressed {
275			// If we didn't progress we're done (or trying again is
276			// not going to help)
277			break
278		}
279		if i >= maxResolveIterations {
280			panic(p.error("too many iterations trying to resolve variable types"))
281		}
282	}
283
284	// Resolve global variables (iteration order is undefined, so
285	// assign indexes basically randomly)
286	prog.Scalars = make(map[string]int)
287	prog.Arrays = make(map[string]int)
288	for name, info := range p.varTypes[""] {
289		_, isFunc := p.functions[name]
290		if isFunc {
291			// Global var can't also be the name of a function
292			panic(p.error("global var %q can't also be a function", name))
293		}
294		var index int
295		if info.scope == ScopeSpecial {
296			index = SpecialVarIndex(name)
297		} else if info.typ == typeArray {
298			index = len(prog.Arrays)
299			prog.Arrays[name] = index
300		} else {
301			index = len(prog.Scalars)
302			prog.Scalars[name] = index
303		}
304		info.index = index
305		p.varTypes[""][name] = info
306	}
307
308	// Fill in unknown parameter types that are being called with arrays,
309	// for example, as in the following code:
310	//
311	// BEGIN { arr[0]; f(arr) }
312	// function f(a) { }
313	for _, c := range p.userCalls {
314		if c.call.Native {
315			continue
316		}
317		function := prog.Functions[c.call.Index]
318		for i, arg := range c.call.Args {
319			varExpr, ok := arg.(*VarExpr)
320			if !ok {
321				continue
322			}
323			funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
324			argType := p.varTypes[funcName][varExpr.Name]
325			paramType := p.varTypes[function.Name][function.Params[i]]
326			if argType.typ == typeArray && paramType.typ == typeUnknown {
327				paramType.typ = argType.typ
328				p.varTypes[function.Name][function.Params[i]] = paramType
329			}
330		}
331	}
332
333	// Resolve local variables (assign indexes in order of params).
334	// Also patch up Function.Arrays (tells interpreter which args
335	// are arrays).
336	for funcName, infos := range p.varTypes {
337		if funcName == "" {
338			continue
339		}
340		scalarIndex := 0
341		arrayIndex := 0
342		functionIndex := p.functions[funcName]
343		function := prog.Functions[functionIndex]
344		arrays := make([]bool, len(function.Params))
345		for i, name := range function.Params {
346			info := infos[name]
347			var index int
348			if info.typ == typeArray {
349				index = arrayIndex
350				arrayIndex++
351				arrays[i] = true
352			} else {
353				// typeScalar or typeUnknown: variables may still be
354				// of unknown type if they've never been referenced --
355				// default to scalar in that case
356				index = scalarIndex
357				scalarIndex++
358			}
359			info.index = index
360			p.varTypes[funcName][name] = info
361		}
362		prog.Functions[functionIndex].Arrays = arrays
363	}
364
365	// Check that variables passed to functions are the correct type
366	for _, c := range p.userCalls {
367		// Check native function calls
368		if c.call.Native {
369			for _, arg := range c.call.Args {
370				varExpr, ok := arg.(*VarExpr)
371				if !ok {
372					// Non-variable expression, must be scalar
373					continue
374				}
375				funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
376				info := p.varTypes[funcName][varExpr.Name]
377				if info.typ == typeArray {
378					message := fmt.Sprintf("can't pass array %q to native function", varExpr.Name)
379					panic(&ParseError{c.pos, message})
380				}
381			}
382			continue
383		}
384
385		// Check AWK function calls
386		function := prog.Functions[c.call.Index]
387		for i, arg := range c.call.Args {
388			varExpr, ok := arg.(*VarExpr)
389			if !ok {
390				if function.Arrays[i] {
391					message := fmt.Sprintf("can't pass scalar %s as array param", arg)
392					panic(&ParseError{c.pos, message})
393				}
394				continue
395			}
396			funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
397			info := p.varTypes[funcName][varExpr.Name]
398			if info.typ == typeArray && !function.Arrays[i] {
399				message := fmt.Sprintf("can't pass array %q as scalar param", varExpr.Name)
400				panic(&ParseError{c.pos, message})
401			}
402			if info.typ != typeArray && function.Arrays[i] {
403				message := fmt.Sprintf("can't pass scalar %q as array param", varExpr.Name)
404				panic(&ParseError{c.pos, message})
405			}
406		}
407	}
408
409	if p.debugTypes {
410		p.printVarTypes(prog)
411	}
412
413	// Patch up variable indexes (interpreter uses an index instead
414	// of name for more efficient lookups)
415	for _, varRef := range p.varRefs {
416		info := p.varTypes[varRef.funcName][varRef.ref.Name]
417		if info.typ == typeArray && !varRef.isArg {
418			message := fmt.Sprintf("can't use array %q as scalar", varRef.ref.Name)
419			panic(&ParseError{varRef.pos, message})
420		}
421		varRef.ref.Index = info.index
422	}
423	for _, arrayRef := range p.arrayRefs {
424		info := p.varTypes[arrayRef.funcName][arrayRef.ref.Name]
425		if info.typ == typeScalar {
426			message := fmt.Sprintf("can't use scalar %q as array", arrayRef.ref.Name)
427			panic(&ParseError{arrayRef.pos, message})
428		}
429		arrayRef.ref.Index = info.index
430	}
431}
432
433// If name refers to a local (in function inFunc), return that
434// function's name, otherwise return "" (meaning global).
435func (p *parser) getVarFuncName(prog *Program, name, inFunc string) string {
436	if inFunc == "" {
437		return ""
438	}
439	for _, param := range prog.Functions[p.functions[inFunc]].Params {
440		if name == param {
441			return inFunc
442		}
443	}
444	return ""
445}
446
447// Record a "multi expression" (comma-separated pseudo-expression
448// used to allow commas around print/printf arguments).
449func (p *parser) multiExpr(exprs []Expr, pos Position) Expr {
450	expr := &MultiExpr{exprs}
451	p.multiExprs[expr] = pos
452	return expr
453}
454
455// Mark the multi expression as used (by a print/printf statement).
456func (p *parser) useMultiExpr(expr *MultiExpr) {
457	delete(p.multiExprs, expr)
458}
459
460// Check that there are no unused multi expressions (syntax error).
461func (p *parser) checkMultiExprs() {
462	if len(p.multiExprs) == 0 {
463		return
464	}
465	// Show error on first comma-separated expression
466	min := Position{1000000000, 1000000000}
467	for _, pos := range p.multiExprs {
468		if pos.Line < min.Line || (pos.Line == min.Line && pos.Column < min.Column) {
469			min = pos
470		}
471	}
472	message := fmt.Sprintf("unexpected comma-separated expression")
473	panic(&ParseError{min, message})
474}
475