1// Copyright 2019 The Go Cloud Development Kit Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Command gatherexamples extracts examples in a Go module into a JSON-formatted
16// object. This is used as input for building the Go CDK Hugo website.
17//
18// Examples must include a comment
19// "// PRAGMA: This example is used on gocloud.dev; PRAGMA comments adjust how it is shown and can be ignored."
20// somewhere in the function body in order to be included in this tool's output.
21//
22// gatherexamples does some minimal rewriting of the example source code for
23// presentation:
24//
25//   - Any imports the example uses will be prepended to the code.
26//   - log.Fatal(err) -> return err
27//   - A comment line "// PRAGMA: On gocloud.dev, hide lines until the next blank line." will
28//     remove any code up to the next blank line. This is intended for
29//     compiler-mandated setup like `ctx := context.Background()`.
30//   - A comment line "// PRAGMA: On gocloud.dev, hide the rest of the function." will
31//     remove any code until the end of the function. This is intended for
32//     compiler-mandated assignments like `_ = bucket`.
33//   - A comment line "// PRAGMA: On gocloud.dev, add a blank import: _ "example.com/foo""
34//     will add the blank import to the example's imports.
35//
36// The key of each JSON object entry will be the import path of the package,
37// followed by a dot ("."), followed by the name of the example function. The
38// value of each JSON object entry is an object like
39// {"imports": "import (\n\t\"fmt\"\n)", "code": "/* ... */"}. These are
40// separated so that templating can format or show them separately.
41package main
42
43import (
44	"encoding/json"
45	"flag"
46	"fmt"
47	"go/ast"
48	"go/format"
49	"go/printer"
50	"go/types"
51	"os"
52	"sort"
53	"strconv"
54	"strings"
55
56	"golang.org/x/tools/go/packages"
57)
58
59func main() {
60	flag.Usage = func() {
61		out := flag.CommandLine.Output()
62		fmt.Fprintln(out, "usage: gatherexamples [options] DIR [...]")
63		fmt.Fprintln(out)
64		fmt.Fprintln(out, "Options:")
65		flag.PrintDefaults()
66	}
67	pattern := flag.String("pattern", "./...", "Go package pattern to use at each directory argument")
68	flag.Parse()
69	if flag.NArg() == 0 {
70		flag.Usage()
71		os.Exit(2) // matches with flag package
72	}
73
74	// Load packages in each module named on the command line and find
75	// all examples.
76	allExamples := make(map[string]example)
77	for _, dir := range flag.Args() {
78		cfg := &packages.Config{
79			Mode:  gatherLoadMode,
80			Dir:   dir,
81			Tests: true,
82		}
83		pkgs, err := packages.Load(cfg, *pattern)
84		if err != nil {
85			fmt.Fprintf(os.Stderr, "gatherexamples: load %s: %v\n", dir, err)
86			os.Exit(1)
87		}
88		examples, err := gather(pkgs)
89		if err != nil {
90			fmt.Fprintf(os.Stderr, "gatherexamples: gather: %v", err)
91			os.Exit(1)
92		}
93		for exampleName, ex := range examples {
94			allExamples[exampleName] = ex
95		}
96	}
97
98	// Write all examples as a JSON object.
99	data, err := json.MarshalIndent(allExamples, "", "\t")
100	if err != nil {
101		fmt.Fprintf(os.Stderr, "gatherexamples: generate JSON: %v\n", err)
102		os.Exit(1)
103	}
104	data = append(data, '\n')
105	if _, err := os.Stdout.Write(data); err != nil {
106		fmt.Fprintf(os.Stderr, "gatherexamples: write output: %v\n", err)
107		os.Exit(1)
108	}
109}
110
111const gatherLoadMode packages.LoadMode = packages.NeedName |
112	packages.NeedFiles |
113	packages.NeedTypes |
114	packages.NeedSyntax |
115	packages.NeedTypesInfo |
116	packages.NeedImports |
117	// TODO(light): We really only need name from deps, but there's no way to
118	// specify that in the current go/packages API. This sadly makes this program
119	// 10x slower. Reported as https://github.com/golang/go/issues/31699.
120	packages.NeedDeps
121
// Comment directives that gatherexamples recognizes inside example bodies.
const (
	// pragmaPrefix starts every comment in an example that acts as a
	// formatting directive for this tool.
	pragmaPrefix = "// PRAGMA: "

	// inclusionComment opts an example into the output. Examples without it
	// are skipped, and must not contain any other PRAGMA comment.
	inclusionComment = pragmaPrefix + "This example is used on gocloud.dev; PRAGMA comments adjust how it is shown and can be ignored."
)

