1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build ignore
6// +build ignore
7
8// mkmerge.go parses generated source files and merges common
9// consts, funcs, and types into a common source file, per GOOS.
10//
11// Usage:
12//     $ go run mkmerge.go -out MERGED FILE [FILE ...]
13//
14// Example:
15//     # Remove all common consts, funcs, and types from zerrors_linux_*.go
16//     # and write the common code into zerrors_linux.go
17//     $ go run mkmerge.go -out zerrors_linux.go zerrors_linux_*.go
18//
19// mkmerge.go performs the merge in the following steps:
// 1. Construct the set of common code that is identical in all
//    architecture-specific files.
22// 2. Write this common code to the merged file.
23// 3. Remove the common code from all architecture-specific files.
24package main
25
26import (
27	"bufio"
28	"bytes"
29	"flag"
30	"fmt"
31	"go/ast"
32	"go/format"
33	"go/parser"
34	"go/token"
35	"io"
36	"io/ioutil"
37	"log"
38	"os"
39	"path"
40	"path/filepath"
41	"regexp"
42	"strconv"
43	"strings"
44)
45
const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"

// validGOOSRegexp matches a file name ending in "_GOOS.go" for one of the
// operating systems listed in validGOOS and captures the GOOS. It is compiled
// once at package initialization instead of on every getValidGOOS call.
var validGOOSRegexp = regexp.MustCompile(`_(` + validGOOS + `)\.go$`)

// getValidGOOS returns GOOS, true if filename ends with a valid "_GOOS.go"
// suffix; for example "zerrors_linux.go" yields "linux", true. Otherwise it
// returns "", false. In particular, an architecture-specific name such as
// "zerrors_linux_amd64.go" does not match, because the suffix must be the GOOS.
func getValidGOOS(filename string) (string, bool) {
	matches := validGOOSRegexp.FindStringSubmatch(filename)
	// One capture group: a successful match has exactly two elements.
	if len(matches) != 2 {
		return "", false
	}
	return matches[1], true
}
56
// codeElem is a comparable stand-in for a top-level declaration, a single
// spec within one, or a file-level comment group. It holds the token that
// introduces the element and the element's formatted source text. Both
// fields are comparable, so a codeElem can be used directly as a map key.
type codeElem struct {
	tok token.Token // e.g. token.CONST, token.TYPE, token.FUNC, or token.COMMENT
	src string      // the node rendered by go/format
}

