1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build ignore
6
7// This program is run via "go generate" (via a directive in sort.go)
8// to generate zfuncversion.go.
9//
10// It copies sort.go to zfuncversion.go, only retaining funcs which
11// take a "data Interface" parameter, and renaming each to have a
12// "_func" suffix and taking a "data lessSwap" instead. It then rewrites
13// each internal function call to the appropriate _func variants.
14
15package main
16
import (
	"bytes"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"go/types"
	"io/ioutil"
	"log"
	"regexp"
)
27
var (
	// fset records position information for the parsed sort.go; the same
	// set is handed to parser.ParseFile and later to format.Node.
	fset = token.NewFileSet()
)
29
30func main() {
31	af, err := parser.ParseFile(fset, "sort.go", nil, 0)
32	if err != nil {
33		log.Fatal(err)
34	}
35	af.Doc = nil
36	af.Imports = nil
37	af.Comments = nil
38
39	var newDecl []ast.Decl
40	for _, d := range af.Decls {
41		fd, ok := d.(*ast.FuncDecl)
42		if !ok {
43			continue
44		}
45		if fd.Recv != nil || fd.Name.IsExported() {
46			continue
47		}
48		typ := fd.Type
49		if len(typ.Params.List) < 1 {
50			continue
51		}
52		arg0 := typ.Params.List[0]
53		arg0Name := arg0.Names[0].Name
54		arg0Type := arg0.Type.(*ast.Ident)
55		if arg0Name != "data" || arg0Type.Name != "Interface" {
56			continue
57		}
58		arg0Type.Name = "lessSwap"
59
60		newDecl = append(newDecl, fd)
61	}
62	af.Decls = newDecl
63	ast.Walk(visitFunc(rewriteCalls), af)
64
65	var out bytes.Buffer
66	if err := format.Node(&out, fset, af); err != nil {
67		log.Fatalf("format.Node: %v", err)
68	}
69
70	// Get rid of blank lines after removal of comments.
71	src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))
72
73	// Add comments to each func, for the lost reader.
74	// This is so much easier than adding comments via the AST
75	// and trying to get position info correct.
76	src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
77
78	// Final gofmt.
79	src, err = format.Source(src)
80	if err != nil {
81		log.Fatalf("format.Source: %v on\n%s", err, src)
82	}
83
84	out.Reset()
85	out.WriteString(`// DO NOT EDIT; AUTO-GENERATED from sort.go using genzfunc.go
86
87// Copyright 2016 The Go Authors. All rights reserved.
88// Use of this source code is governed by a BSD-style
89// license that can be found in the LICENSE file.
90
91`)
92	out.Write(src)
93
94	const target = "zfuncversion.go"
95	if err := ioutil.WriteFile(target, out.Bytes(), 0644); err != nil {
96		log.Fatal(err)
97	}
98}
99
// visitFunc lets a plain function act as an ast.Visitor: ast.Walk calls
// Visit, which simply forwards the node to the wrapped function and hands
// back whatever visitor that function returns.
type visitFunc func(ast.Node) ast.Visitor

// Visit implements ast.Visitor by delegating to fn.
func (fn visitFunc) Visit(node ast.Node) ast.Visitor {
	return fn(node)
}
103
104func rewriteCalls(n ast.Node) ast.Visitor {
105	ce, ok := n.(*ast.CallExpr)
106	if ok {
107		rewriteCall(ce)
108	}
109	return visitFunc(rewriteCalls)
110}
111
// rewriteCall renames a direct call f(args...) in place to f_func(args...),
// so calls between the retained sort.go helpers target their lessSwap
// variants.
//
// Left untouched are:
//   - calls not made through a plain identifier, e.g. data.Less(i, j);
//   - calls with no arguments (every retained func takes data first);
//   - calls of predeclared identifiers: conversions such as int(x),
//     uint(x) or uint32(x) and builtins such as len or panic, none of
//     which has a _func variant.
func rewriteCall(ce *ast.CallExpr) {
	ident, ok := ce.Fun.(*ast.Ident)
	if !ok {
		// e.g. skip SelectorExpr (data.Less(..) calls)
		return
	}
	if len(ce.Args) == 0 {
		return
	}
	// Skip casts and builtins. NOTE(review): this assumes sort.go never
	// shadows a predeclared name with one of the retained functions.
	if types.Universe.Lookup(ident.Name) != nil {
		return
	}
	ident.Name += "_func"
}
127