1package arguments
2
3import (
4	"go/build"
5	"log"
6	"path"
7	"path/filepath"
8	"regexp"
9	"strings"
10	"unicode"
11)
12
13//go:generate counterfeiter . ArgumentParser
14type ArgumentParser interface {
15	ParseArguments(...string) ParsedArguments
16}
17
18func NewArgumentParser(
19	failHandler FailHandler,
20	currentWorkingDir CurrentWorkingDir,
21	symlinkEvaler SymlinkEvaler,
22	fileStatReader FileStatReader,
23) ArgumentParser {
24	return &argumentParser{
25		failHandler:       failHandler,
26		currentWorkingDir: currentWorkingDir,
27		symlinkEvaler:     symlinkEvaler,
28		fileStatReader:    fileStatReader,
29	}
30}
31
32func (argParser *argumentParser) ParseArguments(args ...string) ParsedArguments {
33	if *packageFlag {
34		return argParser.parsePackageArgs(args...)
35	} else {
36		return argParser.parseInterfaceArgs(args...)
37	}
38}
39
// parseInterfaceArgs handles interface mode. Two argument forms are accepted,
// each optionally followed by "-" to request output on stdout:
//
//	<source-dir> <InterfaceName> [-]
//	<import/path>.<InterfaceName> [-]
//
// In the first form the source directory is resolved via getSourceDir and also
// serves as the root for the default output location. In the second form the
// text before the last "." is used as the package path and the default output
// location is rooted at the current working directory.
func (argParser *argumentParser) parseInterfaceArgs(args ...string) ParsedArguments {
	var outputPathFlagValue, fakeNameFlagValue string
	if outputPathFlag != nil {
		outputPathFlagValue = *outputPathFlag
	}
	// Guarded the same way as outputPathFlag so neither flag can nil-panic.
	if fakeNameFlag != nil {
		fakeNameFlagValue = *fakeNameFlag
	}

	// A trailing "-" only selects stdout output. Without dropping it here,
	// "<import/path>.<InterfaceName> -" would be mistaken for the two-argument
	// form with "-" as the interface name.
	positional := args
	if n := len(positional); n > 1 && positional[n-1] == "-" {
		positional = positional[:n-1]
	}

	var interfaceName, rootDestinationDir, sourcePackageDir, packagePath string
	if len(positional) > 1 {
		interfaceName = positional[1]
		sourcePackageDir = argParser.getSourceDir(positional[0])
		rootDestinationDir = sourcePackageDir
	} else {
		// Everything after the last "." is the interface name; everything
		// before it (possibly nothing) is the package path.
		qualified := positional[0]
		dot := strings.LastIndex(qualified, ".")
		interfaceName = qualified[dot+1:]
		if dot >= 0 {
			packagePath = qualified[:dot]
		}
		rootDestinationDir = argParser.currentWorkingDir()
	}

	fakeImplName := getFakeName(interfaceName, fakeNameFlagValue)
	outputPath := argParser.getOutputPath(rootDestinationDir, fakeImplName, outputPathFlagValue)

	// The destination package is named after the directory the fake lands in.
	packageName := restrictToValidPackageName(filepath.Base(filepath.Dir(outputPath)))

	if packagePath == "" {
		packagePath = sourcePackageDir
	}
	packagePath = trimGopathSrc(packagePath)

	log.Printf("Parsed Arguments:\nInterface Name: %s\nPackage Path: %s\nDestination Package Name: %s", interfaceName, packagePath, packageName)
	return ParsedArguments{
		GenerateInterfaceAndShimFromPackageDirectory: false,

		SourcePackageDir: sourcePackageDir,
		OutputPath:       outputPath,
		PackagePath:      packagePath,

		InterfaceName:          interfaceName,
		DestinationPackageName: packageName,
		FakeImplName:           fakeImplName,

		PrintToStdOut: any(args, "-"),
	}
}

