1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build ignore
6
7// mkmerge.go parses generated source files and merges common
8// consts, funcs, and types into a common source file, per GOOS.
9//
10// Usage:
11//     $ go run mkmerge.go -out MERGED FILE [FILE ...]
12//
13// Example:
14//     # Remove all common consts, funcs, and types from zerrors_linux_*.go
15//     # and write the common code into zerrors_linux.go
16//     $ go run mkmerge.go -out zerrors_linux.go zerrors_linux_*.go
17//
18// mkmerge.go performs the merge in the following steps:
// 1. Construct the set of common code that is identical in all
20//    architecture-specific files.
21// 2. Write this common code to the merged file.
22// 3. Remove the common code from all architecture-specific files.
23package main
24
25import (
26	"bufio"
27	"bytes"
28	"flag"
29	"fmt"
30	"go/ast"
31	"go/format"
32	"go/parser"
33	"go/token"
34	"io"
35	"io/ioutil"
36	"log"
37	"os"
38	"path"
39	"path/filepath"
40	"regexp"
41	"strconv"
42	"strings"
43)
44
// validGOOS lists, as a regular-expression alternation, the operating
// systems accepted as the GOOS suffix of a merged file name.
const validGOOS = "aix|darwin|dragonfly|freebsd|linux|netbsd|openbsd|solaris"

// validGOOSRegexp matches a file name ending in "_GOOS.go" for one of the
// operating systems in validGOOS and captures the GOOS. It is compiled once
// at package initialization rather than on every call.
var validGOOSRegexp = regexp.MustCompile(`_(` + validGOOS + `)\.go$`)

// getValidGOOS returns the GOOS and true if filename ends with "_GOOS.go"
// for a GOOS listed in validGOOS; otherwise it returns "" and false.
func getValidGOOS(filename string) (string, bool) {
	matches := validGOOSRegexp.FindStringSubmatch(filename)
	if len(matches) != 2 {
		return "", false
	}
	return matches[1], true
}
55
// codeElem is a comparable representation of one top-level element of a
// source file: a const/type/var/import spec, a func declaration, or a
// file-level comment group. Two elements are equal when their kind and
// formatted source are equal, which lets declarations be compared across
// files.
type codeElem struct {
	tok token.Token // e.g. token.CONST, token.TYPE, token.VAR, token.IMPORT, token.FUNC, or token.COMMENT
	src string      // the element formatted as Go source (comment text for token.COMMENT)
}

// newCodeElem formats node as Go source and pairs it with tok. It returns an
// error if node cannot be formatted.
func newCodeElem(tok token.Token, node ast.Node) (codeElem, error) {
	var b strings.Builder
	// node's positions are not resolved against its original FileSet; a
	// new, empty one is passed to the printer.
	if err := format.Node(&b, token.NewFileSet(), node); err != nil {
		return codeElem{}, err
	}
	return codeElem{tok: tok, src: b.String()}, nil
}

// codeSet is a set of codeElems.
type codeSet struct {
	set map[codeElem]bool // true for all codeElems in the set
}

// newCodeSet returns a new, empty codeSet.
func newCodeSet() *codeSet { return &codeSet{make(map[codeElem]bool)} }

// add adds elem to c.
func (c *codeSet) add(elem codeElem) { c.set[elem] = true }

// has reports whether elem is in c.
func (c *codeSet) has(elem codeElem) bool { return c.set[elem] }

// isEmpty reports whether c contains no elements.
func (c *codeSet) isEmpty() bool { return len(c.set) == 0 }

// intersection returns a new set holding the elements present in both c and
// a. Neither operand is modified.
func (c *codeSet) intersection(a *codeSet) *codeSet {
	// Intersection is symmetric, so iterate over the smaller set and probe
	// the larger one.
	small, large := c, a
	if len(large.set) < len(small.set) {
		small, large = large, small
	}

	res := newCodeSet()
	for elem := range small.set {
		if large.has(elem) {
			res.add(elem)
		}
	}
	return res
}

// keepCommon is a filterFn for building the merged file. It keeps every
// import (unused ones are pruned afterwards by filterImports). It keeps
// consts, types, funcs, and file-level comments only if they are in c, and
// it drops all vars. Any other token kind is a programming error and
// terminates the program via log.Fatalf.
func (c *codeSet) keepCommon(elem codeElem) bool {
	switch elem.tok {
	case token.VAR:
		// Remove all vars from the merged file.
		return false
	case token.CONST, token.TYPE, token.FUNC, token.COMMENT:
		// Remove arch-specific consts, types, functions, and file-level comments from the merged file.
		return c.has(elem)
	case token.IMPORT:
		// Keep imports; they are handled by filterImports.
		return true
	}

	log.Fatalf("keepCommon: invalid elem %v", elem)
	return true // unreachable: log.Fatalf exits the program
}

