1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Stringer is a tool to automate the creation of methods that satisfy the fmt.Stringer
6// interface. Given the name of a (signed or unsigned) integer type T that has constants
7// defined, stringer will create a new self-contained Go source file implementing
8//	func (t T) String() string
9// The file is created in the same package and directory as the package that defines T.
10// It has helpful defaults designed for use with go generate.
11//
12// Stringer works best with constants that are consecutive values such as created using iota,
13// but creates good code regardless. In the future it might also provide custom support for
14// constant sets that are bit patterns.
15//
16// For example, given this snippet,
17//
18//	package painkiller
19//
20//	type Pill int
21//
22//	const (
23//		Placebo Pill = iota
24//		Aspirin
25//		Ibuprofen
26//		Paracetamol
27//		Acetaminophen = Paracetamol
28//	)
29//
30// running this command
31//
32//	stringer -type=Pill
33//
34// in the same directory will create the file pill_string.go, in package painkiller,
35// containing a definition of
36//
37//	func (Pill) String() string
38//
39// That method will translate the value of a Pill constant to the string representation
40// of the respective constant name, so that the call fmt.Print(painkiller.Aspirin) will
41// print the string "Aspirin".
42//
43// Typically this process would be run using go generate, like this:
44//
45//	//go:generate stringer -type=Pill
46//
47// If multiple constants have the same value, the lexically first matching name will
48// be used (in the example, Acetaminophen will print as "Paracetamol").
49//
50// With no arguments, it processes the package in the current directory.
51// Otherwise, the arguments must name a single directory holding a Go package
52// or a set of Go source files that represent a single Go package.
53//
54// The -type flag accepts a comma-separated list of types so a single run can
55// generate methods for multiple types. The default output file is t_string.go,
56// where t is the lower-cased name of the first type listed. It can be overridden
57// with the -output flag.
58//
59// The -linecomment flag tells stringer to generate the text of any line comment, trimmed
60// of leading spaces, instead of the constant name. For instance, if the constants above had a
61// Pill prefix, one could write
62//
63//	PillAspirin // Aspirin
64//
65// to suppress it in the output.
66package main // import "golang.org/x/tools/cmd/stringer"
67
68import (
69	"bytes"
70	"flag"
71	"fmt"
72	"go/ast"
73	"go/constant"
74	"go/format"
75	"go/token"
76	"go/types"
77	"io/ioutil"
78	"log"
79	"os"
80	"path/filepath"
81	"sort"
82	"strings"
83
84	"golang.org/x/tools/go/packages"
85)
86
// Command-line flags, registered with the flag package at program start and
// read by main after flag.Parse.
var (
	// typeNames lists the types to generate String methods for.
	typeNames = flag.String("type", "", "comma-separated list of type names; must be set")

	// output overrides the default output path.
	output = flag.String("output", "", "output file name; default srcdir/<type>_string.go")

	// trimprefix is removed from the start of each constant name when printed.
	trimprefix = flag.String("trimprefix", "", "trim the `prefix` from the generated constant names")

	// linecomment prints a constant's line comment, when present, instead of its name.
	linecomment = flag.Bool("linecomment", false, "use line comment text as printed text when present")

	// buildTags are passed to the build when loading a package directory.
	buildTags = flag.String("tags", "", "comma-separated list of build tags to apply")
)
94
// Usage is a replacement usage function for the flags package. It prints a
// short synopsis, a documentation link and the flag defaults, all to the
// writer that flag.PrintDefaults uses (os.Stderr unless redirected).
func Usage() {
	out := flag.CommandLine.Output()
	fmt.Fprint(out, "Usage of stringer:\n")
	fmt.Fprint(out, "\tstringer [flags] -type T [directory]\n")
	fmt.Fprint(out, "\tstringer [flags] -type T files... # Must be a single package\n")
	fmt.Fprint(out, "For more information, see:\n")
	// godoc.org has been retired in favor of pkg.go.dev.
	fmt.Fprint(out, "\thttps://pkg.go.dev/golang.org/x/tools/cmd/stringer\n")
	fmt.Fprint(out, "Flags:\n")
	flag.PrintDefaults()
}
105
106func main() {
107	log.SetFlags(0)
108	log.SetPrefix("stringer: ")
109	flag.Usage = Usage
110	flag.Parse()
111	if len(*typeNames) == 0 {
112		flag.Usage()
113		os.Exit(2)
114	}
115	types := strings.Split(*typeNames, ",")
116	var tags []string
117	if len(*buildTags) > 0 {
118		tags = strings.Split(*buildTags, ",")
119	}
120
121	// We accept either one directory or a list of files. Which do we have?
122	args := flag.Args()
123	if len(args) == 0 {
124		// Default: process whole package in current directory.
125		args = []string{"."}
126	}
127
128	// Parse the package once.
129	var dir string
130	g := Generator{
131		trimPrefix:  *trimprefix,
132		lineComment: *linecomment,
133	}
134	// TODO(suzmue): accept other patterns for packages (directories, list of files, import paths, etc).
135	if len(args) == 1 && isDirectory(args[0]) {
136		dir = args[0]
137	} else {
138		if len(tags) != 0 {
139			log.Fatal("-tags option applies only to directories, not when files are specified")
140		}
141		dir = filepath.Dir(args[0])
142	}
143
144	g.parsePackage(args, tags)
145
146	// Print the header and package clause.
147	g.Printf("// Code generated by \"stringer %s\"; DO NOT EDIT.\n", strings.Join(os.Args[1:], " "))
148	g.Printf("\n")
149	g.Printf("package %s", g.pkg.name)
150	g.Printf("\n")
151	g.Printf("import \"strconv\"\n") // Used by all methods.
152
153	// Run generate for each type.
154	for _, typeName := range types {
155		g.generate(typeName)
156	}
157
158	// Format the output.
159	src := g.format()
160
161	// Write to file.
162	outputName := *output
163	if outputName == "" {
164		baseName := fmt.Sprintf("%s_string.go", types[0])
165		outputName = filepath.Join(dir, strings.ToLower(baseName))
166	}
167	err := ioutil.WriteFile(outputName, src, 0644)
168	if err != nil {
169		log.Fatalf("writing output: %s", err)
170	}
171}
172
// isDirectory reports whether the named file is a directory. If the file
// cannot be stat'ed, the error is reported with log.Fatal, which exits.
func isDirectory(name string) bool {
	fi, statErr := os.Stat(name)
	if statErr == nil {
		return fi.IsDir()
	}
	log.Fatal(statErr)
	return false // Unreachable: log.Fatal does not return.
}
181
182// Generator holds the state of the analysis. Primarily used to buffer
183// the output for format.Source.
184type Generator struct {
185	buf bytes.Buffer // Accumulated output.
186	pkg *Package     // Package we are scanning.
187
188	trimPrefix  string
189	lineComment bool
190}
191
192func (g *Generator) Printf(format string, args ...interface{}) {
193	fmt.Fprintf(&g.buf, format, args...)
194}
195
196// File holds a single parsed file and associated data.
197type File struct {
198	pkg  *Package  // Package to which this file belongs.
199	file *ast.File // Parsed AST.
200	// These fields are reset for each type being generated.
201	typeName string  // Name of the constant type.
202	values   []Value // Accumulator for constant values of that type.
203
204	trimPrefix  string
205	lineComment bool
206}
207
208type Package struct {
209	name  string
210	defs  map[*ast.Ident]types.Object
211	files []*File
212}
213
214// parsePackage analyzes the single package constructed from the patterns and tags.
215// parsePackage exits if there is an error.
216func (g *Generator) parsePackage(patterns []string, tags []string) {
217	cfg := &packages.Config{
218		Mode: packages.LoadSyntax,
219		// TODO: Need to think about constants in test files. Maybe write type_string_test.go
220		// in a separate pass? For later.
221		Tests:      false,
222		BuildFlags: []string{fmt.Sprintf("-tags=%s", strings.Join(tags, " "))},
223	}
224	pkgs, err := packages.Load(cfg, patterns...)
225	if err != nil {
226		log.Fatal(err)
227	}
228	if len(pkgs) != 1 {
229		log.Fatalf("error: %d packages found", len(pkgs))
230	}
231	g.addPackage(pkgs[0])
232}
233
234// addPackage adds a type checked Package and its syntax files to the generator.
235func (g *Generator) addPackage(pkg *packages.Package) {
236	g.pkg = &Package{
237		name:  pkg.Name,
238		defs:  pkg.TypesInfo.Defs,
239		files: make([]*File, len(pkg.Syntax)),
240	}
241
242	for i, file := range pkg.Syntax {
243		g.pkg.files[i] = &File{
244			file:        file,
245			pkg:         g.pkg,
246			trimPrefix:  g.trimPrefix,
247			lineComment: g.lineComment,
248		}
249	}
250}
251
252// generate produces the String method for the named type.
253func (g *Generator) generate(typeName string) {
254	values := make([]Value, 0, 100)
255	for _, file := range g.pkg.files {
256		// Set the state for this run of the walker.
257		file.typeName = typeName
258		file.values = nil
259		if file.file != nil {
260			ast.Inspect(file.file, file.genDecl)
261			values = append(values, file.values...)
262		}
263	}
264
265	if len(values) == 0 {
266		log.Fatalf("no values defined for type %s", typeName)
267	}
268	// Generate code that will fail if the constants change value.
269	g.Printf("func _() {\n")
270	g.Printf("\t// An \"invalid array index\" compiler error signifies that the constant values have changed.\n")
271	g.Printf("\t// Re-run the stringer command to generate them again.\n")
272	g.Printf("\tvar x [1]struct{}\n")
273	for _, v := range values {
274		g.Printf("\t_ = x[%s - %s]\n", v.originalName, v.str)
275	}
276	g.Printf("}\n")
277	runs := splitIntoRuns(values)
278	// The decision of which pattern to use depends on the number of
279	// runs in the numbers. If there's only one, it's easy. For more than
280	// one, there's a tradeoff between complexity and size of the data
281	// and code vs. the simplicity of a map. A map takes more space,
282	// but so does the code. The decision here (crossover at 10) is
283	// arbitrary, but considers that for large numbers of runs the cost
284	// of the linear scan in the switch might become important, and
285	// rather than use yet another algorithm such as binary search,
286	// we punt and use a map. In any case, the likelihood of a map
287	// being necessary for any realistic example other than bitmasks
288	// is very low. And bitmasks probably deserve their own analysis,
289	// to be done some other day.
290	switch {
291	case len(runs) == 1:
292		g.buildOneRun(runs, typeName)
293	case len(runs) <= 10:
294		g.buildMultipleRuns(runs, typeName)
295	default:
296		g.buildMap(runs, typeName)
297	}
298}
299
300// splitIntoRuns breaks the values into runs of contiguous sequences.
301// For example, given 1,2,3,5,6,7 it returns {1,2,3},{5,6,7}.
302// The input slice is known to be non-empty.
303func splitIntoRuns(values []Value) [][]Value {
304	// We use stable sort so the lexically first name is chosen for equal elements.
305	sort.Stable(byValue(values))
306	// Remove duplicates. Stable sort has put the one we want to print first,
307	// so use that one. The String method won't care about which named constant
308	// was the argument, so the first name for the given value is the only one to keep.
309	// We need to do this because identical values would cause the switch or map
310	// to fail to compile.
311	j := 1
312	for i := 1; i < len(values); i++ {
313		if values[i].value != values[i-1].value {
314			values[j] = values[i]
315			j++
316		}
317	}
318	values = values[:j]
319	runs := make([][]Value, 0, 10)
320	for len(values) > 0 {
321		// One contiguous sequence per outer loop.
322		i := 1
323		for i < len(values) && values[i].value == values[i-1].value+1 {
324			i++
325		}
326		runs = append(runs, values[:i])
327		values = values[i:]
328	}
329	return runs
330}
331
332// format returns the gofmt-ed contents of the Generator's buffer.
333func (g *Generator) format() []byte {
334	src, err := format.Source(g.buf.Bytes())
335	if err != nil {
336		// Should never happen, but can arise when developing this code.
337		// The user can compile the output to see the error.
338		log.Printf("warning: internal error: invalid Go generated: %s", err)
339		log.Printf("warning: compile the package to analyze the error")
340		return g.buf.Bytes()
341	}
342	return src
343}
344
// Value represents a declared constant.
type Value struct {
	originalName string // Constant name as declared in the source.
	name         string // Printed name: prefix-trimmed, or taken from the line comment.
	// value holds only the constant's bit pattern; signed says whether to read
	// it as an int64 or a uint64, which matters only when sorting. Most of the
	// time str, returned by String, is all that is needed.
	value  uint64 // Reinterpreted as int64 when signed is true.
	signed bool   // Whether the constant's type is signed.
	str    string // Textual form given by the "go/constant" package.
}

