1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// The fiximports command fixes import declarations to use the canonical
6// import path for packages that have an "import comment" as defined by
7// https://golang.org/s/go14customimport.
8//
9//
10// Background
11//
12// The Go 1 custom import path mechanism lets the maintainer of a
13// package give it a stable name by which clients may import and "go
14// get" it, independent of the underlying version control system (such
15// as Git) or server (such as github.com) that hosts it.  Requests for
16// the custom name are redirected to the underlying name.  This allows
17// packages to be migrated from one underlying server or system to
18// another without breaking existing clients.
19//
20// Because this redirect mechanism creates aliases for existing
21// packages, it's possible for a single program to import the same
22// package by its canonical name and by an alias.  The resulting
23// executable will contain two copies of the package, which is wasteful
24// at best and incorrect at worst.
25//
26// To avoid this, "go build" reports an error if it encounters a special
27// comment like the one below, and if the import path in the comment
28// does not match the path of the enclosing package relative to
29// GOPATH/src:
30//
//	$ grep ^package $GOPATH/src/github.com/bob/vanity/foo/foo.go
//	package foo // import "vanity.com/foo"
33//
34// The error from "go build" indicates that the package canonically
35// known as "vanity.com/foo" is locally installed under the
36// non-canonical name "github.com/bob/vanity/foo".
37//
38//
39// Usage
40//
41// When a package that you depend on introduces a custom import comment,
42// and your workspace imports it by the non-canonical name, your build
43// will stop working as soon as you update your copy of that package
44// using "go get -u".
45//
46// The purpose of the fiximports tool is to fix up all imports of the
47// non-canonical path within a Go workspace, replacing them with imports
48// of the canonical path.  Following a run of fiximports, the workspace
49// will no longer depend on the non-canonical copy of the package, so it
50// should be safe to delete.  It may be necessary to run "go get -u"
51// again to ensure that the package is locally installed under its
52// canonical path, if it was not already.
53//
54// The fiximports tool operates locally; it does not make HTTP requests
55// and does not discover new custom import comments.  It only operates
56// on non-canonical packages present in your workspace.
57//
58// The -baddomains flag is a list of domain names that should always be
59// considered non-canonical.  You can use this if you wish to make sure
60// that you no longer have any dependencies on packages from that
61// domain, even those that do not yet provide a canonical import path
62// comment.  For example, the default value of -baddomains includes the
63// moribund code hosting site code.google.com, so fiximports will report
64// an error for each import of a package from this domain remaining
65// after canonicalization.
66//
67// To see the changes fiximports would make without applying them, use
68// the -n flag.
69//
70package main
71
72import (
73	"bytes"
74	"encoding/json"
75	"flag"
76	"fmt"
77	"go/ast"
78	"go/build"
79	"go/format"
80	"go/parser"
81	"go/token"
82	"io"
83	"io/ioutil"
84	"log"
85	"os"
86	"os/exec"
87	"path"
88	"path/filepath"
89	"sort"
90	"strconv"
91	"strings"
92)
93
// Command-line flags.
var (
	// dryrun (-n) makes the tool report the changes it would make
	// without writing any file.
	dryrun = flag.Bool(
		"n",
		false,
		"dry run: show changes, but don't apply them",
	)

	// badDomains (-baddomains) is a comma-separated list of hosting
	// domains; packages from them are always treated as non-canonical.
	badDomains = flag.String(
		"baddomains",
		"code.google.com",
		"a comma-separated list of domains from which packages should not be imported",
	)

	// replaceFlag (-replace) is a comma-separated list of explicit
	// noncanonical=canonical rewrites supplied by the user.
	replaceFlag = flag.String(
		"replace",
		"",
		"a comma-separated list of noncanonical=canonical pairs of package paths.  If both items in a pair end with '...', they are treated as path prefixes.",
	)
)
102
// Seams for testing: these package-level variables stand in for
// os.Stderr and ioutil.WriteFile so they can be replaced.

// stderr is the destination for all diagnostics printed by the tool.
var stderr io.Writer = os.Stderr