// keepArchSpecific is a filterFn for filtering the GOARCH-specific files. It
// drops the consts, types, and funcs that are in c, because they are written
// to the merged file. It keeps everything else, including vars, imports, and
// comments.
func (c *codeSet) keepArchSpecific(elem codeElem) bool {
	switch elem.tok {
	case token.CONST, token.TYPE, token.FUNC:
		// Remove common consts, types, or functions from the arch-specific file.
		return !c.has(elem)
	}
	return true
}
128
129// srcFile represents a source file
130type srcFile struct {
131	name string
132	src  []byte
133}
134
135// filterFn is a helper for filter
136type filterFn func(codeElem) bool
137
138// filter parses and filters Go source code from src, removing top
139// level declarations using keep as predicate.
140// For src parameter, please see docs for parser.ParseFile.
141func filter(src interface{}, keep filterFn) ([]byte, error) {
142	// Parse the src into an ast
143	fset := token.NewFileSet()
144	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
145	if err != nil {
146		return nil, err
147	}
148	cmap := ast.NewCommentMap(fset, f, f.Comments)
149
150	// Group const/type specs on adjacent lines
151	var groups specGroups = make(map[string]int)
152	var groupID int
153
154	decls := f.Decls
155	f.Decls = f.Decls[:0]
156	for _, decl := range decls {
157		switch decl := decl.(type) {
158		case *ast.GenDecl:
159			// Filter imports, consts, types, vars
160			specs := decl.Specs
161			decl.Specs = decl.Specs[:0]
162			for i, spec := range specs {
163				elem, err := newCodeElem(decl.Tok, spec)
164				if err != nil {
165					return nil, err
166				}
167
168				// Create new group if there are empty lines between this and the previous spec
169				if i > 0 && fset.Position(specs[i-1].End()).Line < fset.Position(spec.Pos()).Line-1 {
170					groupID++
171				}
172
173				// Check if we should keep this spec
174				if keep(elem) {
175					decl.Specs = append(decl.Specs, spec)
176					groups.add(elem.src, groupID)
177				}
178			}
179			// Check if we should keep this decl
180			if len(decl.Specs) > 0 {
181				f.Decls = append(f.Decls, decl)
182			}
183		case *ast.FuncDecl:
184			// Filter funcs
185			elem, err := newCodeElem(token.FUNC, decl)
186			if err != nil {
187				return nil, err
188			}
189			if keep(elem) {
190				f.Decls = append(f.Decls, decl)
191			}
192		}
193	}
194
195	// Filter file level comments
196	if cmap[f] != nil {
197		commentGroups := cmap[f]
198		cmap[f] = cmap[f][:0]
199		for _, cGrp := range commentGroups {
200			if keep(codeElem{token.COMMENT, cGrp.Text()}) {
201				cmap[f] = append(cmap[f], cGrp)
202			}
203		}
204	}
205	f.Comments = cmap.Filter(f).Comments()
206
207	// Generate code for the filtered ast
208	var buf bytes.Buffer
209	if err = format.Node(&buf, fset, f); err != nil {
210		return nil, err
211	}
212
213	groupedSrc, err := groups.filterEmptyLines(&buf)
214	if err != nil {
215		return nil, err
216	}
217
218	return filterImports(groupedSrc)
219}
220
221// getCommonSet returns the set of consts, types, and funcs that are present in every file.
222func getCommonSet(files []srcFile) (*codeSet, error) {
223	if len(files) == 0 {
224		return nil, fmt.Errorf("no files provided")
225	}
226	// Use the first architecture file as the baseline
227	baseSet, err := getCodeSet(files[0].src)
228	if err != nil {
229		return nil, err
230	}
231
232	// Compare baseline set with other architecture files: discard any element,
233	// that doesn't exist in other architecture files.
234	for _, f := range files[1:] {
235		set, err := getCodeSet(f.src)
236		if err != nil {
237			return nil, err
238		}
239
240		baseSet = baseSet.intersection(set)
241	}
242	return baseSet, nil
243}
244
245// getCodeSet returns the set of all top-level consts, types, and funcs from src.
246// src must be string, []byte, or io.Reader (see go/parser.ParseFile docs)
247func getCodeSet(src interface{}) (*codeSet, error) {
248	set := newCodeSet()
249
250	fset := token.NewFileSet()
251	f, err := parser.ParseFile(fset, "", src, parser.ParseComments)
252	if err != nil {
253		return nil, err
254	}
255
256	for _, decl := range f.Decls {
257		switch decl := decl.(type) {
258		case *ast.GenDecl:
259			// Add const, and type declarations
260			if !(decl.Tok == token.CONST || decl.Tok == token.TYPE) {
261				break
262			}
263
264			for _, spec := range decl.Specs {
265				elem, err := newCodeElem(decl.Tok, spec)
266				if err != nil {
267					return nil, err
268				}
269
270				set.add(elem)
271			}
272		case *ast.FuncDecl:
273			// Add func declarations
274			elem, err := newCodeElem(token.FUNC, decl)
275			if err != nil {
276				return nil, err
277			}
278
279			set.add(elem)
280		}
281	}
282
283	// Add file level comments
284	cmap := ast.NewCommentMap(fset, f, f.Comments)
285	for _, cGrp := range cmap[f] {
286		set.add(codeElem{token.COMMENT, cGrp.Text()})
287	}
288
289	return set, nil
290}
291
// importName returns the identifier (PackageName) under which iSpec makes its
// package available. This is the explicit local name when one is given
// (including "_" and "."); otherwise it is the last element of the unquoted
// import path. It returns an error if the path literal cannot be unquoted.
func importName(iSpec *ast.ImportSpec) (string, error) {
	if iSpec.Name != nil {
		return iSpec.Name.Name, nil
	}
	importPath, err := strconv.Unquote(iSpec.Path.Value)
	if err != nil {
		return "", err
	}
	return path.Base(importPath), nil
}
303
// specGroups maps the gofmt-formatted source of a single spec to the ID of
// the group of line-adjacent specs it was declared in.
type specGroups map[string]int

