1// Copyright 2012 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package main 16 17// This file contains the model construction by parsing source files. 18 19import ( 20 "errors" 21 "flag" 22 "fmt" 23 "go/ast" 24 "go/build" 25 "go/importer" 26 "go/parser" 27 "go/token" 28 "go/types" 29 "io/ioutil" 30 "log" 31 "path" 32 "path/filepath" 33 "strconv" 34 "strings" 35 36 "github.com/golang/mock/mockgen/model" 37) 38 39var ( 40 imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") 41 auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") 42) 43 44// sourceMode generates mocks via source file. 45func sourceMode(source string) (*model.Package, error) { 46 srcDir, err := filepath.Abs(filepath.Dir(source)) 47 if err != nil { 48 return nil, fmt.Errorf("failed getting source directory: %v", err) 49 } 50 51 packageImport, err := parsePackageImport(srcDir) 52 if err != nil { 53 return nil, err 54 } 55 56 fs := token.NewFileSet() 57 file, err := parser.ParseFile(fs, source, nil, 0) 58 if err != nil { 59 return nil, fmt.Errorf("failed parsing source file %v: %v", source, err) 60 } 61 62 p := &fileParser{ 63 fileSet: fs, 64 imports: make(map[string]importedPackage), 65 importedInterfaces: make(map[string]map[string]*ast.InterfaceType), 66 auxInterfaces: make(map[string]map[string]*ast.InterfaceType), 67 srcDir: srcDir, 68 } 69 70 // Handle -imports. 71 dotImports := make(map[string]bool) 72 if *imports != "" { 73 for _, kv := range strings.Split(*imports, ",") { 74 eq := strings.Index(kv, "=") 75 k, v := kv[:eq], kv[eq+1:] 76 if k == "." { 77 dotImports[v] = true 78 } else { 79 p.imports[k] = importedPkg{path: v} 80 } 81 } 82 } 83 84 // Handle -aux_files. 85 if err := p.parseAuxFiles(*auxFiles); err != nil { 86 return nil, err 87 } 88 p.addAuxInterfacesFromFile(packageImport, file) // this file 89 90 pkg, err := p.parseFile(packageImport, file) 91 if err != nil { 92 return nil, err 93 } 94 for pkgPath := range dotImports { 95 pkg.DotImports = append(pkg.DotImports, pkgPath) 96 } 97 return pkg, nil 98} 99 100type importedPackage interface { 101 Path() string 102 Parser() *fileParser 103} 104 105type importedPkg struct { 106 path string 107 parser *fileParser 108} 109 110func (i importedPkg) Path() string { return i.path } 111func (i importedPkg) Parser() *fileParser { return i.parser } 112 113// duplicateImport is a bit of a misnomer. Currently the parser can't 114// handle cases of multi-file packages importing different packages 115// under the same name. Often these imports would not be problematic, 116// so this type lets us defer raising an error unless the package name 117// is actually used. 118type duplicateImport struct { 119 name string 120 duplicates []string 121} 122 123func (d duplicateImport) Error() string { 124 return fmt.Sprintf("%q is ambiguous because of duplicate imports: %v", d.name, d.duplicates) 125} 126 127func (d duplicateImport) Path() string { log.Fatal(d.Error()); return "" } 128func (d duplicateImport) Parser() *fileParser { log.Fatal(d.Error()); return nil } 129 130type fileParser struct { 131 fileSet *token.FileSet 132 imports map[string]importedPackage // package name => imported package 133 importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface 134 135 auxFiles []*ast.File 136 auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface 137 138 srcDir string 139} 140 141func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { 142 ps := p.fileSet.Position(pos) 143 format = "%s:%d:%d: " + format 144 args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) 145 return fmt.Errorf(format, args...) 146} 147 148func (p *fileParser) parseAuxFiles(auxFiles string) error { 149 auxFiles = strings.TrimSpace(auxFiles) 150 if auxFiles == "" { 151 return nil 152 } 153 for _, kv := range strings.Split(auxFiles, ",") { 154 parts := strings.SplitN(kv, "=", 2) 155 if len(parts) != 2 { 156 return fmt.Errorf("bad aux file spec: %v", kv) 157 } 158 pkg, fpath := parts[0], parts[1] 159 160 file, err := parser.ParseFile(p.fileSet, fpath, nil, 0) 161 if err != nil { 162 return err 163 } 164 p.auxFiles = append(p.auxFiles, file) 165 p.addAuxInterfacesFromFile(pkg, file) 166 } 167 return nil 168} 169 170func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { 171 if _, ok := p.auxInterfaces[pkg]; !ok { 172 p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType) 173 } 174 for ni := range iterInterfaces(file) { 175 p.auxInterfaces[pkg][ni.name.Name] = ni.it 176 } 177} 178 179// parseFile loads all file imports and auxiliary files import into the 180// fileParser, parses all file interfaces and returns package model. 181func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) { 182 allImports, dotImports := importsOfFile(file) 183 // Don't stomp imports provided by -imports. Those should take precedence. 184 for pkg, pkgI := range allImports { 185 if _, ok := p.imports[pkg]; !ok { 186 p.imports[pkg] = pkgI 187 } 188 } 189 // Add imports from auxiliary files, which might be needed for embedded interfaces. 190 // Don't stomp any other imports. 191 for _, f := range p.auxFiles { 192 auxImports, _ := importsOfFile(f) 193 for pkg, pkgI := range auxImports { 194 if _, ok := p.imports[pkg]; !ok { 195 p.imports[pkg] = pkgI 196 } 197 } 198 } 199 200 var is []*model.Interface 201 for ni := range iterInterfaces(file) { 202 i, err := p.parseInterface(ni.name.String(), importPath, ni.it) 203 if err != nil { 204 return nil, err 205 } 206 is = append(is, i) 207 } 208 return &model.Package{ 209 Name: file.Name.String(), 210 PkgPath: importPath, 211 Interfaces: is, 212 DotImports: dotImports, 213 }, nil 214} 215 216// parsePackage loads package specified by path, parses it and returns 217// a new fileParser with the parsed imports and interfaces. 218func (p *fileParser) parsePackage(path string) (*fileParser, error) { 219 newP := &fileParser{ 220 fileSet: token.NewFileSet(), 221 imports: make(map[string]importedPackage), 222 importedInterfaces: make(map[string]map[string]*ast.InterfaceType), 223 auxInterfaces: make(map[string]map[string]*ast.InterfaceType), 224 srcDir: p.srcDir, 225 } 226 227 var pkgs map[string]*ast.Package 228 if imp, err := build.Import(path, newP.srcDir, build.FindOnly); err != nil { 229 return nil, err 230 } else if pkgs, err = parser.ParseDir(newP.fileSet, imp.Dir, nil, 0); err != nil { 231 return nil, err 232 } 233 234 for _, pkg := range pkgs { 235 file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) 236 if _, ok := newP.importedInterfaces[path]; !ok { 237 newP.importedInterfaces[path] = make(map[string]*ast.InterfaceType) 238 } 239 for ni := range iterInterfaces(file) { 240 newP.importedInterfaces[path][ni.name.Name] = ni.it 241 } 242 imports, _ := importsOfFile(file) 243 for pkgName, pkgI := range imports { 244 newP.imports[pkgName] = pkgI 245 } 246 } 247 return newP, nil 248} 249 250func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { 251 iface := &model.Interface{Name: name} 252 for _, field := range it.Methods.List { 253 switch v := field.Type.(type) { 254 case *ast.FuncType: 255 if nn := len(field.Names); nn != 1 { 256 return nil, fmt.Errorf("expected one name for interface %v, got %d", iface.Name, nn) 257 } 258 m := &model.Method{ 259 Name: field.Names[0].String(), 260 } 261 var err error 262 m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) 263 if err != nil { 264 return nil, err 265 } 266 iface.AddMethod(m) 267 case *ast.Ident: 268 // Embedded interface in this package. 269 embeddedIfaceType := p.auxInterfaces[pkg][v.String()] 270 if embeddedIfaceType == nil { 271 embeddedIfaceType = p.importedInterfaces[pkg][v.String()] 272 } 273 274 var embeddedIface *model.Interface 275 if embeddedIfaceType != nil { 276 var err error 277 embeddedIface, err = p.parseInterface(v.String(), pkg, embeddedIfaceType) 278 if err != nil { 279 return nil, err 280 } 281 } else { 282 // This is built-in error interface. 283 if v.String() == model.ErrorInterface.Name { 284 embeddedIface = &model.ErrorInterface 285 } else { 286 return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) 287 } 288 } 289 // Copy the methods. 290 for _, m := range embeddedIface.Methods { 291 iface.AddMethod(m) 292 } 293 case *ast.SelectorExpr: 294 // Embedded interface in another package. 295 filePkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() 296 embeddedPkg, ok := p.imports[filePkg] 297 if !ok { 298 return nil, p.errorf(v.X.Pos(), "unknown package %s", filePkg) 299 } 300 301 var embeddedIface *model.Interface 302 var err error 303 embeddedIfaceType := p.auxInterfaces[filePkg][sel] 304 if embeddedIfaceType != nil { 305 embeddedIface, err = p.parseInterface(sel, filePkg, embeddedIfaceType) 306 if err != nil { 307 return nil, err 308 } 309 } else { 310 path := embeddedPkg.Path() 311 parser := embeddedPkg.Parser() 312 if parser == nil { 313 ip, err := p.parsePackage(path) 314 if err != nil { 315 return nil, p.errorf(v.Pos(), "could not parse package %s: %v", path, err) 316 } 317 parser = ip 318 p.imports[filePkg] = importedPkg{ 319 path: embeddedPkg.Path(), 320 parser: parser, 321 } 322 } 323 if embeddedIfaceType = parser.importedInterfaces[path][sel]; embeddedIfaceType == nil { 324 return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", path, sel) 325 } 326 embeddedIface, err = parser.parseInterface(sel, path, embeddedIfaceType) 327 if err != nil { 328 return nil, err 329 } 330 } 331 // Copy the methods. 332 // TODO: apply shadowing rules. 333 for _, m := range embeddedIface.Methods { 334 iface.AddMethod(m) 335 } 336 default: 337 return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) 338 } 339 } 340 return iface, nil 341} 342 343func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (inParam []*model.Parameter, variadic *model.Parameter, outParam []*model.Parameter, err error) { 344 if f.Params != nil { 345 regParams := f.Params.List 346 if isVariadic(f) { 347 n := len(regParams) 348 varParams := regParams[n-1:] 349 regParams = regParams[:n-1] 350 vp, err := p.parseFieldList(pkg, varParams) 351 if err != nil { 352 return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) 353 } 354 variadic = vp[0] 355 } 356 inParam, err = p.parseFieldList(pkg, regParams) 357 if err != nil { 358 return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) 359 } 360 } 361 if f.Results != nil { 362 outParam, err = p.parseFieldList(pkg, f.Results.List) 363 if err != nil { 364 return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) 365 } 366 } 367 return 368} 369 370func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { 371 nf := 0 372 for _, f := range fields { 373 nn := len(f.Names) 374 if nn == 0 { 375 nn = 1 // anonymous parameter 376 } 377 nf += nn 378 } 379 if nf == 0 { 380 return nil, nil 381 } 382 ps := make([]*model.Parameter, nf) 383 i := 0 // destination index 384 for _, f := range fields { 385 t, err := p.parseType(pkg, f.Type) 386 if err != nil { 387 return nil, err 388 } 389 390 if len(f.Names) == 0 { 391 // anonymous arg 392 ps[i] = &model.Parameter{Type: t} 393 i++ 394 continue 395 } 396 for _, name := range f.Names { 397 ps[i] = &model.Parameter{Name: name.Name, Type: t} 398 i++ 399 } 400 } 401 return ps, nil 402} 403 404func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { 405 switch v := typ.(type) { 406 case *ast.ArrayType: 407 ln := -1 408 if v.Len != nil { 409 var value string 410 switch val := v.Len.(type) { 411 case (*ast.BasicLit): 412 value = val.Value 413 case (*ast.Ident): 414 // when the length is a const defined locally 415 value = val.Obj.Decl.(*ast.ValueSpec).Values[0].(*ast.BasicLit).Value 416 case (*ast.SelectorExpr): 417 // when the length is a const defined in an external package 418 usedPkg, err := importer.Default().Import(fmt.Sprintf("%s", val.X)) 419 if err != nil { 420 return nil, p.errorf(v.Len.Pos(), "unknown package in array length: %v", err) 421 } 422 ev, err := types.Eval(token.NewFileSet(), usedPkg, token.NoPos, val.Sel.Name) 423 if err != nil { 424 return nil, p.errorf(v.Len.Pos(), "unknown constant in array length: %v", err) 425 } 426 value = ev.Value.String() 427 } 428 429 x, err := strconv.Atoi(value) 430 if err != nil { 431 return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) 432 } 433 ln = x 434 } 435 t, err := p.parseType(pkg, v.Elt) 436 if err != nil { 437 return nil, err 438 } 439 return &model.ArrayType{Len: ln, Type: t}, nil 440 case *ast.ChanType: 441 t, err := p.parseType(pkg, v.Value) 442 if err != nil { 443 return nil, err 444 } 445 var dir model.ChanDir 446 if v.Dir == ast.SEND { 447 dir = model.SendDir 448 } 449 if v.Dir == ast.RECV { 450 dir = model.RecvDir 451 } 452 return &model.ChanType{Dir: dir, Type: t}, nil 453 case *ast.Ellipsis: 454 // assume we're parsing a variadic argument 455 return p.parseType(pkg, v.Elt) 456 case *ast.FuncType: 457 in, variadic, out, err := p.parseFunc(pkg, v) 458 if err != nil { 459 return nil, err 460 } 461 return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil 462 case *ast.Ident: 463 if v.IsExported() { 464 // `pkg` may be an aliased imported pkg 465 // if so, patch the import w/ the fully qualified import 466 maybeImportedPkg, ok := p.imports[pkg] 467 if ok { 468 pkg = maybeImportedPkg.Path() 469 } 470 // assume type in this package 471 return &model.NamedType{Package: pkg, Type: v.Name}, nil 472 } 473 474 // assume predeclared type 475 return model.PredeclaredType(v.Name), nil 476 case *ast.InterfaceType: 477 if v.Methods != nil && len(v.Methods.List) > 0 { 478 return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") 479 } 480 return model.PredeclaredType("interface{}"), nil 481 case *ast.MapType: 482 key, err := p.parseType(pkg, v.Key) 483 if err != nil { 484 return nil, err 485 } 486 value, err := p.parseType(pkg, v.Value) 487 if err != nil { 488 return nil, err 489 } 490 return &model.MapType{Key: key, Value: value}, nil 491 case *ast.SelectorExpr: 492 pkgName := v.X.(*ast.Ident).String() 493 pkg, ok := p.imports[pkgName] 494 if !ok { 495 return nil, p.errorf(v.Pos(), "unknown package %q", pkgName) 496 } 497 return &model.NamedType{Package: pkg.Path(), Type: v.Sel.String()}, nil 498 case *ast.StarExpr: 499 t, err := p.parseType(pkg, v.X) 500 if err != nil { 501 return nil, err 502 } 503 return &model.PointerType{Type: t}, nil 504 case *ast.StructType: 505 if v.Fields != nil && len(v.Fields.List) > 0 { 506 return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types") 507 } 508 return model.PredeclaredType("struct{}"), nil 509 case *ast.ParenExpr: 510 return p.parseType(pkg, v.X) 511 } 512 513 return nil, fmt.Errorf("don't know how to parse type %T", typ) 514} 515 516// importsOfFile returns a map of package name to import path 517// of the imports in file. 518func importsOfFile(file *ast.File) (normalImports map[string]importedPackage, dotImports []string) { 519 var importPaths []string 520 for _, is := range file.Imports { 521 if is.Name != nil { 522 continue 523 } 524 importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes 525 importPaths = append(importPaths, importPath) 526 } 527 packagesName := createPackageMap(importPaths) 528 normalImports = make(map[string]importedPackage) 529 dotImports = make([]string, 0) 530 for _, is := range file.Imports { 531 var pkgName string 532 importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes 533 534 if is.Name != nil { 535 // Named imports are always certain. 536 if is.Name.Name == "_" { 537 continue 538 } 539 pkgName = is.Name.Name 540 } else { 541 pkg, ok := packagesName[importPath] 542 if !ok { 543 // Fallback to import path suffix. Note that this is uncertain. 544 _, last := path.Split(importPath) 545 // If the last path component has dots, the first dot-delimited 546 // field is used as the name. 547 pkgName = strings.SplitN(last, ".", 2)[0] 548 } else { 549 pkgName = pkg 550 } 551 } 552 553 if pkgName == "." { 554 dotImports = append(dotImports, importPath) 555 } else { 556 if pkg, ok := normalImports[pkgName]; ok { 557 switch p := pkg.(type) { 558 case duplicateImport: 559 normalImports[pkgName] = duplicateImport{ 560 name: p.name, 561 duplicates: append([]string{importPath}, p.duplicates...), 562 } 563 case importedPkg: 564 normalImports[pkgName] = duplicateImport{ 565 name: pkgName, 566 duplicates: []string{p.path, importPath}, 567 } 568 } 569 } else { 570 normalImports[pkgName] = importedPkg{path: importPath} 571 } 572 } 573 } 574 return 575} 576 577type namedInterface struct { 578 name *ast.Ident 579 it *ast.InterfaceType 580} 581 582// Create an iterator over all interfaces in file. 583func iterInterfaces(file *ast.File) <-chan namedInterface { 584 ch := make(chan namedInterface) 585 go func() { 586 for _, decl := range file.Decls { 587 gd, ok := decl.(*ast.GenDecl) 588 if !ok || gd.Tok != token.TYPE { 589 continue 590 } 591 for _, spec := range gd.Specs { 592 ts, ok := spec.(*ast.TypeSpec) 593 if !ok { 594 continue 595 } 596 it, ok := ts.Type.(*ast.InterfaceType) 597 if !ok { 598 continue 599 } 600 601 ch <- namedInterface{ts.Name, it} 602 } 603 } 604 close(ch) 605 }() 606 return ch 607} 608 609// isVariadic returns whether the function is variadic. 610func isVariadic(f *ast.FuncType) bool { 611 nargs := len(f.Params.List) 612 if nargs == 0 { 613 return false 614 } 615 _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis) 616 return ok 617} 618 619// packageNameOfDir get package import path via dir 620func packageNameOfDir(srcDir string) (string, error) { 621 files, err := ioutil.ReadDir(srcDir) 622 if err != nil { 623 log.Fatal(err) 624 } 625 626 var goFilePath string 627 for _, file := range files { 628 if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") { 629 goFilePath = file.Name() 630 break 631 } 632 } 633 if goFilePath == "" { 634 return "", fmt.Errorf("go source file not found %s", srcDir) 635 } 636 637 packageImport, err := parsePackageImport(srcDir) 638 if err != nil { 639 return "", err 640 } 641 return packageImport, nil 642} 643 644var errOutsideGoPath = errors.New("source directory is outside GOPATH") 645