1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"cmd/internal/src"
9	"fmt"
10	"math"
11)
12
// branch identifies which successor edge of a conditional (BlockIf)
// block leads to the block under consideration. See getBranch for how
// it is computed and addBranchRestrictions for how it is used.
type branch int

const (
	// unknown means the edge could not be determined.
	unknown branch = iota
	// positive is the edge through Succs[0]; on it the control value
	// is known to differ from false.
	positive
	// negative is the edge through Succs[1]; on it the control value
	// is known to equal false.
	negative
)
20
// relation is a bit set describing which orderings between the two
// values of a pair (v, w) are still considered possible. With no
// prior knowledge the set is lt | eq | gt: v may be less than, equal
// to, or greater than w. Each time execution branches on a condition
// `v op w`, the set is narrowed by removing every ordering that the
// condition (or its negation) rules out.
//
// E.g.
//
// r := relation(...)
//
// if v < w {
//   newR := r & lt
// }
// if v >= w {
//   newR := r & (eq|gt)
// }
// if v != w {
//   newR := r & (lt|gt)
// }
type relation uint

// The three elementary orderings; a relation is any union of them.
const (
	lt relation = 1 << iota
	eq
	gt
)

// relationStrings holds the printable form of every combination of
// lt, eq and gt, indexed by the relation's numeric value.
var relationStrings = [...]string{
	0:            "none",
	lt:           "<",
	eq:           "==",
	lt | eq:      "<=",
	gt:           ">",
	lt | gt:      "!=",
	eq | gt:      ">=",
	lt | eq | gt: "any",
}

// String returns the operator-like spelling of r, or "relation(N)"
// when r has bits set outside lt|eq|gt.
func (r relation) String() string {
	if uint(r) >= uint(len(relationStrings)) {
		return fmt.Sprintf("relation(%d)", uint(r))
	}
	return relationStrings[r]
}
60
// domain identifies the interpretation under which a relation between
// two values holds. Facts are kept per domain because the same bit
// pattern orders differently depending on how it is read: a relation
// learned for an unsigned pair, for instance, says nothing about the
// same pair read as signed.
type domain uint

// The individual domains; a domain value may be a union of several.
const (
	signed domain = 1 << iota
	unsigned
	pointer
	boolean
)

// domainStrings lists the domain names in bit order: entry i names
// the domain with value 1<<i.
var domainStrings = [...]string{
	"signed",
	"unsigned",
	"pointer",
	"boolean",
}

// String returns the names of the domains contained in d joined by
// "|". Bits that have no name are appended as one hexadecimal number.
// The empty domain yields the empty string.
func (d domain) String() string {
	var out string
	add := func(part string) {
		if out != "" {
			out += "|"
		}
		out += part
	}

	rest := d
	for i, name := range domainStrings {
		bit := domain(1) << uint(i)
		if rest&bit == 0 {
			continue
		}
		add(name)
		rest &^= bit
	}
	if rest != 0 {
		add(fmt.Sprintf("0x%x", uint(rest)))
	}
	return out
}
97
// pair is the key type of factsTable.facts: two values together with
// the domain in which a relation between them is recorded.
type pair struct {
	v, w *Value // a pair of values, ordered by ID.
	// v can be nil, to mean the zero value.
	// for booleans the zero value (v == nil) is false.
	d domain
}
104
// fact is a pair plus a relation for that pair.
//
// Values of this type are pushed on factsTable.stack to remember the
// relation a pair had before an update, so that restore can put it back.
// The zero fact doubles as the checkpoint marker (see checkpointFact).
type fact struct {
	p pair
	r relation
}
110
// a limit holds the known inclusive bounds of a single value, kept
// separately for its signed and its unsigned interpretation.
type limit struct {
	min, max   int64  // min <= value <= max, signed
	umin, umax uint64 // umin <= value <= umax, unsigned
}

// String formats the four bounds in the order signed min, signed max,
// unsigned min, unsigned max.
func (l limit) String() string {
	return fmt.Sprintf("sm,SM,um,UM=%d,%d,%d,%d", l.min, l.max, l.umin, l.umax)
}

// intersect returns the tightest limit implied by both l and l2: the
// larger of each lower bound and the smaller of each upper bound.
// The result may be empty (min > max or umin > umax); callers check
// for that themselves.
func (l limit) intersect(l2 limit) limit {
	out := l
	// Lower bounds: keep the greater one.
	if l2.min > out.min {
		out.min = l2.min
	}
	if l2.umin > out.umin {
		out.umin = l2.umin
	}
	// Upper bounds: keep the smaller one.
	if l2.max < out.max {
		out.max = l2.max
	}
	if l2.umax < out.umax {
		out.umax = l2.umax
	}
	return out
}

