1package main 2 3import ( 4 "flag" 5 "fmt" 6 "os" 7 "path/filepath" 8 "strings" 9 "text/template" 10) 11 12func BuildGenerateCommand() *Command { 13 var agouti, noDot, internal bool 14 flagSet := flag.NewFlagSet("generate", flag.ExitOnError) 15 flagSet.BoolVar(&agouti, "agouti", false, "If set, generate will generate a test file for writing Agouti tests") 16 flagSet.BoolVar(&noDot, "nodot", false, "If set, generate will generate a test file that does not . import ginkgo and gomega") 17 flagSet.BoolVar(&internal, "internal", false, "If set, generate will generate a test file that uses the regular package name") 18 19 return &Command{ 20 Name: "generate", 21 FlagSet: flagSet, 22 UsageCommand: "ginkgo generate <filename(s)>", 23 Usage: []string{ 24 "Generate a test file named filename_test.go", 25 "If the optional <filenames> argument is omitted, a file named after the package in the current directory will be created.", 26 "Accepts the following flags:", 27 }, 28 Command: func(args []string, additionalArgs []string) { 29 generateSpec(args, agouti, noDot, internal) 30 }, 31 } 32} 33 34var specText = `package {{.Package}} 35 36import ( 37 {{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}} 38 {{if .IncludeImports}}. "github.com/onsi/gomega"{{end}} 39 40 {{if .DotImportPackage}}. "{{.PackageImportPath}}"{{end}} 41) 42 43var _ = Describe("{{.Subject}}", func() { 44 45}) 46` 47 48var agoutiSpecText = `package {{.Package}} 49 50import ( 51 {{if .IncludeImports}}. "github.com/onsi/ginkgo"{{end}} 52 {{if .IncludeImports}}. "github.com/onsi/gomega"{{end}} 53 "github.com/sclevine/agouti" 54 . "github.com/sclevine/agouti/matchers" 55 56 {{if .DotImportPackage}}. "{{.PackageImportPath}}"{{end}} 57) 58 59var _ = Describe("{{.Subject}}", func() { 60 var page *agouti.Page 61 62 BeforeEach(func() { 63 var err error 64 page, err = agoutiDriver.NewPage() 65 Expect(err).NotTo(HaveOccurred()) 66 }) 67 68 AfterEach(func() { 69 Expect(page.Destroy()).To(Succeed()) 70 }) 71}) 72` 73 74type specData struct { 75 Package string 76 Subject string 77 PackageImportPath string 78 IncludeImports bool 79 DotImportPackage bool 80} 81 82func generateSpec(args []string, agouti, noDot, internal bool) { 83 if len(args) == 0 { 84 err := generateSpecForSubject("", agouti, noDot, internal) 85 if err != nil { 86 fmt.Println(err.Error()) 87 fmt.Println("") 88 os.Exit(1) 89 } 90 fmt.Println("") 91 return 92 } 93 94 var failed bool 95 for _, arg := range args { 96 err := generateSpecForSubject(arg, agouti, noDot, internal) 97 if err != nil { 98 failed = true 99 fmt.Println(err.Error()) 100 } 101 } 102 fmt.Println("") 103 if failed { 104 os.Exit(1) 105 } 106} 107 108func generateSpecForSubject(subject string, agouti, noDot, internal bool) error { 109 packageName, specFilePrefix, formattedName := getPackageAndFormattedName() 110 if subject != "" { 111 specFilePrefix = formatSubject(subject) 112 formattedName = prettifyPackageName(specFilePrefix) 113 } 114 115 data := specData{ 116 Package: determinePackageName(packageName, internal), 117 Subject: formattedName, 118 PackageImportPath: getPackageImportPath(), 119 IncludeImports: !noDot, 120 DotImportPackage: !internal, 121 } 122 123 targetFile := fmt.Sprintf("%s_test.go", specFilePrefix) 124 if fileExists(targetFile) { 125 return fmt.Errorf("%s already exists.", targetFile) 126 } else { 127 fmt.Printf("Generating ginkgo test for %s in:\n %s\n", data.Subject, targetFile) 128 } 129 130 f, err := os.Create(targetFile) 131 if err != nil { 132 return err 133 } 134 defer f.Close() 135 136 var templateText string 137 if agouti { 138 templateText = agoutiSpecText 139 } else { 140 templateText = specText 141 } 142 143 specTemplate, err := template.New("spec").Parse(templateText) 144 if err != nil { 145 return err 146 } 147 148 specTemplate.Execute(f, data) 149 goFmt(targetFile) 150 return nil 151} 152 153func formatSubject(name string) string { 154 name = strings.Replace(name, "-", "_", -1) 155 name = strings.Replace(name, " ", "_", -1) 156 name = strings.Split(name, ".go")[0] 157 name = strings.Split(name, "_test")[0] 158 return name 159} 160 161func getPackageImportPath() string { 162 workingDir, err := os.Getwd() 163 if err != nil { 164 panic(err.Error()) 165 } 166 sep := string(filepath.Separator) 167 paths := strings.Split(workingDir, sep+"src"+sep) 168 if len(paths) == 1 { 169 fmt.Printf("\nCouldn't identify package import path.\n\n\tginkgo generate\n\nMust be run within a package directory under $GOPATH/src/...\nYou're going to have to change UNKNOWN_PACKAGE_PATH in the generated file...\n\n") 170 return "UNKNOWN_PACKAGE_PATH" 171 } 172 return filepath.ToSlash(paths[len(paths)-1]) 173} 174