1// Copyright 2012 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5/* 6This file contains the code to check range loop variables bound inside function 7literals that are deferred or launched in new goroutines. We only check 8instances where the defer or go statement is the last statement in the loop 9body, as otherwise we would need whole program analysis. 10 11For example: 12 13 for i, v := range s { 14 go func() { 15 println(i, v) // not what you might expect 16 }() 17 } 18 19See: https://golang.org/doc/go_faq.html#closures_and_goroutines 20*/ 21 22package main 23 24import "go/ast" 25 26func init() { 27 register("rangeloops", 28 "check that loop variables are used correctly", 29 checkLoop, 30 rangeStmt, forStmt) 31} 32 33// checkLoop walks the body of the provided loop statement, checking whether 34// its index or value variables are used unsafely inside goroutines or deferred 35// function literals. 36func checkLoop(f *File, node ast.Node) { 37 // Find the variables updated by the loop statement. 38 var vars []*ast.Ident 39 addVar := func(expr ast.Expr) { 40 if id, ok := expr.(*ast.Ident); ok { 41 vars = append(vars, id) 42 } 43 } 44 var body *ast.BlockStmt 45 switch n := node.(type) { 46 case *ast.RangeStmt: 47 body = n.Body 48 addVar(n.Key) 49 addVar(n.Value) 50 case *ast.ForStmt: 51 body = n.Body 52 switch post := n.Post.(type) { 53 case *ast.AssignStmt: 54 // e.g. for p = head; p != nil; p = p.next 55 for _, lhs := range post.Lhs { 56 addVar(lhs) 57 } 58 case *ast.IncDecStmt: 59 // e.g. for i := 0; i < n; i++ 60 addVar(post.X) 61 } 62 } 63 if vars == nil { 64 return 65 } 66 67 // Inspect a go or defer statement 68 // if it's the last one in the loop body. 69 // (We give up if there are following statements, 70 // because it's hard to prove go isn't followed by wait, 71 // or defer by return.) 72 if len(body.List) == 0 { 73 return 74 } 75 var last *ast.CallExpr 76 switch s := body.List[len(body.List)-1].(type) { 77 case *ast.GoStmt: 78 last = s.Call 79 case *ast.DeferStmt: 80 last = s.Call 81 default: 82 return 83 } 84 lit, ok := last.Fun.(*ast.FuncLit) 85 if !ok { 86 return 87 } 88 ast.Inspect(lit.Body, func(n ast.Node) bool { 89 id, ok := n.(*ast.Ident) 90 if !ok || id.Obj == nil { 91 return true 92 } 93 if f.pkg.types[id].Type == nil { 94 // Not referring to a variable (e.g. struct field name) 95 return true 96 } 97 for _, v := range vars { 98 if v.Obj == id.Obj { 99 f.Badf(id.Pos(), "loop variable %s captured by func literal", 100 id.Name) 101 } 102 } 103 return true 104 }) 105} 106