1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build ignore
6// +build ignore
7
8package main
9
10import (
11	"bytes"
12	"fmt"
13	"go/format"
14	"go/types"
15	"io/ioutil"
16	"log"
17	"reflect"
18	"sort"
19	"strings"
20
21	"golang.org/x/tools/go/packages"
22)
23
// Generator-wide state.
var (
	// irPkg is the type-checked cmd/compile/internal/ir package; main
	// assigns it after loading.
	irPkg *types.Package

	// buf accumulates the generated source until it is formatted and
	// written out.
	buf bytes.Buffer
)
26
27func main() {
28	cfg := &packages.Config{
29		Mode: packages.NeedSyntax | packages.NeedTypes,
30	}
31	pkgs, err := packages.Load(cfg, "cmd/compile/internal/ir")
32	if err != nil {
33		log.Fatal(err)
34	}
35	irPkg = pkgs[0].Types
36
37	fmt.Fprintln(&buf, "// Code generated by mknode.go. DO NOT EDIT.")
38	fmt.Fprintln(&buf)
39	fmt.Fprintln(&buf, "package ir")
40	fmt.Fprintln(&buf)
41	fmt.Fprintln(&buf, `import "fmt"`)
42
43	scope := irPkg.Scope()
44	for _, name := range scope.Names() {
45		if strings.HasPrefix(name, "mini") {
46			continue
47		}
48
49		obj, ok := scope.Lookup(name).(*types.TypeName)
50		if !ok {
51			continue
52		}
53		typ := obj.Type().(*types.Named)
54		if !implementsNode(types.NewPointer(typ)) {
55			continue
56		}
57
58		fmt.Fprintf(&buf, "\n")
59		fmt.Fprintf(&buf, "func (n *%s) Format(s fmt.State, verb rune) { fmtNode(n, s, verb) }\n", name)
60
61		switch name {
62		case "Name", "Func":
63			// Too specialized to automate.
64			continue
65		}
66
67		forNodeFields(typ,
68			"func (n *%[1]s) copy() Node { c := *n\n",
69			"",
70			"c.%[1]s = copy%[2]s(c.%[1]s)",
71			"return &c }\n")
72
73		forNodeFields(typ,
74			"func (n *%[1]s) doChildren(do func(Node) bool) bool {\n",
75			"if n.%[1]s != nil && do(n.%[1]s) { return true }",
76			"if do%[2]s(n.%[1]s, do) { return true }",
77			"return false }\n")
78
79		forNodeFields(typ,
80			"func (n *%[1]s) editChildren(edit func(Node) Node) {\n",
81			"if n.%[1]s != nil { n.%[1]s = edit(n.%[1]s).(%[2]s) }",
82			"edit%[2]s(n.%[1]s, edit)",
83			"}\n")
84	}
85
86	makeHelpers()
87
88	out, err := format.Source(buf.Bytes())
89	if err != nil {
90		// write out mangled source so we can see the bug.
91		out = buf.Bytes()
92	}
93
94	err = ioutil.WriteFile("node_gen.go", out, 0666)
95	if err != nil {
96		log.Fatal(err)
97	}
98}
99
// needHelper records every slice helper that generated code refers to.
// Keys are helper base names (the element type without a leading "*",
// plus "s"); values are the slice element type. makeHelpers emits one
// set of helpers per entry.
var needHelper = make(map[string]string)
103
104func makeHelpers() {
105	var names []string
106	for name := range needHelper {
107		names = append(names, name)
108	}
109	sort.Strings(names)
110
111	for _, name := range names {
112		fmt.Fprintf(&buf, sliceHelperTmpl, name, needHelper[name])
113	}
114}
115
// sliceHelperTmpl is the fmt template expanded once per needHelper entry.
// %[1]s is the helper base name (for example "Nodes") and %[2]s is the
// slice element type. Each expansion defines three functions:
//
//   - copy%[1]s returns a shallow copy of list, or nil if list is nil.
//   - do%[1]s calls do on each non-nil element in order and returns true
//     as soon as one call returns true.
//   - edit%[1]s replaces each non-nil element in place with edit's result,
//     type-asserted back to the element type.
const sliceHelperTmpl = `
func copy%[1]s(list []%[2]s) []%[2]s {
	if list == nil {
		return nil
	}
	c := make([]%[2]s, len(list))
	copy(c, list)
	return c
}
func do%[1]s(list []%[2]s, do func(Node) bool) bool {
	for _, x := range list {
		if x != nil && do(x) {
			return true
		}
	}
	return false
}
func edit%[1]s(list []%[2]s, edit func(Node) Node) {
	for i, x := range list {
		if x != nil {
			list[i] = edit(x).(%[2]s)
		}
	}
}
`
141
142func forNodeFields(named *types.Named, prologue, singleTmpl, sliceTmpl, epilogue string) {
143	fmt.Fprintf(&buf, prologue, named.Obj().Name())
144
145	anyField(named.Underlying().(*types.Struct), func(f *types.Var) bool {
146		if f.Embedded() {
147			return false
148		}
149		name, typ := f.Name(), f.Type()
150
151		slice, _ := typ.Underlying().(*types.Slice)
152		if slice != nil {
153			typ = slice.Elem()
154		}
155
156		tmpl, what := singleTmpl, types.TypeString(typ, types.RelativeTo(irPkg))
157		if implementsNode(typ) {
158			if slice != nil {
159				helper := strings.TrimPrefix(what, "*") + "s"
160				needHelper[helper] = what
161				tmpl, what = sliceTmpl, helper
162			}
163		} else if what == "*Field" {
164			// Special case for *Field.
165			tmpl = sliceTmpl
166			if slice != nil {
167				what = "Fields"
168			} else {
169				what = "Field"
170			}
171		} else {
172			return false
173		}
174
175		if tmpl == "" {
176			return false
177		}
178
179		// Allow template to not use all arguments without
180		// upsetting fmt.Printf.
181		s := fmt.Sprintf(tmpl+"\x00 %[1]s %[2]s", name, what)
182		fmt.Fprintln(&buf, s[:strings.LastIndex(s, "\x00")])
183		return false
184	})
185
186	fmt.Fprintf(&buf, epilogue)
187}
188
189func implementsNode(typ types.Type) bool {
190	if _, ok := typ.Underlying().(*types.Interface); ok {
191		// TODO(mdempsky): Check the interface implements Node.
192		// Worst case, node_gen.go will fail to compile if we're wrong.
193		return true
194	}
195
196	if ptr, ok := typ.(*types.Pointer); ok {
197		if str, ok := ptr.Elem().Underlying().(*types.Struct); ok {
198			return anyField(str, func(f *types.Var) bool {
199				return f.Embedded() && f.Name() == "miniNode"
200			})
201		}
202	}
203
204	return false
205}
206
// anyField reports whether pred returns true for some field of typ, in
// declaration order. It searches depth-first: each embedded field is
// passed to pred, and then the fields of its underlying struct type are
// searched too. Embedded pointers are not followed.
//
// A field tagged `mknode:"-"` is skipped entirely: it is not passed to pred
// and, if embedded, not searched. Any other mknode tag value panics.
func anyField(typ *types.Struct, pred func(f *types.Var) bool) bool {
	for i := 0; i < typ.NumFields(); i++ {
		tag, tagged := reflect.StructTag(typ.Tag(i)).Lookup("mknode")
		if tagged {
			if tag == "-" {
				continue
			}
			panic(fmt.Sprintf("unexpected tag value: %q", tag))
		}

		field := typ.Field(i)
		if pred(field) {
			return true
		}
		if !field.Embedded() {
			continue
		}
		inner, isStruct := field.Type().Underlying().(*types.Struct)
		if isStruct && anyField(inner, pred) {
			return true
		}
	}
	return false
}
230