// String returns the constant's value as text, so a *Value prints with %s.
func (v *Value) String() string { return v.str }

// byValue orders constants by increasing value, comparing them as signed or
// unsigned integers according to the signed field of the element at i.
type byValue []Value

func (b byValue) Len() int      { return len(b) }
func (b byValue) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
func (b byValue) Less(i, j int) bool {
	x, y := b[i], b[j]
	if !x.signed {
		return x.value < y.value
	}
	return int64(x.value) < int64(y.value)
}
376
377// genDecl processes one declaration clause.
378func (f *File) genDecl(node ast.Node) bool {
379	decl, ok := node.(*ast.GenDecl)
380	if !ok || decl.Tok != token.CONST {
381		// We only care about const declarations.
382		return true
383	}
384	// The name of the type of the constants we are declaring.
385	// Can change if this is a multi-element declaration.
386	typ := ""
387	// Loop over the elements of the declaration. Each element is a ValueSpec:
388	// a list of names possibly followed by a type, possibly followed by values.
389	// If the type and value are both missing, we carry down the type (and value,
390	// but the "go/types" package takes care of that).
391	for _, spec := range decl.Specs {
392		vspec := spec.(*ast.ValueSpec) // Guaranteed to succeed as this is CONST.
393		if vspec.Type == nil && len(vspec.Values) > 0 {
394			// "X = 1". With no type but a value. If the constant is untyped,
395			// skip this vspec and reset the remembered type.
396			typ = ""
397
398			// If this is a simple type conversion, remember the type.
399			// We don't mind if this is actually a call; a qualified call won't
400			// be matched (that will be SelectorExpr, not Ident), and only unusual
401			// situations will result in a function call that appears to be
402			// a type conversion.
403			ce, ok := vspec.Values[0].(*ast.CallExpr)
404			if !ok {
405				continue
406			}
407			id, ok := ce.Fun.(*ast.Ident)
408			if !ok {
409				continue
410			}
411			typ = id.Name
412		}
413		if vspec.Type != nil {
414			// "X T". We have a type. Remember it.
415			ident, ok := vspec.Type.(*ast.Ident)
416			if !ok {
417				continue
418			}
419			typ = ident.Name
420		}
421		if typ != f.typeName {
422			// This is not the type we're looking for.
423			continue
424		}
425		// We now have a list of names (from one line of source code) all being
426		// declared with the desired type.
427		// Grab their names and actual values and store them in f.values.
428		for _, name := range vspec.Names {
429			if name.Name == "_" {
430				continue
431			}
432			// This dance lets the type checker find the values for us. It's a
433			// bit tricky: look up the object declared by the name, find its
434			// types.Const, and extract its value.
435			obj, ok := f.pkg.defs[name]
436			if !ok {
437				log.Fatalf("no value for constant %s", name)
438			}
439			info := obj.Type().Underlying().(*types.Basic).Info()
440			if info&types.IsInteger == 0 {
441				log.Fatalf("can't handle non-integer constant type %s", typ)
442			}
443			value := obj.(*types.Const).Val() // Guaranteed to succeed as this is CONST.
444			if value.Kind() != constant.Int {
445				log.Fatalf("can't happen: constant is not an integer %s", name)
446			}
447			i64, isInt := constant.Int64Val(value)
448			u64, isUint := constant.Uint64Val(value)
449			if !isInt && !isUint {
450				log.Fatalf("internal error: value of %s is not an integer: %s", name, value.String())
451			}
452			if !isInt {
453				u64 = uint64(i64)
454			}
455			v := Value{
456				originalName: name.Name,
457				value:        u64,
458				signed:       info&types.IsUnsigned == 0,
459				str:          value.String(),
460			}
461			if c := vspec.Comment; f.lineComment && c != nil && len(c.List) == 1 {
462				v.name = strings.TrimSpace(c.Text())
463			} else {
464				v.name = strings.TrimPrefix(v.originalName, f.trimPrefix)
465			}
466			f.values = append(f.values, v)
467		}
468	}
469	return false
470}
471
472// Helpers
473
// usize returns the bit width (8, 16 or 32) of the smallest unsigned integer
// type that can hold n. It is used to give the index arrays into the
// concatenated names the narrowest possible element type. Anything of 1<<16
// or more is given 32 bits.
func usize(n int) int {
	if n < 1<<8 {
		return 8
	}
	if n < 1<<16 {
		return 16
	}
	// 2^32 is enough constants for anyone.
	return 32
}
488
489// declareIndexAndNameVars declares the index slices and concatenated names
490// strings representing the runs of values.
491func (g *Generator) declareIndexAndNameVars(runs [][]Value, typeName string) {
492	var indexes, names []string
493	for i, run := range runs {
494		index, name := g.createIndexAndNameDecl(run, typeName, fmt.Sprintf("_%d", i))
495		if len(run) != 1 {
496			indexes = append(indexes, index)
497		}
498		names = append(names, name)
499	}
500	g.Printf("const (\n")
501	for _, name := range names {
502		g.Printf("\t%s\n", name)
503	}
504	g.Printf(")\n\n")
505
506	if len(indexes) > 0 {
507		g.Printf("var (")
508		for _, index := range indexes {
509			g.Printf("\t%s\n", index)
510		}
511		g.Printf(")\n\n")
512	}
513}
514
515// declareIndexAndNameVar is the single-run version of declareIndexAndNameVars
516func (g *Generator) declareIndexAndNameVar(run []Value, typeName string) {
517	index, name := g.createIndexAndNameDecl(run, typeName, "")
518	g.Printf("const %s\n", name)
519	g.Printf("var %s\n", index)
520}
521
522// createIndexAndNameDecl returns the pair of declarations for the run. The caller will add "const" and "var".
523func (g *Generator) createIndexAndNameDecl(run []Value, typeName string, suffix string) (string, string) {
524	b := new(bytes.Buffer)
525	indexes := make([]int, len(run))
526	for i := range run {
527		b.WriteString(run[i].name)
528		indexes[i] = b.Len()
529	}
530	nameConst := fmt.Sprintf("_%s_name%s = %q", typeName, suffix, b.String())
531	nameLen := b.Len()
532	b.Reset()
533	fmt.Fprintf(b, "_%s_index%s = [...]uint%d{0, ", typeName, suffix, usize(nameLen))
534	for i, v := range indexes {
535		if i > 0 {
536			fmt.Fprintf(b, ", ")
537		}
538		fmt.Fprintf(b, "%d", v)
539	}
540	fmt.Fprintf(b, "}")
541	return b.String(), nameConst
542}
543
544// declareNameVars declares the concatenated names string representing all the values in the runs.
545func (g *Generator) declareNameVars(runs [][]Value, typeName string, suffix string) {
546	g.Printf("const _%s_name%s = \"", typeName, suffix)
547	for _, run := range runs {
548		for i := range run {
549			g.Printf("%s", run[i].name)
550		}
551	}
552	g.Printf("\"\n")
553}
554
555// buildOneRun generates the variables and String method for a single run of contiguous values.
556func (g *Generator) buildOneRun(runs [][]Value, typeName string) {
557	values := runs[0]
558	g.Printf("\n")
559	g.declareIndexAndNameVar(values, typeName)
560	// The generated code is simple enough to write as a Printf format.
561	lessThanZero := ""
562	if values[0].signed {
563		lessThanZero = "i < 0 || "
564	}
565	if values[0].value == 0 { // Signed or unsigned, 0 is still 0.
566		g.Printf(stringOneRun, typeName, usize(len(values)), lessThanZero)
567	} else {
568		g.Printf(stringOneRunWithOffset, typeName, values[0].String(), usize(len(values)), lessThanZero)
569	}
570}
571
// stringOneRun is the String method template used by buildOneRun when the run
// starts at zero. Values outside the run print as "T(n)".
// Arguments to format are:
//	[1]: type name
//	[2]: size of index element (8 for uint8 etc.); not referenced by the template
//	[3]: less than zero check (for signed types)
const stringOneRun = `func (i %[1]s) String() string {
	if %[3]si >= %[1]s(len(_%[1]s_index)-1) {
		return "%[1]s(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _%[1]s_name[_%[1]s_index[i]:_%[1]s_index[i+1]]
}
`
583
// stringOneRunWithOffset is the String method template used by buildOneRun
// when the run starts at a nonzero value. The generated code subtracts the
// lowest value before indexing, and adds it back when formatting an
// out-of-range value as "T(n)".
// Arguments to format are:
//	[1]: type name
//	[2]: lowest defined value for type, as a string
//	[3]: size of index element (8 for uint8 etc.); not referenced by the template
//	[4]: less than zero check (for signed types)
const stringOneRunWithOffset = `func (i %[1]s) String() string {
	i -= %[2]s
	if %[4]si >= %[1]s(len(_%[1]s_index)-1) {
		return "%[1]s(" + strconv.FormatInt(int64(i + %[2]s), 10) + ")"
	}
	return _%[1]s_name[_%[1]s_index[i] : _%[1]s_index[i+1]]
}
`
599
600// buildMultipleRuns generates the variables and String method for multiple runs of contiguous values.
601// For this pattern, a single Printf format won't do.
602func (g *Generator) buildMultipleRuns(runs [][]Value, typeName string) {
603	g.Printf("\n")
604	g.declareIndexAndNameVars(runs, typeName)
605	g.Printf("func (i %s) String() string {\n", typeName)
606	g.Printf("\tswitch {\n")
607	for i, values := range runs {
608		if len(values) == 1 {
609			g.Printf("\tcase i == %s:\n", &values[0])
610			g.Printf("\t\treturn _%s_name_%d\n", typeName, i)
611			continue
612		}
613		if values[0].value == 0 && !values[0].signed {
614			// For an unsigned lower bound of 0, "0 <= i" would be redundant.
615			g.Printf("\tcase i <= %s:\n", &values[len(values)-1])
616		} else {
617			g.Printf("\tcase %s <= i && i <= %s:\n", &values[0], &values[len(values)-1])
618		}
619		if values[0].value != 0 {
620			g.Printf("\t\ti -= %s\n", &values[0])
621		}
622		g.Printf("\t\treturn _%s_name_%d[_%s_index_%d[i]:_%s_index_%d[i+1]]\n",
623			typeName, i, typeName, i, typeName, i)
624	}
625	g.Printf("\tdefault:\n")
626	g.Printf("\t\treturn \"%s(\" + strconv.FormatInt(int64(i), 10) + \")\"\n", typeName)
627	g.Printf("\t}\n")
628	g.Printf("}\n")
629}
630
631// buildMap handles the case where the space is so sparse a map is a reasonable fallback.
632// It's a rare situation but has simple code.
633func (g *Generator) buildMap(runs [][]Value, typeName string) {
634	g.Printf("\n")
635	g.declareNameVars(runs, typeName, "")
636	g.Printf("\nvar _%s_map = map[%s]string{\n", typeName, typeName)
637	n := 0
638	for _, values := range runs {
639		for _, value := range values {
640			g.Printf("\t%s: _%s_name[%d:%d],\n", &value, typeName, n, n+len(value.name))
641			n += len(value.name)
642		}
643	}
644	g.Printf("}\n\n")
645	g.Printf(stringMap, typeName)
646}
647
// stringMap is the String method template used by buildMap. It looks the
// value up in the generated _<T>_map and prints "T(n)" for unknown values.
// Argument to format is the type name, referenced as %[1]s.
const stringMap = `func (i %[1]s) String() string {
	if str, ok := _%[1]s_map[i]; ok {
		return str
	}
	return "%[1]s(" + strconv.FormatInt(int64(i), 10) + ")"
}
`
656