// newCodeElem renders node with go/format and pairs the text with tok.
// If formatting fails, the zero codeElem and the formatting error are returned.
func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
	var sb strings.Builder
	if err := format.Node(&sb, token.NewFileSet(), node); err != nil {
		return codeElem{}, err
	}
	return codeElem{tok: tok, src: sb.String()}, nil
}
72
73// codeSet is a set of codeElems
74type codeSet struct {
75	set map[codeElem]bool // true for all codeElems in the set
76}
77
78// newCodeSet returns a new codeSet
79func newCodeSet() *codeSet { return &codeSet{make(map[codeElem]bool)} }
80
81// add adds elem to c
82func (c *codeSet) add(elem codeElem) { c.set[elem] = true }
83
84// has returns true if elem is in c
85func (c *codeSet) has(elem codeElem) bool { return c.set[elem] }
86
87// isEmpty returns true if the set is empty
88func (c *codeSet) isEmpty() bool { return len(c.set) == 0 }
89
90// intersection returns a new set which is the intersection of c and a
91func (c *codeSet) intersection(a *codeSet) *codeSet {
92	res := newCodeSet()
93
94	for elem := range c.set {
95		if a.has(elem) {
96			res.add(elem)
97		}
98	}
99	return res
100}
101
102// keepCommon is a filterFn for filtering the merged file with common declarations.
103func (c *codeSet) keepCommon(elem codeElem) bool {
104	switch elem.tok {
105	case token.VAR:
106		// Remove all vars from the merged file
107		return false
108	case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
109		// Remove arch-specific consts, types, functions, and file-level comments from the merged file
110		return c.has(elem)
111	case token.IMPORT:
112		// Keep imports, they are handled by filterImports
113		return true
114	}
115
116	log.Fatalf("keepCommon: invalid elem %v", elem)
117	return true
118}
119
120// keepArchSpecific is a filterFn for filtering the GOARC-specific files.
121func (c *codeSet) keepArchSpecific(elem codeElem) bool {
122	switch elem.tok {
123	case token.CONST, token.TYPE, token.FUNC:
124		// Remove common consts, types, or functions from the arch-specific file
125		return !c.has(elem)
126	}
127	return true
128}
129
// srcFile represents a source file: its path together with its contents.
type srcFile struct {
	name string // path the contents were read from (merge writes results back here)
	src  []byte // raw Go source of the file
}
135
// filterFn is the predicate type taken by filter. It reports whether the
// spec, func declaration, or file-level comment described by a codeElem
// should be kept (true) or removed (false).
type filterFn func(codeElem) bool
138
// filter parses the Go source in src, removes every top-level declaration
// for which keep reports false, and returns the re-formatted result.
//
// Filtering is applied per spec for const/type/var/import declarations (a
// GenDecl left with no specs is dropped entirely), per function for func
// declarations, and per comment group for file-level comments. Any other
// kind of top-level declaration is dropped. After formatting, empty lines
// between lines of the same spec group are removed by
// specGroups.filterEmptyLines, and unused imports are pruned by filterImports.
//
// For the accepted types of src, see the docs for parser.ParseFile.
// Errors from parsing, formatting, or the helper passes are returned as-is.
func filter(src interface{}, keep filterFn) ([]byte, error) {
	// Parse the src into an ast
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	// Built before the AST is pruned, so that it can later be filtered down
	// to the comments whose nodes survive.
	cmap := ast.NewCommentMap(fset, f, f.Comments)

	// Group const/type specs on adjacent lines. groups maps each kept spec's
	// normalized source to its group ID. groupID is only incremented between
	// specs of the same declaration, so the first spec of a declaration
	// continues the group of the previous declaration's last spec.
	var groups specGroups = make(map[string]int)
	var groupID int

	// Filter f.Decls in place, reusing its backing array.
	decls := f.Decls
	f.Decls = f.Decls[:0]
	for _, decl := range decls {
		switch decl := decl.(type) {
		case *ast.GenDecl:
			// Filter imports, consts, types, vars spec by spec (in place)
			specs := decl.Specs
			decl.Specs = decl.Specs[:0]
			for i, spec := range specs {
				elem, err := newCodeElem(decl.Tok, spec)
				if err != nil {
					return nil, err
				}

				// Start a new group if the previous spec ends more than one line
				// before this one starts, i.e. something (an empty line or a
				// comment) separates them in the original source.
				if i > 0 && fset.Position(specs[i-1].End()).Line < fset.Position(spec.Pos()).Line-1 {
					groupID++
				}

				// Check if we should keep this spec
				if keep(elem) {
					decl.Specs = append(decl.Specs, spec)
					// The error is not checked: if format.Source rejects the spec
					// on its own, the spec is simply not recorded in groups.
					groups.add(elem.src, groupID)
				}
			}
			// Drop the whole declaration if none of its specs survived
			if len(decl.Specs) > 0 {
				f.Decls = append(f.Decls, decl)
			}
		case *ast.FuncDecl:
			// Filter funcs
			elem, err := newCodeElem(token.FUNC, decl)
			if err != nil {
				return nil, err
			}
			if keep(elem) {
				f.Decls = append(f.Decls, decl)
			}
		}
		// Any other declaration kind (e.g. *ast.BadDecl) is not re-appended.
	}

	// Filter file level comments: the comment groups that cmap associates
	// with the *ast.File itself rather than with a declaration.
	if cmap[f] != nil {
		commentGroups := cmap[f]
		cmap[f] = cmap[f][:0]
		for _, cGrp := range commentGroups {
			if keep(codeElem{token.COMMENT, cGrp.Text()}) {
				cmap[f] = append(cmap[f], cGrp)
			}
		}
	}
	// Keep only the comments whose associated nodes are still in the AST.
	f.Comments = cmap.Filter(f).Comments()

	// Generate code for the filtered ast
	var buf bytes.Buffer
	if err = format.Node(&buf, fset, f); err != nil {
		return nil, err
	}

	// Removing specs can leave empty lines inside a group; drop those.
	groupedSrc, err := groups.filterEmptyLines(&buf)
	if err != nil {
		return nil, err
	}

	// Finally drop imports that the remaining code no longer references.
	return filterImports(groupedSrc)
}
221
222// getCommonSet returns the set of consts, types, and funcs that are present in every file.
223func getCommonSet(files []srcFile) (*codeSet, error) {
224	if len(files) == 0 {
225		return nil, fmt.Errorf("no files provided")
226	}
227	// Use the first architecture file as the baseline
228	baseSet, err := getCodeSet(files[0].src)
229	if err != nil {
230		return nil, err
231	}
232
233	// Compare baseline set with other architecture files: discard any element,
234	// that doesn't exist in other architecture files.
235	for _, f := range files[1:] {
236		set, err := getCodeSet(f.src)
237		if err != nil {
238			return nil, err
239		}
240
241		baseSet = baseSet.intersection(set)
242	}
243	return baseSet, nil
244}
245
246// getCodeSet returns the set of all top-level consts, types, and funcs from src.
247// src must be string, []byte, or io.Reader (see go/parser.ParseFile docs)
248func getCodeSet(src interface{}) (*codeSet, error) {
249	set := newCodeSet()
250
251	fset := token.NewFileSet()
252	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
253	if err != nil {
254		return nil, err
255	}
256
257	for _, decl := range f.Decls {
258		switch decl := decl.(type) {
259		case *ast.GenDecl:
260			// Add const, and type declarations
261			if !(decl.Tok == token.CONST || decl.Tok == token.TYPE) {
262				break
263			}
264
265			for _, spec := range decl.Specs {
266				elem, err := newCodeElem(decl.Tok, spec)
267				if err != nil {
268					return nil, err
269				}
270
271				set.add(elem)
272			}
273		case *ast.FuncDecl:
274			// Add func declarations
275			elem, err := newCodeElem(token.FUNC, decl)
276			if err != nil {
277				return nil, err
278			}
279
280			set.add(elem)
281		}
282	}
283
284	// Add file level comments
285	cmap := ast.NewCommentMap(fset, f, f.Comments)
286	for _, cGrp := range cmap[f] {
287		set.add(codeElem{token.COMMENT, cGrp.Text()})
288	}
289
290	return set, nil
291}
292
// importName returns the identifier under which iSpec's package is visible in
// the importing file. That is the explicit name when one is given; otherwise it
// is the last element of the unquoted import path. An error is returned only
// when the path literal must be unquoted and cannot be.
//
// NOTE(review): for unnamed imports the last path element is only a guess at the
// package name (it is wrong for paths such as "gopkg.in/yaml.v2"). Presumably
// this is adequate for the imports in the generated files this tool processes.
func importName(iSpec *ast.ImportSpec) (string, error) {
	if iSpec.Name != nil {
		return iSpec.Name.Name, nil
	}
	importPath, err := strconv.Unquote(iSpec.Path.Value)
	if err != nil {
		return "", err
	}
	return path.Base(importPath), nil
}
304
// specGroups records which group of adjacent specs each kept const/type (or
// other GenDecl) spec came from. Keys are the spec's source, trimmed and
// normalized with format.Source; values are group IDs.
type specGroups map[string]int

