1// Copyright 2012 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package main 16 17// This file contains the model construction by parsing source files. 18 19import ( 20 "errors" 21 "flag" 22 "fmt" 23 "go/ast" 24 "go/build" 25 "go/parser" 26 "go/token" 27 "io/ioutil" 28 "log" 29 "os" 30 "path" 31 "path/filepath" 32 "strconv" 33 "strings" 34 35 "github.com/golang/mock/mockgen/model" 36 "golang.org/x/mod/modfile" 37) 38 39var ( 40 imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") 41 auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") 42) 43 44// TODO: simplify error reporting 45 46// sourceMode generates mocks via source file. 47func sourceMode(source string) (*model.Package, error) { 48 srcDir, err := filepath.Abs(filepath.Dir(source)) 49 if err != nil { 50 return nil, fmt.Errorf("failed getting source directory: %v", err) 51 } 52 53 packageImport, err := parsePackageImport(srcDir) 54 if err != nil { 55 return nil, err 56 } 57 58 fs := token.NewFileSet() 59 file, err := parser.ParseFile(fs, source, nil, 0) 60 if err != nil { 61 return nil, fmt.Errorf("failed parsing source file %v: %v", source, err) 62 } 63 64 p := &fileParser{ 65 fileSet: fs, 66 imports: make(map[string]importedPackage), 67 importedInterfaces: make(map[string]map[string]*ast.InterfaceType), 68 auxInterfaces: make(map[string]map[string]*ast.InterfaceType), 69 srcDir: srcDir, 70 } 71 72 // Handle -imports. 73 dotImports := make(map[string]bool) 74 if *imports != "" { 75 for _, kv := range strings.Split(*imports, ",") { 76 eq := strings.Index(kv, "=") 77 k, v := kv[:eq], kv[eq+1:] 78 if k == "." { 79 // TODO: Catch dupes? 80 dotImports[v] = true 81 } else { 82 // TODO: Catch dupes? 83 p.imports[k] = importedPkg{path: v} 84 } 85 } 86 } 87 88 // Handle -aux_files. 89 if err := p.parseAuxFiles(*auxFiles); err != nil { 90 return nil, err 91 } 92 p.addAuxInterfacesFromFile(packageImport, file) // this file 93 94 pkg, err := p.parseFile(packageImport, file) 95 if err != nil { 96 return nil, err 97 } 98 for pkgPath := range dotImports { 99 pkg.DotImports = append(pkg.DotImports, pkgPath) 100 } 101 return pkg, nil 102} 103 104type importedPackage interface { 105 Path() string 106 Parser() *fileParser 107} 108 109type importedPkg struct { 110 path string 111 parser *fileParser 112} 113 114func (i importedPkg) Path() string { return i.path } 115func (i importedPkg) Parser() *fileParser { return i.parser } 116 117// duplicateImport is a bit of a misnomer. Currently the parser can't 118// handle cases of multi-file packages importing different packages 119// under the same name. Often these imports would not be problematic, 120// so this type lets us defer raising an error unless the package name 121// is actually used. 122type duplicateImport struct { 123 name string 124 duplicates []string 125} 126 127func (d duplicateImport) Error() string { 128 return fmt.Sprintf("%q is ambigous because of duplicate imports: %v", d.name, d.duplicates) 129} 130 131func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" } 132func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil } 133 134type fileParser struct { 135 fileSet *token.FileSet 136 imports map[string]importedPackage // package name => imported package 137 importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface 138 139 auxFiles []*ast.File 140 auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface 141 142 srcDir string 143} 144 145func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { 146 ps := p.fileSet.Position(pos) 147 format = "%s:%d:%d: " + format 148 args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) 149 return fmt.Errorf(format, args...) 150} 151 152func (p *fileParser) parseAuxFiles(auxFiles string) error { 153 auxFiles = strings.TrimSpace(auxFiles) 154 if auxFiles == "" { 155 return nil 156 } 157 for _, kv := range strings.Split(auxFiles, ",") { 158 parts := strings.SplitN(kv, "=", 2) 159 if len(parts) != 2 { 160 return fmt.Errorf("bad aux file spec: %v", kv) 161 } 162 pkg, fpath := parts[0], parts[1] 163 164 file, err := parser.ParseFile(p.fileSet, fpath, nil, 0) 165 if err != nil { 166 return err 167 } 168 p.auxFiles = append(p.auxFiles, file) 169 p.addAuxInterfacesFromFile(pkg, file) 170 } 171 return nil 172} 173 174func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { 175 if _, ok := p.auxInterfaces[pkg]; !ok { 176 p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType) 177 } 178 for ni := range iterInterfaces(file) { 179 p.auxInterfaces[pkg][ni.name.Name] = ni.it 180 } 181} 182 183// parseFile loads all file imports and auxiliary files import into the 184// fileParser, parses all file interfaces and returns package model. 185func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) { 186 allImports, dotImports := importsOfFile(file) 187 // Don't stomp imports provided by -imports. Those should take precedence. 188 for pkg, pkgI := range allImports { 189 if _, ok := p.imports[pkg]; !ok { 190 p.imports[pkg] = pkgI 191 } 192 } 193 // Add imports from auxiliary files, which might be needed for embedded interfaces. 194 // Don't stomp any other imports. 195 for _, f := range p.auxFiles { 196 auxImports, _ := importsOfFile(f) 197 for pkg, pkgI := range auxImports { 198 if _, ok := p.imports[pkg]; !ok { 199 p.imports[pkg] = pkgI 200 } 201 } 202 } 203 204 var is []*model.Interface 205 for ni := range iterInterfaces(file) { 206 i, err := p.parseInterface(ni.name.String(), importPath, ni.it) 207 if err != nil { 208 return nil, err 209 } 210 is = append(is, i) 211 } 212 return &model.Package{ 213 Name: file.Name.String(), 214 PkgPath: importPath, 215 Interfaces: is, 216 DotImports: dotImports, 217 }, nil 218} 219 220// parsePackage loads package specified by path, parses it and returns 221// a new fileParser with the parsed imports and interfaces. 222func (p *fileParser) parsePackage(path string) (*fileParser, error) { 223 newP := &fileParser{ 224 fileSet: token.NewFileSet(), 225 imports: make(map[string]importedPackage), 226 importedInterfaces: make(map[string]map[string]*ast.InterfaceType), 227 auxInterfaces: make(map[string]map[string]*ast.InterfaceType), 228 srcDir: p.srcDir, 229 } 230 231 var pkgs map[string]*ast.Package 232 if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil { 233 return nil, err 234 } else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil { 235 return nil, err 236 } 237 238 for _, pkg := range pkgs { 239 file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) 240 if _, ok := newP.importedInterfaces[path]; !ok { 241 newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType) 242 } 243 for ni := range iterInterfaces(file) { 244 newP.importedInterfaces[path][ni.name.Name] = ni.it 245 } 246 imports, _ := importsOfFile(file) 247 for pkgName, pkgI := range imports { 248 newP.imports[pkgName] = pkgI 249 } 250 } 251 return newP, nil 252} 253 254func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { 255 intf := &model.Interface{Name: name} 256 for _, field := range it.Methods.List { 257 switch v := field.Type.(type) { 258 case *ast.FuncType: 259 if nn := len(field.Names); nn != 1 { 260 return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn) 261 } 262 m := &model.Method{ 263 Name: field.Names[0].String(), 264 } 265 var err error 266 m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) 267 if err != nil { 268 return nil, err 269 } 270 intf.AddMethod(m) 271 case *ast.Ident: 272 // Embedded interface in this package. 273 ei := p.auxInterfaces[pkg][v.String()] 274 if ei == nil { 275 ei = p.importedInterfaces[pkg][v.String()] 276 } 277 278 var eintf *model.Interface 279 if ei != nil { 280 var err error 281 eintf, err = p.parseInterface(v.String(), pkg, ei) 282 if err != nil { 283 return nil, err 284 } 285 } else { 286 // This is built-in error interface. 287 if v.String() == model.ErrorInterface.Name { 288 eintf = &model.ErrorInterface 289 } else { 290 return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) 291 } 292 } 293 // Copy the methods. 294 for _, m := range eintf.Methods { 295 intf.AddMethod(m) 296 } 297 case *ast.SelectorExpr: 298 // Embedded interface in another package. 299 fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() 300 epkg, ok := p.imports[fpkg] 301 if !ok { 302 return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg) 303 } 304 305 var eintf *model.Interface 306 var err error 307 ei := p.auxInterfaces[fpkg][sel] 308 if ei != nil { 309 eintf, err = p.parseInterface(sel, fpkg, ei) 310 if err != nil { 311 return nil, err 312 } 313 } else { 314 path := epkg.Path() 315 parser := epkg.Parser() 316 if parser == nil { 317 ip, err := p.parsePackage(path) 318 if err != nil { 319 return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) 320 } 321 parser = ip 322 p.imports[fpkg] = importedPkg{ 323 path: epkg.Path(), 324 parser: parser, 325 } 326 } 327 if ei = parser.importedInterfaces[path][sel]; ei == nil { 328 return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) 329 } 330 eintf, err = parser.parseInterface(sel, path, ei) 331 if err != nil { 332 return nil, err 333 } 334 } 335 // Copy the methods. 336 // TODO: apply shadowing rules. 337 for _, m := range eintf.Methods { 338 intf.AddMethod(m) 339 } 340 default: 341 return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) 342 } 343 } 344 return intf, nil 345} 346 347func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) { 348 if f.Params != nil { 349 regParams := f.Params.List 350 if isVariadic(f) { 351 n := len(regParams) 352 varParams := regParams[n-1:] 353 regParams = regParams[:n-1] 354 vp, err := p.parseFieldList(pkg, varParams) 355 if err != nil { 356 return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) 357 } 358 variadic = vp[0] 359 } 360 in, err = p.parseFieldList(pkg, regParams) 361 if err != nil { 362 return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) 363 } 364 } 365 if f.Results != nil { 366 out, err = p.parseFieldList(pkg, f.Results.List) 367 if err != nil { 368 return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) 369 } 370 } 371 return 372} 373 374func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { 375 nf := 0 376 for _, f := range fields { 377 nn := len(f.Names) 378 if nn == 0 { 379 nn = 1 // anonymous parameter 380 } 381 nf += nn 382 } 383 if nf == 0 { 384 return nil, nil 385 } 386 ps := make([]*model.Parameter, nf) 387 i := 0 // destination index 388 for _, f := range fields { 389 t, err := p.parseType(pkg, f.Type) 390 if err != nil { 391 return nil, err 392 } 393 394 if len(f.Names) == 0 { 395 // anonymous arg 396 ps[i] = &model.Parameter{Type: t} 397 i++ 398 continue 399 } 400 for _, name := range f.Names { 401 ps[i] = &model.Parameter{Name: name.Name, Type: t} 402 i++ 403 } 404 } 405 return ps, nil 406} 407 408func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { 409 switch v := typ.(type) { 410 case *ast.ArrayType: 411 ln := -1 412 if v.Len != nil { 413 x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value) 414 if err != nil { 415 return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) 416 } 417 ln = x 418 } 419 t, err := p.parseType(pkg, v.Elt) 420 if err != nil { 421 return nil, err 422 } 423 return &model.ArrayType{Len: ln, Type: t}, nil 424 case *ast.ChanType: 425 t, err := p.parseType(pkg, v.Value) 426 if err != nil { 427 return nil, err 428 } 429 var dir model.ChanDir 430 if v.Dir == ast.SEND { 431 dir = model.SendDir 432 } 433 if v.Dir == ast.RECV { 434 dir = model.RecvDir 435 } 436 return &model.ChanType{Dir: dir, Type: t}, nil 437 case *ast.Ellipsis: 438 // assume we're parsing a variadic argument 439 return p.parseType(pkg, v.Elt) 440 case *ast.FuncType: 441 in, variadic, out, err := p.parseFunc(pkg, v) 442 if err != nil { 443 return nil, err 444 } 445 return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil 446 case *ast.Ident: 447 if v.IsExported() { 448 // `pkg` may be an aliased imported pkg 449 // if so, patch the import w/ the fully qualified import 450 maybeImportedPkg, ok := p.imports[pkg] 451 if ok { 452 pkg = maybeImportedPkg.Path() 453 } 454 // assume type in this package 455 return &model.NamedType{Package: pkg, Type: v.Name}, nil 456 } 457 458 // assume predeclared type 459 return model.PredeclaredType(v.Name), nil 460 case *ast.InterfaceType: 461 if v.Methods != nil && len(v.Methods.List) > 0 { 462 return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") 463 } 464 return model.PredeclaredType("interface{}"), nil 465 case *ast.MapType: 466 key, err := p.parseType(pkg, v.Key) 467 if err != nil { 468 return nil, err 469 } 470 value, err := p.parseType(pkg, v.Value) 471 if err != nil { 472 return nil, err 473 } 474 return &model.MapType{Key: key, Value: value}, nil 475 case *ast.SelectorExpr: 476 pkgName := v.X.(*ast.Ident).String() 477 pkg, ok := p.imports[pkgName] 478 if !ok { 479 return nil, p.errorf(v.Pos(), "unknown package %q", pkgName) 480 } 481 return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil 482 case *ast.StarExpr: 483 t, err := p.parseType(pkg, v.X) 484 if err != nil { 485 return nil, err 486 } 487 return &model.PointerType{Type: t}, nil 488 case *ast.StructType: 489 if v.Fields != nil && len(v.Fields.List) > 0 { 490 return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types") 491 } 492 return model.PredeclaredType("struct{}"), nil 493 case *ast.ParenExpr: 494 return p.parseType(pkg, v.X) 495 } 496 497 return nil, fmt.Errorf("don't know how to parse type %T", typ) 498} 499 500// importsOfFile returns a map of package name to import path 501// of the imports in file. 502func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) { 503 var importPaths []string 504 for _, is := range file.Imports { 505 if is.Name != nil { 506 continue 507 } 508 importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes 509 importPaths = append(importPaths, importPath) 510 } 511 packagesName := createPackageMap(importPaths) 512 normalImports = make(map[string]importedPackage) 513 dotImports = make([]string, 0) 514 for _, is := range file.Imports { 515 var pkgName string 516 importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes 517 518 if is.Name != nil { 519 // Named imports are always certain. 520 if is.Name.Name == "_" { 521 continue 522 } 523 pkgName = is.Name.Name 524 } else { 525 pkg, ok := packagesName[importPath] 526 if !ok { 527 // Fallback to import path suffix. Note that this is uncertain. 528 _, last := path.Split(importPath) 529 // If the last path component has dots, the first dot-delimited 530 // field is used as the name. 531 pkgName = strings.SplitN(last, ".", 2)[0] 532 } else { 533 pkgName = pkg 534 } 535 } 536 537 if pkgName == "." { 538 dotImports = append(dotImports, importPath) 539 } else { 540 if pkg, ok := normalImports[pkgName]; ok { 541 switch p := pkg.(type) { 542 case duplicateImport: 543 normalImports[pkgName] = duplicateImport{ 544 name: p.name, 545 duplicates: append([]string{importPath}, p.duplicates...), 546 } 547 case importedPkg: 548 normalImports[pkgName] = duplicateImport{ 549 name: pkgName, 550 duplicates: []string{p.path, importPath}, 551 } 552 } 553 } else { 554 normalImports[pkgName] = importedPkg{path: importPath} 555 } 556 } 557 } 558 return 559} 560 561type namedInterface struct { 562 name *ast.Ident 563 it *ast.InterfaceType 564} 565 566// Create an iterator over all interfaces in file. 567func iterInterfaces(file *ast.File) <-chan namedInterface { 568 ch := make(chan namedInterface) 569 go func() { 570 for _, decl := range file.Decls { 571 gd, ok := decl.(*ast.GenDecl) 572 if !ok || gd.Tok != token.TYPE { 573 continue 574 } 575 for _, spec := range gd.Specs { 576 ts, ok := spec.(*ast.TypeSpec) 577 if !ok { 578 continue 579 } 580 it, ok := ts.Type.(*ast.InterfaceType) 581 if !ok { 582 continue 583 } 584 585 ch <- namedInterface{ts.Name, it} 586 } 587 } 588 close(ch) 589 }() 590 return ch 591} 592 593// isVariadic returns whether the function is variadic. 594func isVariadic(f *ast.FuncType) bool { 595 nargs := len(f.Params.List) 596 if nargs == 0 { 597 return false 598 } 599 _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis) 600 return ok 601} 602 603// packageNameOfDir get package import path via dir 604func packageNameOfDir(srcDir string) (string, error) { 605 files, err := ioutil.ReadDir(srcDir) 606 if err != nil { 607 log.Fatal(err) 608 } 609 610 var goFilePath string 611 for _, file := range files { 612 if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") { 613 goFilePath = file.Name() 614 break 615 } 616 } 617 if goFilePath == "" { 618 return "", fmt.Errorf("go source file not found %s", srcDir) 619 } 620 621 packageImport, err := parsePackageImport(srcDir) 622 if err != nil { 623 return "", err 624 } 625 return packageImport, nil 626} 627 628var errOutsideGoPath = errors.New("Source directory is outside GOPATH") 629 630// parseImportPackage get package import path via source file 631// an alternative implementation is to use: 632// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir} 633// pkgs, err := packages.Load(cfg, "file="+source) 634// However, it will call "go list" and slow down the performance 635func parsePackageImport(srcDir string) (string, error) { 636 moduleMode := os.Getenv("GO111MODULE") 637 // trying to find the module 638 if moduleMode != "off" { 639 currentDir := srcDir 640 for { 641 dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod")) 642 if os.IsNotExist(err) { 643 if currentDir == filepath.Dir(currentDir) { 644 // at the root 645 break 646 } 647 currentDir = filepath.Dir(currentDir) 648 continue 649 } else if err != nil { 650 return "", err 651 } 652 modulePath := modfile.ModulePath(dat) 653 return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil 654 } 655 } 656 // fall back to GOPATH mode 657 goPaths := os.Getenv("GOPATH") 658 if goPaths == "" { 659 return "", fmt.Errorf("GOPATH is not set") 660 } 661 goPathList := strings.Split(goPaths, string(os.PathListSeparator)) 662 for _, goPath := range goPathList { 663 sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator) 664 if strings.HasPrefix(srcDir, sourceRoot) { 665 return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil 666 } 667 } 668 return "", errOutsideGoPath 669} 670