1package source // import "gotest.tools/v3/internal/source"
2
3import (
4	"bytes"
5	"fmt"
6	"go/ast"
7	"go/format"
8	"go/parser"
9	"go/token"
10	"os"
11	"runtime"
12	"strconv"
13	"strings"
14
15	"github.com/pkg/errors"
16)
17
// baseStackIndex is added to the caller-supplied stack index before it is
// passed to runtime.Caller in CallExprArgs, so that an index of 0 refers to
// the function that called CallExprArgs rather than to CallExprArgs itself.
const baseStackIndex = 1
19
20// FormattedCallExprArg returns the argument from an ast.CallExpr at the
21// index in the call stack. The argument is formatted using FormatNode.
22func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
23	args, err := CallExprArgs(stackIndex + 1)
24	if err != nil {
25		return "", err
26	}
27	if argPos >= len(args) {
28		return "", errors.New("failed to find expression")
29	}
30	return FormatNode(args[argPos])
31}
32
33// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
34// the index in the call stack.
35func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
36	_, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
37	if !ok {
38		return nil, errors.New("failed to get call stack")
39	}
40	debug("call stack position: %s:%d", filename, lineNum)
41
42	node, err := getNodeAtLine(filename, lineNum)
43	if err != nil {
44		return nil, err
45	}
46	debug("found node: %s", debugFormatNode{node})
47
48	return getCallExprArgs(node)
49}
50
51func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
52	fileset := token.NewFileSet()
53	astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
54	if err != nil {
55		return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
56	}
57
58	if node := scanToLine(fileset, astFile, lineNum); node != nil {
59		return node, nil
60	}
61	if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
62		node, err := guessDefer(node)
63		if err != nil || node != nil {
64			return node, err
65		}
66	}
67	return nil, errors.Errorf(
68		"failed to find an expression on line %d in %s", lineNum, filename)
69}
70
71func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
72	var matchedNode ast.Node
73	ast.Inspect(node, func(node ast.Node) bool {
74		switch {
75		case node == nil || matchedNode != nil:
76			return false
77		case nodePosition(fileset, node).Line == lineNum:
78			matchedNode = node
79			return false
80		}
81		return true
82	})
83	return matchedNode
84}
85
86// In golang 1.9 the line number changed from being the line where the statement
87// ended to the line where the statement began.
88func nodePosition(fileset *token.FileSet, node ast.Node) token.Position {
89	if goVersionBefore19 {
90		return fileset.Position(node.End())
91	}
92	return fileset.Position(node.Pos())
93}
94
// GoVersionLessThan returns true if runtime.Version() is semantically less than
// version 1.minor.
//
// It returns false when the version string does not start with "go", when it
// has no minor component, or when the major version is not 1. Anything that
// follows the leading digits of the minor component (for example the "rc1" in
// "go1.8rc1" or the "beta2" in "go1.8beta2") is ignored, so a pre-release is
// compared as the release it leads up to.
func GoVersionLessThan(minor int64) bool {
	return versionLessThan(runtime.Version(), minor)
}

// versionLessThan reports whether version, a string in the format returned by
// runtime.Version, names a 1.x toolchain with x < minor.
func versionLessThan(version string, minor int64) bool {
	// not a release version
	if !strings.HasPrefix(version, "go") {
		return false
	}
	parts := strings.Split(strings.TrimPrefix(version, "go"), ".")
	if len(parts) < 2 || parts[0] != "1" {
		return false
	}

	// Keep only the leading run of digits so that suffixed minor components
	// such as "8rc1" still parse; with no leading digit ParseInt fails on the
	// empty string and the version is reported as not less than.
	digits := parts[1]
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	if end := strings.IndexFunc(digits, notDigit); end >= 0 {
		digits = digits[:end]
	}
	actual, err := strconv.ParseInt(digits, 10, 32)
	return err == nil && actual < minor
}
111
// goVersionBefore19 records, once at package initialisation, whether
// GoVersionLessThan(9) holds for the running toolchain. nodePosition reads it
// to decide whether to use the end or the start of a node.
var goVersionBefore19 = GoVersionLessThan(9)
113
114func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
115	visitor := &callExprVisitor{}
116	ast.Walk(visitor, node)
117	if visitor.expr == nil {
118		return nil, errors.New("failed to find call expression")
119	}
120	debug("callExpr: %s", debugFormatNode{visitor.expr})
121	return visitor.expr.Args, nil
122}
123
124type callExprVisitor struct {
125	expr *ast.CallExpr
126}
127
128func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
129	if v.expr != nil || node == nil {
130		return nil
131	}
132	debug("visit: %s", debugFormatNode{node})
133
134	switch typed := node.(type) {
135	case *ast.CallExpr:
136		v.expr = typed
137		return nil
138	case *ast.DeferStmt:
139		ast.Walk(v, typed.Call.Fun)
140		return nil
141	}
142	return v
143}
144
// FormatNode formats node with go/format.Node, using a fresh, empty
// token.FileSet, and returns the output as a string together with any error
// reported by the formatter. Whatever was written before an error occurred is
// still returned.
func FormatNode(node ast.Node) (string, error) {
	var out bytes.Buffer
	err := format.Node(&out, token.NewFileSet(), node)
	return out.String(), err
}
151
// debugEnabled is true when the GOTESTTOOLS_DEBUG environment variable holds
// a non-empty value at package initialisation. debug only writes output when
// it is true.
var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != ""
153
154func debug(format string, args ...interface{}) {
155	if debugEnabled {
156		fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
157	}
158}
159
160type debugFormatNode struct {
161	ast.Node
162}
163
164func (n debugFormatNode) String() string {
165	out, err := FormatNode(n.Node)
166	if err != nil {
167		return fmt.Sprintf("failed to format %s: %s", n.Node, err)
168	}
169	return fmt.Sprintf("(%T) %s", n.Node, out)
170}
171