// noLimit is the limit that excludes nothing: the full int64 range
// and the full uint64 range.
var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64}
138
// a limitFact is a limit known for a particular value.
//
// Values of this type are pushed on factsTable.limitStack to remember
// the limit a value had before an update, so that restore can put it
// back. An entry whose vid is zero is treated by restore as a
// checkpoint marker (see checkpointBound).
type limitFact struct {
	vid   ID
	limit limit
}
144
// factsTable keeps track of relations between pairs of values.
//
// The fact table logic is sound, but incomplete. Outside of a few
// special cases, it performs no deduction or arithmetic. While there
// are known decision procedures for this, the ad hoc approach taken
// by the facts table is effective for real code while remaining very
// efficient.
//
// State is changed through update and is versioned with
// checkpoint/restore, which must be called in matching pairs.
type factsTable struct {
	// unsat is true if facts contains a contradiction.
	//
	// Note that the factsTable logic is incomplete, so if unsat
	// is false, the assertions in factsTable could be satisfiable
	// *or* unsatisfiable.
	unsat      bool // true if facts contains a contradiction
	unsatDepth int  // number of unsat checkpoints

	// facts holds the relations update records for domains other
	// than signed and unsigned (those two go to the posets below).
	// stack is the undo log for facts; checkpoints are marked in it
	// with checkpointFact.
	facts map[pair]relation // current known set of relation
	stack []fact            // previous sets of relations

	// order is a couple of partial order sets that record information
	// about relations between SSA values in the signed and unsigned
	// domain.
	orderS *poset
	orderU *poset

	// known lower and upper bounds on individual values.
	// limitStack is the undo log for limits; checkpoints are marked
	// in it with checkpointBound.
	limits     map[ID]limit
	limitStack []limitFact // previous entries

	// For each slice s, a map from s to a len(s)/cap(s) value (if any)
	// TODO: check if there are cases that matter where we have
	// more than one len(s) for a slice. We could keep a list if necessary.
	//
	// Both maps are keyed by the ID of the slice value and are
	// allocated lazily by prove, so they may be nil.
	lens map[ID]*Value
	caps map[ID]*Value

	// zero is a zero-valued constant
	// (an Int64 constant 0 obtained in newFactsTable).
	zero *Value
}
183
// checkpointFact is an invalid value used for checkpointing
// and restoring factsTable.
var checkpointFact = fact{}
// checkpointBound is the equivalent marker for factsTable.limitStack.
// restore recognizes it by its zero vid.
var checkpointBound = limitFact{}
188
189func newFactsTable(f *Func) *factsTable {
190	ft := &factsTable{}
191	ft.orderS = f.newPoset()
192	ft.orderU = f.newPoset()
193	ft.orderS.SetUnsigned(false)
194	ft.orderU.SetUnsigned(true)
195	ft.facts = make(map[pair]relation)
196	ft.stack = make([]fact, 4)
197	ft.limits = make(map[ID]limit)
198	ft.limitStack = make([]limitFact, 4)
199	ft.zero = f.ConstInt64(f.Config.Types.Int64, 0)
200	return ft
201}
202
// update updates the set of relations between v and w in domain d
// restricting it to r.
//
// parent is the block the fact is attributed to: it supplies the
// position for debug output and is the block in which helper constants
// are created when deriving bounds for x+delta comparisons.
//
// For d == signed or d == unsigned the relation is recorded in the
// corresponding poset; any other d is recorded in ft.facts. Callers such
// as addRestrictions pass one domain bit at a time. In a numeric domain
// r must be a proper, non-empty subset of lt|eq|gt, otherwise update
// panics with "unknown relation".
//
// If the new fact contradicts what is already known, ft.unsat is set
// and update returns. Once ft.unsat is set, update does nothing (apart
// from debug output). After recording the fact itself, update derives
// further facts (constant limits, len/cap, fence-post, x+delta and
// extension rules) by calling itself recursively.
func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) {
	if parent.Func.pass.debug > 2 {
		parent.Func.Warnl(parent.Pos, "parent=%s, update %s %s %s", parent, v, w, r)
	}
	// No need to do anything else if we already found unsat.
	if ft.unsat {
		return
	}

	// Self-fact. It's wasteful to register it into the facts
	// table, so just note whether it's satisfiable
	if v == w {
		if r&eq == 0 {
			ft.unsat = true
		}
		return
	}

	if d == signed || d == unsigned {
		// Numeric domains are tracked by the posets, which report
		// false when the new fact contradicts existing ones.
		var ok bool
		order := ft.orderS
		if d == unsigned {
			order = ft.orderU
		}
		switch r {
		case lt:
			ok = order.SetOrder(v, w)
		case gt:
			ok = order.SetOrder(w, v)
		case lt | eq:
			ok = order.SetOrderOrEqual(v, w)
		case gt | eq:
			ok = order.SetOrderOrEqual(w, v)
		case eq:
			ok = order.SetEqual(v, w)
		case lt | gt:
			ok = order.SetNonEqual(v, w)
		default:
			panic("unknown relation")
		}
		if !ok {
			if parent.Func.pass.debug > 2 {
				parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r)
			}
			ft.unsat = true
			return
		}
	} else {
		// Canonicalize the pair so that the value with the lower ID
		// (or nil) comes first, flipping r to match.
		if lessByID(w, v) {
			v, w = w, v
			r = reverseBits[r]
		}

		p := pair{v, w, d}
		oldR, ok := ft.facts[p]
		if !ok {
			if v == w {
				oldR = eq
			} else {
				oldR = lt | eq | gt
			}
		}
		// No changes compared to information already in facts table.
		if oldR == r {
			return
		}
		// Log the previous relation so that restore can undo this change.
		ft.stack = append(ft.stack, fact{p, oldR})
		ft.facts[p] = oldR & r
		// If this relation is not satisfiable, mark it and exit right away
		if oldR&r == 0 {
			if parent.Func.pass.debug > 2 {
				parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r)
			}
			ft.unsat = true
			return
		}
	}

	// Extract bounds when comparing against constants
	// (first move a constant v to the right-hand side).
	if v.isGenericIntConst() {
		v, w = w, v
		r = reverseBits[r]
	}
	if v != nil && w.isGenericIntConst() {
		// Note: all the +1/-1 below could overflow/underflow. Either will
		// still generate correct results, it will just lead to imprecision.
		// In fact if there is overflow/underflow, the corresponding
		// code is unreachable because the known range is outside the range
		// of the value's type.
		old, ok := ft.limits[v.ID]
		if !ok {
			old = noLimit
			if v.isGenericIntConst() {
				switch d {
				case signed:
					old.min, old.max = v.AuxInt, v.AuxInt
					if v.AuxInt >= 0 {
						old.umin, old.umax = uint64(v.AuxInt), uint64(v.AuxInt)
					}
				case unsigned:
					old.umin = v.AuxUnsigned()
					old.umax = old.umin
					if int64(old.umin) >= 0 {
						old.min, old.max = int64(old.umin), int64(old.umin)
					}
				}
			}
		}
		// lim is the limit implied by this fact alone; it is
		// intersected with old below.
		lim := noLimit
		switch d {
		case signed:
			c := w.AuxInt
			switch r {
			case lt:
				lim.max = c - 1
			case lt | eq:
				lim.max = c
			case gt | eq:
				lim.min = c
			case gt:
				lim.min = c + 1
			case lt | gt:
				lim = old
				if c == lim.min {
					lim.min++
				}
				if c == lim.max {
					lim.max--
				}
			case eq:
				lim.min = c
				lim.max = c
			}
			if lim.min >= 0 {
				// int(x) >= 0 && int(x) >= N  ⇒  uint(x) >= N
				lim.umin = uint64(lim.min)
			}
			if lim.max != noLimit.max && old.min >= 0 && lim.max >= 0 {
				// 0 <= int(x) <= N  ⇒  0 <= uint(x) <= N
				// This is for a max update, so the lower bound
				// comes from what we already know (old).
				lim.umax = uint64(lim.max)
			}
		case unsigned:
			uc := w.AuxUnsigned()
			switch r {
			case lt:
				lim.umax = uc - 1
			case lt | eq:
				lim.umax = uc
			case gt | eq:
				lim.umin = uc
			case gt:
				lim.umin = uc + 1
			case lt | gt:
				lim = old
				if uc == lim.umin {
					lim.umin++
				}
				if uc == lim.umax {
					lim.umax--
				}
			case eq:
				lim.umin = uc
				lim.umax = uc
			}
			// We could use the contrapositives of the
			// signed implications to derive signed facts,
			// but it turns out not to matter.
		}
		// Log the previous limit so that restore can undo this change.
		ft.limitStack = append(ft.limitStack, limitFact{v.ID, old})
		lim = old.intersect(lim)
		ft.limits[v.ID] = lim
		if v.Block.Func.pass.debug > 2 {
			v.Block.Func.Warnl(parent.Pos, "parent=%s, new limits %s %s %s %s", parent, v, w, r, lim.String())
		}
		// An empty range means the fact cannot hold.
		if lim.min > lim.max || lim.umin > lim.umax {
			ft.unsat = true
			return
		}
	}

	// Derived facts below here are only about numbers.
	if d != signed && d != unsigned {
		return
	}

	// Additional facts we know given the relationship between len and cap.
	//
	// TODO: Since prove now derives transitive relations, it
	// should be sufficient to learn that len(w) <= cap(w) at the
	// beginning of prove where we look for all len/cap ops.
	if v.Op == OpSliceLen && r&lt == 0 && ft.caps[v.Args[0].ID] != nil {
		// len(s) > w implies cap(s) > w
		// len(s) >= w implies cap(s) >= w
		// len(s) == w implies cap(s) >= w
		ft.update(parent, ft.caps[v.Args[0].ID], w, d, r|gt)
	}
	if w.Op == OpSliceLen && r&gt == 0 && ft.caps[w.Args[0].ID] != nil {
		// same, length on the RHS.
		ft.update(parent, v, ft.caps[w.Args[0].ID], d, r|lt)
	}
	if v.Op == OpSliceCap && r&gt == 0 && ft.lens[v.Args[0].ID] != nil {
		// cap(s) < w implies len(s) < w
		// cap(s) <= w implies len(s) <= w
		// cap(s) == w implies len(s) <= w
		ft.update(parent, ft.lens[v.Args[0].ID], w, d, r|lt)
	}
	if w.Op == OpSliceCap && r&lt == 0 && ft.lens[w.Args[0].ID] != nil {
		// same, capacity on the RHS.
		ft.update(parent, v, ft.lens[w.Args[0].ID], d, r|gt)
	}

	// Process fence-post implications.
	//
	// First, make the condition > or >=.
	if r == lt || r == lt|eq {
		v, w = w, v
		r = reverseBits[r]
	}
	switch r {
	case gt:
		if x, delta := isConstDelta(v); x != nil && delta == 1 {
			// x+1 > w  ⇒  x >= w
			//
			// This is useful for eliminating the
			// growslice branch of append.
			ft.update(parent, x, w, d, gt|eq)
		} else if x, delta := isConstDelta(w); x != nil && delta == -1 {
			// v > x-1  ⇒  v >= x
			ft.update(parent, v, x, d, gt|eq)
		}
	case gt | eq:
		if x, delta := isConstDelta(v); x != nil && delta == -1 {
			// x-1 >= w && x > min  ⇒  x > w
			//
			// Useful for i > 0; s[i-1].
			lim, ok := ft.limits[x.ID]
			if ok && ((d == signed && lim.min > opMin[v.Op]) || (d == unsigned && lim.umin > 0)) {
				ft.update(parent, x, w, d, gt)
			}
		} else if x, delta := isConstDelta(w); x != nil && delta == 1 {
			// v >= x+1 && x < max  ⇒  v > x
			lim, ok := ft.limits[x.ID]
			if ok && ((d == signed && lim.max < opMax[w.Op]) || (d == unsigned && lim.umax < opUMax[w.Op])) {
				ft.update(parent, v, x, d, gt)
			}
		}
	}

	// Process: x+delta > w (with delta constant)
	// Only signed domain for now (useful for accesses to slices in loops).
	if r == gt || r == gt|eq {
		if x, delta := isConstDelta(v); x != nil && d == signed {
			if parent.Func.pass.debug > 1 {
				parent.Func.Warnl(parent.Pos, "x+d %s w; x:%v %v delta:%v w:%v d:%v", r, x, parent.String(), delta, w.AuxInt, d)
			}
			if !w.isGenericIntConst() {
				// If we know that x+delta > w but w is not constant, we can derive:
				//    if delta < 0 and x > MinInt - delta, then x > w (because x+delta cannot underflow)
				// This is useful for loops with bounds "len(slice)-K" (delta = -K)
				if l, has := ft.limits[x.ID]; has && delta < 0 {
					if (x.Type.Size() == 8 && l.min >= math.MinInt64-delta) ||
						(x.Type.Size() == 4 && l.min >= math.MinInt32-delta) {
						ft.update(parent, x, w, signed, r)
					}
				}
			} else {
				// With w,delta constants, we want to derive: x+delta > w  ⇒  x > w-delta
				//
				// We compute (using integers of the correct size):
				//    min = w - delta
				//    max = MaxInt - delta
				//
				// And we prove that:
				//    if min<max: min < x AND x <= max
				//    if min>max: min < x OR  x <= max
				//
				// This is always correct, even in case of overflow.
				//
				// If the initial fact is x+delta >= w instead, the derived conditions are:
				//    if min<max: min <= x AND x <= max
				//    if min>max: min <= x OR  x <= max
				//
				// Notice the conditions for max are still <=, as they handle overflows.
				var min, max int64
				var vmin, vmax *Value
				switch x.Type.Size() {
				case 8:
					min = w.AuxInt - delta
					max = int64(^uint64(0)>>1) - delta

					vmin = parent.NewValue0I(parent.Pos, OpConst64, parent.Func.Config.Types.Int64, min)
					vmax = parent.NewValue0I(parent.Pos, OpConst64, parent.Func.Config.Types.Int64, max)

				case 4:
					min = int64(int32(w.AuxInt) - int32(delta))
					max = int64(int32(^uint32(0)>>1) - int32(delta))

					vmin = parent.NewValue0I(parent.Pos, OpConst32, parent.Func.Config.Types.Int32, min)
					vmax = parent.NewValue0I(parent.Pos, OpConst32, parent.Func.Config.Types.Int32, max)

				default:
					panic("unimplemented")
				}

				if min < max {
					// Record that x > min and max >= x
					ft.update(parent, x, vmin, d, r)
					ft.update(parent, vmax, x, d, r|eq)
				} else {
					// We know that either x>min OR x<=max. factsTable cannot record OR conditions,
					// so let's see if we can already prove that one of them is false, in which case
					// the other must be true
					if l, has := ft.limits[x.ID]; has {
						if l.max <= min {
							if r&eq == 0 || l.max < min {
								// x>min (x>=min) is impossible, so it must be x<=max
								ft.update(parent, vmax, x, d, r|eq)
							}
						} else if l.min > max {
							// x<=max is impossible, so it must be x>min
							ft.update(parent, x, vmin, d, r)
						}
					}
				}
			}
		}
	}

	// Look through value-preserving extensions.
	// If the domain is appropriate for the pre-extension Type,
	// repeat the update with the pre-extension Value.
	if isCleanExt(v) {
		switch {
		case d == signed && v.Args[0].Type.IsSigned():
			fallthrough
		case d == unsigned && !v.Args[0].Type.IsSigned():
			ft.update(parent, v.Args[0], w, d, r)
		}
	}
	if isCleanExt(w) {
		switch {
		case d == signed && w.Args[0].Type.IsSigned():
			fallthrough
		case d == unsigned && !w.Args[0].Type.IsSigned():
			ft.update(parent, v, w.Args[0], d, r)
		}
	}
}
555
// opMin gives, for the 32- and 64-bit add/sub ops, the smallest signed
// value of the op's width. factsTable.update consults it in the
// fence-post rules to check that x-1 cannot wrap around. Ops that are
// not listed yield the map's zero value.
var opMin = map[Op]int64{
	OpAdd64: math.MinInt64, OpSub64: math.MinInt64,
	OpAdd32: math.MinInt32, OpSub32: math.MinInt32,
}

