1// Copyright 2012 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17// This file contains the model construction by parsing source files.
18
19import (
20	"errors"
21	"flag"
22	"fmt"
23	"go/ast"
24	"go/build"
25	"go/importer"
26	"go/parser"
27	"go/token"
28	"go/types"
29	"io/ioutil"
30	"log"
31	"path"
32	"path/filepath"
33	"strconv"
34	"strings"
35
36	"github.com/golang/mock/mockgen/model"
37)
38
// imports holds the raw value of the -imports flag (source mode only):
// comma-separated name=path pairs of explicit imports to use.
var imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")

// auxFiles holds the raw value of the -aux_files flag (source mode only):
// comma-separated pkg=path pairs of auxiliary Go source files.
var auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
43
44// sourceMode generates mocks via source file.
45func sourceMode(source string) (*model.Package, error) {
46	srcDir, err := filepath.Abs(filepath.Dir(source))
47	if err != nil {
48		return nil, fmt.Errorf("failed getting source directory: %v", err)
49	}
50
51	packageImport, err := parsePackageImport(srcDir)
52	if err != nil {
53		return nil, err
54	}
55
56	fs := token.NewFileSet()
57	file, err := parser.ParseFile(fs, source, nil, 0)
58	if err != nil {
59		return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
60	}
61
62	p := &fileParser{
63		fileSet:            fs,
64		imports:            make(map[string]importedPackage),
65		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
66		auxInterfaces:      make(map[string]map[string]*ast.InterfaceType),
67		srcDir:             srcDir,
68	}
69
70	// Handle -imports.
71	dotImports := make(map[string]bool)
72	if *imports != "" {
73		for _, kv := range strings.Split(*imports, ",") {
74			eq := strings.Index(kv, "=")
75			k, v := kv[:eq], kv[eq+1:]
76			if k == "." {
77				dotImports[v] = true
78			} else {
79				p.imports[k] = importedPkg{path: v}
80			}
81		}
82	}
83
84	// Handle -aux_files.
85	if err := p.parseAuxFiles(*auxFiles); err != nil {
86		return nil, err
87	}
88	p.addAuxInterfacesFromFile(packageImport, file) // this file
89
90	pkg, err := p.parseFile(packageImport, file)
91	if err != nil {
92		return nil, err
93	}
94	for pkgPath := range dotImports {
95		pkg.DotImports = append(pkg.DotImports, pkgPath)
96	}
97	return pkg, nil
98}
99
// importedPackage describes a package that is reachable through an import
// name. Path returns the package's import path. Parser returns the
// fileParser associated with the package; it may be nil, in which case
// parseInterface parses the package on demand.
type importedPackage interface {
	Path() string
	Parser() *fileParser
}
104
105type importedPkg struct {
106	path   string
107	parser *fileParser
108}
109
110func (i importedPkg) Path() string        { return i.path }
111func (i importedPkg) Parser() *fileParser { return i.parser }
112
// duplicateImport is a bit of a misnomer. The parser cannot currently cope
// with a multi-file package whose files import different packages under the
// same name. Such imports are often harmless, so this type stands in for the
// ambiguous name and lets the error be deferred until the package name is
// actually used.
type duplicateImport struct {
	name       string   // the ambiguous package name
	duplicates []string // import paths competing for that name
}

