1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gofmt
6
7import (
8	"fmt"
9	"go/ast"
10	"go/parser"
11	"go/token"
12	"os"
13	"reflect"
14	"strings"
15	"unicode"
16	"unicode/utf8"
17)
18
19func initRewrite() {
20	if *rewriteRule == "" {
21		rewrite = nil // disable any previous rewrite
22		return
23	}
24	f := strings.Split(*rewriteRule, "->")
25	if len(f) != 2 {
26		fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
27		os.Exit(2)
28	}
29	pattern := parseExpr(f[0], "pattern")
30	replace := parseExpr(f[1], "replacement")
31	rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
32}
33
// parseExpr parses s as a Go expression and returns its syntax tree.
// what names the role of s ("pattern" or "replacement") and is only used
// in the error message. If s does not parse, the error is printed to
// standard error and the process exits with status 2.
//
// Only expressions are supported. Statement patterns might be worth
// adding, but it is unclear how to preserve formatting for them and
// what a statement wildcard would look like.
func parseExpr(s, what string) ast.Expr {
	expr, err := parser.ParseExpr(s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
	os.Exit(2)
	return expr // not reached: os.Exit does not return
}
46
47// Keep this function for debugging.
48/*
49func dump(msg string, val reflect.Value) {
50	fmt.Printf("%s:\n", msg)
51	ast.Print(fileSet, val.Interface())
52	fmt.Println()
53}
54*/
55
56// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
57func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
58	cmap := ast.NewCommentMap(fileSet, p, p.Comments)
59	m := make(map[string]reflect.Value)
60	pat := reflect.ValueOf(pattern)
61	repl := reflect.ValueOf(replace)
62
63	var rewriteVal func(val reflect.Value) reflect.Value
64	rewriteVal = func(val reflect.Value) reflect.Value {
65		// don't bother if val is invalid to start with
66		if !val.IsValid() {
67			return reflect.Value{}
68		}
69		val = apply(rewriteVal, val)
70		for k := range m {
71			delete(m, k)
72		}
73		if match(m, pat, val) {
74			val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
75		}
76		return val
77	}
78
79	r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
80	r.Comments = cmap.Filter(r).Comments() // recreate comments list
81	return r
82}
83
// set assigns y to x via x.Set(y) without letting an impossible assignment
// bring the caller down. Nothing happens if x is not settable or y is
// invalid. A panic from x.Set whose string value mentions "type mismatch" or
// "not assignable" is swallowed, leaving x unchanged; any other panic is
// raised again.
func set(x, y reflect.Value) {
	if !(x.CanSet() && y.IsValid()) {
		return
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		if isString && (strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable")) {
			// x cannot be set to y - ignore this rewrite
			return
		}
		panic(r)
	}()
	x.Set(y)
}
102
// Types that apply, match and subst treat specially.
var (
	identType     = reflect.TypeOf((*ast.Ident)(nil))
	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil))
	positionType  = reflect.TypeOf(token.Pos(0))
	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
)

