1package main
2
3import (
4	"flag"
5	"fmt"
6	"os"
7	"path/filepath"
8	"strings"
9	"text/template"
10)
11
12func BuildGenerateCommand() *Command {
13	var agouti, noDot, internal bool
14	flagSet := flag.NewFlagSet("generate", flag.ExitOnError)
15	flagSet.BoolVar(&agouti, "agouti", false, "If set, generate will generate a test file for writing Agouti tests")
16	flagSet.BoolVar(&noDot, "nodot", false, "If set, generate will generate a test file that does not . import ginkgo and gomega")
17	flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name")
18
19	return &Command{
20		Name:         "generate",
21		FlagSet:      flagSet,
22		UsageCommand: "ginkgo generate <filename(s)>",
23		Usage: []string{
24			"Generate a test file named filename_test.go",
25			"If the optional <filenames> argument is omitted, a file named after the package in the current directory will be created.",
26			"Accepts the following flags:",
27		},
28		Command: func(args []string, additionalArgs []string) {
29			generateSpec(args, agouti, noDot, internal)
30		},
31	}
32}
33
// specText is the text/template source for a plain Ginkgo spec file. It is
// rendered with a specData value:
//   - the ginkgo and gomega dot imports appear only when IncludeImports is true;
//   - the package under test is dot-imported from PackageImportPath only when
//     DotImportPackage is true.
//
// Omitted imports leave blank lines behind. generateSpecForSubject passes the
// rendered file to goFmt afterwards, presumably to tidy them up
// (NOTE(review): confirm what goFmt does).
//
// NOTE(review): when IncludeImports is false, the generated spec still calls
// Describe unqualified, so the package must provide that name some other way.
// Confirm that this is the intended -nodot contract.
var specText = `package {{.Package}}

import (
	{{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}}
	{{if .IncludeImports}}. "github.com/onsi/gomega"{{end}}

	{{if .DotImportPackage}}. "{{.PackageImportPath}}"{{end}}
)

var _ = Describe("{{.Subject}}", func() {

})
`
47
// agoutiSpecText is the text/template source used instead of specText when the
// -agouti flag is set. It takes the same specData fields. It also imports
// agouti and its matchers, creates a fresh *agouti.Page in a BeforeEach, and
// destroys that page in an AfterEach.
//
// NOTE(review): the template references agoutiDriver without declaring it.
// The name is presumably defined by a suite file elsewhere in the generated
// package; confirm.
var agoutiSpecText = `package {{.Package}}

import (
	{{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}}
	{{if .IncludeImports}}. "github.com/onsi/gomega"{{end}}
	"github.com/sclevine/agouti"
	. "github.com/sclevine/agouti/matchers"

	{{if .DotImportPackage}}. "{{.PackageImportPath}}"{{end}}
)

var _ = Describe("{{.Subject}}", func() {
	var page *agouti.Page

	BeforeEach(func() {
		var err error
		page, err = agoutiDriver.NewPage()
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		Expect(page.Destroy()).To(Succeed())
	})
})
`
73
// specData is the value the spec templates (specText / agoutiSpecText) are
// executed with.
type specData struct {
	// Package is the name written in the generated file's package clause.
	// Subject is the description passed to the top-level Describe.
	Package, Subject string

	// PackageImportPath is the import path of the package under test. The
	// templates only render it when DotImportPackage is true.
	PackageImportPath string

	// IncludeImports adds the dot imports of ginkgo and gomega.
	// DotImportPackage adds a dot import of PackageImportPath.
	IncludeImports, DotImportPackage bool
}
81
82func generateSpec(args []string, agouti, noDot, internal bool) {
83	if len(args) == 0 {
84		err := generateSpecForSubject("", agouti, noDot, internal)
85		if err != nil {
86			fmt.Println(err.Error())
87			fmt.Println("")
88			os.Exit(1)
89		}
90		fmt.Println("")
91		return
92	}
93
94	var failed bool
95	for _, arg := range args {
96		err := generateSpecForSubject(arg, agouti, noDot, internal)
97		if err != nil {
98			failed = true
99			fmt.Println(err.Error())
100		}
101	}
102	fmt.Println("")
103	if failed {
104		os.Exit(1)
105	}
106}
107
108func generateSpecForSubject(subject string, agouti, noDot, internal bool) error {
109	packageName, specFilePrefix, formattedName := getPackageAndFormattedName()
110	if subject != "" {
111		subject = strings.Split(subject, ".go")[0]
112		subject = strings.Split(subject, "_test")[0]
113		specFilePrefix = subject
114		formattedName = prettifyPackageName(subject)
115	}
116
117	data := specData{
118		Package:           determinePackageName(packageName, internal),
119		Subject:           formattedName,
120		PackageImportPath: getPackageImportPath(),
121		IncludeImports:    !noDot,
122		DotImportPackage:  !internal,
123	}
124
125	targetFile := fmt.Sprintf("%s_test.go", specFilePrefix)
126	if fileExists(targetFile) {
127		return fmt.Errorf("%s already exists.", targetFile)
128	} else {
129		fmt.Printf("Generating ginkgo test for %s in:\n  %s\n", data.Subject, targetFile)
130	}
131
132	f, err := os.Create(targetFile)
133	if err != nil {
134		return err
135	}
136	defer f.Close()
137
138	var templateText string
139	if agouti {
140		templateText = agoutiSpecText
141	} else {
142		templateText = specText
143	}
144
145	specTemplate, err := template.New("spec").Parse(templateText)
146	if err != nil {
147		return err
148	}
149
150	specTemplate.Execute(f, data)
151	goFmt(targetFile)
152	return nil
153}
154
// getPackageImportPath returns the slash-separated import path of the package
// in the current working directory.
//
// It first looks for a GOPATH root whose "src" directory contains the working
// directory. GOPATH is taken from the environment; if unset, $HOME/go is used,
// the go command's documented default. If no root matches, it falls back to
// everything after the last "<sep>src<sep>" element of the working directory.
// If that also fails, it prints a warning and returns the placeholder
// "UNKNOWN_PACKAGE_PATH".
//
// It panics if the working directory cannot be determined.
func getPackageImportPath() string {
	workingDir, err := os.Getwd()
	if err != nil {
		panic(err.Error())
	}
	sep := string(filepath.Separator)

	// Prefer a real GOPATH root. The "last src" heuristic below truncates the
	// path of any package that lives under a directory named "src"; for
	// example, $GOPATH/src/github.com/org/repo/src/pkg would yield just "pkg".
	gopath := os.Getenv("GOPATH")
	if gopath == "" {
		if home, err := os.UserHomeDir(); err == nil {
			gopath = filepath.Join(home, "go")
		}
	}
	for _, root := range filepath.SplitList(gopath) {
		rel, err := filepath.Rel(filepath.Join(root, "src"), workingDir)
		if err != nil || rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+sep) {
			// Either the working directory is not inside this root's src
			// directory, or the root is unusable (e.g. a relative path).
			continue
		}
		return filepath.ToSlash(rel)
	}

	paths := strings.Split(workingDir, sep+"src"+sep)
	if len(paths) == 1 {
		fmt.Printf("\nCouldn't identify package import path.\n\n\tginkgo generate\n\nMust be run within a package directory under $GOPATH/src/...\nYou're going to have to change UNKNOWN_PACKAGE_PATH in the generated file...\n\n")
		return "UNKNOWN_PACKAGE_PATH"
	}
	return filepath.ToSlash(paths[len(paths)-1])
}
168