// Error describes the ambiguity, naming the package and the conflicting
// import paths.
func (d duplicateImport) Error() string {
	const msg = "%q is ambiguous because of duplicate imports: %v"
	return fmt.Sprintf(msg, d.name, d.duplicates)
}
126
127func (d duplicateImport) Path() string        { log.Fatal(d.Error()); return "" }
128func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil }
129
// fileParser carries the state used while building a model from parsed
// source: the file set used to report positions, the imports known so far,
// and the interface declarations collected so far.
type fileParser struct {
	fileSet            *token.FileSet
	imports            map[string]importedPackage               // package name => imported package
	importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface

	// auxFiles are the parsed -aux_files sources. auxInterfaces holds the
	// interfaces registered through addAuxInterfacesFromFile, which sourceMode
	// also calls for the main source file.
	auxFiles      []*ast.File
	auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface

	// srcDir is the directory handed to build.Import when parsePackage
	// resolves an import path.
	srcDir string
}
140
141func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
142	ps := p.fileSet.Position(pos)
143	format = "%s:%d:%d: " + format
144	args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
145	return fmt.Errorf(format, args...)
146}
147
148func (p *fileParser) parseAuxFiles(auxFiles string) error {
149	auxFiles = strings.TrimSpace(auxFiles)
150	if auxFiles == "" {
151		return nil
152	}
153	for _, kv := range strings.Split(auxFiles, ",") {
154		parts := strings.SplitN(kv, "=", 2)
155		if len(parts) != 2 {
156			return fmt.Errorf("bad aux file spec: %v", kv)
157		}
158		pkg, fpath := parts[0], parts[1]
159
160		file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
161		if err != nil {
162			return err
163		}
164		p.auxFiles = append(p.auxFiles, file)
165		p.addAuxInterfacesFromFile(pkg, file)
166	}
167	return nil
168}
169
170func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
171	if _, ok := p.auxInterfaces[pkg]; !ok {
172		p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
173	}
174	for ni := range iterInterfaces(file) {
175		p.auxInterfaces[pkg][ni.name.Name] = ni.it
176	}
177}
178
179// parseFile loads all file imports and auxiliary files import into the
180// fileParser, parses all file interfaces and returns package model.
181func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
182	allImports, dotImports := importsOfFile(file)
183	// Don't stomp imports provided by -imports. Those should take precedence.
184	for pkg, pkgI := range allImports {
185		if _, ok := p.imports[pkg]; !ok {
186			p.imports[pkg] = pkgI
187		}
188	}
189	// Add imports from auxiliary files, which might be needed for embedded interfaces.
190	// Don't stomp any other imports.
191	for _, f := range p.auxFiles {
192		auxImports, _ := importsOfFile(f)
193		for pkg, pkgI := range auxImports {
194			if _, ok := p.imports[pkg]; !ok {
195				p.imports[pkg] = pkgI
196			}
197		}
198	}
199
200	var is []*model.Interface
201	for ni := range iterInterfaces(file) {
202		i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
203		if err != nil {
204			return nil, err
205		}
206		is = append(is, i)
207	}
208	return &model.Package{
209		Name:       file.Name.String(),
210		PkgPath:    importPath,
211		Interfaces: is,
212		DotImports: dotImports,
213	}, nil
214}
215
216// parsePackage loads package specified by path, parses it and returns
217// a new fileParser with the parsed imports and interfaces.
218func (p *fileParser) parsePackage(path string) (*fileParser, error) {
219	newP := &fileParser{
220		fileSet:            token.NewFileSet(),
221		imports:            make(map[string]importedPackage),
222		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
223		auxInterfaces:      make(map[string]map[string]*ast.InterfaceType),
224		srcDir:             p.srcDir,
225	}
226
227	var pkgs map[string]*ast.Package
228	if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil {
229		return nil, err
230	} else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil {
231		return nil, err
232	}
233
234	for _, pkg := range pkgs {
235		file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
236		if _, ok := newP.importedInterfaces[path]; !ok {
237			newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
238		}
239		for ni := range iterInterfaces(file) {
240			newP.importedInterfaces[path][ni.name.Name] = ni.it
241		}
242		imports, _ := importsOfFile(file)
243		for pkgName, pkgI := range imports {
244			newP.imports[pkgName] = pkgI
245		}
246	}
247	return newP, nil
248}
249
// parseInterface builds the model of the interface it, declared as name in
// package pkg. Explicit methods are parsed with parseFunc; embedded
// interfaces are resolved recursively and their methods are added to the
// result. pkg is the key used to look up same-package embedded interfaces
// and is passed on to parseFunc.
//
// An error is returned for a method field without exactly one name, for an
// embedded interface or package that cannot be resolved, and for any field
// whose type is not a function type, an identifier or a selector.
func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
	iface := &model.Interface{Name: name}
	for _, field := range it.Methods.List {
		switch v := field.Type.(type) {
		case *ast.FuncType:
			// An explicitly declared method.
			if nn := len(field.Names); nn != 1 {
				return nil, fmt.Errorf("expected one name for interface %v, got %d", iface.Name, nn)
			}
			m := &model.Method{
				Name: field.Names[0].String(),
			}
			var err error
			m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
			if err != nil {
				return nil, err
			}
			iface.AddMethod(m)
		case *ast.Ident:
			// Embedded interface in this package.
			// Auxiliary interfaces are consulted first, then the interfaces
			// recorded for pkg in importedInterfaces.
			embeddedIfaceType := p.auxInterfaces[pkg][v.String()]
			if embeddedIfaceType == nil {
				embeddedIfaceType = p.importedInterfaces[pkg][v.String()]
			}

			var embeddedIface *model.Interface
			if embeddedIfaceType != nil {
				var err error
				embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType)
				if err != nil {
					return nil, err
				}
			} else {
				// This is built-in error interface.
				if v.String() == model.ErrorInterface.Name {
					embeddedIface = &model.ErrorInterface
				} else {
					return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
				}
			}
			// Copy the methods.
			for _, m := range embeddedIface.Methods {
				iface.AddMethod(m)
			}
		case *ast.SelectorExpr:
			// Embedded interface in another package.
			filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
			embeddedPkg, ok := p.imports[filePkg]
			if !ok {
				return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg)
			}

			var embeddedIface *model.Interface
			var err error
			// Prefer an interface registered under the package name as
			// written in the source (for example through -aux_files).
			embeddedIfaceType := p.auxInterfaces[filePkg][sel]
			if embeddedIfaceType != nil {
				embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType)
				if err != nil {
					return nil, err
				}
			} else {
				path := embeddedPkg.Path()
				parser := embeddedPkg.Parser()
				if parser == nil {
					// The package has not been parsed yet: parse it now and
					// store the parser in p.imports so it is reused next time.
					ip, err := p.parsePackage(path)
					if err != nil {
						return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err)
					}
					parser = ip
					p.imports[filePkg] = importedPkg{
						path:   embeddedPkg.Path(),
						parser: parser,
					}
				}
				if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil {
					return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel)
				}
				// Parse with the imported package's own parser so that names
				// are resolved against that package's imports.
				embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType)
				if err != nil {
					return nil, err
				}
			}
			// Copy the methods.
			// TODO: apply shadowing rules.
			for _, m := range embeddedIface.Methods {
				iface.AddMethod(m)
			}
		default:
			return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
		}
	}
	return iface, nil
}
342
343func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) {
344	if f.Params != nil {
345		regParams := f.Params.List
346		if isVariadic(f) {
347			n := len(regParams)
348			varParams := regParams[n-1:]
349			regParams = regParams[:n-1]
350			vp, err := p.parseFieldList(pkg, varParams)
351			if err != nil {
352				return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
353			}
354			variadic = vp[0]
355		}
356		inParam, err = p.parseFieldList(pkg, regParams)
357		if err != nil {
358			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
359		}
360	}
361	if f.Results != nil {
362		outParam, err = p.parseFieldList(pkg, f.Results.List)
363		if err != nil {
364			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
365		}
366	}
367	return
368}
369
370func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
371	nf := 0
372	for _, f := range fields {
373		nn := len(f.Names)
374		if nn == 0 {
375			nn = 1 // anonymous parameter
376		}
377		nf += nn
378	}
379	if nf == 0 {
380		return nil, nil
381	}
382	ps := make([]*model.Parameter, nf)
383	i := 0 // destination index
384	for _, f := range fields {
385		t, err := p.parseType(pkg, f.Type)
386		if err != nil {
387			return nil, err
388		}
389
390		if len(f.Names) == 0 {
391			// anonymous arg
392			ps[i] = &model.Parameter{Type: t}
393			i++
394			continue
395		}
396		for _, name := range f.Names {
397			ps[i] = &model.Parameter{Name: name.Name, Type: t}
398			i++
399		}
400	}
401	return ps, nil
402}
403
404func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
405	switch v := typ.(type) {
406	case *ast.ArrayType:
407		ln := -1
408		if v.Len != nil {
409			var value string
410			switch val := v.Len.(type) {
411			case (*ast.BasicLit):
412				value = val.Value
413			case (*ast.Ident):
414				// when the length is a const defined locally
415				value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value
416			case (*ast.SelectorExpr):
417				// when the length is a const defined in an external package
418				usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X))
419				if err != nil {
420					return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err)
421				}
422				ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name)
423				if err != nil {
424					return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err)
425				}
426				value = ev.Value.String()
427			}
428
429			x, err := strconv.Atoi(value)
430			if err != nil {
431				return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
432			}
433			ln = x
434		}
435		t, err := p.parseType(pkg, v.Elt)
436		if err != nil {
437			return nil, err
438		}
439		return &model.ArrayType{Len: ln, Type: t}, nil
440	case *ast.ChanType:
441		t, err := p.parseType(pkg, v.Value)
442		if err != nil {
443			return nil, err
444		}
445		var dir model.ChanDir
446		if v.Dir == ast.SEND {
447			dir = model.SendDir
448		}
449		if v.Dir == ast.RECV {
450			dir = model.RecvDir
451		}
452		return &model.ChanType{Dir: dir, Type: t}, nil
453	case *ast.Ellipsis:
454		// assume we're parsing a variadic argument
455		return p.parseType(pkg, v.Elt)
456	case *ast.FuncType:
457		in, variadic, out, err := p.parseFunc(pkg, v)
458		if err != nil {
459			return nil, err
460		}
461		return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
462	case *ast.Ident:
463		if v.IsExported() {
464			// `pkg` may be an aliased imported pkg
465			// if so, patch the import w/ the fully qualified import
466			maybeImportedPkg, ok := p.imports[pkg]
467			if ok {
468				pkg = maybeImportedPkg.Path()
469			}
470			// assume type in this package
471			return &model.NamedType{Package: pkg, Type: v.Name}, nil
472		}
473
474		// assume predeclared type
475		return model.PredeclaredType(v.Name), nil
476	case *ast.InterfaceType:
477		if v.Methods != nil && len(v.Methods.List) > 0 {
478			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
479		}
480		return model.PredeclaredType("interface{}"), nil
481	case *ast.MapType:
482		key, err := p.parseType(pkg, v.Key)
483		if err != nil {
484			return nil, err
485		}
486		value, err := p.parseType(pkg, v.Value)
487		if err != nil {
488			return nil, err
489		}
490		return &model.MapType{Key: key, Value: value}, nil
491	case *ast.SelectorExpr:
492		pkgName := v.X.(*ast.Ident).String()
493		pkg, ok := p.imports[pkgName]
494		if !ok {
495			return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
496		}
497		return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil
498	case *ast.StarExpr:
499		t, err := p.parseType(pkg, v.X)
500		if err != nil {
501			return nil, err
502		}
503		return &model.PointerType{Type: t}, nil
504	case *ast.StructType:
505		if v.Fields != nil && len(v.Fields.List) > 0 {
506			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
507		}
508		return model.PredeclaredType("struct{}"), nil
509	case *ast.ParenExpr:
510		return p.parseType(pkg, v.X)
511	}
512
513	return nil, fmt.Errorf("don't know how to parse type %T", typ)
514}
515
516// importsOfFile returns a map of package name to import path
517// of the imports in file.
518func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) {
519	var importPaths []string
520	for _, is := range file.Imports {
521		if is.Name != nil {
522			continue
523		}
524		importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
525		importPaths = append(importPaths, importPath)
526	}
527	packagesName := createPackageMap(importPaths)
528	normalImports = make(map[string]importedPackage)
529	dotImports = make([]string, 0)
530	for _, is := range file.Imports {
531		var pkgName string
532		importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
533
534		if is.Name != nil {
535			// Named imports are always certain.
536			if is.Name.Name == "_" {
537				continue
538			}
539			pkgName = is.Name.Name
540		} else {
541			pkg, ok := packagesName[importPath]
542			if !ok {
543				// Fallback to import path suffix. Note that this is uncertain.
544				_, last := path.Split(importPath)
545				// If the last path component has dots, the first dot-delimited
546				// field is used as the name.
547				pkgName = strings.SplitN(last, ".", 2)[0]
548			} else {
549				pkgName = pkg
550			}
551		}
552
553		if pkgName == "." {
554			dotImports = append(dotImports, importPath)
555		} else {
556			if pkg, ok := normalImports[pkgName]; ok {
557				switch p := pkg.(type) {
558				case duplicateImport:
559					normalImports[pkgName] = duplicateImport{
560						name:       p.name,
561						duplicates: append([]string{importPath}, p.duplicates...),
562					}
563				case importedPkg:
564					normalImports[pkgName] = duplicateImport{
565						name:       pkgName,
566						duplicates: []string{p.path, importPath},
567					}
568				}
569			} else {
570				normalImports[pkgName] = importedPkg{path: importPath}
571			}
572		}
573	}
574	return
575}
576
// namedInterface pairs an interface type expression with the identifier it
// is declared under.
type namedInterface struct {
	name *ast.Ident
	it   *ast.InterfaceType
}

