1package source // import "gotest.tools/v3/internal/source" 2 3import ( 4 "bytes" 5 "fmt" 6 "go/ast" 7 "go/format" 8 "go/parser" 9 "go/token" 10 "os" 11 "runtime" 12 "strconv" 13 "strings" 14 15 "github.com/pkg/errors" 16) 17 18const baseStackIndex = 1 19 20// FormattedCallExprArg returns the argument from an ast.CallExpr at the 21// index in the call stack. The argument is formatted using FormatNode. 22func FormattedCallExprArg(stackIndex int, argPos int) (string, error) { 23 args, err := CallExprArgs(stackIndex + 1) 24 if err != nil { 25 return "", err 26 } 27 if argPos >= len(args) { 28 return "", errors.New("failed to find expression") 29 } 30 return FormatNode(args[argPos]) 31} 32 33// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at 34// the index in the call stack. 35func CallExprArgs(stackIndex int) ([]ast.Expr, error) { 36 _, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex) 37 if !ok { 38 return nil, errors.New("failed to get call stack") 39 } 40 debug("call stack position: %s:%d", filename, lineNum) 41 42 node, err := getNodeAtLine(filename, lineNum) 43 if err != nil { 44 return nil, err 45 } 46 debug("found node: %s", debugFormatNode{node}) 47 48 return getCallExprArgs(node) 49} 50 51func getNodeAtLine(filename string, lineNum int) (ast.Node, error) { 52 fileset := token.NewFileSet() 53 astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors) 54 if err != nil { 55 return nil, errors.Wrapf(err, "failed to parse source file: %s", filename) 56 } 57 58 if node := scanToLine(fileset, astFile, lineNum); node != nil { 59 return node, nil 60 } 61 if node := scanToDeferLine(fileset, astFile, lineNum); node != nil { 62 node, err := guessDefer(node) 63 if err != nil || node != nil { 64 return node, err 65 } 66 } 67 return nil, errors.Errorf( 68 "failed to find an expression on line %d in %s", lineNum, filename) 69} 70 71func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node { 72 var matchedNode ast.Node 73 ast.Inspect(node, func(node ast.Node) bool { 74 switch { 75 case node == nil || matchedNode != nil: 76 return false 77 case nodePosition(fileset, node).Line == lineNum: 78 matchedNode = node 79 return false 80 } 81 return true 82 }) 83 return matchedNode 84} 85 86// In golang 1.9 the line number changed from being the line where the statement 87// ended to the line where the statement began. 88func nodePosition(fileset *token.FileSet, node ast.Node) token.Position { 89 if goVersionBefore19 { 90 return fileset.Position(node.End()) 91 } 92 return fileset.Position(node.Pos()) 93} 94 95// GoVersionLessThan returns true if runtime.Version() is semantically less than 96// version 1.minor. 97func GoVersionLessThan(minor int64) bool { 98 version := runtime.Version() 99 // not a release version 100 if !strings.HasPrefix(version, "go") { 101 return false 102 } 103 version = strings.TrimPrefix(version, "go") 104 parts := strings.Split(version, ".") 105 if len(parts) < 2 { 106 return false 107 } 108 actual, err := strconv.ParseInt(parts[1], 10, 32) 109 return err == nil && parts[0] == "1" && actual < minor 110} 111 112var goVersionBefore19 = GoVersionLessThan(9) 113 114func getCallExprArgs(node ast.Node) ([]ast.Expr, error) { 115 visitor := &callExprVisitor{} 116 ast.Walk(visitor, node) 117 if visitor.expr == nil { 118 return nil, errors.New("failed to find call expression") 119 } 120 debug("callExpr: %s", debugFormatNode{visitor.expr}) 121 return visitor.expr.Args, nil 122} 123 124type callExprVisitor struct { 125 expr *ast.CallExpr 126} 127 128func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor { 129 if v.expr != nil || node == nil { 130 return nil 131 } 132 debug("visit: %s", debugFormatNode{node}) 133 134 switch typed := node.(type) { 135 case *ast.CallExpr: 136 v.expr = typed 137 return nil 138 case *ast.DeferStmt: 139 ast.Walk(v, typed.Call.Fun) 140 return nil 141 } 142 return v 143} 144 145// FormatNode using go/format.Node and return the result as a string 146func FormatNode(node ast.Node) (string, error) { 147 buf := new(bytes.Buffer) 148 err := format.Node(buf, token.NewFileSet(), node) 149 return buf.String(), err 150} 151 152var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != "" 153 154func debug(format string, args ...interface{}) { 155 if debugEnabled { 156 fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...) 157 } 158} 159 160type debugFormatNode struct { 161 ast.Node 162} 163 164func (n debugFormatNode) String() string { 165 out, err := FormatNode(n.Node) 166 if err != nil { 167 return fmt.Sprintf("failed to format %s: %s", n.Node, err) 168 } 169 return fmt.Sprintf("(%T) %s", n.Node, out) 170} 171