1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssagen
6
7import (
8	"bytes"
9	"fmt"
10
11	"cmd/compile/internal/base"
12	"cmd/compile/internal/ir"
13	"cmd/compile/internal/typecheck"
14	"cmd/compile/internal/types"
15	"cmd/internal/obj"
16	"cmd/internal/src"
17)
18
19func EnableNoWriteBarrierRecCheck() {
20	nowritebarrierrecCheck = newNowritebarrierrecChecker()
21}
22
23func NoWriteBarrierRecCheck() {
24	// Write barriers are now known. Check the
25	// call graph.
26	nowritebarrierrecCheck.check()
27	nowritebarrierrecCheck = nil
28}
29
30var nowritebarrierrecCheck *nowritebarrierrecChecker
31
32type nowritebarrierrecChecker struct {
33	// extraCalls contains extra function calls that may not be
34	// visible during later analysis. It maps from the ODCLFUNC of
35	// the caller to a list of callees.
36	extraCalls map[*ir.Func][]nowritebarrierrecCall
37
38	// curfn is the current function during AST walks.
39	curfn *ir.Func
40}
41
42type nowritebarrierrecCall struct {
43	target *ir.Func // caller or callee
44	lineno src.XPos // line of call
45}
46
47// newNowritebarrierrecChecker creates a nowritebarrierrecChecker. It
48// must be called before walk
49func newNowritebarrierrecChecker() *nowritebarrierrecChecker {
50	c := &nowritebarrierrecChecker{
51		extraCalls: make(map[*ir.Func][]nowritebarrierrecCall),
52	}
53
54	// Find all systemstack calls and record their targets. In
55	// general, flow analysis can't see into systemstack, but it's
56	// important to handle it for this check, so we model it
57	// directly. This has to happen before transforming closures in walk since
58	// it's a lot harder to work out the argument after.
59	for _, n := range typecheck.Target.Decls {
60		if n.Op() != ir.ODCLFUNC {
61			continue
62		}
63		c.curfn = n.(*ir.Func)
64		if c.curfn.ABIWrapper() {
65			// We only want "real" calls to these
66			// functions, not the generated ones within
67			// their own ABI wrappers.
68			continue
69		}
70		ir.Visit(n, c.findExtraCalls)
71	}
72	c.curfn = nil
73	return c
74}
75
76func (c *nowritebarrierrecChecker) findExtraCalls(nn ir.Node) {
77	if nn.Op() != ir.OCALLFUNC {
78		return
79	}
80	n := nn.(*ir.CallExpr)
81	if n.X == nil || n.X.Op() != ir.ONAME {
82		return
83	}
84	fn := n.X.(*ir.Name)
85	if fn.Class != ir.PFUNC || fn.Defn == nil {
86		return
87	}
88	if !types.IsRuntimePkg(fn.Sym().Pkg) || fn.Sym().Name != "systemstack" {
89		return
90	}
91
92	var callee *ir.Func
93	arg := n.Args[0]
94	switch arg.Op() {
95	case ir.ONAME:
96		arg := arg.(*ir.Name)
97		callee = arg.Defn.(*ir.Func)
98	case ir.OCLOSURE:
99		arg := arg.(*ir.ClosureExpr)
100		callee = arg.Func
101	default:
102		base.Fatalf("expected ONAME or OCLOSURE node, got %+v", arg)
103	}
104	if callee.Op() != ir.ODCLFUNC {
105		base.Fatalf("expected ODCLFUNC node, got %+v", callee)
106	}
107	c.extraCalls[c.curfn] = append(c.extraCalls[c.curfn], nowritebarrierrecCall{callee, n.Pos()})
108}
109
110// recordCall records a call from ODCLFUNC node "from", to function
111// symbol "to" at position pos.
112//
113// This should be done as late as possible during compilation to
114// capture precise call graphs. The target of the call is an LSym
115// because that's all we know after we start SSA.
116//
117// This can be called concurrently for different from Nodes.
118func (c *nowritebarrierrecChecker) recordCall(fn *ir.Func, to *obj.LSym, pos src.XPos) {
119	// We record this information on the *Func so this is concurrent-safe.
120	if fn.NWBRCalls == nil {
121		fn.NWBRCalls = new([]ir.SymAndPos)
122	}
123	*fn.NWBRCalls = append(*fn.NWBRCalls, ir.SymAndPos{Sym: to, Pos: pos})
124}
125
126func (c *nowritebarrierrecChecker) check() {
127	// We walk the call graph as late as possible so we can
128	// capture all calls created by lowering, but this means we
129	// only get to see the obj.LSyms of calls. symToFunc lets us
130	// get back to the ODCLFUNCs.
131	symToFunc := make(map[*obj.LSym]*ir.Func)
132	// funcs records the back-edges of the BFS call graph walk. It
133	// maps from the ODCLFUNC of each function that must not have
134	// write barriers to the call that inhibits them. Functions
135	// that are directly marked go:nowritebarrierrec are in this
136	// map with a zero-valued nowritebarrierrecCall. This also
137	// acts as the set of marks for the BFS of the call graph.
138	funcs := make(map[*ir.Func]nowritebarrierrecCall)
139	// q is the queue of ODCLFUNC Nodes to visit in BFS order.
140	var q ir.NameQueue
141
142	for _, n := range typecheck.Target.Decls {
143		if n.Op() != ir.ODCLFUNC {
144			continue
145		}
146		fn := n.(*ir.Func)
147
148		symToFunc[fn.LSym] = fn
149
150		// Make nowritebarrierrec functions BFS roots.
151		if fn.Pragma&ir.Nowritebarrierrec != 0 {
152			funcs[fn] = nowritebarrierrecCall{}
153			q.PushRight(fn.Nname)
154		}
155		// Check go:nowritebarrier functions.
156		if fn.Pragma&ir.Nowritebarrier != 0 && fn.WBPos.IsKnown() {
157			base.ErrorfAt(fn.WBPos, "write barrier prohibited")
158		}
159	}
160
161	// Perform a BFS of the call graph from all
162	// go:nowritebarrierrec functions.
163	enqueue := func(src, target *ir.Func, pos src.XPos) {
164		if target.Pragma&ir.Yeswritebarrierrec != 0 {
165			// Don't flow into this function.
166			return
167		}
168		if _, ok := funcs[target]; ok {
169			// Already found a path to target.
170			return
171		}
172
173		// Record the path.
174		funcs[target] = nowritebarrierrecCall{target: src, lineno: pos}
175		q.PushRight(target.Nname)
176	}
177	for !q.Empty() {
178		fn := q.PopLeft().Func
179
180		// Check fn.
181		if fn.WBPos.IsKnown() {
182			var err bytes.Buffer
183			call := funcs[fn]
184			for call.target != nil {
185				fmt.Fprintf(&err, "\n\t%v: called by %v", base.FmtPos(call.lineno), call.target.Nname)
186				call = funcs[call.target]
187			}
188			base.ErrorfAt(fn.WBPos, "write barrier prohibited by caller; %v%s", fn.Nname, err.String())
189			continue
190		}
191
192		// Enqueue fn's calls.
193		for _, callee := range c.extraCalls[fn] {
194			enqueue(fn, callee.target, callee.lineno)
195		}
196		if fn.NWBRCalls == nil {
197			continue
198		}
199		for _, callee := range *fn.NWBRCalls {
200			target := symToFunc[callee.Sym]
201			if target != nil {
202				enqueue(fn, target, callee.Pos)
203			}
204		}
205	}
206}
207