// example is the JSON value emitted for a single example function. Imports
// holds the Go import declaration for the packages the example uses ("" when
// it uses none), and Code holds the rewritten example body. They are separate
// fields so that site templates can format or show them independently.
type example struct {
	Imports string `json:"imports"`
	Code    string `json:"code"`
}
134
135// gather extracts the code from the example functions in the given packages
136// and returns a map like the one described in the package documentation.
137func gather(pkgs []*packages.Package) (map[string]example, error) {
138	examples := make(map[string]example)
139	for _, pkg := range pkgs {
140		for _, file := range pkg.Syntax {
141			for _, decl := range file.Decls {
142				// Determine whether this declaration is an example function.
143				fn, ok := decl.(*ast.FuncDecl)
144				if !ok || !strings.HasPrefix(fn.Name.Name, "Example") || len(fn.Type.Params.List) > 0 || len(fn.Type.Params.List) > 0 {
145					continue
146				}
147
148				// Format example into string.
149				sb := new(strings.Builder)
150				err := format.Node(sb, pkg.Fset, &printer.CommentedNode{
151					Node:     fn.Body,
152					Comments: file.Comments,
153				})
154				if err != nil {
155					return nil, err // will only occur for bad invocations of Fprint
156				}
157				original := sb.String()
158				if !strings.Contains(original, inclusionComment) {
159					// Does not contain the inclusion comment. Skip it, but first verify
160					// that it doesn't contain any PRAGMA comments; only examples with
161					// the inclusion comment should include pragmas.
162					if strings.Contains(original, pragmaPrefix) {
163						return nil, fmt.Errorf("%s in package %s has PRAGMA(s) for gatherexamples, but is not marked for inclusion with %q", fn.Name.Name, pkg.PkgPath, inclusionComment)
164					}
165					continue
166				}
167				exampleCode, blankImports := rewriteBlock(original)
168
169				// Gather map of imported packages to overridden identifier.
170				usedPackages := make(map[string]string)
171				for _, path := range blankImports {
172					usedPackages[path] = "_"
173				}
174				ast.Inspect(fn.Body, func(node ast.Node) bool {
175					id, ok := node.(*ast.Ident)
176					if !ok {
177						return true
178					}
179					refPkg, ok := pkg.TypesInfo.ObjectOf(id).(*types.PkgName)
180					if !ok {
181						return true
182					}
183					overrideName := ""
184					if id.Name != refPkg.Imported().Name() {
185						overrideName = id.Name
186					}
187					usedPackages[refPkg.Imported().Path()] = overrideName
188					return true
189				})
190				// Remove "log" import since it's almost always used for log.Fatal(err).
191				delete(usedPackages, "log")
192
193				pkgPath := strings.TrimSuffix(pkg.PkgPath, "_test")
194				exampleName := pkgPath + "." + fn.Name.Name
195				examples[exampleName] = example{
196					Imports: formatImports(usedPackages),
197					Code:    exampleCode,
198				}
199			}
200		}
201	}
202	return examples, nil
203}
204
205// rewriteBlock reformats a Go block statement for display as an example.
206// It also extracts any blank imports found
207func rewriteBlock(block string) (_ string, blankImports []string) {
208	// Trim block.
209	block = strings.TrimPrefix(block, "{")
210	block = strings.TrimSuffix(block, "}")
211
212	// Rewrite line-by-line.
213	sb := new(strings.Builder)
214rewrite:
215	for len(block) > 0 {
216		var line string
217		line, block = nextLine(block)
218
219		// Dedent line.
220		// TODO(light): In the case of a multi-line raw string literal,
221		// this can produce incorrect rewrites.
222		line = strings.TrimPrefix(line, "\t")
223
224		// Write the line to sb, performing textual substitutions as needed.
225		start := strings.IndexFunc(line, func(r rune) bool { return r != ' ' && r != '\t' })
226		if start == -1 {
227			// Blank.
228			sb.WriteString(line)
229			sb.WriteByte('\n')
230			continue
231		}
232		const importBlankPrefix = pragmaPrefix + "On gocloud.dev, add a blank import: _ "
233		indent, lineContent := line[:start], line[start:]
234		switch {
235		case lineContent == pragmaPrefix+"On gocloud.dev, hide lines until the next blank line.":
236			// Skip lines until we hit a blank line.
237			for len(block) > 0 {
238				var next string
239				next, block = nextLine(block)
240				if strings.TrimSpace(next) == "" {
241					break
242				}
243			}
244		case lineContent == pragmaPrefix+"On gocloud.dev, hide the rest of the function.":
245			// Ignore remaining lines.
246			break rewrite
247		case lineContent == "log.Fatal(err)":
248			sb.WriteString(indent)
249			sb.WriteString("return err")
250			sb.WriteByte('\n')
251		case strings.HasPrefix(lineContent, importBlankPrefix):
252			// Blank import.
253			path, err := strconv.Unquote(lineContent[len(importBlankPrefix):])
254			if err == nil {
255				blankImports = append(blankImports, path)
256			}
257		case strings.Contains(lineContent, inclusionComment):
258			// inclusion comment. Skip it.
259		default:
260			// Ordinary line, write as-is.
261			sb.WriteString(line)
262			sb.WriteByte('\n')
263		}
264	}
265	return strings.TrimSpace(sb.String()), blankImports
266}
267
// nextLine returns the text of s before its first '\n' as line, and the text
// after that '\n' as tail. If s contains no '\n', the whole of s is the line
// and tail is empty.
func nextLine(s string) (line, tail string) {
	if nl := strings.Index(s, "\n"); nl >= 0 {
		return s[:nl], s[nl+1:]
	}
	return s, ""
}
276
// formatImports renders usedPackages, a map from import path to the local
// name to write before that path ("" for none, "_" for a blank import), as a
// Go import declaration. It returns "" for an empty map and a one-line
// declaration for a single import. Otherwise it returns a parenthesized
// declaration that lists standard-library paths first, then third-party
// paths, each group sorted and the two separated by a blank line.
func formatImports(usedPackages map[string]string) string {
	// spec renders one import spec, e.g. `"fmt"` or `_ "example.com/foo"`.
	spec := func(path string) string {
		if name := usedPackages[path]; name != "" {
			return fmt.Sprintf("%s %q", name, path)
		}
		return fmt.Sprintf("%q", path)
	}

	switch len(usedPackages) {
	case 0:
		return ""
	case 1:
		for path := range usedPackages {
			return "import " + spec(path)
		}
	}

	// A dot in the path usually means a domain name, i.e. a third-party
	// module; anything else is assumed to be standard library.
	var stdlib, thirdParty []string
	for path := range usedPackages {
		if strings.Contains(path, ".") {
			thirdParty = append(thirdParty, path)
		} else {
			stdlib = append(stdlib, path)
		}
	}

	var groups []string
	for _, paths := range [][]string{stdlib, thirdParty} {
		if len(paths) == 0 {
			continue
		}
		sort.Strings(paths)
		var group strings.Builder
		for _, path := range paths {
			group.WriteString("\t" + spec(path) + "\n")
		}
		groups = append(groups, group.String())
	}
	// Joining with "\n" yields the blank line between the groups.
	return "import (\n" + strings.Join(groups, "\n") + ")"
}
328