// writeFile is the function used to write a rewritten source file.
var writeFile = ioutil.WriteFile
108
// usage is the help text printed by main when no package arguments are given.
// It lists every flag the command defines.
const usage = `fiximports: rewrite import paths to use canonical package names.

Usage: fiximports [flags] package...

The package... arguments specify a list of packages
in the style of the go tool; see "go help packages".
Hint: use "all" or "..." to match the entire workspace.

For details, see http://godoc.org/golang.org/x/tools/cmd/fiximports.

Flags:
  -n           dry run: show changes, but don't apply them
  -baddomains  a comma-separated list of domains from which packages
               should not be imported
  -replace     a comma-separated list of noncanonical=canonical pairs
               of package paths; if both items in a pair end with
               '...', they are treated as path prefixes
`
124
125func main() {
126	flag.Parse()
127
128	if len(flag.Args()) == 0 {
129		fmt.Fprint(stderr, usage)
130		os.Exit(1)
131	}
132	if !fiximports(flag.Args()...) {
133		os.Exit(1)
134	}
135}
136
// canonicalName records the canonical import path of a package together
// with the package name. An empty path means the canonical location is
// unknown.
type canonicalName struct {
	path string // canonical import path
	name string // package name
}
138
139// fiximports fixes imports in the specified packages.
140// Invariant: a false result implies an error was already printed.
141func fiximports(packages ...string) bool {
142	// importedBy is the transpose of the package import graph.
143	importedBy := make(map[string]map[*build.Package]bool)
144
145	// addEdge adds an edge to the import graph.
146	addEdge := func(from *build.Package, to string) {
147		if to == "C" || to == "unsafe" {
148			return // fake
149		}
150		pkgs := importedBy[to]
151		if pkgs == nil {
152			pkgs = make(map[*build.Package]bool)
153			importedBy[to] = pkgs
154		}
155		pkgs[from] = true
156	}
157
158	// List metadata for all packages in the workspace.
159	pkgs, err := list("...")
160	if err != nil {
161		fmt.Fprintf(stderr, "importfix: %v\n", err)
162		return false
163	}
164
165	// packageName maps each package's path to its name.
166	packageName := make(map[string]string)
167	for _, p := range pkgs {
168		packageName[p.ImportPath] = p.Package.Name
169	}
170
171	// canonical maps each non-canonical package path to
172	// its canonical path and name.
173	// A present nil value indicates that the canonical package
174	// is unknown: hosted on a bad domain with no redirect.
175	canonical := make(map[string]canonicalName)
176	domains := strings.Split(*badDomains, ",")
177
178	type replaceItem struct {
179		old, new    string
180		matchPrefix bool
181	}
182	var replace []replaceItem
183	for _, pair := range strings.Split(*replaceFlag, ",") {
184		if pair == "" {
185			continue
186		}
187		words := strings.Split(pair, "=")
188		if len(words) != 2 {
189			fmt.Fprintf(stderr, "importfix: -replace: %q is not of the form \"canonical=noncanonical\".\n", pair)
190			return false
191		}
192		replace = append(replace, replaceItem{
193			old: strings.TrimSuffix(words[0], "..."),
194			new: strings.TrimSuffix(words[1], "..."),
195			matchPrefix: strings.HasSuffix(words[0], "...") &&
196				strings.HasSuffix(words[1], "..."),
197		})
198	}
199
200	// Find non-canonical packages and populate importedBy graph.
201	for _, p := range pkgs {
202		if p.Error != nil {
203			msg := p.Error.Err
204			if strings.Contains(msg, "code in directory") &&
205				strings.Contains(msg, "expects import") {
206				// don't show the very errors we're trying to fix
207			} else {
208				fmt.Fprintln(stderr, p.Error)
209			}
210		}
211
212		for _, imp := range p.Imports {
213			addEdge(&p.Package, imp)
214		}
215		for _, imp := range p.TestImports {
216			addEdge(&p.Package, imp)
217		}
218		for _, imp := range p.XTestImports {
219			addEdge(&p.Package, imp)
220		}
221
222		// Does package have an explicit import comment?
223		if p.ImportComment != "" {
224			if p.ImportComment != p.ImportPath {
225				canonical[p.ImportPath] = canonicalName{
226					path: p.Package.ImportComment,
227					name: p.Package.Name,
228				}
229			}
230		} else {
231			// Is package matched by a -replace item?
232			var newPath string
233			for _, item := range replace {
234				if item.matchPrefix {
235					if strings.HasPrefix(p.ImportPath, item.old) {
236						newPath = item.new + p.ImportPath[len(item.old):]
237						break
238					}
239				} else if p.ImportPath == item.old {
240					newPath = item.new
241					break
242				}
243			}
244			if newPath != "" {
245				newName := packageName[newPath]
246				if newName == "" {
247					newName = filepath.Base(newPath) // a guess
248				}
249				canonical[p.ImportPath] = canonicalName{
250					path: newPath,
251					name: newName,
252				}
253				continue
254			}
255
256			// Is package matched by a -baddomains item?
257			for _, domain := range domains {
258				slash := strings.Index(p.ImportPath, "/")
259				if slash < 0 {
260					continue // no slash: standard package
261				}
262				if p.ImportPath[:slash] == domain {
263					// Package comes from bad domain and has no import comment.
264					// Report an error each time this package is imported.
265					canonical[p.ImportPath] = canonicalName{}
266
267					// TODO(adonovan): should we make an HTTP request to
268					// see if there's an HTTP redirect, a "go-import" meta tag,
269					// or an import comment in the the latest revision?
270					// It would duplicate a lot of logic from "go get".
271				}
272				break
273			}
274		}
275	}
276
277	// Find all clients (direct importers) of canonical packages.
278	// These are the packages that need fixing up.
279	clients := make(map[*build.Package]bool)
280	for path := range canonical {
281		for client := range importedBy[path] {
282			clients[client] = true
283		}
284	}
285
286	// Restrict rewrites to the set of packages specified by the user.
287	if len(packages) == 1 && (packages[0] == "all" || packages[0] == "...") {
288		// no restriction
289	} else {
290		pkgs, err := list(packages...)
291		if err != nil {
292			fmt.Fprintf(stderr, "importfix: %v\n", err)
293			return false
294		}
295		seen := make(map[string]bool)
296		for _, p := range pkgs {
297			seen[p.ImportPath] = true
298		}
299		for client := range clients {
300			if !seen[client.ImportPath] {
301				delete(clients, client)
302			}
303		}
304	}
305
306	// Rewrite selected client packages.
307	ok := true
308	for client := range clients {
309		if !rewritePackage(client, canonical) {
310			ok = false
311
312			// There were errors.
313			// Show direct and indirect imports of client.
314			seen := make(map[string]bool)
315			var direct, indirect []string
316			for p := range importedBy[client.ImportPath] {
317				direct = append(direct, p.ImportPath)
318				seen[p.ImportPath] = true
319			}
320
321			var visit func(path string)
322			visit = func(path string) {
323				for q := range importedBy[path] {
324					qpath := q.ImportPath
325					if !seen[qpath] {
326						seen[qpath] = true
327						indirect = append(indirect, qpath)
328						visit(qpath)
329					}
330				}
331			}
332
333			if direct != nil {
334				fmt.Fprintf(stderr, "\timported directly by:\n")
335				sort.Strings(direct)
336				for _, path := range direct {
337					fmt.Fprintf(stderr, "\t\t%s\n", path)
338					visit(path)
339				}
340
341				if indirect != nil {
342					fmt.Fprintf(stderr, "\timported indirectly by:\n")
343					sort.Strings(indirect)
344					for _, path := range indirect {
345						fmt.Fprintf(stderr, "\t\t%s\n", path)
346					}
347				}
348			}
349		}
350	}
351
352	return ok
353}
354
355// Invariant: false result => error already printed.
356func rewritePackage(client *build.Package, canonical map[string]canonicalName) bool {
357	ok := true
358
359	used := make(map[string]bool)
360	var filenames []string
361	filenames = append(filenames, client.GoFiles...)
362	filenames = append(filenames, client.TestGoFiles...)
363	filenames = append(filenames, client.XTestGoFiles...)
364	var first bool
365	for _, filename := range filenames {
366		if !first {
367			first = true
368			fmt.Fprintf(stderr, "%s\n", client.ImportPath)
369		}
370		err := rewriteFile(filepath.Join(client.Dir, filename), canonical, used)
371		if err != nil {
372			fmt.Fprintf(stderr, "\tERROR: %v\n", err)
373			ok = false
374		}
375	}
376
377	// Show which imports were renamed in this package.
378	var keys []string
379	for key := range used {
380		keys = append(keys, key)
381	}
382	sort.Strings(keys)
383	for _, key := range keys {
384		if p := canonical[key]; p.path != "" {
385			fmt.Fprintf(stderr, "\tfixed: %s -> %s\n", key, p.path)
386		} else {
387			fmt.Fprintf(stderr, "\tERROR: %s has no import comment\n", key)
388			ok = false
389		}
390	}
391
392	return ok
393}
394
395// rewrite reads, modifies, and writes filename, replacing all imports
396// of packages P in canonical by canonical[P].
397// It records in used which canonical packages were imported.
398// used[P]=="" indicates that P was imported but its canonical path is unknown.
399func rewriteFile(filename string, canonical map[string]canonicalName, used map[string]bool) error {
400	fset := token.NewFileSet()
401	f, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
402	if err != nil {
403		return err
404	}
405	var changed bool
406	for _, imp := range f.Imports {
407		impPath, err := strconv.Unquote(imp.Path.Value)
408		if err != nil {
409			log.Printf("%s: bad import spec %q: %v",
410				fset.Position(imp.Pos()), imp.Path.Value, err)
411			continue
412		}
413		canon, ok := canonical[impPath]
414		if !ok {
415			continue // import path is canonical
416		}
417
418		used[impPath] = true
419
420		if canon.path == "" {
421			// The canonical path is unknown (a -baddomain).
422			// Show the offending import.
423			// TODO(adonovan): should we show the actual source text?
424			fmt.Fprintf(stderr, "\t%s:%d: import %q\n",
425				shortPath(filename),
426				fset.Position(imp.Pos()).Line, impPath)
427			continue
428		}
429
430		changed = true
431
432		imp.Path.Value = strconv.Quote(canon.path)
433
434		// Add a renaming import if necessary.
435		//
436		// This is a guess at best.  We can't see whether a 'go
437		// get' of the canonical import path would have the same
438		// name or not.  Assume it's the last segment.
439		newBase := path.Base(canon.path)
440		if imp.Name == nil && newBase != canon.name {
441			imp.Name = &ast.Ident{Name: canon.name}
442		}
443	}
444
445	if changed && !*dryrun {
446		var buf bytes.Buffer
447		if err := format.Node(&buf, fset, f); err != nil {
448			return fmt.Errorf("%s: couldn't format file: %v", filename, err)
449		}
450		return writeFile(filename, buf.Bytes(), 0644)
451	}
452
453	return nil
454}
455
// listPackage is a copy of cmd/go/list.Package.
// It has more fields than build.Package and we need some of them.
//
// The list function decodes each JSON object printed by
// "go list -e -json" into one listPackage.
type listPackage struct {
	// Package is embedded, so its fields (ImportPath, Imports,
	// ImportComment, ...) are accessible directly on a listPackage.
	build.Package
	// Error is checked against nil by fiximports before its message is
	// inspected or printed.
	Error *packageError // error loading package
}
462
// A packageError describes an error loading information about a package.
type packageError struct {
	ImportStack []string // shortest path from package named on command line to this one
	Pos         string   // position of error
	Err         string   // the error itself
}

