1package main
2
import (
	"flag"
	"fmt"
	"go/build"
	"os"
	"path/filepath"
	"strings"
	"text/template"
)
11
12func BuildGenerateCommand() *Command {
13	var agouti, noDot, internal bool
14	flagSet := flag.NewFlagSet("generate", flag.ExitOnError)
15	flagSet.BoolVar(&agouti, "agouti", false, "If set, generate will generate a test file for writing Agouti tests")
16	flagSet.BoolVar(&noDot, "nodot", false, "If set, generate will generate a test file that does not . import ginkgo and gomega")
17	flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name")
18
19	return &Command{
20		Name:         "generate",
21		FlagSet:      flagSet,
22		UsageCommand: "ginkgo generate <filename(s)>",
23		Usage: []string{
24			"Generate a test file named filename_test.go",
25			"If the optional <filenames> argument is omitted, a file named after the package in the current directory will be created.",
26			"Accepts the following flags:",
27		},
28		Command: func(args []string, additionalArgs []string) {
29			generateSpec(args, agouti, noDot, internal)
30		},
31	}
32}
33
// specText is the text/template source for a plain Ginkgo spec file. It is
// executed with a specData value:
//   - IncludeImports toggles the dot-imports of ginkgo and gomega.
//   - DotImportPackage toggles a dot-import of PackageImportPath, the package under test.
//   - Subject becomes the description of the single top-level Describe.
//
// The conditionals leave blank import lines behind. The rendered file is
// passed to goFmt afterwards, which presumably tidies them up.
// NOTE(review): with IncludeImports false the output still calls Describe
// unqualified, so the package is presumably expected to provide it some other
// way (e.g. via "ginkgo nodot"). Confirm.
var specText = `package {{.Package}}

import (
	{{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}}
	{{if .IncludeImports}}. "github.com/onsi/gomega"{{end}}

	{{if .DotImportPackage}}. "{{.PackageImportPath}}"{{end}}
)

var _ = Describe("{{.Subject}}", func() {

})
`
47
// agoutiSpecText is the text/template source used when -agouti is set. It
// takes the same specData fields as specText. It additionally imports agouti
// and its matchers, and opens a fresh agouti.Page before each spec and
// destroys it after each one.
//
// NOTE(review): the rendered code calls agoutiDriver.NewPage(), but nothing
// here declares agoutiDriver. It is presumably a package-level driver defined
// in the suite's bootstrap file. Confirm.
var agoutiSpecText = `package {{.Package}}

import (
	{{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}}
	{{if .IncludeImports}}. "github.com/onsi/gomega"{{end}}
	"github.com/sclevine/agouti"
	. "github.com/sclevine/agouti/matchers"

	{{if .DotImportPackage}}. "{{.PackageImportPath}}"{{end}}
)

var _ = Describe("{{.Subject}}", func() {
	var page *agouti.Page

	BeforeEach(func() {
		var err error
		page, err = agoutiDriver.NewPage()
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		Expect(page.Destroy()).To(Succeed())
	})
})
`
73
// specData holds the values substituted into specText / agoutiSpecText.
type specData struct {
	// Package is the name written in the generated file's package clause.
	// Subject is the description given to the top-level Describe.
	// PackageImportPath is the import path of the package under test; it is
	// dot-imported only when DotImportPackage is true.
	Package, Subject, PackageImportPath string

	// IncludeImports adds dot-imports of ginkgo and gomega.
	// DotImportPackage adds a dot-import of PackageImportPath.
	IncludeImports, DotImportPackage bool
}
81
82func generateSpec(args []string, agouti, noDot, internal bool) {
83	if len(args) == 0 {
84		err := generateSpecForSubject("", agouti, noDot, internal)
85		if err != nil {
86			fmt.Println(err.Error())
87			fmt.Println("")
88			os.Exit(1)
89		}
90		fmt.Println("")
91		return
92	}
93
94	var failed bool
95	for _, arg := range args {
96		err := generateSpecForSubject(arg, agouti, noDot, internal)
97		if err != nil {
98			failed = true
99			fmt.Println(err.Error())
100		}
101	}
102	fmt.Println("")
103	if failed {
104		os.Exit(1)
105	}
106}
107
108func generateSpecForSubject(subject string, agouti, noDot, internal bool) error {
109	packageName, specFilePrefix, formattedName := getPackageAndFormattedName()
110	if subject != "" {
111		specFilePrefix = formatSubject(subject)
112		formattedName = prettifyPackageName(specFilePrefix)
113	}
114
115	data := specData{
116		Package:           determinePackageName(packageName, internal),
117		Subject:           formattedName,
118		PackageImportPath: getPackageImportPath(),
119		IncludeImports:    !noDot,
120		DotImportPackage:  !internal,
121	}
122
123	targetFile := fmt.Sprintf("%s_test.go", specFilePrefix)
124	if fileExists(targetFile) {
125		return fmt.Errorf("%s already exists.", targetFile)
126	} else {
127		fmt.Printf("Generating ginkgo test for %s in:\n  %s\n", data.Subject, targetFile)
128	}
129
130	f, err := os.Create(targetFile)
131	if err != nil {
132		return err
133	}
134	defer f.Close()
135
136	var templateText string
137	if agouti {
138		templateText = agoutiSpecText
139	} else {
140		templateText = specText
141	}
142
143	specTemplate, err := template.New("spec").Parse(templateText)
144	if err != nil {
145		return err
146	}
147
148	specTemplate.Execute(f, data)
149	goFmt(targetFile)
150	return nil
151}
152
// formatSubject turns a user-supplied subject into the file-name prefix for
// its spec. Hyphens and spaces become underscores, then a trailing ".go" and
// a trailing "_test" are stripped. So "book", "book.go", "book_test" and
// "book_test.go" all yield "book".
//
// Only suffixes are stripped. A subject such as "user_testimonials" is kept
// intact rather than being cut at the first "_test" (or ".go") it contains.
func formatSubject(name string) string {
	name = strings.NewReplacer("-", "_", " ", "_").Replace(name)
	name = strings.TrimSuffix(name, ".go")
	name = strings.TrimSuffix(name, "_test")
	return name
}
160
// getPackageImportPath guesses the import path of the package in the current
// working directory.
//
// It first looks for a GOPATH entry (as reported by go/build) whose src
// directory contains the working directory, and returns the relative path.
// This stays correct even when the repository itself has nested "src"
// directories. Otherwise it falls back to taking everything after the last
// "<sep>src<sep>" element of the path. If that also fails, it prints a warning
// and returns the placeholder "UNKNOWN_PACKAGE_PATH" for the user to fix by
// hand.
//
// It panics if the working directory cannot be determined.
func getPackageImportPath() string {
	workingDir, err := os.Getwd()
	if err != nil {
		panic(err.Error())
	}

	if importPath, ok := importPathUnderGopath(workingDir, filepath.SplitList(build.Default.GOPATH)); ok {
		return importPath
	}

	// Heuristic fallback for directories not under a known GOPATH.
	sep := string(filepath.Separator)
	paths := strings.Split(workingDir, sep+"src"+sep)
	if len(paths) == 1 {
		fmt.Printf("\nCouldn't identify package import path.\n\n\tginkgo generate\n\nMust be run within a package directory under $GOPATH/src/...\nYou're going to have to change UNKNOWN_PACKAGE_PATH in the generated file...\n\n")
		return "UNKNOWN_PACKAGE_PATH"
	}
	return filepath.ToSlash(paths[len(paths)-1])
}

// importPathUnderGopath returns dir relative to the src directory of the
// first gopaths entry that strictly contains it, in slash form. The boolean
// reports whether such an entry was found.
func importPathUnderGopath(dir string, gopaths []string) (string, bool) {
	for _, gopath := range gopaths {
		if gopath == "" {
			continue
		}
		rel, err := filepath.Rel(filepath.Join(gopath, "src"), dir)
		if err != nil || rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			// Not inside this workspace's src tree (or is the src root itself).
			continue
		}
		return filepath.ToSlash(rel), true
	}
	return "", false
}
174