1package ssa
2
// indVar describes one induction variable recognized by findIndVar,
// together with the bounds that are known to hold for it inside the loop.
type indVar struct {
	ind   *Value // induction variable, a Phi of min and nxt
	inc   *Value // increment, a positive 64-bit constant
	nxt   *Value // ind+inc variable
	min   *Value // minimum value, inclusive.
	max   *Value // maximum value, exclusive.
	entry *Block // entry block in the loop: the header successor taken when ind < max.
	// Invariants: for all blocks dominated by entry:
	//	min <= ind < max
	//	min <= nxt <= max
}
14
15// findIndVar finds induction variables in a function.
16//
17// Look for variables and blocks that satisfy the following
18//
19// loop:
20//   ind = (Phi min nxt),
21//   if ind < max
22//     then goto enter_loop
23//     else goto exit_loop
24//
25//   enter_loop:
26//	do something
27//      nxt = inc + ind
28//	goto loop
29//
30// exit_loop:
31//
32//
33// TODO: handle 32 bit operations
34func findIndVar(f *Func) []indVar {
35	var iv []indVar
36	sdom := f.sdom()
37
38nextb:
39	for _, b := range f.Blocks {
40		if b.Kind != BlockIf || len(b.Preds) != 2 {
41			continue
42		}
43
44		var ind, max *Value // induction, and maximum
45		entry := -1         // which successor of b enters the loop
46
47		// Check thet the control if it either ind < max or max > ind.
48		// TODO: Handle Leq64, Geq64.
49		switch b.Control.Op {
50		case OpLess64:
51			entry = 0
52			ind, max = b.Control.Args[0], b.Control.Args[1]
53		case OpGreater64:
54			entry = 0
55			ind, max = b.Control.Args[1], b.Control.Args[0]
56		default:
57			continue nextb
58		}
59
60		// Check that the induction variable is a phi that depends on itself.
61		if ind.Op != OpPhi {
62			continue
63		}
64
65		// Extract min and nxt knowing that nxt is an addition (e.g. Add64).
66		var min, nxt *Value // minimum, and next value
67		if n := ind.Args[0]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
68			min, nxt = ind.Args[1], n
69		} else if n := ind.Args[1]; n.Op == OpAdd64 && (n.Args[0] == ind || n.Args[1] == ind) {
70			min, nxt = ind.Args[0], n
71		} else {
72			// Not a recognized induction variable.
73			continue
74		}
75
76		var inc *Value
77		if nxt.Args[0] == ind { // nxt = ind + inc
78			inc = nxt.Args[1]
79		} else if nxt.Args[1] == ind { // nxt = inc + ind
80			inc = nxt.Args[0]
81		} else {
82			panic("unreachable") // one of the cases must be true from the above.
83		}
84
85		// Expect the increment to be a positive constant.
86		// TODO: handle negative increment.
87		if inc.Op != OpConst64 || inc.AuxInt <= 0 {
88			continue
89		}
90
91		// Up to now we extracted the induction variable (ind),
92		// the increment delta (inc), the temporary sum (nxt),
93		// the mininum value (min) and the maximum value (max).
94		//
95		// We also know that ind has the form (Phi min nxt) where
96		// nxt is (Add inc nxt) which means: 1) inc dominates nxt
97		// and 2) there is a loop starting at inc and containing nxt.
98		//
99		// We need to prove that the induction variable is incremented
100		// only when it's smaller than the maximum value.
101		// Two conditions must happen listed below to accept ind
102		// as an induction variable.
103
104		// First condition: loop entry has a single predecessor, which
105		// is the header block.  This implies that b.Succs[entry] is
106		// reached iff ind < max.
107		if len(b.Succs[entry].b.Preds) != 1 {
108			// b.Succs[1-entry] must exit the loop.
109			continue
110		}
111
112		// Second condition: b.Succs[entry] dominates nxt so that
113		// nxt is computed when inc < max, meaning nxt <= max.
114		if !sdom.isAncestorEq(b.Succs[entry].b, nxt.Block) {
115			// inc+ind can only be reached through the branch that enters the loop.
116			continue
117		}
118
119		// If max is c + SliceLen with c <= 0 then we drop c.
120		// Makes sure c + SliceLen doesn't overflow when SliceLen == 0.
121		// TODO: save c as an offset from max.
122		if w, c := dropAdd64(max); (w.Op == OpStringLen || w.Op == OpSliceLen) && 0 >= c && -c >= 0 {
123			max = w
124		}
125
126		// We can only guarantee that the loops runs within limits of induction variable
127		// if the increment is 1 or when the limits are constants.
128		if inc.AuxInt != 1 {
129			ok := false
130			if min.Op == OpConst64 && max.Op == OpConst64 {
131				if max.AuxInt > min.AuxInt && max.AuxInt%inc.AuxInt == min.AuxInt%inc.AuxInt { // handle overflow
132					ok = true
133				}
134			}
135			if !ok {
136				continue
137			}
138		}
139
140		if f.pass.debug > 1 {
141			if min.Op == OpConst64 {
142				b.Func.Warnl(b.Pos, "Induction variable with minimum %d and increment %d", min.AuxInt, inc.AuxInt)
143			} else {
144				b.Func.Warnl(b.Pos, "Induction variable with non-const minimum and increment %d", inc.AuxInt)
145			}
146		}
147
148		iv = append(iv, indVar{
149			ind:   ind,
150			inc:   inc,
151			nxt:   nxt,
152			min:   min,
153			max:   max,
154			entry: b.Succs[entry].b,
155		})
156		b.Logf("found induction variable %v (inc = %v, min = %v, max = %v)\n", ind, inc, min, max)
157	}
158
159	return iv
160}
161
162// loopbce performs loop based bounds check elimination.
163func loopbce(f *Func) {
164	ivList := findIndVar(f)
165
166	m := make(map[*Value]indVar)
167	for _, iv := range ivList {
168		m[iv.ind] = iv
169	}
170
171	removeBoundsChecks(f, m)
172}
173
174// removesBoundsChecks remove IsInBounds and IsSliceInBounds based on the induction variables.
175func removeBoundsChecks(f *Func, m map[*Value]indVar) {
176	sdom := f.sdom()
177	for _, b := range f.Blocks {
178		if b.Kind != BlockIf {
179			continue
180		}
181
182		v := b.Control
183
184		// Simplify:
185		// (IsInBounds ind max) where 0 <= const == min <= ind < max.
186		// (IsSliceInBounds ind max) where 0 <= const == min <= ind < max.
187		// Found in:
188		//	for i := range a {
189		//		use a[i]
190		//		use a[i:]
191		//		use a[:i]
192		//	}
193		if v.Op == OpIsInBounds || v.Op == OpIsSliceInBounds {
194			ind, add := dropAdd64(v.Args[0])
195			if ind.Op != OpPhi {
196				goto skip1
197			}
198			if v.Op == OpIsInBounds && add != 0 {
199				goto skip1
200			}
201			if v.Op == OpIsSliceInBounds && (0 > add || add > 1) {
202				goto skip1
203			}
204
205			if iv, has := m[ind]; has && sdom.isAncestorEq(iv.entry, b) && isNonNegative(iv.min) {
206				if v.Args[1] == iv.max {
207					if f.pass.debug > 0 {
208						f.Warnl(b.Pos, "Found redundant %s", v.Op)
209					}
210					goto simplify
211				}
212			}
213		}
214	skip1:
215
216		// Simplify:
217		// (IsSliceInBounds ind (SliceCap a)) where 0 <= min <= ind < max == (SliceLen a)
218		// Found in:
219		//	for i := range a {
220		//		use a[:i]
221		//		use a[:i+1]
222		//	}
223		if v.Op == OpIsSliceInBounds {
224			ind, add := dropAdd64(v.Args[0])
225			if ind.Op != OpPhi {
226				goto skip2
227			}
228			if 0 > add || add > 1 {
229				goto skip2
230			}
231
232			if iv, has := m[ind]; has && sdom.isAncestorEq(iv.entry, b) && isNonNegative(iv.min) {
233				if v.Args[1].Op == OpSliceCap && iv.max.Op == OpSliceLen && v.Args[1].Args[0] == iv.max.Args[0] {
234					if f.pass.debug > 0 {
235						f.Warnl(b.Pos, "Found redundant %s (len promoted to cap)", v.Op)
236					}
237					goto simplify
238				}
239			}
240		}
241	skip2:
242
243		// Simplify
244		// (IsInBounds (Add64 ind) (Const64 [c])) where 0 <= min <= ind < max <= (Const64 [c])
245		// (IsSliceInBounds ind (Const64 [c])) where 0 <= min <= ind < max <= (Const64 [c])
246		if v.Op == OpIsInBounds || v.Op == OpIsSliceInBounds {
247			ind, add := dropAdd64(v.Args[0])
248			if ind.Op != OpPhi {
249				goto skip3
250			}
251
252			// ind + add >= 0 <-> min + add >= 0 <-> min >= -add
253			if iv, has := m[ind]; has && sdom.isAncestorEq(iv.entry, b) && isGreaterOrEqualThan(iv.min, -add) {
254				if !v.Args[1].isGenericIntConst() || !iv.max.isGenericIntConst() {
255					goto skip3
256				}
257
258				limit := v.Args[1].AuxInt
259				if v.Op == OpIsSliceInBounds {
260					// If limit++ overflows signed integer then 0 <= max && max <= limit will be false.
261					limit++
262				}
263
264				if max := iv.max.AuxInt + add; 0 <= max && max <= limit { // handle overflow
265					if f.pass.debug > 0 {
266						f.Warnl(b.Pos, "Found redundant (%s ind %d), ind < %d", v.Op, v.Args[1].AuxInt, iv.max.AuxInt+add)
267					}
268					goto simplify
269				}
270			}
271		}
272	skip3:
273
274		continue
275
276	simplify:
277		f.Logf("removing bounds check %v at %v in %s\n", b.Control, b, f.Name)
278		b.Kind = BlockFirst
279		b.SetControl(nil)
280	}
281}
282
283func dropAdd64(v *Value) (*Value, int64) {
284	if v.Op == OpAdd64 && v.Args[0].Op == OpConst64 {
285		return v.Args[1], v.Args[0].AuxInt
286	}
287	if v.Op == OpAdd64 && v.Args[1].Op == OpConst64 {
288		return v.Args[0], v.Args[1].AuxInt
289	}
290	return v, 0
291}
292
293func isGreaterOrEqualThan(v *Value, c int64) bool {
294	if c == 0 {
295		return isNonNegative(v)
296	}
297	if v.isGenericIntConst() && v.AuxInt >= c {
298		return true
299	}
300	return false
301}
302