1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package eg
6
7// This file defines the AST rewriting pass.
8// Most of it was plundered directly from
9// $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).
10
11import (
12	"fmt"
13	"go/ast"
14	"go/token"
15	"go/types"
16	"os"
17	"reflect"
18	"sort"
19	"strconv"
20	"strings"
21
22	"golang.org/x/tools/go/ast/astutil"
23)
24
25// transformItem takes a reflect.Value representing a variable of type ast.Node
26// transforms its child elements recursively with apply, and then transforms the
27// actual element if it contains an expression.
28func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
29	// don't bother if val is invalid to start with
30	if !rv.IsValid() {
31		return reflect.Value{}, false, nil
32	}
33
34	rv, changed, newEnv := tr.apply(tr.transformItem, rv)
35
36	e := rvToExpr(rv)
37	if e == nil {
38		return rv, changed, newEnv
39	}
40
41	savedEnv := tr.env
42	tr.env = make(map[string]ast.Expr) // inefficient!  Use a slice of k/v pairs
43
44	if tr.matchExpr(tr.before, e) {
45		if tr.verbose {
46			fmt.Fprintf(os.Stderr, "%s matches %s",
47				astString(tr.fset, tr.before), astString(tr.fset, e))
48			if len(tr.env) > 0 {
49				fmt.Fprintf(os.Stderr, " with:")
50				for name, ast := range tr.env {
51					fmt.Fprintf(os.Stderr, " %s->%s",
52						name, astString(tr.fset, ast))
53				}
54			}
55			fmt.Fprintf(os.Stderr, "\n")
56		}
57		tr.nsubsts++
58
59		// Clone the replacement tree, performing parameter substitution.
60		// We update all positions to n.Pos() to aid comment placement.
61		rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
62			reflect.ValueOf(e.Pos()))
63		changed = true
64		newEnv = tr.env
65	}
66	tr.env = savedEnv
67
68	return rv, changed, newEnv
69}
70
71// Transform applies the transformation to the specified parsed file,
72// whose type information is supplied in info, and returns the number
73// of replacements that were made.
74//
75// It mutates the AST in place (the identity of the root node is
76// unchanged), and may add nodes for which no type information is
77// available in info.
78//
79// Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
80//
81func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
82	if !tr.seenInfos[info] {
83		tr.seenInfos[info] = true
84		mergeTypeInfo(tr.info, info)
85	}
86	tr.currentPkg = pkg
87	tr.nsubsts = 0
88
89	if tr.verbose {
90		fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
91		fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
92		fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
93	}
94
95	o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
96	if changed {
97		panic("BUG")
98	}
99	file2 := o.Interface().(*ast.File)
100
101	// By construction, the root node is unchanged.
102	if file != file2 {
103		panic("BUG")
104	}
105
106	// Add any necessary imports.
107	// TODO(adonovan): remove no-longer needed imports too.
108	if tr.nsubsts > 0 {
109		pkgs := make(map[string]*types.Package)
110		for obj := range tr.importedObjs {
111			pkgs[obj.Pkg().Path()] = obj.Pkg()
112		}
113
114		for _, imp := range file.Imports {
115			path, _ := strconv.Unquote(imp.Path.Value)
116			delete(pkgs, path)
117		}
118		delete(pkgs, pkg.Path()) // don't import self
119
120		// NB: AddImport may completely replace the AST!
121		// It thus renders info and tr.info no longer relevant to file.
122		var paths []string
123		for path := range pkgs {
124			paths = append(paths, path)
125		}
126		sort.Strings(paths)
127		for _, path := range paths {
128			astutil.AddImport(tr.fset, file, path)
129		}
130	}
131
132	tr.currentPkg = nil
133
134	return tr.nsubsts
135}
136
// setValue performs x.Set(y) with two exceptions. If y is the zero
// reflect.Value, it does nothing. If reflect panics because y cannot be
// stored in x (a "type mismatch" or "not assignable" panic), that rewrite is
// silently skipped. Any other panic from x.Set, such as an unsettable x, is
// re-raised unchanged.
func setValue(x, y reflect.Value) {
	if y.IsValid() {
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			msg, isString := r.(string)
			incompatible := isString &&
				(strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable"))
			if !incompatible {
				panic(r)
			}
			// y cannot be stored in x: ignore this rewrite.
		}()
		x.Set(y)
	}
}
156
// Reflection values and types that apply and subst use to recognize AST
// nodes needing special treatment during traversal and copying.
var (
	// *ast.Object and *ast.Scope pointers are not followed; they are
	// replaced by the typed nil values below.
	objectPtrType = reflect.TypeOf((*ast.Object)(nil))
	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))
	objectPtrNil  = reflect.Zero(objectPtrType)
	scopePtrNil   = reflect.Zero(scopePtrType)

	// Node types inspected for wildcard and qualified-name substitution.
	identType        = reflect.TypeOf((*ast.Ident)(nil))
	selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))

	// The ast.Stmt interface type itself (element type of statement lists).
	statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()

	// token.Pos fields are rewritten to the position of the replaced node.
	positionType = reflect.TypeOf(token.Pos(0))
)
169
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
// f takes a reflect.Value representing the variable to modify of type ast.Node.
// It returns a reflect.Value containing the transformed value of type ast.Node,
// whether any change was made, and a map of identifiers to ast.Expr (so we can
// do contextually correct substitutions in the parent statements).
//
// apply only visits one level of val; recursion happens through f.
//   - *ast.Object and *ast.Scope values are replaced by typed nils and are
//     not traversed.
//   - Pointers are followed via reflect.Indirect. A nil pointer yields an
//     invalid Value and falls through to the final return, so it is
//     returned unchanged.
//   - For a slice of ast.Stmt, the instantiated tr.afterStmts are inserted
//     before every statement that f reports as changed. A NEW slice Value
//     is returned, which the caller must store (callers do so via setValue).
//     The result is reported as unchanged with a nil env, so an enclosing
//     statement list does not emit the same setup statements again.
//   - Other slices, structs and interfaces are updated element-wise in
//     place. The returned env is that of the last element reported as
//     changed.
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
	if !val.IsValid() {
		return reflect.Value{}, false, nil
	}

	// *ast.Objects introduce cycles and are likely incorrect after
	// rewrite; don't follow them but replace with nil instead
	if val.Type() == objectPtrType {
		return objectPtrNil, false, nil
	}

	// similarly for scopes: they are likely incorrect after a rewrite;
	// replace them with nil
	if val.Type() == scopePtrType {
		return scopePtrNil, false, nil
	}

	switch v := reflect.Indirect(val); v.Kind() {
	case reflect.Slice:
		// no possible rewriting of statements.
		if v.Type().Elem() != statementType {
			changed := false
			var envp map[string]ast.Expr
			for i := 0; i < v.Len(); i++ {
				e := v.Index(i)
				o, localchanged, env := f(e)
				if localchanged {
					changed = true
					// we clobber envp here,
					// which means if we have two successive
					// replacements inside the same statement
					// we will only generate the setup for one of them.
					envp = env
				}
				setValue(e, o)
			}
			return val, changed, envp
		}

		// statements are rewritten.
		var out []ast.Stmt
		for i := 0; i < v.Len(); i++ {
			e := v.Index(i)
			o, changed, env := f(e)
			if changed {
				// Emit the setup statements, instantiated with this
				// statement's wildcard bindings, ahead of the statement.
				for _, s := range tr.afterStmts {
					t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
					out = append(out, t.(ast.Stmt))
				}
			}
			// Store the transformed statement back into the slice element
			// first, then read it out of e for the new list.
			setValue(e, o)
			out = append(out, e.Interface().(ast.Stmt))
		}
		return reflect.ValueOf(out), false, nil
	case reflect.Struct:
		changed := false
		var envp map[string]ast.Expr
		for i := 0; i < v.NumField(); i++ {
			e := v.Field(i)
			o, localchanged, env := f(e)
			if localchanged {
				changed = true
				envp = env
			}
			setValue(e, o)
		}
		return val, changed, envp
	case reflect.Interface:
		// For a nil interface, e is invalid; f then returns an invalid
		// Value, which setValue ignores.
		e := v.Elem()
		o, changed, env := f(e)
		setValue(v, o)
		return val, changed, env
	}
	return val, false, nil
}
251
252// subst returns a copy of (replacement) pattern with values from env
253// substituted in place of wildcards and pos used as the position of
254// tokens from the pattern.  if env == nil, subst returns a copy of
255// pattern and doesn't change the line number information.
256func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
257	if !pattern.IsValid() {
258		return reflect.Value{}
259	}
260
261	// *ast.Objects introduce cycles and are likely incorrect after
262	// rewrite; don't follow them but replace with nil instead
263	if pattern.Type() == objectPtrType {
264		return objectPtrNil
265	}
266
267	// similarly for scopes: they are likely incorrect after a rewrite;
268	// replace them with nil
269	if pattern.Type() == scopePtrType {
270		return scopePtrNil
271	}
272
273	// Wildcard gets replaced with map value.
274	if env != nil && pattern.Type() == identType {
275		id := pattern.Interface().(*ast.Ident)
276		if old, ok := env[id.Name]; ok {
277			return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
278		}
279	}
280
281	// Emit qualified identifiers in the pattern by appropriate
282	// (possibly qualified) identifier in the input.
283	//
284	// The template cannot contain dot imports, so all identifiers
285	// for imported objects are explicitly qualified.
286	//
287	// We assume (unsoundly) that there are no dot or named
288	// imports in the input code, nor are any imported package
289	// names shadowed, so the usual normal qualified identifier
290	// syntax may be used.
291	// TODO(adonovan): fix: avoid this assumption.
292	//
293	// A refactoring may be applied to a package referenced by the
294	// template.  Objects belonging to the current package are
295	// denoted by unqualified identifiers.
296	//
297	if tr.importedObjs != nil && pattern.Type() == selectorExprType {
298		obj := isRef(pattern.Interface().(*ast.SelectorExpr), tr.info)
299		if obj != nil {
300			if sel, ok := tr.importedObjs[obj]; ok {
301				var id ast.Expr
302				if obj.Pkg() == tr.currentPkg {
303					id = sel.Sel // unqualified
304				} else {
305					id = sel // pkg-qualified
306				}
307
308				// Return a clone of id.
309				saved := tr.importedObjs
310				tr.importedObjs = nil // break cycle
311				r := tr.subst(nil, reflect.ValueOf(id), pos)
312				tr.importedObjs = saved
313				return r
314			}
315		}
316	}
317
318	if pos.IsValid() && pattern.Type() == positionType {
319		// use new position only if old position was valid in the first place
320		if old := pattern.Interface().(token.Pos); !old.IsValid() {
321			return pattern
322		}
323		return pos
324	}
325
326	// Otherwise copy.
327	switch p := pattern; p.Kind() {
328	case reflect.Slice:
329		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
330		for i := 0; i < p.Len(); i++ {
331			v.Index(i).Set(tr.subst(env, p.Index(i), pos))
332		}
333		return v
334
335	case reflect.Struct:
336		v := reflect.New(p.Type()).Elem()
337		for i := 0; i < p.NumField(); i++ {
338			v.Field(i).Set(tr.subst(env, p.Field(i), pos))
339		}
340		return v
341
342	case reflect.Ptr:
343		v := reflect.New(p.Type()).Elem()
344		if elem := p.Elem(); elem.IsValid() {
345			v.Set(tr.subst(env, elem, pos).Addr())
346		}
347
348		// Duplicate type information for duplicated ast.Expr.
349		// All ast.Node implementations are *structs,
350		// so this case catches them all.
351		if e := rvToExpr(v); e != nil {
352			updateTypeInfo(tr.info, e, p.Interface().(ast.Expr))
353		}
354		return v
355
356	case reflect.Interface:
357		v := reflect.New(p.Type()).Elem()
358		if elem := p.Elem(); elem.IsValid() {
359			v.Set(tr.subst(env, elem, pos))
360		}
361		return v
362	}
363
364	return pattern
365}
366
367// -- utilities -------------------------------------------------------
368
// rvToExpr returns the ast.Expr held by rv. It returns nil if rv is the zero
// reflect.Value, if rv cannot be converted to an interface (e.g. it was
// obtained through an unexported struct field), or if its dynamic value does
// not implement ast.Expr.
//
// A typed nil pointer such as (*ast.Ident)(nil) still implements ast.Expr, so
// it is returned as a non-nil ast.Expr.
func rvToExpr(rv reflect.Value) ast.Expr {
	// reflect.Value.CanInterface panics on the zero Value, so check it first.
	if !rv.IsValid() || !rv.CanInterface() {
		return nil
	}
	e, _ := rv.Interface().(ast.Expr)
	return e
}
377
// updateTypeInfo copies the type information that info records for the
// existing expression old onto its duplicate new. For identifiers this
// covers Defs and Uses, and for selector expressions it covers Selections.
// Types is copied for every kind of expression. Only entries that exist for
// old are copied. new and old are expected to have the same dynamic type; a
// mismatch for an identifier or selector panics on the type assertion.
func updateTypeInfo(info *types.Info, new, old ast.Expr) {
	if dupIdent, isIdent := new.(*ast.Ident); isIdent {
		srcIdent := old.(*ast.Ident)
		if def, found := info.Defs[srcIdent]; found {
			info.Defs[dupIdent] = def
		}
		if use, found := info.Uses[srcIdent]; found {
			info.Uses[dupIdent] = use
		}
	} else if dupSel, isSel := new.(*ast.SelectorExpr); isSel {
		srcSel := old.(*ast.SelectorExpr)
		if selection, found := info.Selections[srcSel]; found {
			info.Selections[dupSel] = selection
		}
	}

	if tv, found := info.Types[old]; found {
		info.Types[new] = tv
	}
}
402