// Error returns the error text, prefixed by the position and ": " when a
// position is recorded.
func (e packageError) Error() string {
	msg := e.Err
	if e.Pos != "" {
		msg = e.Pos + ": " + msg
	}
	return msg
}
476
477// list runs 'go list' with the specified arguments and returns the
478// metadata for matching packages.
479func list(args ...string) ([]*listPackage, error) {
480	cmd := exec.Command("go", append([]string{"list", "-e", "-json"}, args...)...)
481	cmd.Stdout = new(bytes.Buffer)
482	cmd.Stderr = stderr
483	if err := cmd.Run(); err != nil {
484		return nil, err
485	}
486
487	dec := json.NewDecoder(cmd.Stdout.(io.Reader))
488	var pkgs []*listPackage
489	for {
490		var p listPackage
491		if err := dec.Decode(&p); err == io.EOF {
492			break
493		} else if err != nil {
494			return nil, err
495		}
496		pkgs = append(pkgs, &p)
497	}
498	return pkgs, nil
499}
500
// cwd holds the tool's current working directory as reported by
// os.Getwd; the program exits via log.Fatalf if it cannot be determined.
//
// It is set by a variable initializer so that its value is available to
// other package variables and init functions that depend on it, such as
// the gopath variable in main_test.go.
var cwd = func() string {
	dir, err := os.Getwd()
	if err != nil {
		log.Fatalf("os.Getwd: %v", err)
	}
	return dir
}()
513
514// shortPath returns an absolute or relative name for path, whatever is shorter.
515// Plundered from $GOROOT/src/cmd/go/build.go.
516func shortPath(path string) string {
517	if rel, err := filepath.Rel(cwd, path); err == nil && len(rel) < len(path) {
518		return rel
519	}
520	return path
521}
522