1package convert
2
import (
	"fmt"
	"go/ast"
	"go/token"
)
7
/*
 * importsForRootNode returns the first non-empty import declaration
 * (an *ast.GenDecl whose Tok is token.IMPORT) in the given file.
 *
 * Empty import groups such as `import ()` are skipped, so a returned
 * declaration always has at least one spec. An error is returned when
 * rootNode is nil or contains no such declaration.
 */
func importsForRootNode(rootNode *ast.File) (imports *ast.GenDecl, err error) {
	if rootNode == nil {
		return nil, fmt.Errorf("could not find imports: root node is nil")
	}

	for _, declaration := range rootNode.Decls {
		decl, ok := declaration.(*ast.GenDecl)
		if ok && decl.Tok == token.IMPORT && len(decl.Specs) > 0 {
			return decl, nil
		}
	}

	// rootNode.Name is an *ast.Ident whose String method tolerates a nil
	// receiver, so this is safe even for hand-built files without a package
	// clause. Dumping the whole file with %#v only printed pointer addresses.
	return nil, fmt.Errorf("could not find imports for file in package %s", rootNode.Name)
}
29
30/*
31 * Removes "testing" import, if present
32 */
33func removeTestingImport(rootNode *ast.File) {
34	importDecl, err := importsForRootNode(rootNode)
35	if err != nil {
36		panic(err.Error())
37	}
38
39	var index int
40	for i, importSpec := range importDecl.Specs {
41		importSpec := importSpec.(*ast.ImportSpec)
42		if importSpec.Path.Value == "\"testing\"" {
43			index = i
44			break
45		}
46	}
47
48	importDecl.Specs = append(importDecl.Specs[:index], importDecl.Specs[index+1:]...)
49}
50
51/*
52 * Adds import statements for onsi/ginkgo, if missing
53 */
54func addGinkgoImports(rootNode *ast.File) {
55	importDecl, err := importsForRootNode(rootNode)
56	if err != nil {
57		panic(err.Error())
58	}
59
60	if len(importDecl.Specs) == 0 {
61		// TODO: might need to create a import decl here
62		panic("unimplemented : expected to find an imports block")
63	}
64
65	needsGinkgo := true
66	for _, importSpec := range importDecl.Specs {
67		importSpec, ok := importSpec.(*ast.ImportSpec)
68		if !ok {
69			continue
70		}
71
72		if importSpec.Path.Value == "\"github.com/onsi/ginkgo\"" {
73			needsGinkgo = false
74		}
75	}
76
77	if needsGinkgo {
78		importDecl.Specs = append(importDecl.Specs, createImport(".", "\"github.com/onsi/ginkgo\""))
79	}
80}
81
/*
 * createImport builds an import spec for path, which must be a quoted
 * import path such as "\"github.com/onsi/ginkgo\"" (the quotes are part of
 * an ast.BasicLit string value). name is the local package name: "." gives
 * a dot-import, and an empty name gives an unnamed import.
 */
func createImport(name, path string) *ast.ImportSpec {
	spec := &ast.ImportSpec{
		Path: &ast.BasicLit{Kind: token.STRING, Value: path},
	}
	if name != "" {
		spec.Name = &ast.Ident{Name: name}
	}
	return spec
}
91