1// Copyright 2012 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package main 16 17// This file contains the model construction by parsing source files. 18 19import ( 20 "flag" 21 "fmt" 22 "go/ast" 23 "go/build" 24 "go/parser" 25 "go/token" 26 "log" 27 "path" 28 "path/filepath" 29 "strconv" 30 "strings" 31 32 "github.com/golang/mock/mockgen/model" 33) 34 35var ( 36 imports = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.") 37 auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.") 38) 39 40// TODO: simplify error reporting 41 42func ParseFile(source string) (*model.Package, error) { 43 srcDir, err := filepath.Abs(filepath.Dir(source)) 44 if err != nil { 45 return nil, fmt.Errorf("failed getting source directory: %v", err) 46 } 47 48 var packageImport string 49 if p, err := build.ImportDir(srcDir, 0); err == nil { 50 packageImport = p.ImportPath 51 } // TODO: should we fail if this returns an error? 52 53 fs := token.NewFileSet() 54 file, err := parser.ParseFile(fs, source, nil, 0) 55 if err != nil { 56 return nil, fmt.Errorf("failed parsing source file %v: %v", source, err) 57 } 58 59 p := &fileParser{ 60 fileSet: fs, 61 imports: make(map[string]string), 62 importedInterfaces: make(map[string]map[string]*ast.InterfaceType), 63 auxInterfaces: make(map[string]map[string]*ast.InterfaceType), 64 srcDir: srcDir, 65 } 66 67 // Handle -imports. 68 dotImports := make(map[string]bool) 69 if *imports != "" { 70 for _, kv := range strings.Split(*imports, ",") { 71 eq := strings.Index(kv, "=") 72 k, v := kv[:eq], kv[eq+1:] 73 if k == "." { 74 // TODO: Catch dupes? 75 dotImports[v] = true 76 } else { 77 // TODO: Catch dupes? 78 p.imports[k] = v 79 } 80 } 81 } 82 83 // Handle -aux_files. 84 if err := p.parseAuxFiles(*auxFiles); err != nil { 85 return nil, err 86 } 87 p.addAuxInterfacesFromFile(packageImport, file) // this file 88 89 pkg, err := p.parseFile(packageImport, file) 90 if err != nil { 91 return nil, err 92 } 93 pkg.DotImports = make([]string, 0, len(dotImports)) 94 for path := range dotImports { 95 pkg.DotImports = append(pkg.DotImports, path) 96 } 97 return pkg, nil 98} 99 100type fileParser struct { 101 fileSet *token.FileSet 102 imports map[string]string // package name => import path 103 importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface 104 105 auxFiles []*ast.File 106 auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface 107 108 srcDir string 109} 110 111func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error { 112 ps := p.fileSet.Position(pos) 113 format = "%s:%d:%d: " + format 114 args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...) 115 return fmt.Errorf(format, args...) 116} 117 118func (p *fileParser) parseAuxFiles(auxFiles string) error { 119 auxFiles = strings.TrimSpace(auxFiles) 120 if auxFiles == "" { 121 return nil 122 } 123 for _, kv := range strings.Split(auxFiles, ",") { 124 parts := strings.SplitN(kv, "=", 2) 125 if len(parts) != 2 { 126 return fmt.Errorf("bad aux file spec: %v", kv) 127 } 128 pkg, fpath := parts[0], parts[1] 129 130 file, err := parser.ParseFile(p.fileSet, fpath, nil, 0) 131 if err != nil { 132 return err 133 } 134 p.auxFiles = append(p.auxFiles, file) 135 p.addAuxInterfacesFromFile(pkg, file) 136 } 137 return nil 138} 139 140func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) { 141 if _, ok := p.auxInterfaces[pkg]; !ok { 142 p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType) 143 } 144 for ni := range iterInterfaces(file) { 145 p.auxInterfaces[pkg][ni.name.Name] = ni.it 146 } 147} 148 149// parseFile loads all file imports and auxiliary files import into the 150// fileParser, parses all file interfaces and returns package model. 151func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) { 152 allImports := importsOfFile(file) 153 // Don't stomp imports provided by -imports. Those should take precedence. 154 for pkg, path := range allImports { 155 if _, ok := p.imports[pkg]; !ok { 156 p.imports[pkg] = path 157 } 158 } 159 // Add imports from auxiliary files, which might be needed for embedded interfaces. 160 // Don't stomp any other imports. 161 for _, f := range p.auxFiles { 162 for pkg, path := range importsOfFile(f) { 163 if _, ok := p.imports[pkg]; !ok { 164 p.imports[pkg] = path 165 } 166 } 167 } 168 169 var is []*model.Interface 170 for ni := range iterInterfaces(file) { 171 i, err := p.parseInterface(ni.name.String(), importPath, ni.it) 172 if err != nil { 173 return nil, err 174 } 175 is = append(is, i) 176 } 177 return &model.Package{ 178 Name: file.Name.String(), 179 Interfaces: is, 180 }, nil 181} 182 183// parsePackage loads package specified by path, parses it and populates 184// corresponding imports and importedInterfaces into the fileParser. 185func (p *fileParser) parsePackage(path string) error { 186 var pkgs map[string]*ast.Package 187 if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil { 188 return err 189 } else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil { 190 return err 191 } 192 for _, pkg := range pkgs { 193 file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates) 194 if _, ok := p.importedInterfaces[path]; !ok { 195 p.importedInterfaces[path] = make(map[string]*ast.InterfaceType) 196 } 197 for ni := range iterInterfaces(file) { 198 p.importedInterfaces[path][ni.name.Name] = ni.it 199 } 200 for pkgName, pkgPath := range importsOfFile(file) { 201 if _, ok := p.imports[pkgName]; !ok { 202 p.imports[pkgName] = pkgPath 203 } 204 } 205 } 206 return nil 207} 208 209func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) { 210 intf := &model.Interface{Name: name} 211 for _, field := range it.Methods.List { 212 switch v := field.Type.(type) { 213 case *ast.FuncType: 214 if nn := len(field.Names); nn != 1 { 215 return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn) 216 } 217 m := &model.Method{ 218 Name: field.Names[0].String(), 219 } 220 var err error 221 m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v) 222 if err != nil { 223 return nil, err 224 } 225 intf.Methods = append(intf.Methods, m) 226 case *ast.Ident: 227 // Embedded interface in this package. 228 ei := p.auxInterfaces[pkg][v.String()] 229 if ei == nil { 230 if ei = p.importedInterfaces[pkg][v.String()]; ei == nil { 231 return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String()) 232 } 233 } 234 eintf, err := p.parseInterface(v.String(), pkg, ei) 235 if err != nil { 236 return nil, err 237 } 238 // Copy the methods. 239 // TODO: apply shadowing rules. 240 for _, m := range eintf.Methods { 241 intf.Methods = append(intf.Methods, m) 242 } 243 case *ast.SelectorExpr: 244 // Embedded interface in another package. 245 fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String() 246 epkg, ok := p.imports[fpkg] 247 if !ok { 248 return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg) 249 } 250 ei := p.auxInterfaces[fpkg][sel] 251 if ei == nil { 252 fpkg = epkg 253 if _, ok = p.importedInterfaces[epkg]; !ok { 254 if err := p.parsePackage(epkg); err != nil { 255 return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err) 256 } 257 } 258 if ei = p.importedInterfaces[epkg][sel]; ei == nil { 259 return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel) 260 } 261 } 262 eintf, err := p.parseInterface(sel, fpkg, ei) 263 if err != nil { 264 return nil, err 265 } 266 // Copy the methods. 267 // TODO: apply shadowing rules. 268 for _, m := range eintf.Methods { 269 intf.Methods = append(intf.Methods, m) 270 } 271 default: 272 return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type) 273 } 274 } 275 return intf, nil 276} 277 278func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) { 279 if f.Params != nil { 280 regParams := f.Params.List 281 if isVariadic(f) { 282 n := len(regParams) 283 varParams := regParams[n-1:] 284 regParams = regParams[:n-1] 285 vp, err := p.parseFieldList(pkg, varParams) 286 if err != nil { 287 return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err) 288 } 289 variadic = vp[0] 290 } 291 in, err = p.parseFieldList(pkg, regParams) 292 if err != nil { 293 return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err) 294 } 295 } 296 if f.Results != nil { 297 out, err = p.parseFieldList(pkg, f.Results.List) 298 if err != nil { 299 return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err) 300 } 301 } 302 return 303} 304 305func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) { 306 nf := 0 307 for _, f := range fields { 308 nn := len(f.Names) 309 if nn == 0 { 310 nn = 1 // anonymous parameter 311 } 312 nf += nn 313 } 314 if nf == 0 { 315 return nil, nil 316 } 317 ps := make([]*model.Parameter, nf) 318 i := 0 // destination index 319 for _, f := range fields { 320 t, err := p.parseType(pkg, f.Type) 321 if err != nil { 322 return nil, err 323 } 324 325 if len(f.Names) == 0 { 326 // anonymous arg 327 ps[i] = &model.Parameter{Type: t} 328 i++ 329 continue 330 } 331 for _, name := range f.Names { 332 ps[i] = &model.Parameter{Name: name.Name, Type: t} 333 i++ 334 } 335 } 336 return ps, nil 337} 338 339func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) { 340 switch v := typ.(type) { 341 case *ast.ArrayType: 342 ln := -1 343 if v.Len != nil { 344 x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value) 345 if err != nil { 346 return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err) 347 } 348 ln = x 349 } 350 t, err := p.parseType(pkg, v.Elt) 351 if err != nil { 352 return nil, err 353 } 354 return &model.ArrayType{Len: ln, Type: t}, nil 355 case *ast.ChanType: 356 t, err := p.parseType(pkg, v.Value) 357 if err != nil { 358 return nil, err 359 } 360 var dir model.ChanDir 361 if v.Dir == ast.SEND { 362 dir = model.SendDir 363 } 364 if v.Dir == ast.RECV { 365 dir = model.RecvDir 366 } 367 return &model.ChanType{Dir: dir, Type: t}, nil 368 case *ast.Ellipsis: 369 // assume we're parsing a variadic argument 370 return p.parseType(pkg, v.Elt) 371 case *ast.FuncType: 372 in, variadic, out, err := p.parseFunc(pkg, v) 373 if err != nil { 374 return nil, err 375 } 376 return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil 377 case *ast.Ident: 378 if v.IsExported() { 379 // `pkg` may be an aliased imported pkg 380 // if so, patch the import w/ the fully qualified import 381 maybeImportedPkg, ok := p.imports[pkg] 382 if ok { 383 pkg = maybeImportedPkg 384 } 385 // assume type in this package 386 return &model.NamedType{Package: pkg, Type: v.Name}, nil 387 } else { 388 // assume predeclared type 389 return model.PredeclaredType(v.Name), nil 390 } 391 case *ast.InterfaceType: 392 if v.Methods != nil && len(v.Methods.List) > 0 { 393 return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types") 394 } 395 return model.PredeclaredType("interface{}"), nil 396 case *ast.MapType: 397 key, err := p.parseType(pkg, v.Key) 398 if err != nil { 399 return nil, err 400 } 401 value, err := p.parseType(pkg, v.Value) 402 if err != nil { 403 return nil, err 404 } 405 return &model.MapType{Key: key, Value: value}, nil 406 case *ast.SelectorExpr: 407 pkgName := v.X.(*ast.Ident).String() 408 pkg, ok := p.imports[pkgName] 409 if !ok { 410 return nil, p.errorf(v.Pos(), "unknown package %q", pkgName) 411 } 412 return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil 413 case *ast.StarExpr: 414 t, err := p.parseType(pkg, v.X) 415 if err != nil { 416 return nil, err 417 } 418 return &model.PointerType{Type: t}, nil 419 case *ast.StructType: 420 if v.Fields != nil && len(v.Fields.List) > 0 { 421 return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types") 422 } 423 return model.PredeclaredType("struct{}"), nil 424 } 425 426 return nil, fmt.Errorf("don't know how to parse type %T", typ) 427} 428 429// importsOfFile returns a map of package name to import path 430// of the imports in file. 431func importsOfFile(file *ast.File) map[string]string { 432 m := make(map[string]string) 433 for _, is := range file.Imports { 434 var pkgName string 435 importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes 436 437 if is.Name != nil { 438 // Named imports are always certain. 439 if is.Name.Name == "_" { 440 continue 441 } 442 pkgName = removeDot(is.Name.Name) 443 } else { 444 pkg, err := build.Import(importPath, "", 0) 445 if err != nil { 446 // Fallback to import path suffix. Note that this is uncertain. 447 _, last := path.Split(importPath) 448 // If the last path component has dots, the first dot-delimited 449 // field is used as the name. 450 pkgName = strings.SplitN(last, ".", 2)[0] 451 } else { 452 pkgName = pkg.Name 453 } 454 } 455 456 if _, ok := m[pkgName]; ok { 457 log.Fatalf("imported package collision: %q imported twice", pkgName) 458 } 459 m[pkgName] = importPath 460 } 461 return m 462} 463 464type namedInterface struct { 465 name *ast.Ident 466 it *ast.InterfaceType 467} 468 469// Create an iterator over all interfaces in file. 470func iterInterfaces(file *ast.File) <-chan namedInterface { 471 ch := make(chan namedInterface) 472 go func() { 473 for _, decl := range file.Decls { 474 gd, ok := decl.(*ast.GenDecl) 475 if !ok || gd.Tok != token.TYPE { 476 continue 477 } 478 for _, spec := range gd.Specs { 479 ts, ok := spec.(*ast.TypeSpec) 480 if !ok { 481 continue 482 } 483 it, ok := ts.Type.(*ast.InterfaceType) 484 if !ok { 485 continue 486 } 487 488 ch <- namedInterface{ts.Name, it} 489 } 490 } 491 close(ch) 492 }() 493 return ch 494} 495 496// isVariadic returns whether the function is variadic. 497func isVariadic(f *ast.FuncType) bool { 498 nargs := len(f.Params.List) 499 if nargs == 0 { 500 return false 501 } 502 _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis) 503 return ok 504} 505