// opMax is the signed-maximum counterpart of opMin, consulted to check
// that x+1 cannot wrap around in the signed domain.
var opMax = map[Op]int64{
	OpAdd64: math.MaxInt64, OpSub64: math.MaxInt64,
	OpAdd32: math.MaxInt32, OpSub32: math.MaxInt32,
}

// opUMax is the unsigned-maximum counterpart of opMax, consulted to
// check that x+1 cannot wrap around in the unsigned domain.
var opUMax = map[Op]uint64{
	OpAdd64: math.MaxUint64, OpSub64: math.MaxUint64,
	OpAdd32: math.MaxUint32, OpSub32: math.MaxUint32,
}
570
571// isNonNegative reports whether v is known to be non-negative.
572func (ft *factsTable) isNonNegative(v *Value) bool {
573	if isNonNegative(v) {
574		return true
575	}
576
577	var max int64
578	switch v.Type.Size() {
579	case 1:
580		max = math.MaxInt8
581	case 2:
582		max = math.MaxInt16
583	case 4:
584		max = math.MaxInt32
585	case 8:
586		max = math.MaxInt64
587	default:
588		panic("unexpected integer size")
589	}
590
591	// Check if the recorded limits can prove that the value is positive
592
593	if l, has := ft.limits[v.ID]; has && (l.min >= 0 || l.umax <= uint64(max)) {
594		return true
595	}
596
597	// Check if v = x+delta, and we can use x's limits to prove that it's positive
598	if x, delta := isConstDelta(v); x != nil {
599		if l, has := ft.limits[x.ID]; has {
600			if delta > 0 && l.min >= -delta && l.max <= max-delta {
601				return true
602			}
603			if delta < 0 && l.min >= -delta {
604				return true
605			}
606		}
607	}
608
609	// Check if v is a value-preserving extension of a non-negative value.
610	if isCleanExt(v) && ft.isNonNegative(v.Args[0]) {
611		return true
612	}
613
614	// Check if the signed poset can prove that the value is >= 0
615	return ft.orderS.OrderedOrEqual(ft.zero, v)
616}
617
618// checkpoint saves the current state of known relations.
619// Called when descending on a branch.
620func (ft *factsTable) checkpoint() {
621	if ft.unsat {
622		ft.unsatDepth++
623	}
624	ft.stack = append(ft.stack, checkpointFact)
625	ft.limitStack = append(ft.limitStack, checkpointBound)
626	ft.orderS.Checkpoint()
627	ft.orderU.Checkpoint()
628}
629
630// restore restores known relation to the state just
631// before the previous checkpoint.
632// Called when backing up on a branch.
633func (ft *factsTable) restore() {
634	if ft.unsatDepth > 0 {
635		ft.unsatDepth--
636	} else {
637		ft.unsat = false
638	}
639	for {
640		old := ft.stack[len(ft.stack)-1]
641		ft.stack = ft.stack[:len(ft.stack)-1]
642		if old == checkpointFact {
643			break
644		}
645		if old.r == lt|eq|gt {
646			delete(ft.facts, old.p)
647		} else {
648			ft.facts[old.p] = old.r
649		}
650	}
651	for {
652		old := ft.limitStack[len(ft.limitStack)-1]
653		ft.limitStack = ft.limitStack[:len(ft.limitStack)-1]
654		if old.vid == 0 { // checkpointBound
655			break
656		}
657		if old.limit == noLimit {
658			delete(ft.limits, old.vid)
659		} else {
660			ft.limits[old.vid] = old.limit
661		}
662	}
663	ft.orderS.Undo()
664	ft.orderU.Undo()
665}
666
667func lessByID(v, w *Value) bool {
668	if v == nil && w == nil {
669		// Should not happen, but just in case.
670		return false
671	}
672	if v == nil {
673		return true
674	}
675	return w != nil && v.ID < w.ID
676}
677
var (
	// reverseBits maps a relation between (v, w) to the equivalent
	// relation between (w, v): the lt and gt bits are exchanged and
	// the eq bit is kept. It is indexed by the relation's value.
	reverseBits = [...]relation{0, 4, 2, 6, 1, 5, 3, 7}

	// maps what we learn when the positive branch is taken.
	// For example:
	//      OpLess8:   {signed, lt},
	//	v1 = (OpLess8 v2 v3).
	// If v1 branch is taken then we learn that the rangeMask
	// can be at most lt.
	//
	// The relation is between the op's first and second argument, in
	// the listed domain(s). On the negative branch the complement of
	// the relation is used (see addBranchRestrictions).
	domainRelationTable = map[Op]struct {
		d domain
		r relation
	}{
		OpEq8:   {signed | unsigned, eq},
		OpEq16:  {signed | unsigned, eq},
		OpEq32:  {signed | unsigned, eq},
		OpEq64:  {signed | unsigned, eq},
		OpEqPtr: {pointer, eq},

		OpNeq8:   {signed | unsigned, lt | gt},
		OpNeq16:  {signed | unsigned, lt | gt},
		OpNeq32:  {signed | unsigned, lt | gt},
		OpNeq64:  {signed | unsigned, lt | gt},
		OpNeqPtr: {pointer, lt | gt},

		OpLess8:   {signed, lt},
		OpLess8U:  {unsigned, lt},
		OpLess16:  {signed, lt},
		OpLess16U: {unsigned, lt},
		OpLess32:  {signed, lt},
		OpLess32U: {unsigned, lt},
		OpLess64:  {signed, lt},
		OpLess64U: {unsigned, lt},

		OpLeq8:   {signed, lt | eq},
		OpLeq8U:  {unsigned, lt | eq},
		OpLeq16:  {signed, lt | eq},
		OpLeq16U: {unsigned, lt | eq},
		OpLeq32:  {signed, lt | eq},
		OpLeq32U: {unsigned, lt | eq},
		OpLeq64:  {signed, lt | eq},
		OpLeq64U: {unsigned, lt | eq},

		OpGeq8:   {signed, eq | gt},
		OpGeq8U:  {unsigned, eq | gt},
		OpGeq16:  {signed, eq | gt},
		OpGeq16U: {unsigned, eq | gt},
		OpGeq32:  {signed, eq | gt},
		OpGeq32U: {unsigned, eq | gt},
		OpGeq64:  {signed, eq | gt},
		OpGeq64U: {unsigned, eq | gt},

		OpGreater8:   {signed, gt},
		OpGreater8U:  {unsigned, gt},
		OpGreater16:  {signed, gt},
		OpGreater16U: {unsigned, gt},
		OpGreater32:  {signed, gt},
		OpGreater32U: {unsigned, gt},
		OpGreater64:  {signed, gt},
		OpGreater64U: {unsigned, gt},

		// For these ops, the negative branch is different: we can only
		// prove signed/GE (signed/GT) if we can prove that arg0 is non-negative.
		// See the special case in addBranchRestrictions.
		OpIsInBounds:      {signed | unsigned, lt},      // 0 <= arg0 < arg1
		OpIsSliceInBounds: {signed | unsigned, lt | eq}, // 0 <= arg0 <= arg1
	}
)
746
// prove removes redundant BlockIf branches that can be inferred
// from previous dominating comparisons.
//
// By far, the most common redundant pair are generated by bounds checking.
// For example for the code:
//
//    a[i] = 4
//    foo(a[i])
//
// The compiler will generate the following code:
//
//    if i >= len(a) {
//        panic("not in bounds")
//    }
//    a[i] = 4
//    if i >= len(a) {
//        panic("not in bounds")
//    }
//    foo(a[i])
//
// The second comparison i >= len(a) is clearly redundant because if the
// else branch of the first comparison is executed, we already know that i < len(a).
// The code for the second panic can be removed.
//
// prove works by finding contradictions and trimming branches whose
// conditions are unsatisfiable given the branches leading up to them.
// It tracks a "fact table" of branch conditions. For each branching
// block, it asserts the branch conditions that uniquely dominate that
// block, and then separately asserts the block's branch condition and
// its negation. If either leads to a contradiction, it can trim that
// successor.
func prove(f *Func) {
	ft := newFactsTable(f)
	// Outermost checkpoint; it is undone by the final ft.restore()
	// after the walk, which discards the facts recorded below.
	ft.checkpoint()

	// lensVars collects, per block, the len/cap values defined in that
	// block whose slice operand is an OpSliceMake. They are asserted
	// equal to the make's operands when the block is entered.
	var lensVars map[*Block][]*Value

	// Find length and capacity ops.
	for _, b := range f.Blocks {
		for _, v := range b.Values {
			if v.Uses == 0 {
				// We don't care about dead values.
				// (There can be some that are CSEd but not removed yet.)
				continue
			}
			switch v.Op {
			case OpStringLen:
				ft.update(b, v, ft.zero, signed, gt|eq)
			case OpSliceLen:
				if ft.lens == nil {
					ft.lens = map[ID]*Value{}
				}
				ft.lens[v.Args[0].ID] = v
				ft.update(b, v, ft.zero, signed, gt|eq)
				if v.Args[0].Op == OpSliceMake {
					if lensVars == nil {
						lensVars = make(map[*Block][]*Value)
					}
					lensVars[b] = append(lensVars[b], v)
				}
			case OpSliceCap:
				if ft.caps == nil {
					ft.caps = map[ID]*Value{}
				}
				ft.caps[v.Args[0].ID] = v
				ft.update(b, v, ft.zero, signed, gt|eq)
				if v.Args[0].Op == OpSliceMake {
					if lensVars == nil {
						lensVars = make(map[*Block][]*Value)
					}
					lensVars[b] = append(lensVars[b], v)
				}
			}
		}
	}

	// Find induction variables. Currently, findIndVars
	// is limited to one induction variable per block.
	var indVars map[*Block]indVar
	for _, v := range findIndVar(f) {
		if indVars == nil {
			indVars = make(map[*Block]indVar)
		}
		indVars[v.entry] = v
	}

	// current node state
	type walkState int
	const (
		descend walkState = iota
		simplify
	)
	// work maintains the DFS stack.
	type bp struct {
		block *Block    // current handled block
		state walkState // what's to do
	}
	work := make([]bp, 0, 256)
	work = append(work, bp{
		block: f.Entry,
		state: descend,
	})

	idom := f.Idom()
	sdom := f.Sdom()

	// DFS on the dominator tree.
	//
	// For efficiency, we consider only the dominator tree rather
	// than the entire flow graph. On the way down, we consider
	// incoming branches and accumulate conditions that uniquely
	// dominate the current block. If we discover a contradiction,
	// we can eliminate the entire block and all of its children.
	// On the way back up, we consider outgoing branches that
	// haven't already been considered. This way we consider each
	// branch condition only once.
	for len(work) > 0 {
		node := work[len(work)-1]
		work = work[:len(work)-1]
		parent := idom[node.block.ID]
		branch := getBranch(sdom, parent, node.block)

		switch node.state {
		case descend:
			// Matched by the ft.restore() in the simplify case, or
			// by the one just below if the block proves unreachable.
			ft.checkpoint()

			// Entering the block, add the block-depending facts that we collected
			// at the beginning: induction variables and lens/caps of slices.
			if iv, ok := indVars[node.block]; ok {
				addIndVarRestrictions(ft, parent, iv)
			}
			if lens, ok := lensVars[node.block]; ok {
				for _, v := range lens {
					// v.Args[0] is an OpSliceMake here; assert that
					// the len (cap) equals its second (third) argument.
					switch v.Op {
					case OpSliceLen:
						ft.update(node.block, v, v.Args[0].Args[1], signed, eq)
					case OpSliceCap:
						ft.update(node.block, v, v.Args[0].Args[2], signed, eq)
					}
				}
			}

			if branch != unknown {
				addBranchRestrictions(ft, parent, branch)
				if ft.unsat {
					// node.block is unreachable.
					// Remove it and don't visit
					// its children.
					removeBranch(parent, branch)
					ft.restore()
					// Leaves the switch: nothing has been pushed on
					// work for this block.
					break
				}
				// Otherwise, we can now commit to
				// taking this branch. We'll restore
				// ft when we unwind.
			}

			// Add inductive facts for phis in this block.
			addLocalInductiveFacts(ft, node.block)

			// The simplify entry is pushed before the children so
			// that it is popped after all of them have been handled.
			work = append(work, bp{
				block: node.block,
				state: simplify,
			})
			for s := sdom.Child(node.block); s != nil; s = sdom.Sibling(s) {
				work = append(work, bp{
					block: s,
					state: descend,
				})
			}

		case simplify:
			simplifyBlock(sdom, ft, node.block)
			ft.restore()
		}
	}

	ft.restore()

	// Return the posets to the free list
	for _, po := range []*poset{ft.orderS, ft.orderU} {
		// Make sure it's empty as it should be. A non-empty poset
		// might cause errors and miscompilations if reused.
		if checkEnabled {
			if err := po.CheckEmpty(); err != nil {
				f.Fatalf("prove poset not empty after function %s: %v", f.Name, err)
			}
		}
		f.retPoset(po)
	}
}
938
939// getBranch returns the range restrictions added by p
940// when reaching b. p is the immediate dominator of b.
941func getBranch(sdom SparseTree, p *Block, b *Block) branch {
942	if p == nil || p.Kind != BlockIf {
943		return unknown
944	}
945	// If p and p.Succs[0] are dominators it means that every path
946	// from entry to b passes through p and p.Succs[0]. We care that
947	// no path from entry to b passes through p.Succs[1]. If p.Succs[0]
948	// has one predecessor then (apart from the degenerate case),
949	// there is no path from entry that can reach b through p.Succs[1].
950	// TODO: how about p->yes->b->yes, i.e. a loop in yes.
951	if sdom.IsAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 {
952		return positive
953	}
954	if sdom.IsAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 {
955		return negative
956	}
957	return unknown
958}
959
960// addIndVarRestrictions updates the factsTables ft with the facts
961// learned from the induction variable indVar which drives the loop
962// starting in Block b.
963func addIndVarRestrictions(ft *factsTable, b *Block, iv indVar) {
964	d := signed
965	if ft.isNonNegative(iv.min) && ft.isNonNegative(iv.max) {
966		d |= unsigned
967	}
968
969	if iv.flags&indVarMinExc == 0 {
970		addRestrictions(b, ft, d, iv.min, iv.ind, lt|eq)
971	} else {
972		addRestrictions(b, ft, d, iv.min, iv.ind, lt)
973	}
974
975	if iv.flags&indVarMaxInc == 0 {
976		addRestrictions(b, ft, d, iv.ind, iv.max, lt)
977	} else {
978		addRestrictions(b, ft, d, iv.ind, iv.max, lt|eq)
979	}
980}
981
982// addBranchRestrictions updates the factsTables ft with the facts learned when
983// branching from Block b in direction br.
984func addBranchRestrictions(ft *factsTable, b *Block, br branch) {
985	c := b.Controls[0]
986	switch br {
987	case negative:
988		addRestrictions(b, ft, boolean, nil, c, eq)
989	case positive:
990		addRestrictions(b, ft, boolean, nil, c, lt|gt)
991	default:
992		panic("unknown branch")
993	}
994	if tr, has := domainRelationTable[c.Op]; has {
995		// When we branched from parent we learned a new set of
996		// restrictions. Update the factsTable accordingly.
997		d := tr.d
998		if d == signed && ft.isNonNegative(c.Args[0]) && ft.isNonNegative(c.Args[1]) {
999			d |= unsigned
1000		}
1001		switch c.Op {
1002		case OpIsInBounds, OpIsSliceInBounds:
1003			// 0 <= a0 < a1 (or 0 <= a0 <= a1)
1004			//
1005			// On the positive branch, we learn:
1006			//   signed: 0 <= a0 < a1 (or 0 <= a0 <= a1)
1007			//   unsigned:    a0 < a1 (or a0 <= a1)
1008			//
1009			// On the negative branch, we learn (0 > a0 ||
1010			// a0 >= a1). In the unsigned domain, this is
1011			// simply a0 >= a1 (which is the reverse of the
1012			// positive branch, so nothing surprising).
1013			// But in the signed domain, we can't express the ||
1014			// condition, so check if a0 is non-negative instead,
1015			// to be able to learn something.
1016			switch br {
1017			case negative:
1018				d = unsigned
1019				if ft.isNonNegative(c.Args[0]) {
1020					d |= signed
1021				}
1022				addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r^(lt|gt|eq))
1023			case positive:
1024				addRestrictions(b, ft, signed, ft.zero, c.Args[0], lt|eq)
1025				addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r)
1026			}
1027		default:
1028			switch br {
1029			case negative:
1030				addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r^(lt|gt|eq))
1031			case positive:
1032				addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r)
1033			}
1034		}
1035
1036	}
1037}
1038
1039// addRestrictions updates restrictions from the immediate
1040// dominating block (p) using r.
1041func addRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r relation) {
1042	if t == 0 {
1043		// Trivial case: nothing to do.
1044		// Shoult not happen, but just in case.
1045		return
1046	}
1047	for i := domain(1); i <= t; i <<= 1 {
1048		if t&i == 0 {
1049			continue
1050		}
1051		ft.update(parent, v, w, i, r)
1052	}
1053}
1054
1055// addLocalInductiveFacts adds inductive facts when visiting b, where
1056// b is a join point in a loop. In contrast with findIndVar, this
1057// depends on facts established for b, which is why it happens when
1058// visiting b. addLocalInductiveFacts specifically targets the pattern
1059// created by OFORUNTIL, which isn't detected by findIndVar.
1060//
1061// TODO: It would be nice to combine this with findIndVar.
1062func addLocalInductiveFacts(ft *factsTable, b *Block) {
1063	// This looks for a specific pattern of induction:
1064	//
1065	// 1. i1 = OpPhi(min, i2) in b
1066	// 2. i2 = i1 + 1
1067	// 3. i2 < max at exit from b.Preds[1]
1068	// 4. min < max
1069	//
1070	// If all of these conditions are true, then i1 < max and i1 >= min.
1071
1072	for _, i1 := range b.Values {
1073		if i1.Op != OpPhi {
1074			continue
1075		}
1076
1077		// Check for conditions 1 and 2. This is easy to do
1078		// and will throw out most phis.
1079		min, i2 := i1.Args[0], i1.Args[1]
1080		if i1q, delta := isConstDelta(i2); i1q != i1 || delta != 1 {
1081			continue
1082		}
1083
1084		// Try to prove condition 3. We can't just query the
1085		// fact table for this because we don't know what the
1086		// facts of b.Preds[1] are (in general, b.Preds[1] is
1087		// a loop-back edge, so we haven't even been there
1088		// yet). As a conservative approximation, we look for
1089		// this condition in the predecessor chain until we
1090		// hit a join point.
1091		uniquePred := func(b *Block) *Block {
1092			if len(b.Preds) == 1 {
1093				return b.Preds[0].b
1094			}
1095			return nil
1096		}
1097		pred, child := b.Preds[1].b, b
1098		for ; pred != nil; pred = uniquePred(pred) {
1099			if pred.Kind != BlockIf {
1100				continue
1101			}
1102			control := pred.Controls[0]
1103
1104			br := unknown
1105			if pred.Succs[0].b == child {
1106				br = positive
1107			}
1108			if pred.Succs[1].b == child {
1109				if br != unknown {
1110					continue
1111				}
1112				br = negative
1113			}
1114
1115			tr, has := domainRelationTable[control.Op]
1116			if !has {
1117				continue
1118			}
1119			r := tr.r
1120			if br == negative {
1121				// Negative branch taken to reach b.
1122				// Complement the relations.
1123				r = (lt | eq | gt) ^ r
1124			}
1125
1126			// Check for i2 < max or max > i2.
1127			var max *Value
1128			if r == lt && control.Args[0] == i2 {
1129				max = control.Args[1]
1130			} else if r == gt && control.Args[1] == i2 {
1131				max = control.Args[0]
1132			} else {
1133				continue
1134			}
1135
1136			// Check condition 4 now that we have a
1137			// candidate max. For this we can query the
1138			// fact table. We "prove" min < max by showing
1139			// that min >= max is unsat. (This may simply
1140			// compare two constants; that's fine.)
1141			ft.checkpoint()
1142			ft.update(b, min, max, tr.d, gt|eq)
1143			proved := ft.unsat
1144			ft.restore()
1145
1146			if proved {
1147				// We know that min <= i1 < max.
1148				if b.Func.pass.debug > 0 {
1149					printIndVar(b, i1, min, max, 1, 0)
1150				}
1151				ft.update(b, min, i1, tr.d, lt|eq)
1152				ft.update(b, i1, max, tr.d, lt)
1153			}
1154		}
1155	}
1156}
1157
1158var ctzNonZeroOp = map[Op]Op{OpCtz8: OpCtz8NonZero, OpCtz16: OpCtz16NonZero, OpCtz32: OpCtz32NonZero, OpCtz64: OpCtz64NonZero}
1159var mostNegativeDividend = map[Op]int64{
1160	OpDiv16: -1 << 15,
1161	OpMod16: -1 << 15,
1162	OpDiv32: -1 << 31,
1163	OpMod32: -1 << 31,
1164	OpDiv64: -1 << 63,
1165	OpMod64: -1 << 63}
1166
1167// simplifyBlock simplifies some constant values in b and evaluates
1168// branches to non-uniquely dominated successors of b.
1169func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) {
1170	for _, v := range b.Values {
1171		switch v.Op {
1172		case OpSlicemask:
1173			// Replace OpSlicemask operations in b with constants where possible.
1174			x, delta := isConstDelta(v.Args[0])
1175			if x == nil {
1176				continue
1177			}
1178			// slicemask(x + y)
1179			// if x is larger than -y (y is negative), then slicemask is -1.
1180			lim, ok := ft.limits[x.ID]
1181			if !ok {
1182				continue
1183			}
1184			if lim.umin > uint64(-delta) {
1185				if v.Args[0].Op == OpAdd64 {
1186					v.reset(OpConst64)
1187				} else {
1188					v.reset(OpConst32)
1189				}
1190				if b.Func.pass.debug > 0 {
1191					b.Func.Warnl(v.Pos, "Proved slicemask not needed")
1192				}
1193				v.AuxInt = -1
1194			}
1195		case OpCtz8, OpCtz16, OpCtz32, OpCtz64:
1196			// On some architectures, notably amd64, we can generate much better
1197			// code for CtzNN if we know that the argument is non-zero.
1198			// Capture that information here for use in arch-specific optimizations.
1199			x := v.Args[0]
1200			lim, ok := ft.limits[x.ID]
1201			if !ok {
1202				continue
1203			}
1204			if lim.umin > 0 || lim.min > 0 || lim.max < 0 {
1205				if b.Func.pass.debug > 0 {
1206					b.Func.Warnl(v.Pos, "Proved %v non-zero", v.Op)
1207				}
1208				v.Op = ctzNonZeroOp[v.Op]
1209			}
1210
1211		case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64,
1212			OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64,
1213			OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64,
1214			OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64,
1215			OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64,
1216			OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64,
1217			OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64,
1218			OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64,
1219			OpRsh8Ux8, OpRsh8Ux16, OpRsh8Ux32, OpRsh8Ux64,
1220			OpRsh16Ux8, OpRsh16Ux16, OpRsh16Ux32, OpRsh16Ux64,
1221			OpRsh32Ux8, OpRsh32Ux16, OpRsh32Ux32, OpRsh32Ux64,
1222			OpRsh64Ux8, OpRsh64Ux16, OpRsh64Ux32, OpRsh64Ux64:
1223			// Check whether, for a << b, we know that b
1224			// is strictly less than the number of bits in a.
1225			by := v.Args[1]
1226			lim, ok := ft.limits[by.ID]
1227			if !ok {
1228				continue
1229			}
1230			bits := 8 * v.Args[0].Type.Size()
1231			if lim.umax < uint64(bits) || (lim.max < bits && ft.isNonNegative(by)) {
1232				v.AuxInt = 1 // see shiftIsBounded
1233				if b.Func.pass.debug > 0 {
1234					b.Func.Warnl(v.Pos, "Proved %v bounded", v.Op)
1235				}
1236			}
1237		case OpDiv16, OpDiv32, OpDiv64, OpMod16, OpMod32, OpMod64:
1238			// On amd64 and 386 fix-up code can be avoided if we know
1239			//  the divisor is not -1 or the dividend > MinIntNN.
1240			divr := v.Args[1]
1241			divrLim, divrLimok := ft.limits[divr.ID]
1242			divd := v.Args[0]
1243			divdLim, divdLimok := ft.limits[divd.ID]
1244			if (divrLimok && (divrLim.max < -1 || divrLim.min > -1)) ||
1245				(divdLimok && divdLim.min > mostNegativeDividend[v.Op]) {
1246				v.AuxInt = 1 // see NeedsFixUp in genericOps - v.AuxInt = 0 means we have not proved
1247				// that the divisor is not -1 and the dividend is not the most negative,
1248				// so we need to add fix-up code.
1249				if b.Func.pass.debug > 0 {
1250					b.Func.Warnl(v.Pos, "Proved %v does not need fix-up", v.Op)
1251				}
1252			}
1253		}
1254	}
1255
1256	if b.Kind != BlockIf {
1257		return
1258	}
1259
1260	// Consider outgoing edges from this block.
1261	parent := b
1262	for i, branch := range [...]branch{positive, negative} {
1263		child := parent.Succs[i].b
1264		if getBranch(sdom, parent, child) != unknown {
1265			// For edges to uniquely dominated blocks, we
1266			// already did this when we visited the child.
1267			continue
1268		}
1269		// For edges to other blocks, this can trim a branch
1270		// even if we couldn't get rid of the child itself.
1271		ft.checkpoint()
1272		addBranchRestrictions(ft, parent, branch)
1273		unsat := ft.unsat
1274		ft.restore()
1275		if unsat {
1276			// This branch is impossible, so remove it
1277			// from the block.
1278			removeBranch(parent, branch)
1279			// No point in considering the other branch.
1280			// (It *is* possible for both to be
1281			// unsatisfiable since the fact table is
1282			// incomplete. We could turn this into a
1283			// BlockExit, but it doesn't seem worth it.)
1284			break
1285		}
1286	}
1287}
1288
1289func removeBranch(b *Block, branch branch) {
1290	c := b.Controls[0]
1291	if b.Func.pass.debug > 0 {
1292		verb := "Proved"
1293		if branch == positive {
1294			verb = "Disproved"
1295		}
1296		if b.Func.pass.debug > 1 {
1297			b.Func.Warnl(b.Pos, "%s %s (%s)", verb, c.Op, c)
1298		} else {
1299			b.Func.Warnl(b.Pos, "%s %s", verb, c.Op)
1300		}
1301	}
1302	if c != nil && c.Pos.IsStmt() == src.PosIsStmt && c.Pos.SameFileAndLine(b.Pos) {
1303		// attempt to preserve statement marker.
1304		b.Pos = b.Pos.WithIsStmt()
1305	}
1306	b.Kind = BlockFirst
1307	b.ResetControls()
1308	if branch == positive {
1309		b.swapSuccessors()
1310	}
1311}
1312
1313// isNonNegative reports whether v is known to be greater or equal to zero.
1314func isNonNegative(v *Value) bool {
1315	switch v.Op {
1316	case OpConst64:
1317		return v.AuxInt >= 0
1318
1319	case OpConst32:
1320		return int32(v.AuxInt) >= 0
1321
1322	case OpStringLen, OpSliceLen, OpSliceCap,
1323		OpZeroExt8to64, OpZeroExt16to64, OpZeroExt32to64:
1324		return true
1325
1326	case OpRsh64Ux64:
1327		by := v.Args[1]
1328		return by.Op == OpConst64 && by.AuxInt > 0
1329
1330	case OpRsh64x64:
1331		return isNonNegative(v.Args[0])
1332	}
1333	return false
1334}
1335
1336// isConstDelta returns non-nil if v is equivalent to w+delta (signed).
1337func isConstDelta(v *Value) (w *Value, delta int64) {
1338	cop := OpConst64
1339	switch v.Op {
1340	case OpAdd32, OpSub32:
1341		cop = OpConst32
1342	}
1343	switch v.Op {
1344	case OpAdd64, OpAdd32:
1345		if v.Args[0].Op == cop {
1346			return v.Args[1], v.Args[0].AuxInt
1347		}
1348		if v.Args[1].Op == cop {
1349			return v.Args[0], v.Args[1].AuxInt
1350		}
1351	case OpSub64, OpSub32:
1352		if v.Args[1].Op == cop {
1353			aux := v.Args[1].AuxInt
1354			if aux != -aux { // Overflow; too bad
1355				return v.Args[0], -aux
1356			}
1357		}
1358	}
1359	return nil, 0
1360}
1361
1362// isCleanExt reports whether v is the result of a value-preserving
1363// sign or zero extension
1364func isCleanExt(v *Value) bool {
1365	switch v.Op {
1366	case OpSignExt8to16, OpSignExt8to32, OpSignExt8to64,
1367		OpSignExt16to32, OpSignExt16to64, OpSignExt32to64:
1368		// signed -> signed is the only value-preserving sign extension
1369		return v.Args[0].Type.IsSigned() && v.Type.IsSigned()
1370
1371	case OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64,
1372		OpZeroExt16to32, OpZeroExt16to64, OpZeroExt32to64:
1373		// unsigned -> signed/unsigned are value-preserving zero extensions
1374		return !v.Args[0].Type.IsSigned()
1375	}
1376	return false
1377}
1378