1// Copyright 2012 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17// This file contains the model construction by parsing source files.
18
19import (
20	"errors"
21	"flag"
22	"fmt"
23	"go/ast"
24	"go/build"
25	"go/parser"
26	"go/token"
27	"io/ioutil"
28	"log"
29	"os"
30	"path"
31	"path/filepath"
32	"strconv"
33	"strings"
34
35	"github.com/golang/mock/mockgen/model"
36	"golang.org/x/mod/modfile"
37)
38
// imports holds the value of the -imports flag: a comma-separated list of
// name=path pairs naming explicit imports (source mode only).
var imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")

// auxFiles holds the value of the -aux_files flag: a comma-separated list of
// pkg=path pairs naming auxiliary Go source files (source mode only).
var auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
43
44// TODO: simplify error reporting
45
46// sourceMode generates mocks via source file.
47func sourceMode(source string) (*model.Package, error) {
48	srcDir, err := filepath.Abs(filepath.Dir(source))
49	if err != nil {
50		return nil, fmt.Errorf("failed getting source directory: %v", err)
51	}
52
53	packageImport, err := parsePackageImport(srcDir)
54	if err != nil {
55		return nil, err
56	}
57
58	fs := token.NewFileSet()
59	file, err := parser.ParseFile(fs, source, nil, 0)
60	if err != nil {
61		return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
62	}
63
64	p := &fileParser{
65		fileSet:            fs,
66		imports:            make(map[string]importedPackage),
67		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
68		auxInterfaces:      make(map[string]map[string]*ast.InterfaceType),
69		srcDir:             srcDir,
70	}
71
72	// Handle -imports.
73	dotImports := make(map[string]bool)
74	if *imports != "" {
75		for _, kv := range strings.Split(*imports, ",") {
76			eq := strings.Index(kv, "=")
77			k, v := kv[:eq], kv[eq+1:]
78			if k == "." {
79				// TODO: Catch dupes?
80				dotImports[v] = true
81			} else {
82				// TODO: Catch dupes?
83				p.imports[k] = importedPkg{path: v}
84			}
85		}
86	}
87
88	// Handle -aux_files.
89	if err := p.parseAuxFiles(*auxFiles); err != nil {
90		return nil, err
91	}
92	p.addAuxInterfacesFromFile(packageImport, file) // this file
93
94	pkg, err := p.parseFile(packageImport, file)
95	if err != nil {
96		return nil, err
97	}
98	for pkgPath := range dotImports {
99		pkg.DotImports = append(pkg.DotImports, pkgPath)
100	}
101	return pkg, nil
102}
103
104type importedPackage interface {
105	Path() string
106	Parser() *fileParser
107}
108
109type importedPkg struct {
110	path   string
111	parser *fileParser
112}
113
114func (i importedPkg) Path() string        { return i.path }
115func (i importedPkg) Parser() *fileParser { return i.parser }
116
117// duplicateImport is a bit of a misnomer. Currently the parser can't
118// handle cases of multi-file packages importing different packages
119// under the same name. Often these imports would not be problematic,
120// so this type lets us defer raising an error unless the package name
121// is actually used.
122type duplicateImport struct {
123	name       string
124	duplicates []string
125}
126
127func (d duplicateImport) Error() string {
128	return fmt.Sprintf("%q is ambigous because of duplicate imports: %v", d.name, d.duplicates)
129}
130
131func (d duplicateImport) Path() string        { log.Fatal(d.Error()); return "" }
132func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil }
133
134type fileParser struct {
135	fileSet            *token.FileSet
136	imports            map[string]importedPackage               // package name => imported package
137	importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
138
139	auxFiles      []*ast.File
140	auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
141
142	srcDir string
143}
144
145func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
146	ps := p.fileSet.Position(pos)
147	format = "%s:%d:%d: " + format
148	args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
149	return fmt.Errorf(format, args...)
150}
151
152func (p *fileParser) parseAuxFiles(auxFiles string) error {
153	auxFiles = strings.TrimSpace(auxFiles)
154	if auxFiles == "" {
155		return nil
156	}
157	for _, kv := range strings.Split(auxFiles, ",") {
158		parts := strings.SplitN(kv, "=", 2)
159		if len(parts) != 2 {
160			return fmt.Errorf("bad aux file spec: %v", kv)
161		}
162		pkg, fpath := parts[0], parts[1]
163
164		file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
165		if err != nil {
166			return err
167		}
168		p.auxFiles = append(p.auxFiles, file)
169		p.addAuxInterfacesFromFile(pkg, file)
170	}
171	return nil
172}
173
174func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
175	if _, ok := p.auxInterfaces[pkg]; !ok {
176		p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
177	}
178	for ni := range iterInterfaces(file) {
179		p.auxInterfaces[pkg][ni.name.Name] = ni.it
180	}
181}
182
183// parseFile loads all file imports and auxiliary files import into the
184// fileParser, parses all file interfaces and returns package model.
185func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
186	allImports, dotImports := importsOfFile(file)
187	// Don't stomp imports provided by -imports. Those should take precedence.
188	for pkg, pkgI := range allImports {
189		if _, ok := p.imports[pkg]; !ok {
190			p.imports[pkg] = pkgI
191		}
192	}
193	// Add imports from auxiliary files, which might be needed for embedded interfaces.
194	// Don't stomp any other imports.
195	for _, f := range p.auxFiles {
196		auxImports, _ := importsOfFile(f)
197		for pkg, pkgI := range auxImports {
198			if _, ok := p.imports[pkg]; !ok {
199				p.imports[pkg] = pkgI
200			}
201		}
202	}
203
204	var is []*model.Interface
205	for ni := range iterInterfaces(file) {
206		i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
207		if err != nil {
208			return nil, err
209		}
210		is = append(is, i)
211	}
212	return &model.Package{
213		Name:       file.Name.String(),
214		PkgPath:    importPath,
215		Interfaces: is,
216		DotImports: dotImports,
217	}, nil
218}
219
220// parsePackage loads package specified by path, parses it and returns
221// a new fileParser with the parsed imports and interfaces.
222func (p *fileParser) parsePackage(path string) (*fileParser, error) {
223	newP := &fileParser{
224		fileSet:            token.NewFileSet(),
225		imports:            make(map[string]importedPackage),
226		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
227		auxInterfaces:      make(map[string]map[string]*ast.InterfaceType),
228		srcDir:             p.srcDir,
229	}
230
231	var pkgs map[string]*ast.Package
232	if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil {
233		return nil, err
234	} else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil {
235		return nil, err
236	}
237
238	for _, pkg := range pkgs {
239		file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
240		if _, ok := newP.importedInterfaces[path]; !ok {
241			newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
242		}
243		for ni := range iterInterfaces(file) {
244			newP.importedInterfaces[path][ni.name.Name] = ni.it
245		}
246		imports, _ := importsOfFile(file)
247		for pkgName, pkgI := range imports {
248			newP.imports[pkgName] = pkgI
249		}
250	}
251	return newP, nil
252}
253
254func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
255	intf := &model.Interface{Name: name}
256	for _, field := range it.Methods.List {
257		switch v := field.Type.(type) {
258		case *ast.FuncType:
259			if nn := len(field.Names); nn != 1 {
260				return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn)
261			}
262			m := &model.Method{
263				Name: field.Names[0].String(),
264			}
265			var err error
266			m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
267			if err != nil {
268				return nil, err
269			}
270			intf.AddMethod(m)
271		case *ast.Ident:
272			// Embedded interface in this package.
273			ei := p.auxInterfaces[pkg][v.String()]
274			if ei == nil {
275				ei = p.importedInterfaces[pkg][v.String()]
276			}
277
278			var eintf *model.Interface
279			if ei != nil {
280				var err error
281				eintf, err = p.parseInterface(v.String(), pkg, ei)
282				if err != nil {
283					return nil, err
284				}
285			} else {
286				// This is built-in error interface.
287				if v.String() == model.ErrorInterface.Name {
288					eintf = &model.ErrorInterface
289				} else {
290					return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
291				}
292			}
293			// Copy the methods.
294			for _, m := range eintf.Methods {
295				intf.AddMethod(m)
296			}
297		case *ast.SelectorExpr:
298			// Embedded interface in another package.
299			fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
300			epkg, ok := p.imports[fpkg]
301			if !ok {
302				return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
303			}
304
305			var eintf *model.Interface
306			var err error
307			ei := p.auxInterfaces[fpkg][sel]
308			if ei != nil {
309				eintf, err = p.parseInterface(sel, fpkg, ei)
310				if err != nil {
311					return nil, err
312				}
313			} else {
314				path := epkg.Path()
315				parser := epkg.Parser()
316				if parser == nil {
317					ip, err := p.parsePackage(path)
318					if err != nil {
319						return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err)
320					}
321					parser = ip
322					p.imports[fpkg] = importedPkg{
323						path:   epkg.Path(),
324						parser: parser,
325					}
326				}
327				if ei = parser.importedInterfaces[path][sel]; ei == nil {
328					return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel)
329				}
330				eintf, err = parser.parseInterface(sel, path, ei)
331				if err != nil {
332					return nil, err
333				}
334			}
335			// Copy the methods.
336			// TODO: apply shadowing rules.
337			for _, m := range eintf.Methods {
338				intf.AddMethod(m)
339			}
340		default:
341			return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
342		}
343	}
344	return intf, nil
345}
346
347func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) {
348	if f.Params != nil {
349		regParams := f.Params.List
350		if isVariadic(f) {
351			n := len(regParams)
352			varParams := regParams[n-1:]
353			regParams = regParams[:n-1]
354			vp, err := p.parseFieldList(pkg, varParams)
355			if err != nil {
356				return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
357			}
358			variadic = vp[0]
359		}
360		in, err = p.parseFieldList(pkg, regParams)
361		if err != nil {
362			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
363		}
364	}
365	if f.Results != nil {
366		out, err = p.parseFieldList(pkg, f.Results.List)
367		if err != nil {
368			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
369		}
370	}
371	return
372}
373
374func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
375	nf := 0
376	for _, f := range fields {
377		nn := len(f.Names)
378		if nn == 0 {
379			nn = 1 // anonymous parameter
380		}
381		nf += nn
382	}
383	if nf == 0 {
384		return nil, nil
385	}
386	ps := make([]*model.Parameter, nf)
387	i := 0 // destination index
388	for _, f := range fields {
389		t, err := p.parseType(pkg, f.Type)
390		if err != nil {
391			return nil, err
392		}
393
394		if len(f.Names) == 0 {
395			// anonymous arg
396			ps[i] = &model.Parameter{Type: t}
397			i++
398			continue
399		}
400		for _, name := range f.Names {
401			ps[i] = &model.Parameter{Name: name.Name, Type: t}
402			i++
403		}
404	}
405	return ps, nil
406}
407
408func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
409	switch v := typ.(type) {
410	case *ast.ArrayType:
411		ln := -1
412		if v.Len != nil {
413			x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value)
414			if err != nil {
415				return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
416			}
417			ln = x
418		}
419		t, err := p.parseType(pkg, v.Elt)
420		if err != nil {
421			return nil, err
422		}
423		return &model.ArrayType{Len: ln, Type: t}, nil
424	case *ast.ChanType:
425		t, err := p.parseType(pkg, v.Value)
426		if err != nil {
427			return nil, err
428		}
429		var dir model.ChanDir
430		if v.Dir == ast.SEND {
431			dir = model.SendDir
432		}
433		if v.Dir == ast.RECV {
434			dir = model.RecvDir
435		}
436		return &model.ChanType{Dir: dir, Type: t}, nil
437	case *ast.Ellipsis:
438		// assume we're parsing a variadic argument
439		return p.parseType(pkg, v.Elt)
440	case *ast.FuncType:
441		in, variadic, out, err := p.parseFunc(pkg, v)
442		if err != nil {
443			return nil, err
444		}
445		return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
446	case *ast.Ident:
447		if v.IsExported() {
448			// `pkg` may be an aliased imported pkg
449			// if so, patch the import w/ the fully qualified import
450			maybeImportedPkg, ok := p.imports[pkg]
451			if ok {
452				pkg = maybeImportedPkg.Path()
453			}
454			// assume type in this package
455			return &model.NamedType{Package: pkg, Type: v.Name}, nil
456		}
457
458		// assume predeclared type
459		return model.PredeclaredType(v.Name), nil
460	case *ast.InterfaceType:
461		if v.Methods != nil && len(v.Methods.List) > 0 {
462			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
463		}
464		return model.PredeclaredType("interface{}"), nil
465	case *ast.MapType:
466		key, err := p.parseType(pkg, v.Key)
467		if err != nil {
468			return nil, err
469		}
470		value, err := p.parseType(pkg, v.Value)
471		if err != nil {
472			return nil, err
473		}
474		return &model.MapType{Key: key, Value: value}, nil
475	case *ast.SelectorExpr:
476		pkgName := v.X.(*ast.Ident).String()
477		pkg, ok := p.imports[pkgName]
478		if !ok {
479			return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
480		}
481		return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil
482	case *ast.StarExpr:
483		t, err := p.parseType(pkg, v.X)
484		if err != nil {
485			return nil, err
486		}
487		return &model.PointerType{Type: t}, nil
488	case *ast.StructType:
489		if v.Fields != nil && len(v.Fields.List) > 0 {
490			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
491		}
492		return model.PredeclaredType("struct{}"), nil
493	case *ast.ParenExpr:
494		return p.parseType(pkg, v.X)
495	}
496
497	return nil, fmt.Errorf("don't know how to parse type %T", typ)
498}
499
500// importsOfFile returns a map of package name to import path
501// of the imports in file.
502func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
503	var importPaths []string
504	for _, is := range file.Imports {
505		if is.Name != nil {
506			continue
507		}
508		importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
509		importPaths = append(importPaths, importPath)
510	}
511	packagesName := createPackageMap(importPaths)
512	normalImports = make(map[string]importedPackage)
513	dotImports = make([]string, 0)
514	for _, is := range file.Imports {
515		var pkgName string
516		importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
517
518		if is.Name != nil {
519			// Named imports are always certain.
520			if is.Name.Name == "_" {
521				continue
522			}
523			pkgName = is.Name.Name
524		} else {
525			pkg, ok := packagesName[importPath]
526			if !ok {
527				// Fallback to import path suffix. Note that this is uncertain.
528				_, last := path.Split(importPath)
529				// If the last path component has dots, the first dot-delimited
530				// field is used as the name.
531				pkgName = strings.SplitN(last, ".", 2)[0]
532			} else {
533				pkgName = pkg
534			}
535		}
536
537		if pkgName == "." {
538			dotImports = append(dotImports, importPath)
539		} else {
540			if pkg, ok := normalImports[pkgName]; ok {
541				switch p := pkg.(type) {
542				case duplicateImport:
543					normalImports[pkgName] = duplicateImport{
544						name:       p.name,
545						duplicates: append([]string{importPath}, p.duplicates...),
546					}
547				case importedPkg:
548					normalImports[pkgName] = duplicateImport{
549						name:       pkgName,
550						duplicates: []string{p.path, importPath},
551					}
552				}
553			} else {
554				normalImports[pkgName] = importedPkg{path: importPath}
555			}
556		}
557	}
558	return
559}
560
// namedInterface pairs an interface type with the identifier it is declared
// under.
type namedInterface struct {
	name *ast.Ident
	it   *ast.InterfaceType
}