// iterInterfaces returns a channel that yields every interface type declared
// in the top-level type declarations of file, in declaration order, and is
// then closed.
//
// The channel is filled and closed before it is returned, so no goroutine is
// involved and a consumer that stops reading early (for example on an error)
// leaves nothing blocked behind.
func iterInterfaces(file *ast.File) <-chan namedInterface {
	var found []namedInterface
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			it, ok := ts.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			found = append(found, namedInterface{ts.Name, it})
		}
	}

	// A buffer sized to the result lets every send complete immediately.
	ch := make(chan namedInterface, len(found))
	for _, ni := range found {
		ch <- ni
	}
	close(ch)
	return ch
}
608
// isVariadic reports whether the last parameter of f is declared with an
// ellipsis type. A function without parameters is not variadic.
func isVariadic(f *ast.FuncType) bool {
	params := f.Params.List
	if len(params) == 0 {
		return false
	}
	switch params[len(params)-1].Type.(type) {
	case *ast.Ellipsis:
		return true
	default:
		return false
	}
}
618
619// packageNameOfDir get package import path via dir
620func packageNameOfDir(srcDir string) (string, error) {
621	files, err := ioutil.ReadDir(srcDir)
622	if err != nil {
623		log.Fatal(err)
624	}
625
626	var goFilePath string
627	for _, file := range files {
628		if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") {
629			goFilePath = file.Name()
630			break
631		}
632	}
633	if goFilePath == "" {
634		return "", fmt.Errorf("go source file not found %s", srcDir)
635	}
636
637	packageImport, err := parsePackageImport(srcDir)
638	if err != nil {
639		return "", err
640	}
641	return packageImport, nil
642}
643
// errOutsideGoPath is the sentinel error reporting that a source directory
// is outside GOPATH.
var errOutsideGoPath = errors.New("source directory is outside GOPATH")
645