1// Resolve function calls and variable types 2 3package parser 4 5import ( 6 "fmt" 7 "reflect" 8 "sort" 9 10 . "github.com/benhoyt/goawk/internal/ast" 11 . "github.com/benhoyt/goawk/lexer" 12) 13 14type varType int 15 16const ( 17 typeUnknown varType = iota 18 typeScalar 19 typeArray 20) 21 22func (t varType) String() string { 23 switch t { 24 case typeScalar: 25 return "Scalar" 26 case typeArray: 27 return "Array" 28 default: 29 return "Unknown" 30 } 31} 32 33// typeInfo records type information for a single variable 34type typeInfo struct { 35 typ varType 36 ref *VarExpr 37 scope VarScope 38 index int 39 callName string 40 argIndex int 41} 42 43// Used by printVarTypes when debugTypes is turned on 44func (t typeInfo) String() string { 45 var scope string 46 switch t.scope { 47 case ScopeGlobal: 48 scope = "Global" 49 case ScopeLocal: 50 scope = "Local" 51 default: 52 scope = "Special" 53 } 54 return fmt.Sprintf("typ=%s ref=%p scope=%s index=%d callName=%q argIndex=%d", 55 t.typ, t.ref, scope, t.index, t.callName, t.argIndex) 56} 57 58// A single variable reference (normally scalar) 59type varRef struct { 60 funcName string 61 ref *VarExpr 62 isArg bool 63 pos Position 64} 65 66// A single array reference 67type arrayRef struct { 68 funcName string 69 ref *ArrayExpr 70 pos Position 71} 72 73// Initialize the resolver 74func (p *parser) initResolve() { 75 p.varTypes = make(map[string]map[string]typeInfo) 76 p.varTypes[""] = make(map[string]typeInfo) // globals 77 p.functions = make(map[string]int) 78 p.arrayRef("ARGV", Position{1, 1}) // interpreter relies on ARGV being present 79 p.multiExprs = make(map[*MultiExpr]Position, 3) 80} 81 82// Signal the start of a function 83func (p *parser) startFunction(name string, params []string) { 84 p.funcName = name 85 p.varTypes[name] = make(map[string]typeInfo) 86} 87 88// Signal the end of a function 89func (p *parser) stopFunction() { 90 p.funcName = "" 91} 92 93// Add function by name with given index 94func (p *parser) addFunction(name string, index int) { 95 p.functions[name] = index 96} 97 98// Records a call to a user function (for resolving indexes later) 99type userCall struct { 100 call *UserCallExpr 101 pos Position 102 inFunc string 103} 104 105// Record a user call site 106func (p *parser) recordUserCall(call *UserCallExpr, pos Position) { 107 p.userCalls = append(p.userCalls, userCall{call, pos, p.funcName}) 108} 109 110// After parsing, resolve all user calls to their indexes. Also 111// ensures functions called have actually been defined, and that 112// they're not being called with too many arguments. 113func (p *parser) resolveUserCalls(prog *Program) { 114 // Number the native funcs (order by name to get consistent order) 115 nativeNames := make([]string, 0, len(p.nativeFuncs)) 116 for name := range p.nativeFuncs { 117 nativeNames = append(nativeNames, name) 118 } 119 sort.Strings(nativeNames) 120 nativeIndexes := make(map[string]int, len(nativeNames)) 121 for i, name := range nativeNames { 122 nativeIndexes[name] = i 123 } 124 125 for _, c := range p.userCalls { 126 // AWK-defined functions take precedence over native Go funcs 127 index, ok := p.functions[c.call.Name] 128 if !ok { 129 f, haveNative := p.nativeFuncs[c.call.Name] 130 if !haveNative { 131 panic(&ParseError{c.pos, fmt.Sprintf("undefined function %q", c.call.Name)}) 132 } 133 typ := reflect.TypeOf(f) 134 if !typ.IsVariadic() && len(c.call.Args) > typ.NumIn() { 135 panic(&ParseError{c.pos, fmt.Sprintf("%q called with more arguments than declared", c.call.Name)}) 136 } 137 c.call.Native = true 138 c.call.Index = nativeIndexes[c.call.Name] 139 continue 140 } 141 function := prog.Functions[index] 142 if len(c.call.Args) > len(function.Params) { 143 panic(&ParseError{c.pos, fmt.Sprintf("%q called with more arguments than declared", c.call.Name)}) 144 } 145 c.call.Index = index 146 } 147} 148 149// For arguments that are variable references, we don't know the 150// type based on context, so mark the types for these as unknown. 151func (p *parser) processUserCallArg(funcName string, arg Expr, index int) { 152 if varExpr, ok := arg.(*VarExpr); ok { 153 scope, varFuncName := p.getScope(varExpr.Name) 154 ref := p.varTypes[varFuncName][varExpr.Name].ref 155 if ref == varExpr { 156 // Only applies if this is the first reference to this 157 // variable (otherwise we know the type already) 158 p.varTypes[varFuncName][varExpr.Name] = typeInfo{typeUnknown, ref, scope, 0, funcName, index} 159 } 160 // Mark the last related varRef (the most recent one) as a 161 // call argument for later error handling 162 p.varRefs[len(p.varRefs)-1].isArg = true 163 } 164} 165 166// Determine scope of given variable reference (and funcName if it's 167// a local, otherwise empty string) 168func (p *parser) getScope(name string) (VarScope, string) { 169 switch { 170 case p.locals[name]: 171 return ScopeLocal, p.funcName 172 case SpecialVarIndex(name) > 0: 173 return ScopeSpecial, "" 174 default: 175 return ScopeGlobal, "" 176 } 177} 178 179// Record a variable (scalar) reference and return the *VarExpr (but 180// VarExpr.Index won't be set till later) 181func (p *parser) varRef(name string, pos Position) *VarExpr { 182 scope, funcName := p.getScope(name) 183 expr := &VarExpr{scope, 0, name} 184 p.varRefs = append(p.varRefs, varRef{funcName, expr, false, pos}) 185 info := p.varTypes[funcName][name] 186 if info.typ == typeUnknown { 187 p.varTypes[funcName][name] = typeInfo{typeScalar, expr, scope, 0, info.callName, 0} 188 } 189 return expr 190} 191 192// Record an array reference and return the *ArrayExpr (but 193// ArrayExpr.Index won't be set till later) 194func (p *parser) arrayRef(name string, pos Position) *ArrayExpr { 195 scope, funcName := p.getScope(name) 196 if scope == ScopeSpecial { 197 panic(p.error("can't use scalar %q as array", name)) 198 } 199 expr := &ArrayExpr{scope, 0, name} 200 p.arrayRefs = append(p.arrayRefs, arrayRef{funcName, expr, pos}) 201 info := p.varTypes[funcName][name] 202 if info.typ == typeUnknown { 203 p.varTypes[funcName][name] = typeInfo{typeArray, nil, scope, 0, info.callName, 0} 204 } 205 return expr 206} 207 208// Print variable type information (for debugging) on p.debugWriter 209func (p *parser) printVarTypes(prog *Program) { 210 fmt.Fprintf(p.debugWriter, "scalars: %v\n", prog.Scalars) 211 fmt.Fprintf(p.debugWriter, "arrays: %v\n", prog.Arrays) 212 funcNames := []string{} 213 for funcName := range p.varTypes { 214 funcNames = append(funcNames, funcName) 215 } 216 sort.Strings(funcNames) 217 for _, funcName := range funcNames { 218 if funcName != "" { 219 fmt.Fprintf(p.debugWriter, "function %s\n", funcName) 220 } else { 221 fmt.Fprintf(p.debugWriter, "globals\n") 222 } 223 varNames := []string{} 224 for name := range p.varTypes[funcName] { 225 varNames = append(varNames, name) 226 } 227 sort.Strings(varNames) 228 for _, name := range varNames { 229 info := p.varTypes[funcName][name] 230 fmt.Fprintf(p.debugWriter, " %s: %s\n", name, info) 231 } 232 } 233} 234 235// If we can't finish resolving after this many iterations, give up 236const maxResolveIterations = 10000 237 238// Resolve unknown variables types and generate variable indexes and 239// name-to-index mappings for interpreter 240func (p *parser) resolveVars(prog *Program) { 241 // First go through all unknown types and try to determine the 242 // type from the parameter type in that function definition. May 243 // need multiple passes depending on the order of functions. This 244 // is not particularly efficient, but on realistic programs it's 245 // not an issue. 246 for i := 0; ; i++ { 247 progressed := false 248 for funcName, infos := range p.varTypes { 249 for name, info := range infos { 250 if info.scope == ScopeSpecial || info.typ != typeUnknown { 251 // It's a special var or type is already known 252 continue 253 } 254 funcIndex, ok := p.functions[info.callName] 255 if !ok { 256 // Function being called is a native function 257 continue 258 } 259 // Determine var type based on type of this parameter 260 // in the called function (if we know that) 261 paramName := prog.Functions[funcIndex].Params[info.argIndex] 262 typ := p.varTypes[info.callName][paramName].typ 263 if typ != typeUnknown { 264 if p.debugTypes { 265 fmt.Fprintf(p.debugWriter, "resolving %s:%s to %s\n", 266 funcName, name, typ) 267 } 268 info.typ = typ 269 p.varTypes[funcName][name] = info 270 progressed = true 271 } 272 } 273 } 274 if !progressed { 275 // If we didn't progress we're done (or trying again is 276 // not going to help) 277 break 278 } 279 if i >= maxResolveIterations { 280 panic(p.error("too many iterations trying to resolve variable types")) 281 } 282 } 283 284 // Resolve global variables (iteration order is undefined, so 285 // assign indexes basically randomly) 286 prog.Scalars = make(map[string]int) 287 prog.Arrays = make(map[string]int) 288 for name, info := range p.varTypes[""] { 289 _, isFunc := p.functions[name] 290 if isFunc { 291 // Global var can't also be the name of a function 292 panic(p.error("global var %q can't also be a function", name)) 293 } 294 var index int 295 if info.scope == ScopeSpecial { 296 index = SpecialVarIndex(name) 297 } else if info.typ == typeArray { 298 index = len(prog.Arrays) 299 prog.Arrays[name] = index 300 } else { 301 index = len(prog.Scalars) 302 prog.Scalars[name] = index 303 } 304 info.index = index 305 p.varTypes[""][name] = info 306 } 307 308 // Fill in unknown parameter types that are being called with arrays, 309 // for example, as in the following code: 310 // 311 // BEGIN { arr[0]; f(arr) } 312 // function f(a) { } 313 for _, c := range p.userCalls { 314 if c.call.Native { 315 continue 316 } 317 function := prog.Functions[c.call.Index] 318 for i, arg := range c.call.Args { 319 varExpr, ok := arg.(*VarExpr) 320 if !ok { 321 continue 322 } 323 funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc) 324 argType := p.varTypes[funcName][varExpr.Name] 325 paramType := p.varTypes[function.Name][function.Params[i]] 326 if argType.typ == typeArray && paramType.typ == typeUnknown { 327 paramType.typ = argType.typ 328 p.varTypes[function.Name][function.Params[i]] = paramType 329 } 330 } 331 } 332 333 // Resolve local variables (assign indexes in order of params). 334 // Also patch up Function.Arrays (tells interpreter which args 335 // are arrays). 336 for funcName, infos := range p.varTypes { 337 if funcName == "" { 338 continue 339 } 340 scalarIndex := 0 341 arrayIndex := 0 342 functionIndex := p.functions[funcName] 343 function := prog.Functions[functionIndex] 344 arrays := make([]bool, len(function.Params)) 345 for i, name := range function.Params { 346 info := infos[name] 347 var index int 348 if info.typ == typeArray { 349 index = arrayIndex 350 arrayIndex++ 351 arrays[i] = true 352 } else { 353 // typeScalar or typeUnknown: variables may still be 354 // of unknown type if they've never been referenced -- 355 // default to scalar in that case 356 index = scalarIndex 357 scalarIndex++ 358 } 359 info.index = index 360 p.varTypes[funcName][name] = info 361 } 362 prog.Functions[functionIndex].Arrays = arrays 363 } 364 365 // Check that variables passed to functions are the correct type 366 for _, c := range p.userCalls { 367 // Check native function calls 368 if c.call.Native { 369 for _, arg := range c.call.Args { 370 varExpr, ok := arg.(*VarExpr) 371 if !ok { 372 // Non-variable expression, must be scalar 373 continue 374 } 375 funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc) 376 info := p.varTypes[funcName][varExpr.Name] 377 if info.typ == typeArray { 378 message := fmt.Sprintf("can't pass array %q to native function", varExpr.Name) 379 panic(&ParseError{c.pos, message}) 380 } 381 } 382 continue 383 } 384 385 // Check AWK function calls 386 function := prog.Functions[c.call.Index] 387 for i, arg := range c.call.Args { 388 varExpr, ok := arg.(*VarExpr) 389 if !ok { 390 if function.Arrays[i] { 391 message := fmt.Sprintf("can't pass scalar %s as array param", arg) 392 panic(&ParseError{c.pos, message}) 393 } 394 continue 395 } 396 funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc) 397 info := p.varTypes[funcName][varExpr.Name] 398 if info.typ == typeArray && !function.Arrays[i] { 399 message := fmt.Sprintf("can't pass array %q as scalar param", varExpr.Name) 400 panic(&ParseError{c.pos, message}) 401 } 402 if info.typ != typeArray && function.Arrays[i] { 403 message := fmt.Sprintf("can't pass scalar %q as array param", varExpr.Name) 404 panic(&ParseError{c.pos, message}) 405 } 406 } 407 } 408 409 if p.debugTypes { 410 p.printVarTypes(prog) 411 } 412 413 // Patch up variable indexes (interpreter uses an index instead 414 // of name for more efficient lookups) 415 for _, varRef := range p.varRefs { 416 info := p.varTypes[varRef.funcName][varRef.ref.Name] 417 if info.typ == typeArray && !varRef.isArg { 418 message := fmt.Sprintf("can't use array %q as scalar", varRef.ref.Name) 419 panic(&ParseError{varRef.pos, message}) 420 } 421 varRef.ref.Index = info.index 422 } 423 for _, arrayRef := range p.arrayRefs { 424 info := p.varTypes[arrayRef.funcName][arrayRef.ref.Name] 425 if info.typ == typeScalar { 426 message := fmt.Sprintf("can't use scalar %q as array", arrayRef.ref.Name) 427 panic(&ParseError{arrayRef.pos, message}) 428 } 429 arrayRef.ref.Index = info.index 430 } 431} 432 433// If name refers to a local (in function inFunc), return that 434// function's name, otherwise return "" (meaning global). 435func (p *parser) getVarFuncName(prog *Program, name, inFunc string) string { 436 if inFunc == "" { 437 return "" 438 } 439 for _, param := range prog.Functions[p.functions[inFunc]].Params { 440 if name == param { 441 return inFunc 442 } 443 } 444 return "" 445} 446 447// Record a "multi expression" (comma-separated pseudo-expression 448// used to allow commas around print/printf arguments). 449func (p *parser) multiExpr(exprs []Expr, pos Position) Expr { 450 expr := &MultiExpr{exprs} 451 p.multiExprs[expr] = pos 452 return expr 453} 454 455// Mark the multi expression as used (by a print/printf statement). 456func (p *parser) useMultiExpr(expr *MultiExpr) { 457 delete(p.multiExprs, expr) 458} 459 460// Check that there are no unused multi expressions (syntax error). 461func (p *parser) checkMultiExprs() { 462 if len(p.multiExprs) == 0 { 463 return 464 } 465 // Show error on first comma-separated expression 466 min := Position{1000000000, 1000000000} 467 for _, pos := range p.multiExprs { 468 if pos.Line < min.Line || (pos.Line == min.Line && pos.Column < min.Column) { 469 min = pos 470 } 471 } 472 message := fmt.Sprintf("unexpected comma-separated expression") 473 panic(&ParseError{min, message}) 474} 475