1// Copyright 2016 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// +build ignore 6 7// This program is run via "go generate" (via a directive in sort.go) 8// to generate zfuncversion.go. 9// 10// It copies sort.go to zfuncversion.go, only retaining funcs which 11// take a "data Interface" parameter, and renaming each to have a 12// "_func" suffix and taking a "data lessSwap" instead. It then rewrites 13// each internal function call to the appropriate _func variants. 14 15package main 16 17import ( 18 "bytes" 19 "go/ast" 20 "go/format" 21 "go/parser" 22 "go/token" 23 "log" 24 "os" 25 "regexp" 26) 27 28var fset = token.NewFileSet() 29 30func main() { 31 af, err := parser.ParseFile(fset, "sort.go", nil, 0) 32 if err != nil { 33 log.Fatal(err) 34 } 35 af.Doc = nil 36 af.Imports = nil 37 af.Comments = nil 38 39 var newDecl []ast.Decl 40 for _, d := range af.Decls { 41 fd, ok := d.(*ast.FuncDecl) 42 if !ok { 43 continue 44 } 45 if fd.Recv != nil || fd.Name.IsExported() { 46 continue 47 } 48 typ := fd.Type 49 if len(typ.Params.List) < 1 { 50 continue 51 } 52 arg0 := typ.Params.List[0] 53 arg0Name := arg0.Names[0].Name 54 arg0Type := arg0.Type.(*ast.Ident) 55 if arg0Name != "data" || arg0Type.Name != "Interface" { 56 continue 57 } 58 arg0Type.Name = "lessSwap" 59 60 newDecl = append(newDecl, fd) 61 } 62 af.Decls = newDecl 63 ast.Walk(visitFunc(rewriteCalls), af) 64 65 var out bytes.Buffer 66 if err := format.Node(&out, fset, af); err != nil { 67 log.Fatalf("format.Node: %v", err) 68 } 69 70 // Get rid of blank lines after removal of comments. 71 src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n")) 72 73 // Add comments to each func, for the lost reader. 74 // This is so much easier than adding comments via the AST 75 // and trying to get position info correct. 76 src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func")) 77 78 // Final gofmt. 79 src, err = format.Source(src) 80 if err != nil { 81 log.Fatalf("format.Source: %v on\n%s", err, src) 82 } 83 84 out.Reset() 85 out.WriteString(`// Code generated from sort.go using genzfunc.go; DO NOT EDIT. 86 87// Copyright 2016 The Go Authors. All rights reserved. 88// Use of this source code is governed by a BSD-style 89// license that can be found in the LICENSE file. 90 91`) 92 out.Write(src) 93 94 const target = "zfuncversion.go" 95 if err := os.WriteFile(target, out.Bytes(), 0644); err != nil { 96 log.Fatal(err) 97 } 98} 99 100type visitFunc func(ast.Node) ast.Visitor 101 102func (f visitFunc) Visit(n ast.Node) ast.Visitor { return f(n) } 103 104func rewriteCalls(n ast.Node) ast.Visitor { 105 ce, ok := n.(*ast.CallExpr) 106 if ok { 107 rewriteCall(ce) 108 } 109 return visitFunc(rewriteCalls) 110} 111 112func rewriteCall(ce *ast.CallExpr) { 113 ident, ok := ce.Fun.(*ast.Ident) 114 if !ok { 115 // e.g. skip SelectorExpr (data.Less(..) calls) 116 return 117 } 118 // skip casts 119 if ident.Name == "int" || ident.Name == "uint" { 120 return 121 } 122 if len(ce.Args) < 1 { 123 return 124 } 125 ident.Name += "_func" 126} 127