// Nil pointer values of the object and scope pointer types.
var (
	objectPtrNil = reflect.Zero(objectPtrType)
	scopePtrNil  = reflect.Zero(scopePtrType)
)
114
115// apply replaces each AST field x in val with f(x), returning val.
116// To avoid extra conversions, f operates on the reflect.Value form.
117func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
118	if !val.IsValid() {
119		return reflect.Value{}
120	}
121
122	// *ast.Objects introduce cycles and are likely incorrect after
123	// rewrite; don't follow them but replace with nil instead
124	if val.Type() == objectPtrType {
125		return objectPtrNil
126	}
127
128	// similarly for scopes: they are likely incorrect after a rewrite;
129	// replace them with nil
130	if val.Type() == scopePtrType {
131		return scopePtrNil
132	}
133
134	switch v := reflect.Indirect(val); v.Kind() {
135	case reflect.Slice:
136		for i := 0; i < v.Len(); i++ {
137			e := v.Index(i)
138			set(e, f(e))
139		}
140	case reflect.Struct:
141		for i := 0; i < v.NumField(); i++ {
142			e := v.Field(i)
143			set(e, f(e))
144		}
145	case reflect.Interface:
146		e := v.Elem()
147		set(v, f(e))
148	}
149	return val
150}
151
// isWildcard reports whether s is a wildcard name, that is, whether it
// consists of exactly one rune and that rune is a lower-case letter.
func isWildcard(s string) bool {
	r, width := utf8.DecodeRuneInString(s)
	if width != len(s) {
		// more than one rune
		return false
	}
	return unicode.IsLower(r)
}
156
157// match reports whether pattern matches val,
158// recording wildcard submatches in m.
159// If m == nil, match checks whether pattern == val.
160func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
161	// Wildcard matches any expression. If it appears multiple
162	// times in the pattern, it must match the same expression
163	// each time.
164	if m != nil && pattern.IsValid() && pattern.Type() == identType {
165		name := pattern.Interface().(*ast.Ident).Name
166		if isWildcard(name) && val.IsValid() {
167			// wildcards only match valid (non-nil) expressions.
168			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
169				if old, ok := m[name]; ok {
170					return match(nil, old, val)
171				}
172				m[name] = val
173				return true
174			}
175		}
176	}
177
178	// Otherwise, pattern and val must match recursively.
179	if !pattern.IsValid() || !val.IsValid() {
180		return !pattern.IsValid() && !val.IsValid()
181	}
182	if pattern.Type() != val.Type() {
183		return false
184	}
185
186	// Special cases.
187	switch pattern.Type() {
188	case identType:
189		// For identifiers, only the names need to match
190		// (and none of the other *ast.Object information).
191		// This is a common case, handle it all here instead
192		// of recursing down any further via reflection.
193		p := pattern.Interface().(*ast.Ident)
194		v := val.Interface().(*ast.Ident)
195		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
196	case objectPtrType, positionType:
197		// object pointers and token positions always match
198		return true
199	case callExprType:
200		// For calls, the Ellipsis fields (token.Position) must
201		// match since that is how f(x) and f(x...) are different.
202		// Check them here but fall through for the remaining fields.
203		p := pattern.Interface().(*ast.CallExpr)
204		v := val.Interface().(*ast.CallExpr)
205		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
206			return false
207		}
208	}
209
210	p := reflect.Indirect(pattern)
211	v := reflect.Indirect(val)
212	if !p.IsValid() || !v.IsValid() {
213		return !p.IsValid() && !v.IsValid()
214	}
215
216	switch p.Kind() {
217	case reflect.Slice:
218		if p.Len() != v.Len() {
219			return false
220		}
221		for i := 0; i < p.Len(); i++ {
222			if !match(m, p.Index(i), v.Index(i)) {
223				return false
224			}
225		}
226		return true
227
228	case reflect.Struct:
229		for i := 0; i < p.NumField(); i++ {
230			if !match(m, p.Field(i), v.Field(i)) {
231				return false
232			}
233		}
234		return true
235
236	case reflect.Interface:
237		return match(m, p.Elem(), v.Elem())
238	}
239
240	// Handle token integers, etc.
241	return p.Interface() == v.Interface()
242}
243
244// subst returns a copy of pattern with values from m substituted in place
245// of wildcards and pos used as the position of tokens from the pattern.
246// if m == nil, subst returns a copy of pattern and doesn't change the line
247// number information.
248func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
249	if !pattern.IsValid() {
250		return reflect.Value{}
251	}
252
253	// Wildcard gets replaced with map value.
254	if m != nil && pattern.Type() == identType {
255		name := pattern.Interface().(*ast.Ident).Name
256		if isWildcard(name) {
257			if old, ok := m[name]; ok {
258				return subst(nil, old, reflect.Value{})
259			}
260		}
261	}
262
263	if pos.IsValid() && pattern.Type() == positionType {
264		// use new position only if old position was valid in the first place
265		if old := pattern.Interface().(token.Pos); !old.IsValid() {
266			return pattern
267		}
268		return pos
269	}
270
271	// Otherwise copy.
272	switch p := pattern; p.Kind() {
273	case reflect.Slice:
274		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
275		for i := 0; i < p.Len(); i++ {
276			v.Index(i).Set(subst(m, p.Index(i), pos))
277		}
278		return v
279
280	case reflect.Struct:
281		v := reflect.New(p.Type()).Elem()
282		for i := 0; i < p.NumField(); i++ {
283			v.Field(i).Set(subst(m, p.Field(i), pos))
284		}
285		return v
286
287	case reflect.Ptr:
288		v := reflect.New(p.Type()).Elem()
289		if elem := p.Elem(); elem.IsValid() {
290			v.Set(subst(m, elem, pos).Addr())
291		}
292		return v
293
294	case reflect.Interface:
295		v := reflect.New(p.Type()).Elem()
296		if elem := p.Elem(); elem.IsValid() {
297			v.Set(subst(m, elem, pos))
298		}
299		return v
300	}
301
302	return pattern
303}
304