// add normalizes src (trimmed, then run through format.Source) and records
// groupID for it. If format.Source rejects src, nothing is recorded and the
// formatting error is returned.
func (s specGroups) add(src string, groupID int) error {
	key, err := format.Source(bytes.TrimSpace([]byte(src)))
	if err == nil {
		s[string(key)] = groupID
	}
	return err
}

// filterEmptyLines copies src line by line. A run of empty lines is dropped
// when the non-empty lines on both sides of it belong to the same spec group;
// otherwise it is kept. Empty lines at the very end of src are dropped.
// Returns the filtered source, or the scanner's error.
func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
	var out, pending bytes.Buffer
	lastGroup := -1 // sentinel: no group seen yet

	sc := bufio.NewScanner(src)
	for sc.Scan() {
		raw := sc.Bytes()
		trimmed := bytes.TrimSpace(raw)

		// Hold empty lines back until the next non-empty line decides
		// whether they sit inside a group.
		if len(trimmed) == 0 {
			fmt.Fprintf(&pending, "%s\n", raw)
			continue
		}

		// Only lines that format.Source accepts on their own update the group
		// state; lines such as "const (" or ")" leave it unchanged.
		if formatted, err := format.Source(trimmed); err == nil {
			id, known := s[string(formatted)]
			if known && id == lastGroup {
				pending.Reset()
			}
			// NOTE(review): for a line that is not a recorded spec, id is the
			// map's zero value, so lastGroup becomes 0. A following group-0 spec
			// then has its preceding empty lines dropped. The first group in a
			// file (e.g. after "package"/"import" lines) appears to rely on this,
			// so it is kept unchanged.
			lastGroup = id
		}

		pending.WriteTo(&out)
		fmt.Fprintf(&out, "%s\n", raw)
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
352
353// filterImports removes unused imports from fileSrc, and returns a formatted src.
354func filterImports(fileSrc []byte) ([]byte, error) {
355	fset := token.NewFileSet()
356	file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
357	if err != nil {
358		return nil, err
359	}
360	cmap := ast.NewCommentMap(fset, file, file.Comments)
361
362	// create set of references to imported identifiers
363	keepImport := make(map[string]bool)
364	for _, u := range file.Unresolved {
365		keepImport[u.Name] = true
366	}
367
368	// filter import declarations
369	decls := file.Decls
370	file.Decls = file.Decls[:0]
371	for _, decl := range decls {
372		importDecl, ok := decl.(*ast.GenDecl)
373
374		// Keep non-import declarations
375		if !ok || importDecl.Tok != token.IMPORT {
376			file.Decls = append(file.Decls, decl)
377			continue
378		}
379
380		// Filter the import specs
381		specs := importDecl.Specs
382		importDecl.Specs = importDecl.Specs[:0]
383		for _, spec := range specs {
384			iSpec := spec.(*ast.ImportSpec)
385			name, err := importName(iSpec)
386			if err != nil {
387				return nil, err
388			}
389
390			if keepImport[name] {
391				importDecl.Specs = append(importDecl.Specs, iSpec)
392			}
393		}
394		if len(importDecl.Specs) > 0 {
395			file.Decls = append(file.Decls, importDecl)
396		}
397	}
398
399	// filter file.Imports
400	imports := file.Imports
401	file.Imports = file.Imports[:0]
402	for _, spec := range imports {
403		name, err := importName(spec)
404		if err != nil {
405			return nil, err
406		}
407
408		if keepImport[name] {
409			file.Imports = append(file.Imports, spec)
410		}
411	}
412	file.Comments = cmap.Filter(file).Comments()
413
414	var buf bytes.Buffer
415	err = format.Node(&buf, fset, file)
416	if err != nil {
417		return nil, err
418	}
419
420	return buf.Bytes(), nil
421}
422
423// merge extracts duplicate code from archFiles and merges it to mergeFile.
424// 1. Construct commonSet: the set of code that is idential in all archFiles.
425// 2. Write the code in commonSet to mergedFile.
426// 3. Remove the commonSet code from all archFiles.
427func merge(mergedFile string, archFiles ...string) error {
428	// extract and validate the GOOS part of the merged filename
429	goos, ok := getValidGOOS(mergedFile)
430	if !ok {
431		return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
432	}
433
434	// Read architecture files
435	var inSrc []srcFile
436	for _, file := range archFiles {
437		src, err := ioutil.ReadFile(file)
438		if err != nil {
439			return fmt.Errorf("cannot read archfile %s: %w", file, err)
440		}
441
442		inSrc = append(inSrc, srcFile{file, src})
443	}
444
445	// 1. Construct the set of top-level declarations common for all files
446	commonSet, err := getCommonSet(inSrc)
447	if err != nil {
448		return err
449	}
450	if commonSet.isEmpty() {
451		// No common code => do not modify any files
452		return nil
453	}
454
455	// 2. Write the merged file
456	mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
457	if err != nil {
458		return err
459	}
460
461	f, err := os.Create(mergedFile)
462	if err != nil {
463		return err
464	}
465
466	buf := bufio.NewWriter(f)
467	fmt.Fprintln(buf, "// Code generated by mkmerge.go; DO NOT EDIT.")
468	fmt.Fprintln(buf)
469	fmt.Fprintf(buf, "//go:build %s\n", goos)
470	fmt.Fprintf(buf, "// +build %s\n", goos)
471	fmt.Fprintln(buf)
472	buf.Write(mergedSrc)
473
474	err = buf.Flush()
475	if err != nil {
476		return err
477	}
478	err = f.Close()
479	if err != nil {
480		return err
481	}
482
483	// 3. Remove duplicate declarations from the architecture files
484	for _, inFile := range inSrc {
485		src, err := filter(inFile.src, commonSet.keepArchSpecific)
486		if err != nil {
487			return err
488		}
489		err = ioutil.WriteFile(inFile.name, src, 0644)
490		if err != nil {
491			return err
492		}
493	}
494	return nil
495}
496
497func main() {
498	var mergedFile string
499	flag.StringVar(&mergedFile, "out", "", "Write merged code to `FILE`")
500	flag.Parse()
501
502	// Expand wildcards
503	var filenames []string
504	for _, arg := range flag.Args() {
505		matches, err := filepath.Glob(arg)
506		if err != nil {
507			fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", arg, err)
508			os.Exit(1)
509		}
510		filenames = append(filenames, matches...)
511	}
512
513	if len(filenames) < 2 {
514		// No need to merge
515		return
516	}
517
518	err := merge(mergedFile, filenames...)
519	if err != nil {
520		fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
521		os.Exit(1)
522	}
523}
524