1package nodot
2
3import (
4	"fmt"
5	"go/ast"
6	"go/build"
7	"go/parser"
8	"go/token"
9	"path/filepath"
10	"strings"
11)
12
13func ApplyNoDot(data []byte) ([]byte, error) {
14	sections, err := generateNodotSections()
15	if err != nil {
16		return nil, err
17	}
18
19	for _, section := range sections {
20		data = section.createOrUpdateIn(data)
21	}
22
23	return data, nil
24}
25
// nodotSection describes one generated block of aliases: a header comment
// "// Declarations for <name>" followed by `type` and `var` lines that
// re-export identifiers of package pkg so they can be referenced unqualified.
type nodotSection struct {
	name         string   // title used in the section's header comment
	pkg          string   // package qualifier, e.g. "ginkgo"
	declarations []string // identifiers emitted as `var X = pkg.X`
	types        []string // identifiers emitted as `type X pkg.X`
}

// createOrUpdateIn returns data with this section regenerated at its end.
//
// The section's header line is removed, and so is every line whose last
// space-separated word is one of the section's qualified identifiers (for
// example "ginkgo.It"). A single blank line directly following those removed
// lines is removed too. For each removed alias line, its second word is
// remembered as the local name, so a rename such as `var Specify = ginkgo.It`
// survives regeneration. The section is then appended after one blank line,
// with types first and declarations second. The result ends with a newline.
// Running it again on its own output yields the same output.
func (s nodotSection) createOrUpdateIn(data []byte) []byte {
	renames := map[string]string{}

	lines := strings.Split(string(data), "\n")
	comment := "// Declarations for " + s.name

	newLines := make([]string, 0, len(lines))
	droppedPrev := false // whether the previous line belonged to the old section
	for _, line := range lines {
		words := strings.Split(line, " ")
		lastWord := words[len(words)-1]

		switch {
		case line == comment:
			droppedPrev = true
			continue
		case len(words) > 1 && s.containsDeclarationOrType(lastWord):
			// Alias lines have at least two words ("type X pkg.Y",
			// "var X = pkg.Y"). The length check keeps a line that is only a
			// qualified identifier from panicking on words[1].
			renames[lastWord] = words[1]
			droppedPrev = true
			continue
		case line == "" && droppedPrev:
			// This is the separator that followed the old copy of the section.
			// The section is re-appended below with its own separator.
			// Keeping this one would grow the file by a blank line on
			// every regeneration.
			droppedPrev = false
			continue
		}

		droppedPrev = false
		newLines = append(newLines, line)
	}

	// Separate the section from the preceding content by one blank line.
	// newLines is empty when data held nothing but the old section, so check
	// its length before indexing.
	if n := len(newLines); n > 0 && len(newLines[n-1]) > 0 {
		newLines = append(newLines, "")
	}

	newLines = append(newLines, comment)

	for _, typ := range s.types {
		name, ok := renames[s.prefix(typ)]
		if !ok {
			name = typ
		}
		newLines = append(newLines, fmt.Sprintf("type %s %s", name, s.prefix(typ)))
	}

	for _, decl := range s.declarations {
		name, ok := renames[s.prefix(decl)]
		if !ok {
			name = decl
		}
		newLines = append(newLines, fmt.Sprintf("var %s = %s", name, s.prefix(decl)))
	}

	// A trailing empty element makes the joined output end with "\n".
	newLines = append(newLines, "")

	return []byte(strings.Join(newLines, "\n"))
}

// prefix qualifies declOrType with the section's package, e.g. "ginkgo.It".
func (s nodotSection) prefix(declOrType string) string {
	return s.pkg + "." + declOrType
}

// containsDeclarationOrType reports whether word is exactly the qualified
// form (pkg.Name) of one of the section's declarations or types.
func (s nodotSection) containsDeclarationOrType(word string) bool {
	for _, declaration := range s.declarations {
		if s.prefix(declaration) == word {
			return true
		}
	}

	for _, typ := range s.types {
		if s.prefix(typ) == word {
			return true
		}
	}

	return false
}
107
108func generateNodotSections() ([]nodotSection, error) {
109	sections := []nodotSection{}
110
111	declarations, err := getExportedDeclerationsForPackage("github.com/onsi/ginkgo", "ginkgo_dsl.go", "GINKGO_VERSION", "GINKGO_PANIC")
112	if err != nil {
113		return nil, err
114	}
115	sections = append(sections, nodotSection{
116		name:         "Ginkgo DSL",
117		pkg:          "ginkgo",
118		declarations: declarations,
119		types:        []string{"Done", "Benchmarker"},
120	})
121
122	declarations, err = getExportedDeclerationsForPackage("github.com/onsi/gomega", "gomega_dsl.go", "GOMEGA_VERSION")
123	if err != nil {
124		return nil, err
125	}
126	sections = append(sections, nodotSection{
127		name:         "Gomega DSL",
128		pkg:          "gomega",
129		declarations: declarations,
130	})
131
132	declarations, err = getExportedDeclerationsForPackage("github.com/onsi/gomega", "matchers.go")
133	if err != nil {
134		return nil, err
135	}
136	sections = append(sections, nodotSection{
137		name:         "Gomega Matchers",
138		pkg:          "gomega",
139		declarations: declarations,
140	})
141
142	return sections, nil
143}
144
145func getExportedDeclerationsForPackage(pkgPath string, filename string, blacklist ...string) ([]string, error) {
146	pkg, err := build.Import(pkgPath, ".", 0)
147	if err != nil {
148		return []string{}, err
149	}
150
151	declarations, err := getExportedDeclarationsForFile(filepath.Join(pkg.Dir, filename))
152	if err != nil {
153		return []string{}, err
154	}
155
156	blacklistLookup := map[string]bool{}
157	for _, declaration := range blacklist {
158		blacklistLookup[declaration] = true
159	}
160
161	filteredDeclarations := []string{}
162	for _, declaration := range declarations {
163		if blacklistLookup[declaration] {
164			continue
165		}
166		filteredDeclarations = append(filteredDeclarations, declaration)
167	}
168
169	return filteredDeclarations, nil
170}
171
// getExportedDeclarationsForFile parses the Go source file at path. It returns
// the names of its exported package-level functions, variables and constants,
// in source order. Type declarations are not included. On a parse error it
// returns an empty, non-nil slice and the error.
//
// Methods are skipped. `var M = pkg.M` is not valid for a method, and a method
// that shares a name with a top-level function would otherwise produce a
// duplicate alias.
func getExportedDeclarationsForFile(path string) ([]string, error) {
	fset := token.NewFileSet()
	tree, err := parser.ParseFile(fset, path, nil, 0)
	if err != nil {
		return []string{}, err
	}

	// FileExports trims the AST in place so only exported top-level nodes
	// remain. Exported methods survive it, hence the Recv check below.
	ast.FileExports(tree)

	declarations := []string{}
	for _, decl := range tree.Decls {
		switch x := decl.(type) {
		case *ast.GenDecl:
			// Visit every spec and every name, so grouped declarations such as
			// `var ( A = 1; B = 2 )` or `var A, B = 1, 2` are fully reported.
			for _, spec := range x.Specs {
				vs, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for _, ident := range vs.Names {
					declarations = append(declarations, ident.Name)
				}
			}
		case *ast.FuncDecl:
			if x.Recv == nil {
				declarations = append(declarations, x.Name.Name)
			}
		}
	}

	return declarations, nil
}
195