1package convert
2
3import (
4	"bytes"
5	"fmt"
6	"go/ast"
7	"go/format"
8	"go/parser"
9	"go/token"
10	"io/ioutil"
11	"os"
12)
13
14/*
15 * Given a file path, rewrites any tests in the Ginkgo format.
16 * First, we parse the AST, and update the imports declaration.
17 * Then, we walk the first child elements in the file, returning tests to rewrite.
18 * A top level init func is declared, with a single Describe func inside.
19 * Then the test functions to rewrite are inserted as It statements inside the Describe.
20 * Finally we walk the rest of the file, replacing other usages of *testing.T
21 * Once that is complete, we write the AST back out again to its file.
22 */
23func rewriteTestsInFile(pathToFile string) {
24	fileSet := token.NewFileSet()
25	rootNode, err := parser.ParseFile(fileSet, pathToFile, nil, 0)
26	if err != nil {
27		panic(fmt.Sprintf("Error parsing test file '%s':\n%s\n", pathToFile, err.Error()))
28	}
29
30	addGinkgoImports(rootNode)
31	removeTestingImport(rootNode)
32
33	varUnderscoreBlock := createVarUnderscoreBlock()
34	describeBlock := createDescribeBlock()
35	varUnderscoreBlock.Values = []ast.Expr{describeBlock}
36
37	for _, testFunc := range findTestFuncs(rootNode) {
38		rewriteTestFuncAsItStatement(testFunc, rootNode, describeBlock)
39	}
40
41	underscoreDecl := &ast.GenDecl{
42		Tok:    85, // gah, magick numbers are needed to make this work
43		TokPos: 14, // this tricks Go into writing "var _ = Describe"
44		Specs:  []ast.Spec{varUnderscoreBlock},
45	}
46
47	imports := rootNode.Decls[0]
48	tail := rootNode.Decls[1:]
49	rootNode.Decls = append(append([]ast.Decl{imports}, underscoreDecl), tail...)
50	rewriteOtherFuncsToUseGinkgoT(rootNode.Decls)
51	walkNodesInRootNodeReplacingTestingT(rootNode)
52
53	var buffer bytes.Buffer
54	if err = format.Node(&buffer, fileSet, rootNode); err != nil {
55		panic(fmt.Sprintf("Error formatting ast node after rewriting tests.\n%s\n", err.Error()))
56	}
57
58	fileInfo, err := os.Stat(pathToFile)
59	if err != nil {
60		panic(fmt.Sprintf("Error stat'ing file: %s\n", pathToFile))
61	}
62
63	ioutil.WriteFile(pathToFile, buffer.Bytes(), fileInfo.Mode())
64	return
65}
66
67/*
68 * Given a test func named TestDoesSomethingNeat, rewrites it as
69 * It("does something neat", func() { __test_body_here__ }) and adds it
70 * to the Describe's list of statements
71 */
72func rewriteTestFuncAsItStatement(testFunc *ast.FuncDecl, rootNode *ast.File, describe *ast.CallExpr) {
73	var funcIndex int = -1
74	for index, child := range rootNode.Decls {
75		if child == testFunc {
76			funcIndex = index
77			break
78		}
79	}
80
81	if funcIndex < 0 {
82		panic(fmt.Sprintf("Assert failed: Error finding index for test node %s\n", testFunc.Name.Name))
83	}
84
85	var block *ast.BlockStmt = blockStatementFromDescribe(describe)
86	block.List = append(block.List, createItStatementForTestFunc(testFunc))
87	replaceTestingTsWithGinkgoT(block, namedTestingTArg(testFunc))
88
89	// remove the old test func from the root node's declarations
90	rootNode.Decls = append(rootNode.Decls[:funcIndex], rootNode.Decls[funcIndex+1:]...)
91	return
92}
93
94/*
95 * walks nodes inside of a test func's statements and replaces the usage of
96 * it's named *testing.T param with GinkgoT's
97 */
98func replaceTestingTsWithGinkgoT(statementsBlock *ast.BlockStmt, testingT string) {
99	ast.Inspect(statementsBlock, func(node ast.Node) bool {
100		if node == nil {
101			return false
102		}
103
104		keyValueExpr, ok := node.(*ast.KeyValueExpr)
105		if ok {
106			replaceNamedTestingTsInKeyValueExpression(keyValueExpr, testingT)
107			return true
108		}
109
110		funcLiteral, ok := node.(*ast.FuncLit)
111		if ok {
112			replaceTypeDeclTestingTsInFuncLiteral(funcLiteral)
113			return true
114		}
115
116		callExpr, ok := node.(*ast.CallExpr)
117		if !ok {
118			return true
119		}
120		replaceTestingTsInArgsLists(callExpr, testingT)
121
122		funCall, ok := callExpr.Fun.(*ast.SelectorExpr)
123		if ok {
124			replaceTestingTsMethodCalls(funCall, testingT)
125		}
126
127		return true
128	})
129}
130
131/*
132 * rewrite t.Fail() or any other *testing.T method by replacing with T().Fail()
133 * This function receives a selector expression (eg: t.Fail()) and
134 * the name of the *testing.T param from the function declaration. Rewrites the
135 * selector expression in place if the target was a *testing.T
136 */
137func replaceTestingTsMethodCalls(selectorExpr *ast.SelectorExpr, testingT string) {
138	ident, ok := selectorExpr.X.(*ast.Ident)
139	if !ok {
140		return
141	}
142
143	if ident.Name == testingT {
144		selectorExpr.X = newGinkgoTFromIdent(ident)
145	}
146}
147
148/*
149 * replaces usages of a named *testing.T param inside of a call expression
150 * with a new GinkgoT object
151 */
152func replaceTestingTsInArgsLists(callExpr *ast.CallExpr, testingT string) {
153	for index, arg := range callExpr.Args {
154		ident, ok := arg.(*ast.Ident)
155		if !ok {
156			continue
157		}
158
159		if ident.Name == testingT {
160			callExpr.Args[index] = newGinkgoTFromIdent(ident)
161		}
162	}
163}
164