1// Copyright 2010 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// MockGen generates mock implementations of Go interfaces.
16package main
17
18// TODO: This does not support recursive embedded interfaces.
19// TODO: This does not support embedding package-local interfaces in a separate file.
20
21import (
22	"bytes"
23	"encoding/json"
24	"flag"
25	"fmt"
26	"go/token"
27	"io"
28	"io/ioutil"
29	"log"
30	"os"
31	"os/exec"
32	"path"
33	"path/filepath"
34	"sort"
35	"strconv"
36	"strings"
37	"unicode"
38
39	"github.com/golang/mock/mockgen/model"
40
41	toolsimports "golang.org/x/tools/imports"
42)
43
const (
	// gomockImportPath is the import path of the gomock runtime package.
	// Generate always adds it to the import set of the generated file.
	gomockImportPath = "github.com/golang/mock/gomock"
)
47
// Build metadata. printVersion prints these three values when version is
// non-empty and falls back to printModuleVersion otherwise.
// NOTE(review): presumably overridden at build time (e.g. via -ldflags -X);
// confirm against the release tooling.
var (
	version = ""
	commit  = "none"
	date    = "unknown"
)
53
// Command-line flags. They are registered on the default flag set at package
// initialization and parsed in main.
var (
	// Input/output selection and naming of the generated code.
	source          = flag.String("source", "", "(source mode) Input Go source file; enables source mode.")
	destination     = flag.String("destination", "", "Output file; defaults to stdout.")
	mockNames       = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.")
	packageOut      = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.")
	selfPackage     = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.")
	writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.")
	copyrightFile   = flag.String("copyright_file", "", "Copyright file used to add copyright header")

	// Diagnostics: both make main return before any mock is generated.
	debugParser = flag.Bool("debug_parser", false, "Print out parser results only.")
	showVersion = flag.Bool("version", false, "Print version.")
)
66
67func main() {
68	flag.Usage = usage
69	flag.Parse()
70
71	if *showVersion {
72		printVersion()
73		return
74	}
75
76	var pkg *model.Package
77	var err error
78	var packageName string
79	if *source != "" {
80		pkg, err = sourceMode(*source)
81	} else {
82		if flag.NArg() != 2 {
83			usage()
84			log.Fatal("Expected exactly two arguments")
85		}
86		packageName = flag.Arg(0)
87		if packageName == "." {
88			dir, err := os.Getwd()
89			if err != nil {
90				log.Fatalf("Get current directory failed: %v", err)
91			}
92			packageName, err = packageNameOfDir(dir)
93			if err != nil {
94				log.Fatalf("Parse package name failed: %v", err)
95			}
96		}
97		pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ","))
98	}
99	if err != nil {
100		log.Fatalf("Loading input failed: %v", err)
101	}
102
103	if *debugParser {
104		pkg.Print(os.Stdout)
105		return
106	}
107
108	dst := os.Stdout
109	if len(*destination) > 0 {
110		if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
111			log.Fatalf("Unable to create directory: %v", err)
112		}
113		f, err := os.Create(*destination)
114		if err != nil {
115			log.Fatalf("Failed opening destination file: %v", err)
116		}
117		defer f.Close()
118		dst = f
119	}
120
121	outputPackageName := *packageOut
122	if outputPackageName == "" {
123		// pkg.Name in reflect mode is the base name of the import path,
124		// which might have characters that are illegal to have in package names.
125		outputPackageName = "mock_" + sanitize(pkg.Name)
126	}
127
128	// outputPackagePath represents the fully qualified name of the package of
129	// the generated code. Its purposes are to prevent the module from importing
130	// itself and to prevent qualifying type names that come from its own
131	// package (i.e. if there is a type called X then we want to print "X" not
132	// "package.X" since "package" is this package). This can happen if the mock
133	// is output into an already existing package.
134	outputPackagePath := *selfPackage
135	if outputPackagePath == "" && *destination != "" {
136		dstPath, err := filepath.Abs(filepath.Dir(*destination))
137		if err != nil {
138			log.Fatalf("Unable to determine destination file path: %v", err)
139		}
140		outputPackagePath, err = parsePackageImport(dstPath)
141		if err != nil {
142			log.Fatalf("Unable to determine destination file path: %v", err)
143		}
144	}
145
146	g := new(generator)
147	if *source != "" {
148		g.filename = *source
149	} else {
150		g.srcPackage = packageName
151		g.srcInterfaces = flag.Arg(1)
152	}
153	g.destination = *destination
154
155	if *mockNames != "" {
156		g.mockNames = parseMockNames(*mockNames)
157	}
158	if *copyrightFile != "" {
159		header, err := ioutil.ReadFile(*copyrightFile)
160		if err != nil {
161			log.Fatalf("Failed reading copyright file: %v", err)
162		}
163
164		g.copyrightHeader = string(header)
165	}
166	if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil {
167		log.Fatalf("Failed generating mock: %v", err)
168	}
169	if _, err := dst.Write(g.Output()); err != nil {
170		log.Fatalf("Failed writing to destination: %v", err)
171	}
172}
173
// parseMockNames turns a comma-separated list of interfaceName=mockName pairs
// into a map keyed by interface name. Only the first "=" of each pair
// separates key from value. A pair without "=" or with an empty mock name
// aborts the program via log.Fatalf.
func parseMockNames(names string) map[string]string {
	pairs := strings.Split(names, ",")
	result := make(map[string]string)
	for i := range pairs {
		spec := pairs[i]
		fields := strings.SplitN(spec, "=", 2)
		if len(fields) == 2 && fields[1] != "" {
			result[fields[0]] = fields[1]
			continue
		}
		log.Fatalf("bad mock names spec: %v", spec)
	}
	return result
}
185
// usage prints the mode overview in usageText to standard error, followed by
// the defaults of all registered flags. It is installed as flag.Usage in main
// and is also called directly when the positional arguments are wrong.
func usage() {
	// The write error is deliberately ignored: there is nowhere better to
	// report a failure to write the help text.
	_, _ = io.WriteString(os.Stderr, usageText)
	flag.PrintDefaults()
}
190
// usageText is the help preamble printed by usage before the flag defaults.
// NOTE(review): it mentions the -imports and -aux_files flags, which are not
// declared in this file; presumably they are registered elsewhere in the
// package — confirm.
const usageText = `mockgen has two modes of operation: source and reflect.

Source mode generates mock interfaces from a source file.
It is enabled by using the -source flag. Other flags that
may be useful in this mode are -imports and -aux_files.
Example:
	mockgen -source=foo.go [other options]

Reflect mode generates mock interfaces by building a program
that uses reflection to understand interfaces. It is enabled
by passing two non-flag arguments: an import path, and a
comma-separated list of symbols.
Example:
	mockgen database/sql/driver Conn,Driver

`
207
// generator accumulates the text of one generated mock file together with the
// settings that influence it. main fills in the configuration fields, Generate
// writes into buf, and Output returns the formatted result.
type generator struct {
	// buf holds the generated source; p appends one line at a time.
	buf bytes.Buffer
	// indent is the prefix p puts in front of every line; in and out grow and
	// shrink it by one tab.
	indent                    string
	mockNames                 map[string]string // may be empty
	filename                  string            // may be empty
	destination               string            // may be empty
	srcPackage, srcInterfaces string            // may be empty
	// copyrightHeader, when non-empty, is emitted as a comment block at the
	// very top of the generated file.
	copyrightHeader string

	packageMap map[string]string // map from import path to package name
}
219
220func (g *generator) p(format string, args ...interface{}) {
221	fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
222}
223
224func (g *generator) in() {
225	g.indent += "\t"
226}
227
228func (g *generator) out() {
229	if len(g.indent) > 0 {
230		g.indent = g.indent[0 : len(g.indent)-1]
231	}
232}
233
// removeDot returns s without its final character when that character is a
// dot. At most one trailing dot is removed; any other string is returned
// unchanged.
func removeDot(s string) string {
	return strings.TrimSuffix(s, ".")
}
240
// sanitize cleans up a string to make a suitable package name.
//
// Letters and underscores are always kept, and digits are kept everywhere
// except in the first position. Every other rune is replaced by an
// underscore. A result consisting of a single underscore is replaced by "x".
func sanitize(s string) string {
	var b strings.Builder
	for _, r := range s {
		keep := unicode.IsLetter(r) || r == '_'
		if !keep && b.Len() > 0 {
			// Digits are only acceptable after the first character.
			keep = unicode.IsDigit(r)
		}
		if keep {
			b.WriteRune(r)
		} else {
			b.WriteByte('_')
		}
	}
	if name := b.String(); name != "_" {
		return name
	}
	return "x"
}
263
// Generate writes the complete mock source file for pkg into the generator's
// buffer: an optional copyright header, the "Code generated" banner, the
// package clause, the import block, and one mock per entry in pkg.Interfaces.
//
// outputPkgName is the package name used in the generated package clause.
// outputPackagePath is the import path of the package the generated file
// belongs to; it is never imported and is handed to GenerateMockInterface as
// the package whose types are printed unqualified. It is discarded when
// outputPkgName differs from pkg.Name and -self_package was not given.
//
// It returns the first error reported by GenerateMockInterface, or nil.
func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error {
	if outputPkgName != pkg.Name && *selfPackage == "" {
		// reset outputPackagePath if it's not passed in through -self_package
		outputPackagePath = ""
	}

	// Emit the copyright header, one comment line per line of the header text.
	if g.copyrightHeader != "" {
		lines := strings.Split(g.copyrightHeader, "\n")
		for _, line := range lines {
			g.p("// %s", line)
		}
		g.p("")
	}

	g.p("// Code generated by MockGen. DO NOT EDIT.")
	if g.filename != "" {
		g.p("// Source: %v", g.filename)
	} else {
		g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces)
	}
	g.p("")

	// Get all required imports, and generate unique names for them all.
	im := pkg.Imports()
	im[gomockImportPath] = true

	// Only import reflect if it's used. We only use reflect in mocked methods
	// so only import if any of the mocked interfaces have methods.
	for _, intf := range pkg.Interfaces {
		if len(intf.Methods) > 0 {
			im["reflect"] = true
			break
		}
	}

	// Sort keys to make import alias generation predictable
	sortedPaths := make([]string, len(im))
	x := 0
	for pth := range im {
		sortedPaths[x] = pth
		x++
	}
	sort.Strings(sortedPaths)

	// Ask the go tool for the real package names; paths it does not report
	// fall back to a sanitized base name below.
	packagesName := createPackageMap(sortedPaths)

	g.packageMap = make(map[string]string, len(im))
	localNames := make(map[string]bool, len(im))
	for _, pth := range sortedPaths {
		base, ok := packagesName[pth]
		if !ok {
			base = sanitize(path.Base(pth))
		}

		// Local names for an imported package can usually be the basename of the import path.
		// A couple of situations don't permit that, such as duplicate local names
		// (e.g. importing "html/template" and "text/template"), or where the basename is
		// a keyword (e.g. "foo/case").
		// try base0, base1, ...
		pkgName := base
		i := 0
		for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() {
			pkgName = base + strconv.Itoa(i)
			i++
		}

		// Avoid importing package if source pkg == output pkg
		if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath {
			continue
		}

		g.packageMap[pth] = pkgName
		localNames[pkgName] = true
	}

	if *writePkgComment {
		g.p("// Package %v is a generated GoMock package.", outputPkgName)
	}
	g.p("package %v", outputPkgName)
	g.p("")
	g.p("import (")
	g.in()
	// NOTE(review): map iteration order is unspecified, so the import lines
	// are emitted in varying order here; presumably the formatting pass in
	// Output normalizes it — confirm.
	for pkgPath, pkgName := range g.packageMap {
		if pkgPath == outputPackagePath {
			continue
		}
		g.p("%v %q", pkgName, pkgPath)
	}
	for _, pkgPath := range pkg.DotImports {
		g.p(". %q", pkgPath)
	}
	g.out()
	g.p(")")

	for _, intf := range pkg.Interfaces {
		if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
			return err
		}
	}

	return nil
}
366
367// The name of the mock type to use for the given interface identifier.
368func (g *generator) mockName(typeName string) string {
369	if mockName, ok := g.mockNames[typeName]; ok {
370		return mockName
371	}
372
373	return "Mock" + typeName
374}
375
// GenerateMockInterface emits the mock for a single interface: the mock
// struct, its recorder struct, the New<Mock> constructor, the EXPECT accessor
// and finally, through GenerateMockMethods, a mock method and a recorder
// method for every method of intf.
//
// outputPackagePath is passed on to GenerateMockMethods as the package whose
// types are printed unqualified. The function currently always returns nil.
func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
	mockType := g.mockName(intf.Name)

	// The mock itself: a controller plus the recorder handed out by EXPECT.
	g.p("")
	g.p("// %v is a mock of %v interface.", mockType, intf.Name)
	g.p("type %v struct {", mockType)
	g.in()
	g.p("ctrl     *gomock.Controller")
	g.p("recorder *%vMockRecorder", mockType)
	g.out()
	g.p("}")
	g.p("")

	// The recorder only points back at its mock.
	g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType)
	g.p("type %vMockRecorder struct {", mockType)
	g.in()
	g.p("mock *%v", mockType)
	g.out()
	g.p("}")
	g.p("")

	// TODO: Re-enable this if we can import the interface reliably.
	// g.p("// Verify that the mock satisfies the interface at compile time.")
	// g.p("var _ %v = (*%v)(nil)", typeName, mockType)
	// g.p("")

	g.p("// New%v creates a new mock instance.", mockType)
	g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
	g.in()
	g.p("mock := &%v{ctrl: ctrl}", mockType)
	g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
	g.p("return mock")
	g.out()
	g.p("}")
	g.p("")

	// XXX: possible name collision here if someone has EXPECT in their interface.
	g.p("// EXPECT returns an object that allows the caller to indicate expected use.")
	g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
	g.in()
	g.p("return m.recorder")
	g.out()
	g.p("}")

	g.GenerateMockMethods(mockType, intf, outputPackagePath)

	return nil
}
424
425type byMethodName []*model.Method
426
427func (b byMethodName) Len() int           { return len(b) }
428func (b byMethodName) Swap(i, j int)      { b[i], b[j] = b[j], b[i] }
429func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name }
430
431func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
432	sort.Sort(byMethodName(intf.Methods))
433	for _, m := range intf.Methods {
434		g.p("")
435		_ = g.GenerateMockMethod(mockType, m, pkgOverride)
436		g.p("")
437		_ = g.GenerateMockRecorderMethod(mockType, m)
438	}
439}
440
// makeArgString renders a Go parameter list from parallel slices of names and
// types. Consecutive parameters of the same type share a single type
// annotation, e.g. "a, b int, c string". argTypes must be at least as long as
// argNames.
func makeArgString(argNames, argTypes []string) string {
	var sb strings.Builder
	for i, name := range argNames {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(name)

		// The type is omitted when the next parameter repeats it.
		sharesType := i+1 < len(argTypes) && argTypes[i] == argTypes[i+1]
		if !sharesType {
			sb.WriteString(" ")
			sb.WriteString(argTypes[i])
		}
	}
	return sb.String()
}
453
454// GenerateMockMethod generates a mock method implementation.
455// If non-empty, pkgOverride is the package in which unqualified types reside.
456func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
457	argNames := g.getArgNames(m)
458	argTypes := g.getArgTypes(m, pkgOverride)
459	argString := makeArgString(argNames, argTypes)
460
461	rets := make([]string, len(m.Out))
462	for i, p := range m.Out {
463		rets[i] = p.Type.String(g.packageMap, pkgOverride)
464	}
465	retString := strings.Join(rets, ", ")
466	if len(rets) > 1 {
467		retString = "(" + retString + ")"
468	}
469	if retString != "" {
470		retString = " " + retString
471	}
472
473	ia := newIdentifierAllocator(argNames)
474	idRecv := ia.allocateIdentifier("m")
475
476	g.p("// %v mocks base method.", m.Name)
477	g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
478	g.in()
479	g.p("%s.ctrl.T.Helper()", idRecv)
480
481	var callArgs string
482	if m.Variadic == nil {
483		if len(argNames) > 0 {
484			callArgs = ", " + strings.Join(argNames, ", ")
485		}
486	} else {
487		// Non-trivial. The generated code must build a []interface{},
488		// but the variadic argument may be any type.
489		idVarArgs := ia.allocateIdentifier("varargs")
490		idVArg := ia.allocateIdentifier("a")
491		g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", "))
492		g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1])
493		g.in()
494		g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg)
495		g.out()
496		g.p("}")
497		callArgs = ", " + idVarArgs + "..."
498	}
499	if len(m.Out) == 0 {
500		g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs)
501	} else {
502		idRet := ia.allocateIdentifier("ret")
503		g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs)
504
505		// Go does not allow "naked" type assertions on nil values, so we use the two-value form here.
506		// The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T.
507		// Happily, this coincides with the semantics we want here.
508		retNames := make([]string, len(rets))
509		for i, t := range rets {
510			retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i))
511			g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t)
512		}
513		g.p("return " + strings.Join(retNames, ", "))
514	}
515
516	g.out()
517	g.p("}")
518	return nil
519}
520
521func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
522	argNames := g.getArgNames(m)
523
524	var argString string
525	if m.Variadic == nil {
526		argString = strings.Join(argNames, ", ")
527	} else {
528		argString = strings.Join(argNames[:len(argNames)-1], ", ")
529	}
530	if argString != "" {
531		argString += " interface{}"
532	}
533
534	if m.Variadic != nil {
535		if argString != "" {
536			argString += ", "
537		}
538		argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1])
539	}
540
541	ia := newIdentifierAllocator(argNames)
542	idRecv := ia.allocateIdentifier("mr")
543
544	g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
545	g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
546	g.in()
547	g.p("%s.mock.ctrl.T.Helper()", idRecv)
548
549	var callArgs string
550	if m.Variadic == nil {
551		if len(argNames) > 0 {
552			callArgs = ", " + strings.Join(argNames, ", ")
553		}
554	} else {
555		if len(argNames) == 1 {
556			// Easy: just use ... to push the arguments through.
557			callArgs = ", " + argNames[0] + "..."
558		} else {
559			// Hard: create a temporary slice.
560			idVarArgs := ia.allocateIdentifier("varargs")
561			g.p("%s := append([]interface{}{%s}, %s...)",
562				idVarArgs,
563				strings.Join(argNames[:len(argNames)-1], ", "),
564				argNames[len(argNames)-1])
565			callArgs = ", " + idVarArgs + "..."
566		}
567	}
568	g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)
569
570	g.out()
571	g.p("}")
572	return nil
573}
574
575func (g *generator) getArgNames(m *model.Method) []string {
576	argNames := make([]string, len(m.In))
577	for i, p := range m.In {
578		name := p.Name
579		if name == "" || name == "_" {
580			name = fmt.Sprintf("arg%d", i)
581		}
582		argNames[i] = name
583	}
584	if m.Variadic != nil {
585		name := m.Variadic.Name
586		if name == "" {
587			name = fmt.Sprintf("arg%d", len(m.In))
588		}
589		argNames = append(argNames, name)
590	}
591	return argNames
592}
593
594func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string {
595	argTypes := make([]string, len(m.In))
596	for i, p := range m.In {
597		argTypes[i] = p.Type.String(g.packageMap, pkgOverride)
598	}
599	if m.Variadic != nil {
600		argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride))
601	}
602	return argTypes
603}
604
// identifierAllocator is the set of identifiers already in use within one
// generated function.
type identifierAllocator map[string]struct{}

