1// Copyright 2012 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package main
16
17// This file contains the model construction by parsing source files.
18
19import (
20	"flag"
21	"fmt"
22	"go/ast"
23	"go/build"
24	"go/parser"
25	"go/token"
26	"log"
27	"path"
28	"path/filepath"
29	"strconv"
30	"strings"
31
32	"github.com/golang/mock/mockgen/model"
33)
34
// imports holds the -imports flag: comma-separated name=path pairs that
// override import resolution in source mode. A name of "." denotes a dot import.
var imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")

// auxFiles holds the -aux_files flag: comma-separated pkg=path pairs naming
// extra Go source files whose interfaces may be embedded.
var auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
39
40// TODO: simplify error reporting
41
42func ParseFile(source string) (*model.Package, error) {
43	srcDir, err := filepath.Abs(filepath.Dir(source))
44	if err != nil {
45		return nil, fmt.Errorf("failed getting source directory: %v", err)
46	}
47
48	var packageImport string
49	if p, err := build.ImportDir(srcDir, 0); err == nil {
50		packageImport = p.ImportPath
51	} // TODO: should we fail if this returns an error?
52
53	fs := token.NewFileSet()
54	file, err := parser.ParseFile(fs, source, nil, 0)
55	if err != nil {
56		return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
57	}
58
59	p := &fileParser{
60		fileSet:            fs,
61		imports:            make(map[string]string),
62		importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
63		auxInterfaces:      make(map[string]map[string]*ast.InterfaceType),
64		srcDir:             srcDir,
65	}
66
67	// Handle -imports.
68	dotImports := make(map[string]bool)
69	if *imports != "" {
70		for _, kv := range strings.Split(*imports, ",") {
71			eq := strings.Index(kv, "=")
72			k, v := kv[:eq], kv[eq+1:]
73			if k == "." {
74				// TODO: Catch dupes?
75				dotImports[v] = true
76			} else {
77				// TODO: Catch dupes?
78				p.imports[k] = v
79			}
80		}
81	}
82
83	// Handle -aux_files.
84	if err := p.parseAuxFiles(*auxFiles); err != nil {
85		return nil, err
86	}
87	p.addAuxInterfacesFromFile(packageImport, file) // this file
88
89	pkg, err := p.parseFile(packageImport, file)
90	if err != nil {
91		return nil, err
92	}
93	pkg.DotImports = make([]string, 0, len(dotImports))
94	for path := range dotImports {
95		pkg.DotImports = append(pkg.DotImports, path)
96	}
97	return pkg, nil
98}
99
// fileParser carries the state shared while parsing a source file, its
// auxiliary files, and any packages whose interfaces it embeds.
type fileParser struct {
	// fileSet records positions for every file parsed by this parser.
	fileSet *token.FileSet
	// imports maps package name => import path.
	imports map[string]string
	// importedInterfaces maps package (or "") => name => interface for
	// interfaces loaded from other packages.
	importedInterfaces map[string]map[string]*ast.InterfaceType

	// auxFiles are the parsed auxiliary source files.
	auxFiles []*ast.File
	// auxInterfaces maps package (or "") => name => interface for interfaces
	// declared in the source file and the auxiliary files.
	auxInterfaces map[string]map[string]*ast.InterfaceType

	// srcDir is the directory other packages are resolved relative to.
	srcDir string
}
110
111func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
112	ps := p.fileSet.Position(pos)
113	format = "%s:%d:%d: " + format
114	args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
115	return fmt.Errorf(format, args...)
116}
117
118func (p *fileParser) parseAuxFiles(auxFiles string) error {
119	auxFiles = strings.TrimSpace(auxFiles)
120	if auxFiles == "" {
121		return nil
122	}
123	for _, kv := range strings.Split(auxFiles, ",") {
124		parts := strings.SplitN(kv, "=", 2)
125		if len(parts) != 2 {
126			return fmt.Errorf("bad aux file spec: %v", kv)
127		}
128		pkg, fpath := parts[0], parts[1]
129
130		file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
131		if err != nil {
132			return err
133		}
134		p.auxFiles = append(p.auxFiles, file)
135		p.addAuxInterfacesFromFile(pkg, file)
136	}
137	return nil
138}
139
140func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
141	if _, ok := p.auxInterfaces[pkg]; !ok {
142		p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
143	}
144	for ni := range iterInterfaces(file) {
145		p.auxInterfaces[pkg][ni.name.Name] = ni.it
146	}
147}
148
149// parseFile loads all file imports and auxiliary files import into the
150// fileParser, parses all file interfaces and returns package model.
151func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
152	allImports := importsOfFile(file)
153	// Don't stomp imports provided by -imports. Those should take precedence.
154	for pkg, path := range allImports {
155		if _, ok := p.imports[pkg]; !ok {
156			p.imports[pkg] = path
157		}
158	}
159	// Add imports from auxiliary files, which might be needed for embedded interfaces.
160	// Don't stomp any other imports.
161	for _, f := range p.auxFiles {
162		for pkg, path := range importsOfFile(f) {
163			if _, ok := p.imports[pkg]; !ok {
164				p.imports[pkg] = path
165			}
166		}
167	}
168
169	var is []*model.Interface
170	for ni := range iterInterfaces(file) {
171		i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
172		if err != nil {
173			return nil, err
174		}
175		is = append(is, i)
176	}
177	return &model.Package{
178		Name:       file.Name.String(),
179		Interfaces: is,
180	}, nil
181}
182
183// parsePackage loads package specified by path, parses it and populates
184// corresponding imports and importedInterfaces into the fileParser.
185func (p *fileParser) parsePackage(path string) error {
186	var pkgs map[string]*ast.Package
187	if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil {
188		return err
189	} else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil {
190		return err
191	}
192	for _, pkg := range pkgs {
193		file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
194		if _, ok := p.importedInterfaces[path]; !ok {
195			p.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
196		}
197		for ni := range iterInterfaces(file) {
198			p.importedInterfaces[path][ni.name.Name] = ni.it
199		}
200		for pkgName, pkgPath := range importsOfFile(file) {
201			if _, ok := p.imports[pkgName]; !ok {
202				p.imports[pkgName] = pkgPath
203			}
204		}
205	}
206	return nil
207}
208
209func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
210	intf := &model.Interface{Name: name}
211	for _, field := range it.Methods.List {
212		switch v := field.Type.(type) {
213		case *ast.FuncType:
214			if nn := len(field.Names); nn != 1 {
215				return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn)
216			}
217			m := &model.Method{
218				Name: field.Names[0].String(),
219			}
220			var err error
221			m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
222			if err != nil {
223				return nil, err
224			}
225			intf.Methods = append(intf.Methods, m)
226		case *ast.Ident:
227			// Embedded interface in this package.
228			ei := p.auxInterfaces[pkg][v.String()]
229			if ei == nil {
230				if ei = p.importedInterfaces[pkg][v.String()]; ei == nil {
231					return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
232				}
233			}
234			eintf, err := p.parseInterface(v.String(), pkg, ei)
235			if err != nil {
236				return nil, err
237			}
238			// Copy the methods.
239			// TODO: apply shadowing rules.
240			for _, m := range eintf.Methods {
241				intf.Methods = append(intf.Methods, m)
242			}
243		case *ast.SelectorExpr:
244			// Embedded interface in another package.
245			fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
246			epkg, ok := p.imports[fpkg]
247			if !ok {
248				return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
249			}
250			ei := p.auxInterfaces[fpkg][sel]
251			if ei == nil {
252				fpkg = epkg
253				if _, ok = p.importedInterfaces[epkg]; !ok {
254					if err := p.parsePackage(epkg); err != nil {
255						return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err)
256					}
257				}
258				if ei = p.importedInterfaces[epkg][sel]; ei == nil {
259					return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel)
260				}
261			}
262			eintf, err := p.parseInterface(sel, fpkg, ei)
263			if err != nil {
264				return nil, err
265			}
266			// Copy the methods.
267			// TODO: apply shadowing rules.
268			for _, m := range eintf.Methods {
269				intf.Methods = append(intf.Methods, m)
270			}
271		default:
272			return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
273		}
274	}
275	return intf, nil
276}
277
278func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) {
279	if f.Params != nil {
280		regParams := f.Params.List
281		if isVariadic(f) {
282			n := len(regParams)
283			varParams := regParams[n-1:]
284			regParams = regParams[:n-1]
285			vp, err := p.parseFieldList(pkg, varParams)
286			if err != nil {
287				return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
288			}
289			variadic = vp[0]
290		}
291		in, err = p.parseFieldList(pkg, regParams)
292		if err != nil {
293			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
294		}
295	}
296	if f.Results != nil {
297		out, err = p.parseFieldList(pkg, f.Results.List)
298		if err != nil {
299			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
300		}
301	}
302	return
303}
304
305func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
306	nf := 0
307	for _, f := range fields {
308		nn := len(f.Names)
309		if nn == 0 {
310			nn = 1 // anonymous parameter
311		}
312		nf += nn
313	}
314	if nf == 0 {
315		return nil, nil
316	}
317	ps := make([]*model.Parameter, nf)
318	i := 0 // destination index
319	for _, f := range fields {
320		t, err := p.parseType(pkg, f.Type)
321		if err != nil {
322			return nil, err
323		}
324
325		if len(f.Names) == 0 {
326			// anonymous arg
327			ps[i] = &model.Parameter{Type: t}
328			i++
329			continue
330		}
331		for _, name := range f.Names {
332			ps[i] = &model.Parameter{Name: name.Name, Type: t}
333			i++
334		}
335	}
336	return ps, nil
337}
338
339func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
340	switch v := typ.(type) {
341	case *ast.ArrayType:
342		ln := -1
343		if v.Len != nil {
344			x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value)
345			if err != nil {
346				return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
347			}
348			ln = x
349		}
350		t, err := p.parseType(pkg, v.Elt)
351		if err != nil {
352			return nil, err
353		}
354		return &model.ArrayType{Len: ln, Type: t}, nil
355	case *ast.ChanType:
356		t, err := p.parseType(pkg, v.Value)
357		if err != nil {
358			return nil, err
359		}
360		var dir model.ChanDir
361		if v.Dir == ast.SEND {
362			dir = model.SendDir
363		}
364		if v.Dir == ast.RECV {
365			dir = model.RecvDir
366		}
367		return &model.ChanType{Dir: dir, Type: t}, nil
368	case *ast.Ellipsis:
369		// assume we're parsing a variadic argument
370		return p.parseType(pkg, v.Elt)
371	case *ast.FuncType:
372		in, variadic, out, err := p.parseFunc(pkg, v)
373		if err != nil {
374			return nil, err
375		}
376		return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
377	case *ast.Ident:
378		if v.IsExported() {
379			// `pkg` may be an aliased imported pkg
380			// if so, patch the import w/ the fully qualified import
381			maybeImportedPkg, ok := p.imports[pkg]
382			if ok {
383				pkg = maybeImportedPkg
384			}
385			// assume type in this package
386			return &model.NamedType{Package: pkg, Type: v.Name}, nil
387		} else {
388			// assume predeclared type
389			return model.PredeclaredType(v.Name), nil
390		}
391	case *ast.InterfaceType:
392		if v.Methods != nil && len(v.Methods.List) > 0 {
393			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
394		}
395		return model.PredeclaredType("interface{}"), nil
396	case *ast.MapType:
397		key, err := p.parseType(pkg, v.Key)
398		if err != nil {
399			return nil, err
400		}
401		value, err := p.parseType(pkg, v.Value)
402		if err != nil {
403			return nil, err
404		}
405		return &model.MapType{Key: key, Value: value}, nil
406	case *ast.SelectorExpr:
407		pkgName := v.X.(*ast.Ident).String()
408		pkg, ok := p.imports[pkgName]
409		if !ok {
410			return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
411		}
412		return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil
413	case *ast.StarExpr:
414		t, err := p.parseType(pkg, v.X)
415		if err != nil {
416			return nil, err
417		}
418		return &model.PointerType{Type: t}, nil
419	case *ast.StructType:
420		if v.Fields != nil && len(v.Fields.List) > 0 {
421			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
422		}
423		return model.PredeclaredType("struct{}"), nil
424	}
425
426	return nil, fmt.Errorf("don't know how to parse type %T", typ)
427}
428
429// importsOfFile returns a map of package name to import path
430// of the imports in file.
431func importsOfFile(file *ast.File) map[string]string {
432	m := make(map[string]string)
433	for _, is := range file.Imports {
434		var pkgName string
435		importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
436
437		if is.Name != nil {
438			// Named imports are always certain.
439			if is.Name.Name == "_" {
440				continue
441			}
442			pkgName = removeDot(is.Name.Name)
443		} else {
444			pkg, err := build.Import(importPath, "", 0)
445			if err != nil {
446				// Fallback to import path suffix. Note that this is uncertain.
447				_, last := path.Split(importPath)
448				// If the last path component has dots, the first dot-delimited
449				// field is used as the name.
450				pkgName = strings.SplitN(last, ".", 2)[0]
451			} else {
452				pkgName = pkg.Name
453			}
454		}
455
456		if _, ok := m[pkgName]; ok {
457			log.Fatalf("imported package collision: %q imported twice", pkgName)
458		}
459		m[pkgName] = importPath
460	}
461	return m
462}
463
// namedInterface pairs a declared interface type with its name.
type namedInterface struct {
	name *ast.Ident
	it   *ast.InterfaceType
}

// iterInterfaces returns a channel yielding every interface type declared at
// the top level of file, in declaration order, and then closed.
//
// The channel is buffered, filled and closed before it is returned, so no
// goroutine is involved. A caller may stop receiving at any point without
// leaking a blocked producer.
func iterInterfaces(file *ast.File) <-chan namedInterface {
	var found []namedInterface
	for _, decl := range file.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			if it, ok := ts.Type.(*ast.InterfaceType); ok {
				found = append(found, namedInterface{ts.Name, it})
			}
		}
	}

	ch := make(chan namedInterface, len(found))
	for _, ni := range found {
		ch <- ni
	}
	close(ch)
	return ch
}
495
// isVariadic reports whether the last parameter of f is declared with "...".
// It expects f.Params to be non-nil (as go/ast guarantees for parsed FuncTypes).
func isVariadic(f *ast.FuncType) bool {
	params := f.Params.List
	if len(params) == 0 {
		return false
	}
	switch params[len(params)-1].Type.(type) {
	case *ast.Ellipsis:
		return true
	default:
		return false
	}
}
505