1// Copyright 2010 Google Inc. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// MockGen generates mock implementations of Go interfaces. 16package main 17 18// TODO: This does not support recursive embedded interfaces. 19// TODO: This does not support embedding package-local interfaces in a separate file. 20 21import ( 22 "bytes" 23 "encoding/json" 24 "flag" 25 "fmt" 26 "go/token" 27 "io" 28 "io/ioutil" 29 "log" 30 "os" 31 "os/exec" 32 "path" 33 "path/filepath" 34 "sort" 35 "strconv" 36 "strings" 37 "unicode" 38 39 "github.com/golang/mock/mockgen/model" 40 41 toolsimports "golang.org/x/tools/imports" 42) 43 44const ( 45 gomockImportPath = "github.com/golang/mock/gomock" 46) 47 48var ( 49 version = "" 50 commit = "none" 51 date = "unknown" 52) 53 54var ( 55 source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") 56 destination = flag.String("destination", "", "Output file; defaults to stdout.") 57 mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") 58 packageOut = flag.String("package", "", "Package of the generated code; defaults to the package of the input with a 'mock_' prefix.") 59 selfPackage = flag.String("self_package", "", "The full package import path for the generated code. The purpose of this flag is to prevent import cycles in the generated code by trying to include its own package. This can happen if the mock's package is set to one of its inputs (usually the main one) and the output is stdio so mockgen cannot detect the final output package. Setting this flag will then tell mockgen which import to exclude.") 60 writePkgComment = flag.Bool("write_package_comment", true, "Writes package documentation comment (godoc) if true.") 61 copyrightFile = flag.String("copyright_file", "", "Copyright file used to add copyright header") 62 63 debugParser = flag.Bool("debug_parser", false, "Print out parser results only.") 64 showVersion = flag.Bool("version", false, "Print version.") 65) 66 67func main() { 68 flag.Usage = usage 69 flag.Parse() 70 71 if *showVersion { 72 printVersion() 73 return 74 } 75 76 var pkg *model.Package 77 var err error 78 var packageName string 79 if *source != "" { 80 pkg, err = sourceMode(*source) 81 } else { 82 if flag.NArg() != 2 { 83 usage() 84 log.Fatal("Expected exactly two arguments") 85 } 86 packageName = flag.Arg(0) 87 if packageName == "." { 88 dir, err := os.Getwd() 89 if err != nil { 90 log.Fatalf("Get current directory failed: %v", err) 91 } 92 packageName, err = packageNameOfDir(dir) 93 if err != nil { 94 log.Fatalf("Parse package name failed: %v", err) 95 } 96 } 97 pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ",")) 98 } 99 if err != nil { 100 log.Fatalf("Loading input failed: %v", err) 101 } 102 103 if *debugParser { 104 pkg.Print(os.Stdout) 105 return 106 } 107 108 dst := os.Stdout 109 if len(*destination) > 0 { 110 if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil { 111 log.Fatalf("Unable to create directory: %v", err) 112 } 113 f, err := os.Create(*destination) 114 if err != nil { 115 log.Fatalf("Failed opening destination file: %v", err) 116 } 117 defer f.Close() 118 dst = f 119 } 120 121 outputPackageName := *packageOut 122 if outputPackageName == "" { 123 // pkg.Name in reflect mode is the base name of the import path, 124 // which might have characters that are illegal to have in package names. 125 outputPackageName = "mock_" + sanitize(pkg.Name) 126 } 127 128 // outputPackagePath represents the fully qualified name of the package of 129 // the generated code. Its purposes are to prevent the module from importing 130 // itself and to prevent qualifying type names that come from its own 131 // package (i.e. if there is a type called X then we want to print "X" not 132 // "package.X" since "package" is this package). This can happen if the mock 133 // is output into an already existing package. 134 outputPackagePath := *selfPackage 135 if outputPackagePath == "" && *destination != "" { 136 dstPath, err := filepath.Abs(filepath.Dir(*destination)) 137 if err != nil { 138 log.Fatalf("Unable to determine destination file path: %v", err) 139 } 140 outputPackagePath, err = parsePackageImport(dstPath) 141 if err != nil { 142 log.Fatalf("Unable to determine destination file path: %v", err) 143 } 144 } 145 146 g := new(generator) 147 if *source != "" { 148 g.filename = *source 149 } else { 150 g.srcPackage = packageName 151 g.srcInterfaces = flag.Arg(1) 152 } 153 g.destination = *destination 154 155 if *mockNames != "" { 156 g.mockNames = parseMockNames(*mockNames) 157 } 158 if *copyrightFile != "" { 159 header, err := ioutil.ReadFile(*copyrightFile) 160 if err != nil { 161 log.Fatalf("Failed reading copyright file: %v", err) 162 } 163 164 g.copyrightHeader = string(header) 165 } 166 if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil { 167 log.Fatalf("Failed generating mock: %v", err) 168 } 169 if _, err := dst.Write(g.Output()); err != nil { 170 log.Fatalf("Failed writing to destination: %v", err) 171 } 172} 173 174func parseMockNames(names string) map[string]string { 175 mocksMap := make(map[string]string) 176 for _, kv := range strings.Split(names, ",") { 177 parts := strings.SplitN(kv, "=", 2) 178 if len(parts) != 2 || parts[1] == "" { 179 log.Fatalf("bad mock names spec: %v", kv) 180 } 181 mocksMap[parts[0]] = parts[1] 182 } 183 return mocksMap 184} 185 186func usage() { 187 _, _ = io.WriteString(os.Stderr, usageText) 188 flag.PrintDefaults() 189} 190 191const usageText = `mockgen has two modes of operation: source and reflect. 192 193Source mode generates mock interfaces from a source file. 194It is enabled by using the -source flag. Other flags that 195may be useful in this mode are -imports and -aux_files. 196Example: 197 mockgen -source=foo.go [other options] 198 199Reflect mode generates mock interfaces by building a program 200that uses reflection to understand interfaces. It is enabled 201by passing two non-flag arguments: an import path, and a 202comma-separated list of symbols. 203Example: 204 mockgen database/sql/driver Conn,Driver 205 206` 207 208type generator struct { 209 buf bytes.Buffer 210 indent string 211 mockNames map[string]string // may be empty 212 filename string // may be empty 213 destination string // may be empty 214 srcPackage, srcInterfaces string // may be empty 215 copyrightHeader string 216 217 packageMap map[string]string // map from import path to package name 218} 219 220func (g *generator) p(format string, args ...interface{}) { 221 fmt.Fprintf(&g.buf, g.indent+format+"\n", args...) 222} 223 224func (g *generator) in() { 225 g.indent += "\t" 226} 227 228func (g *generator) out() { 229 if len(g.indent) > 0 { 230 g.indent = g.indent[0 : len(g.indent)-1] 231 } 232} 233 234func removeDot(s string) string { 235 if len(s) > 0 && s[len(s)-1] == '.' { 236 return s[0 : len(s)-1] 237 } 238 return s 239} 240 241// sanitize cleans up a string to make a suitable package name. 242func sanitize(s string) string { 243 t := "" 244 for _, r := range s { 245 if t == "" { 246 if unicode.IsLetter(r) || r == '_' { 247 t += string(r) 248 continue 249 } 250 } else { 251 if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' { 252 t += string(r) 253 continue 254 } 255 } 256 t += "_" 257 } 258 if t == "_" { 259 t = "x" 260 } 261 return t 262} 263 264func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error { 265 if outputPkgName != pkg.Name && *selfPackage == "" { 266 // reset outputPackagePath if it's not passed in through -self_package 267 outputPackagePath = "" 268 } 269 270 if g.copyrightHeader != "" { 271 lines := strings.Split(g.copyrightHeader, "\n") 272 for _, line := range lines { 273 g.p("// %s", line) 274 } 275 g.p("") 276 } 277 278 g.p("// Code generated by MockGen. DO NOT EDIT.") 279 if g.filename != "" { 280 g.p("// Source: %v", g.filename) 281 } else { 282 g.p("// Source: %v (interfaces: %v)", g.srcPackage, g.srcInterfaces) 283 } 284 g.p("") 285 286 // Get all required imports, and generate unique names for them all. 287 im := pkg.Imports() 288 im[gomockImportPath] = true 289 290 // Only import reflect if it's used. We only use reflect in mocked methods 291 // so only import if any of the mocked interfaces have methods. 292 for _, intf := range pkg.Interfaces { 293 if len(intf.Methods) > 0 { 294 im["reflect"] = true 295 break 296 } 297 } 298 299 // Sort keys to make import alias generation predictable 300 sortedPaths := make([]string, len(im)) 301 x := 0 302 for pth := range im { 303 sortedPaths[x] = pth 304 x++ 305 } 306 sort.Strings(sortedPaths) 307 308 packagesName := createPackageMap(sortedPaths) 309 310 g.packageMap = make(map[string]string, len(im)) 311 localNames := make(map[string]bool, len(im)) 312 for _, pth := range sortedPaths { 313 base, ok := packagesName[pth] 314 if !ok { 315 base = sanitize(path.Base(pth)) 316 } 317 318 // Local names for an imported package can usually be the basename of the import path. 319 // A couple of situations don't permit that, such as duplicate local names 320 // (e.g. importing "html/template" and "text/template"), or where the basename is 321 // a keyword (e.g. "foo/case"). 322 // try base0, base1, ... 323 pkgName := base 324 i := 0 325 for localNames[pkgName] || token.Lookup(pkgName).IsKeyword() { 326 pkgName = base + strconv.Itoa(i) 327 i++ 328 } 329 330 // Avoid importing package if source pkg == output pkg 331 if pth == pkg.PkgPath && outputPackagePath == pkg.PkgPath { 332 continue 333 } 334 335 g.packageMap[pth] = pkgName 336 localNames[pkgName] = true 337 } 338 339 if *writePkgComment { 340 g.p("// Package %v is a generated GoMock package.", outputPkgName) 341 } 342 g.p("package %v", outputPkgName) 343 g.p("") 344 g.p("import (") 345 g.in() 346 for pkgPath, pkgName := range g.packageMap { 347 if pkgPath == outputPackagePath { 348 continue 349 } 350 g.p("%v %q", pkgName, pkgPath) 351 } 352 for _, pkgPath := range pkg.DotImports { 353 g.p(". %q", pkgPath) 354 } 355 g.out() 356 g.p(")") 357 358 for _, intf := range pkg.Interfaces { 359 if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil { 360 return err 361 } 362 } 363 364 return nil 365} 366 367// The name of the mock type to use for the given interface identifier. 368func (g *generator) mockName(typeName string) string { 369 if mockName, ok := g.mockNames[typeName]; ok { 370 return mockName 371 } 372 373 return "Mock" + typeName 374} 375 376func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error { 377 mockType := g.mockName(intf.Name) 378 379 g.p("") 380 g.p("// %v is a mock of %v interface.", mockType, intf.Name) 381 g.p("type %v struct {", mockType) 382 g.in() 383 g.p("ctrl *gomock.Controller") 384 g.p("recorder *%vMockRecorder", mockType) 385 g.out() 386 g.p("}") 387 g.p("") 388 389 g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType) 390 g.p("type %vMockRecorder struct {", mockType) 391 g.in() 392 g.p("mock *%v", mockType) 393 g.out() 394 g.p("}") 395 g.p("") 396 397 // TODO: Re-enable this if we can import the interface reliably. 398 // g.p("// Verify that the mock satisfies the interface at compile time.") 399 // g.p("var _ %v = (*%v)(nil)", typeName, mockType) 400 // g.p("") 401 402 g.p("// New%v creates a new mock instance.", mockType) 403 g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType) 404 g.in() 405 g.p("mock := &%v{ctrl: ctrl}", mockType) 406 g.p("mock.recorder = &%vMockRecorder{mock}", mockType) 407 g.p("return mock") 408 g.out() 409 g.p("}") 410 g.p("") 411 412 // XXX: possible name collision here if someone has EXPECT in their interface. 413 g.p("// EXPECT returns an object that allows the caller to indicate expected use.") 414 g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType) 415 g.in() 416 g.p("return m.recorder") 417 g.out() 418 g.p("}") 419 420 g.GenerateMockMethods(mockType, intf, outputPackagePath) 421 422 return nil 423} 424 425type byMethodName []*model.Method 426 427func (b byMethodName) Len() int { return len(b) } 428func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] } 429func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name } 430 431func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) { 432 sort.Sort(byMethodName(intf.Methods)) 433 for _, m := range intf.Methods { 434 g.p("") 435 _ = g.GenerateMockMethod(mockType, m, pkgOverride) 436 g.p("") 437 _ = g.GenerateMockRecorderMethod(mockType, m) 438 } 439} 440 441func makeArgString(argNames, argTypes []string) string { 442 args := make([]string, len(argNames)) 443 for i, name := range argNames { 444 // specify the type only once for consecutive args of the same type 445 if i+1 < len(argTypes) && argTypes[i] == argTypes[i+1] { 446 args[i] = name 447 } else { 448 args[i] = name + " " + argTypes[i] 449 } 450 } 451 return strings.Join(args, ", ") 452} 453 454// GenerateMockMethod generates a mock method implementation. 455// If non-empty, pkgOverride is the package in which unqualified types reside. 456func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error { 457 argNames := g.getArgNames(m) 458 argTypes := g.getArgTypes(m, pkgOverride) 459 argString := makeArgString(argNames, argTypes) 460 461 rets := make([]string, len(m.Out)) 462 for i, p := range m.Out { 463 rets[i] = p.Type.String(g.packageMap, pkgOverride) 464 } 465 retString := strings.Join(rets, ", ") 466 if len(rets) > 1 { 467 retString = "(" + retString + ")" 468 } 469 if retString != "" { 470 retString = " " + retString 471 } 472 473 ia := newIdentifierAllocator(argNames) 474 idRecv := ia.allocateIdentifier("m") 475 476 g.p("// %v mocks base method.", m.Name) 477 g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString) 478 g.in() 479 g.p("%s.ctrl.T.Helper()", idRecv) 480 481 var callArgs string 482 if m.Variadic == nil { 483 if len(argNames) > 0 { 484 callArgs = ", " + strings.Join(argNames, ", ") 485 } 486 } else { 487 // Non-trivial. The generated code must build a []interface{}, 488 // but the variadic argument may be any type. 489 idVarArgs := ia.allocateIdentifier("varargs") 490 idVArg := ia.allocateIdentifier("a") 491 g.p("%s := []interface{}{%s}", idVarArgs, strings.Join(argNames[:len(argNames)-1], ", ")) 492 g.p("for _, %s := range %s {", idVArg, argNames[len(argNames)-1]) 493 g.in() 494 g.p("%s = append(%s, %s)", idVarArgs, idVarArgs, idVArg) 495 g.out() 496 g.p("}") 497 callArgs = ", " + idVarArgs + "..." 498 } 499 if len(m.Out) == 0 { 500 g.p(`%v.ctrl.Call(%v, %q%v)`, idRecv, idRecv, m.Name, callArgs) 501 } else { 502 idRet := ia.allocateIdentifier("ret") 503 g.p(`%v := %v.ctrl.Call(%v, %q%v)`, idRet, idRecv, idRecv, m.Name, callArgs) 504 505 // Go does not allow "naked" type assertions on nil values, so we use the two-value form here. 506 // The value of that is either (x.(T), true) or (Z, false), where Z is the zero value for T. 507 // Happily, this coincides with the semantics we want here. 508 retNames := make([]string, len(rets)) 509 for i, t := range rets { 510 retNames[i] = ia.allocateIdentifier(fmt.Sprintf("ret%d", i)) 511 g.p("%s, _ := %s[%d].(%s)", retNames[i], idRet, i, t) 512 } 513 g.p("return " + strings.Join(retNames, ", ")) 514 } 515 516 g.out() 517 g.p("}") 518 return nil 519} 520 521func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error { 522 argNames := g.getArgNames(m) 523 524 var argString string 525 if m.Variadic == nil { 526 argString = strings.Join(argNames, ", ") 527 } else { 528 argString = strings.Join(argNames[:len(argNames)-1], ", ") 529 } 530 if argString != "" { 531 argString += " interface{}" 532 } 533 534 if m.Variadic != nil { 535 if argString != "" { 536 argString += ", " 537 } 538 argString += fmt.Sprintf("%s ...interface{}", argNames[len(argNames)-1]) 539 } 540 541 ia := newIdentifierAllocator(argNames) 542 idRecv := ia.allocateIdentifier("mr") 543 544 g.p("// %v indicates an expected call of %v.", m.Name, m.Name) 545 g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString) 546 g.in() 547 g.p("%s.mock.ctrl.T.Helper()", idRecv) 548 549 var callArgs string 550 if m.Variadic == nil { 551 if len(argNames) > 0 { 552 callArgs = ", " + strings.Join(argNames, ", ") 553 } 554 } else { 555 if len(argNames) == 1 { 556 // Easy: just use ... to push the arguments through. 557 callArgs = ", " + argNames[0] + "..." 558 } else { 559 // Hard: create a temporary slice. 560 idVarArgs := ia.allocateIdentifier("varargs") 561 g.p("%s := append([]interface{}{%s}, %s...)", 562 idVarArgs, 563 strings.Join(argNames[:len(argNames)-1], ", "), 564 argNames[len(argNames)-1]) 565 callArgs = ", " + idVarArgs + "..." 566 } 567 } 568 g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs) 569 570 g.out() 571 g.p("}") 572 return nil 573} 574 575func (g *generator) getArgNames(m *model.Method) []string { 576 argNames := make([]string, len(m.In)) 577 for i, p := range m.In { 578 name := p.Name 579 if name == "" || name == "_" { 580 name = fmt.Sprintf("arg%d", i) 581 } 582 argNames[i] = name 583 } 584 if m.Variadic != nil { 585 name := m.Variadic.Name 586 if name == "" { 587 name = fmt.Sprintf("arg%d", len(m.In)) 588 } 589 argNames = append(argNames, name) 590 } 591 return argNames 592} 593 594func (g *generator) getArgTypes(m *model.Method, pkgOverride string) []string { 595 argTypes := make([]string, len(m.In)) 596 for i, p := range m.In { 597 argTypes[i] = p.Type.String(g.packageMap, pkgOverride) 598 } 599 if m.Variadic != nil { 600 argTypes = append(argTypes, "..."+m.Variadic.Type.String(g.packageMap, pkgOverride)) 601 } 602 return argTypes 603} 604 605type identifierAllocator map[string]struct{} 606 607func newIdentifierAllocator(taken []string) identifierAllocator { 608 a := make(identifierAllocator, len(taken)) 609 for _, s := range taken { 610 a[s] = struct{}{} 611 } 612 return a 613} 614 615func (o identifierAllocator) allocateIdentifier(want string) string { 616 id := want 617 for i := 2; ; i++ { 618 if _, ok := o[id]; !ok { 619 o[id] = struct{}{} 620 return id 621 } 622 id = want + "_" + strconv.Itoa(i) 623 } 624} 625 626// Output returns the generator's output, formatted in the standard Go style. 627func (g *generator) Output() []byte { 628 src, err := toolsimports.Process(g.destination, g.buf.Bytes(), nil) 629 if err != nil { 630 log.Fatalf("Failed to format generated source code: %s\n%s", err, g.buf.String()) 631 } 632 return src 633} 634 635// createPackageMap returns a map of import path to package name 636// for specified importPaths. 637func createPackageMap(importPaths []string) map[string]string { 638 var pkg struct { 639 Name string 640 ImportPath string 641 } 642 pkgMap := make(map[string]string) 643 b := bytes.NewBuffer(nil) 644 args := []string{"list", "-json"} 645 args = append(args, importPaths...) 646 cmd := exec.Command("go", args...) 647 cmd.Stdout = b 648 cmd.Run() 649 dec := json.NewDecoder(b) 650 for dec.More() { 651 err := dec.Decode(&pkg) 652 if err != nil { 653 log.Printf("failed to decode 'go list' output: %v", err) 654 continue 655 } 656 pkgMap[pkg.ImportPath] = pkg.Name 657 } 658 return pkgMap 659} 660 661func printVersion() { 662 if version != "" { 663 fmt.Printf("v%s\nCommit: %s\nDate: %s\n", version, commit, date) 664 } else { 665 printModuleVersion() 666 } 667} 668