// iterInterfaces returns a channel yielding, in declaration order, every
// interface type declared by a top-level type declaration in file.
//
// The channel is filled and closed before it is returned, so no goroutine is
// involved and a consumer may stop receiving at any point (for example by
// returning from a range loop on error) without leaving a blocked sender
// behind.
func iterInterfaces(file *ast.File) <-chan namedInterface {
	var found []namedInterface
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			it, ok := ts.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			found = append(found, namedInterface{name: ts.Name, it: it})
		}
	}

	// The buffer holds exactly the collected items, so the sends below can
	// never block.
	ch := make(chan namedInterface, len(found))
	for _, ni := range found {
		ch <- ni
	}
	close(ch)
	return ch
}
592
// isVariadic returns whether the function is variadic, i.e. whether the type
// of its last parameter field is an ellipsis ("...T"). A nil function type,
// a nil parameter list and an empty parameter list all report false.
func isVariadic(f *ast.FuncType) bool {
	if f == nil || f.Params == nil {
		return false
	}
	params := f.Params.List
	if len(params) == 0 {
		return false
	}
	_, ok := params[len(params)-1].Type.(*ast.Ellipsis)
	return ok
}
602
603// packageNameOfDir get package import path via dir
604func packageNameOfDir(srcDir string) (string, error) {
605	files, err := ioutil.ReadDir(srcDir)
606	if err != nil {
607		log.Fatal(err)
608	}
609
610	var goFilePath string
611	for _, file := range files {
612		if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") {
613			goFilePath = file.Name()
614			break
615		}
616	}
617	if goFilePath == "" {
618		return "", fmt.Errorf("go source file not found %s", srcDir)
619	}
620
621	packageImport, err := parsePackageImport(srcDir)
622	if err != nil {
623		return "", err
624	}
625	return packageImport, nil
626}
627
628var errOutsideGoPath = errors.New("Source directory is outside GOPATH")
629
630// parseImportPackage get package import path via source file
631// an alternative implementation is to use:
632// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir}
633// pkgs, err := packages.Load(cfg, "file="+source)
634// However, it will call "go list" and slow down the performance
635func parsePackageImport(srcDir string) (string, error) {
636	moduleMode := os.Getenv("GO111MODULE")
637	// trying to find the module
638	if moduleMode != "off" {
639		currentDir := srcDir
640		for {
641			dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod"))
642			if os.IsNotExist(err) {
643				if currentDir == filepath.Dir(currentDir) {
644					// at the root
645					break
646				}
647				currentDir = filepath.Dir(currentDir)
648				continue
649			} else if err != nil {
650				return "", err
651			}
652			modulePath := modfile.ModulePath(dat)
653			return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil
654		}
655	}
656	// fall back to GOPATH mode
657	goPaths := os.Getenv("GOPATH")
658	if goPaths == "" {
659		return "", fmt.Errorf("GOPATH is not set")
660	}
661	goPathList := strings.Split(goPaths, string(os.PathListSeparator))
662	for _, goPath := range goPathList {
663		sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator)
664		if strings.HasPrefix(srcDir, sourceRoot) {
665			return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil
666		}
667	}
668	return "", errOutsideGoPath
669}
670