// add records that the spec with source src belongs to group groupID. src is
// trimmed and formatted with format.Source first, so it can later be matched
// against individual output lines in filterEmptyLines. If formatting fails,
// nothing is recorded and the error is returned.
func (s specGroups) add(src string, groupID int) error {
	srcBytes, err := format.Source(bytes.TrimSpace([]byte(src)))
	if err != nil {
		return err
	}
	s[string(srcBytes)] = groupID
	return nil
}

// filterEmptyLines copies src line by line and removes runs of empty lines
// that separate two non-empty lines belonging to the same spec group.
//
// A line belongs to a group only if its formatted text was recorded with
// add. Formattable lines that were not recorded (package clause, comments,
// other statements) belong to no group. Lines that cannot be formatted on
// their own, such as the inner lines of a multi-line construct, leave the
// current group unchanged. Returns the filtered source.
func (s specGroups) filterEmptyLines(src io.Reader) ([]byte, error) {
	scanner := bufio.NewScanner(src)
	var out bytes.Buffer

	var emptyLines bytes.Buffer // empty lines seen since the last non-empty line
	prevGroupID, prevInGroup := 0, false
	for scanner.Scan() {
		line := bytes.TrimSpace(scanner.Bytes())

		if len(line) == 0 {
			fmt.Fprintf(&emptyLines, "%s\n", scanner.Bytes())
			continue
		}

		// Discard emptyLines if the previous non-empty line belonged to the
		// same group as this line.
		if src, err := format.Source(line); err == nil {
			groupID, ok := s[string(src)]
			if ok && prevInGroup && groupID == prevGroupID {
				emptyLines.Reset()
			}
			// Track membership explicitly. A missing key yields the zero
			// value 0, and treating that as a group ID would wrongly tie an
			// unrelated line to spec group 0.
			prevGroupID, prevInGroup = groupID, ok
		}

		emptyLines.WriteTo(&out)
		fmt.Fprintf(&out, "%s\n", scanner.Bytes())
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
351
352// filterImports removes unused imports from fileSrc, and returns a formatted src.
353func filterImports(fileSrc []byte) ([]byte, error) {
354	fset := token.NewFileSet()
355	file, err := parser.ParseFile(fset, "", fileSrc, parser.ParseComments)
356	if err != nil {
357		return nil, err
358	}
359	cmap := ast.NewCommentMap(fset, file, file.Comments)
360
361	// create set of references to imported identifiers
362	keepImport := make(map[string]bool)
363	for _, u := range file.Unresolved {
364		keepImport[u.Name] = true
365	}
366
367	// filter import declarations
368	decls := file.Decls
369	file.Decls = file.Decls[:0]
370	for _, decl := range decls {
371		importDecl, ok := decl.(*ast.GenDecl)
372
373		// Keep non-import declarations
374		if !ok || importDecl.Tok != token.IMPORT {
375			file.Decls = append(file.Decls, decl)
376			continue
377		}
378
379		// Filter the import specs
380		specs := importDecl.Specs
381		importDecl.Specs = importDecl.Specs[:0]
382		for _, spec := range specs {
383			iSpec := spec.(*ast.ImportSpec)
384			name, err := importName(iSpec)
385			if err != nil {
386				return nil, err
387			}
388
389			if keepImport[name] {
390				importDecl.Specs = append(importDecl.Specs, iSpec)
391			}
392		}
393		if len(importDecl.Specs) > 0 {
394			file.Decls = append(file.Decls, importDecl)
395		}
396	}
397
398	// filter file.Imports
399	imports := file.Imports
400	file.Imports = file.Imports[:0]
401	for _, spec := range imports {
402		name, err := importName(spec)
403		if err != nil {
404			return nil, err
405		}
406
407		if keepImport[name] {
408			file.Imports = append(file.Imports, spec)
409		}
410	}
411	file.Comments = cmap.Filter(file).Comments()
412
413	var buf bytes.Buffer
414	err = format.Node(&buf, fset, file)
415	if err != nil {
416		return nil, err
417	}
418
419	return buf.Bytes(), nil
420}
421
422// merge extracts duplicate code from archFiles and merges it to mergeFile.
423// 1. Construct commonSet: the set of code that is idential in all archFiles.
424// 2. Write the code in commonSet to mergedFile.
425// 3. Remove the commonSet code from all archFiles.
426func merge(mergedFile string, archFiles ...string) error {
427	// extract and validate the GOOS part of the merged filename
428	goos, ok := getValidGOOS(mergedFile)
429	if !ok {
430		return fmt.Errorf("invalid GOOS in merged file name %s", mergedFile)
431	}
432
433	// Read architecture files
434	var inSrc []srcFile
435	for _, file := range archFiles {
436		src, err := ioutil.ReadFile(file)
437		if err != nil {
438			return fmt.Errorf("cannot read archfile %s: %w", file, err)
439		}
440
441		inSrc = append(inSrc, srcFile{file, src})
442	}
443
444	// 1. Construct the set of top-level declarations common for all files
445	commonSet, err := getCommonSet(inSrc)
446	if err != nil {
447		return err
448	}
449	if commonSet.isEmpty() {
450		// No common code => do not modify any files
451		return nil
452	}
453
454	// 2. Write the merged file
455	mergedSrc, err := filter(inSrc[0].src, commonSet.keepCommon)
456	if err != nil {
457		return err
458	}
459
460	f, err := os.Create(mergedFile)
461	if err != nil {
462		return err
463	}
464
465	buf := bufio.NewWriter(f)
466	fmt.Fprintln(buf, "// Code generated by mkmerge.go; DO NOT EDIT.")
467	fmt.Fprintln(buf)
468	fmt.Fprintf(buf, "// +build %s\n", goos)
469	fmt.Fprintln(buf)
470	buf.Write(mergedSrc)
471
472	err = buf.Flush()
473	if err != nil {
474		return err
475	}
476	err = f.Close()
477	if err != nil {
478		return err
479	}
480
481	// 3. Remove duplicate declarations from the architecture files
482	for _, inFile := range inSrc {
483		src, err := filter(inFile.src, commonSet.keepArchSpecific)
484		if err != nil {
485			return err
486		}
487		err = ioutil.WriteFile(inFile.name, src, 0644)
488		if err != nil {
489			return err
490		}
491	}
492	return nil
493}
494
495func main() {
496	var mergedFile string
497	flag.StringVar(&mergedFile, "out", "", "Write merged code to `FILE`")
498	flag.Parse()
499
500	// Expand wildcards
501	var filenames []string
502	for _, arg := range flag.Args() {
503		matches, err := filepath.Glob(arg)
504		if err != nil {
505			fmt.Fprintf(os.Stderr, "Invalid command line argument %q: %v\n", arg, err)
506			os.Exit(1)
507		}
508		filenames = append(filenames, matches...)
509	}
510
511	if len(filenames) < 2 {
512		// No need to merge
513		return
514	}
515
516	err := merge(mergedFile, filenames...)
517	if err != nil {
518		fmt.Fprintf(os.Stderr, "Merge failed with error: %v\n", err)
519		os.Exit(1)
520	}
521}
522