// newIdentifierAllocator returns an allocator in which every name in taken is
// already reserved.
func newIdentifierAllocator(taken []string) identifierAllocator {
	reserved := make(identifierAllocator, len(taken))
	for i := range taken {
		reserved[taken[i]] = struct{}{}
	}
	return reserved
}

// allocateIdentifier reserves and returns an unused identifier based on want.
// It returns want itself when that is free, otherwise the first free name
// among want_2, want_3, and so on.
func (o identifierAllocator) allocateIdentifier(want string) string {
	candidate := want
	suffix := 2
	for {
		if _, used := o[candidate]; !used {
			break
		}
		candidate = want + "_" + strconv.Itoa(suffix)
		suffix++
	}
	o[candidate] = struct{}{}
	return candidate
}
625
626// Output returns the generator's output, formatted in the standard Go style.
627func (g *generator) Output() []byte {
628	src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil)
629	if err != nil {
630		log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String())
631	}
632	return src
633}
634
// createPackageMap returns a map of import path to package name
// for specified importPaths.
//
// The names are obtained by running "go list -json" on the paths and decoding
// the stream of JSON objects it prints. The lookup is best effort: paths the
// go tool does not report are simply absent from the result, and the result
// is empty (never nil) when the command produces no usable output.
func createPackageMap(importPaths []string) map[string]string {
	pkgMap := make(map[string]string)

	var stdout bytes.Buffer
	args := []string{"list", "-json"}
	args = append(args, importPaths...)
	cmd := exec.Command("go", args...)
	cmd.Stdout = &stdout
	// The exit status is deliberately ignored: whatever was written to
	// stdout before a failure is still decoded below.
	_ = cmd.Run()

	dec := json.NewDecoder(&stdout)
	for dec.More() {
		// A fresh value per object keeps fields of a previous entry from
		// leaking into one that omits them.
		var pkg struct {
			Name       string
			ImportPath string
		}
		if err := dec.Decode(&pkg); err != nil {
			log.Printf("failed to decode 'go list' output: %v", err)
			if _, ok := err.(*json.UnmarshalTypeError); ok {
				// The object was consumed; skip it and keep going.
				continue
			}
			// Syntax errors and truncated input are remembered by the
			// decoder while More keeps reporting pending data, so
			// continuing here would loop forever.
			break
		}
		pkgMap[pkg.ImportPath] = pkg.Name
	}
	return pkgMap
}
660
661func printVersion() {
662	if version != "" {
663		fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date)
664	} else {
665		printModuleVersion()
666	}
667}
668