// trimGopathSrc turns a directory located under the src tree of any GOPATH
// entry into the corresponding slash-separated import path. Any other value
// (including an import path that was given directly) is returned unchanged.
// GOPATH is a list, so every entry is checked. Empty entries are skipped
// because an empty prefix would match every path.
func trimGopathSrc(dir string) string {
	for _, gopath := range filepath.SplitList(build.Default.GOPATH) {
		if gopath == "" {
			continue
		}
		srcPrefix := filepath.Join(gopath, "src") + string(filepath.Separator)
		if strings.HasPrefix(dir, srcPrefix) {
			return filepath.ToSlash(strings.TrimPrefix(dir, srcPrefix))
		}
	}
	return dir
}
92
// parsePackageArgs handles package mode (packageFlag set). args[0] is taken
// as the slash-separated import path of the package to shim; it is reported
// as both SourcePackageDir and PackagePath. The shim goes into a directory
// named "<base>shim", which is placed in the current working directory unless
// outputPathFlag is non-empty.
func (argParser *argumentParser) parsePackageArgs(args ...string) ParsedArguments {
	packagePath := args[0]
	// Import paths are always slash-separated, so path (not filepath) is the
	// right package for taking the base name.
	baseName := path.Base(packagePath)
	// NOTE(review): a base name such as "yaml.v2" yields "yaml.v2shim", which is
	// not a valid package clause; consider restrictToValidPackageName here.
	packageName := baseName + "shim"

	var outputPath string
	if outputPathFlag != nil && *outputPathFlag != "" {
		// TODO: sensible checking of dirs and symlinks
		outputPath = *outputPathFlag
	} else {
		// The working directory is an OS path, so join with filepath.
		outputPath = filepath.Join(argParser.currentWorkingDir(), packageName)
	}

	// Capitalise the first rune (not the first byte) so that a multi-byte
	// leading character is not cut into invalid UTF-8. path.Base never returns
	// "", so implRunes is non-empty.
	implRunes := []rune(baseName)
	implRunes[0] = unicode.ToUpper(implRunes[0])

	log.Printf("Parsed Arguments:\nPackage Name: %s\nDestination Package Name: %s", packagePath, packageName)
	return ParsedArguments{
		GenerateInterfaceAndShimFromPackageDirectory: true,
		SourcePackageDir:       packagePath,
		OutputPath:             outputPath,
		PackagePath:            packagePath,
		DestinationPackageName: packageName,
		FakeImplName:           string(implRunes),
		PrintToStdOut:          any(args, "-"),
	}
}
116
117type argumentParser struct {
118	failHandler       FailHandler
119	currentWorkingDir CurrentWorkingDir
120	symlinkEvaler     SymlinkEvaler
121	fileStatReader    FileStatReader
122}
123
// ParsedArguments holds the values derived from counterfeiter's command-line
// arguments in either interface mode or package mode.
type ParsedArguments struct {
	// GenerateInterfaceAndShimFromPackageDirectory is true in package mode and
	// false in interface mode.
	GenerateInterfaceAndShimFromPackageDirectory bool

	// SourcePackageDir is, in interface mode with a <source-dir> argument, the
	// directory containing the interface to fake (made absolute against the
	// working directory when given relative; the parent directory when a file
	// was named). It is empty for the <import/path>.<Interface> form. In
	// package mode it holds the package's import path.
	SourcePackageDir string
	// PackagePath is the package containing the interface to fake: the import
	// path when one was given or could be derived by stripping a GOPATH src
	// prefix, otherwise the source directory itself.
	PackagePath string
	// OutputPath is the path to write the fake file to.
	OutputPath string

	// DestinationPackageName is often the base-dir for OutputPath but must be
	// a valid package name.
	DestinationPackageName string

	// InterfaceName is the interface to counterfeit (empty in package mode).
	InterfaceName string
	// FakeImplName is the name of the struct implementing the given interface.
	FakeImplName string

	// PrintToStdOut is true when any argument is exactly "-".
	PrintToStdOut bool
}
138
// fixupUnexportedNames upper-cases the first rune of interfaceName when it is a
// lower-case letter, so that the name derived from it is exported. Any other
// input (empty, or already starting with a non-lower-case rune) is returned
// as-is.
func fixupUnexportedNames(interfaceName string) string {
	runes := []rune(interfaceName)
	if len(runes) > 0 && unicode.IsLower(runes[0]) {
		runes[0] = unicode.ToUpper(runes[0])
		return string(runes)
	}
	return interfaceName
}
147
148func getFakeName(interfaceName, arg string) string {
149	if arg == "" {
150		interfaceName = fixupUnexportedNames(interfaceName)
151		return "Fake" + interfaceName
152	} else {
153		return arg
154	}
155}
156
// camelRegexp matches every lower-case-to-upper-case boundary in a CamelCase
// name (for example "eM" in "FakeMyThing") so that an underscore can be
// inserted there when deriving a snake_case file name.
var camelRegexp = regexp.MustCompile(`([a-z])([A-Z])`)
158
// getOutputPath decides where the fake is written. An explicit
// outputPathFlagValue wins and is resolved against the working directory when
// relative. Otherwise the file goes to
// <rootDestinationDir>/<base>fakes/<snake_case fake name>.go.
func (argParser *argumentParser) getOutputPath(rootDestinationDir, fakeName, outputPathFlagValue string) string {
	if outputPathFlagValue != "" {
		if filepath.IsAbs(outputPathFlagValue) {
			return outputPathFlagValue
		}
		return filepath.Join(argParser.currentWorkingDir(), outputPathFlagValue)
	}

	fileName := strings.ToLower(camelRegexp.ReplaceAllString(fakeName, "${1}_${2}")) + ".go"
	return filepath.Join(rootDestinationDir, packageNameForPath(rootDestinationDir), fileName)
}
170
// packageNameForPath returns the name of the default fakes directory for a
// source directory: its last path element followed by "fakes".
//
// Trailing separators are trimmed first. An absolute directory such as
// "/src/foo/" would otherwise have an empty last element and produce "fakes"
// instead of "foofakes". The root directory (and "") still yield "fakes".
func packageNameForPath(pathToPackage string) string {
	trimmed := strings.TrimRight(pathToPackage, string(filepath.Separator))
	_, packageName := filepath.Split(trimmed)
	return packageName + "fakes"
}
175
// getSourceDir resolves a user-supplied source location to a directory.
// A relative sourcePath is joined onto the current working directory; an
// absolute one is cleaned so both cases are normalised the same way. The path
// is passed through symlinkEvaler and the result is stat'ed. If it is not a
// directory, the parent directory of the (unevaluated) path is returned;
// otherwise the path itself is returned.
//
// On a lookup error failHandler is called. Should the handler return rather
// than abort, getSourceDir returns "" instead of using an invalid stat result.
func (argParser *argumentParser) getSourceDir(sourcePath string) string {
	// Named sourcePath so the imported path package is not shadowed.
	if filepath.IsAbs(sourcePath) {
		sourcePath = filepath.Clean(sourcePath)
	} else {
		sourcePath = filepath.Join(argParser.currentWorkingDir(), sourcePath)
	}

	evaluatedPath, err := argParser.symlinkEvaler(sourcePath)
	if err != nil {
		argParser.failHandler("No such file/directory/package: '%s'", sourcePath)
		return ""
	}

	stat, err := argParser.fileStatReader(evaluatedPath)
	if err != nil {
		argParser.failHandler("No such file/directory/package: '%s'", sourcePath)
		return ""
	}

	if stat.IsDir() {
		return sourcePath
	}
	return filepath.Dir(sourcePath)
}
197
// any reports whether needle occurs in slice. A nil or empty slice contains
// nothing.
func any(slice []string, needle string) bool {
	for i := range slice {
		if slice[i] == needle {
			return true
		}
	}
	return false
}
207
// restrictToValidPackageName derives a package name from input. It keeps only
// ASCII letters, digits and '_', a conservative subset of the characters
// allowed in a Go identifier, and drops everything else. A Go identifier may
// not begin with a digit, so a result that would start with one is prefixed
// with '_' (e.g. "2fafakes" becomes "_2fafakes"). An input containing no
// usable characters yields "".
func restrictToValidPackageName(input string) string {
	name := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '_':
			return r
		default:
			return -1
		}
	}, input)

	if name != "" && name[0] >= '0' && name[0] <= '9' {
		name = "_" + name
	}
	return name
}
217
// FailHandler receives a printf-style format string and its arguments when
// argument parsing hits an error (for example, an unresolvable source path).
// NOTE(review): whether a handler is expected to abort the process is decided
// by the caller that supplies it.
type FailHandler func(format string, args ...interface{})
219