1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ssa
6
7import (
8	"fmt"
9	"os"
10)
11
// debugPoset, when set, makes the poset query methods run CheckIntegrity
// after they finish, to catch internal corruption early.
var debugPoset bool
14
// uintSize is the width in bits of the platform's uint type: 32 or 64.
const uintSize = 32 << (^uint(0) >> 63)

// bitset is a dense bit array addressed by node index; each uint word
// holds uintSize consecutive bits.
type bitset []uint

// newBitset returns a zeroed bitset with room for at least n bits.
func newBitset(n int) bitset {
	words := (n + uintSize - 1) / uintSize
	return make(bitset, words)
}

// Reset clears every bit while keeping the underlying storage.
func (bs bitset) Reset() {
	for w := 0; w < len(bs); w++ {
		bs[w] = 0
	}
}

// Set turns on bit idx.
func (bs bitset) Set(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] |= 1 << bit
}

// Clear turns off bit idx.
func (bs bitset) Clear(idx uint32) {
	word, bit := idx/uintSize, idx%uintSize
	bs[word] &^= 1 << bit
}

// Test reports whether bit idx is on.
func (bs bitset) Test(idx uint32) bool {
	word, bit := idx/uintSize, idx%uintSize
	return bs[word]&(1<<bit) != 0
}
41
// undoType is the discriminant of a posetUndo entry: it selects which
// mutation the entry records and therefore how it must be reverted.
type undoType uint8

const (
	undoInvalid     undoType = iota
	undoCheckpoint           // a checkpoint to group undo passes
	undoSetChl               // change back left child of undo.idx to undo.edge
	undoSetChr               // change back right child of undo.idx to undo.edge
	undoNonEqual             // forget that SSA value undo.ID is non-equal to undo.idx (another ID)
	undoNewNode              // remove new node created for SSA value undo.ID
	undoNewConstant          // remove the constant node idx from the constants map
	undoAliasNode            // unalias SSA value undo.ID so that it points back to node index undo.idx
	undoNewRoot              // remove node undo.idx from root list
	undoChangeRoot           // remove node undo.idx from root list, and put back undo.edge.Target instead
	undoMergeRoot            // remove node undo.idx from root list, and put back its children instead
)
57
// posetUndo represents an undo pass to be performed.
// It's a union of fields that can be used to store information,
// and typ is the discriminant, that specifies which kind
// of operation must be performed. Not all fields are always used.
type posetUndo struct {
	typ  undoType
	idx  uint32
	ID   ID
	edge posetEdge
}
68
// Bit flags stored in poset.flags.
const (
	// Make poset handle constants as unsigned numbers.
	posetFlagUnsigned = 1 << iota
)
73
// posetEdge is a directed edge of the poset DAG; its zero value means
// "no edge". The target node index lives in the upper 31 bits and the
// strict flag (a "<" relation rather than "<=") in the lowest bit.
type posetEdge uint32

// newedge packs target index t and the strict flag into an edge.
func newedge(t uint32, strict bool) posetEdge {
	e := posetEdge(t << 1)
	if strict {
		e |= 1
	}
	return e
}

// Target returns the index of the node the edge points to.
func (e posetEdge) Target() uint32 {
	return uint32(e >> 1)
}

// Strict reports whether the edge records a strict relation.
func (e posetEdge) Strict() bool {
	return e&1 == 1
}

// String formats the edge as its target index, followed by "*" if strict.
func (e posetEdge) String() string {
	if !e.Strict() {
		return fmt.Sprint(e.Target())
	}
	return fmt.Sprint(e.Target()) + "*"
}
94
// posetNode is a node of a DAG within the poset.
// It has at most two outgoing edges, l and r; a zero edge marks an empty
// slot (see addchild, which fills l first, then r, then inserts a dummy node).
type posetNode struct {
	l, r posetEdge
}
99
// poset is a union-find data structure that can represent a partially ordered set
// of SSA values. Given a binary relation that creates a partial order (eg: '<'),
// clients can record relations between SSA values using SetOrder, and later
// check relations (in the transitive closure) with Ordered. For instance,
// if SetOrder is called to record that A<B and B<C, Ordered will later confirm
// that A<C.
//
// It is possible to record equality relations between SSA values with SetEqual and check
// equality with Equal. Equality propagates into the transitive closure for the partial
// order so that if we know that A<B<C and later learn that A==D, Ordered will return
// true for D<C.
//
// It is also possible to record inequality relations between nodes with SetNonEqual;
// non-equality relations are not transitive, but they can still be useful: for instance
// if we know that A<=B and later we learn that A!=B, we can deduce that A<B.
// NonEqual can be used to check whether it is known that the nodes are different, either
// because SetNonEqual was called before, or because we know that they are strictly ordered.
//
// poset will refuse to record new relations that contradict existing relations:
// for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
// calling SetEqual for C==A will fail.
//
// poset is implemented as a forest of DAGs; in each DAG, if there is a path (directed)
// from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
// two SSA values to the same DAG node; when a new equality relation is recorded
// between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
//
// Constants are specially treated. When a constant is added to the poset, it is
// immediately linked to other constants already present; so for instance if the
// poset knows that x<=3, and then x is tested against 5, 5 is first added and linked
// to 3 (using 3<5), so that the poset knows that x<=3<5; at that point, it is able
// to answer x<5 correctly. This means that all constants are always within the same
// DAG; as an implementation detail, we enforce that the DAG containing the constants
// is always the first in the forest.
//
// poset is designed to be memory efficient and do few allocations during normal usage.
// Most internal data structures are pre-allocated and flat, so for instance adding a
// new relation does not cause any allocation. For performance reasons,
// each node has only up to two outgoing edges (like a binary tree), so intermediate
// "dummy" nodes are required to represent more than two relations. For instance,
// to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
// following DAG:
//
//         A
//        / \
//       I  dummy
//           /  \
//          J    K
//
type poset struct {
	lastidx   uint32            // last generated dense index
	flags     uint8             // internal flags
	values    map[ID]uint32     // map SSA values to dense indexes
	constants map[int64]uint32  // record SSA constants together with their value
	nodes     []posetNode       // nodes (in all DAGs); index 0 is the null node
	roots     []uint32          // list of root nodes (forest); roots[0] holds all constants
	noneq     map[uint32]bitset // non-equal relations, keyed by the higher node index
	undo      []posetUndo       // undo chain
}
159
160func newPoset() *poset {
161	return &poset{
162		values:    make(map[ID]uint32),
163		constants: make(map[int64]uint32, 8),
164		nodes:     make([]posetNode, 1, 16),
165		roots:     make([]uint32, 0, 4),
166		noneq:     make(map[uint32]bitset),
167		undo:      make([]posetUndo, 0, 4),
168	}
169}
170
171func (po *poset) SetUnsigned(uns bool) {
172	if uns {
173		po.flags |= posetFlagUnsigned
174	} else {
175		po.flags &^= posetFlagUnsigned
176	}
177}
178
179// Handle children
180func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
181func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
182func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
183func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
184func (po *poset) children(i uint32) (posetEdge, posetEdge) {
185	return po.nodes[i].l, po.nodes[i].r
186}
187
188// upush records a new undo step. It can be used for simple
189// undo passes that record up to one index and one edge.
190func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
191	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
192}
193
194// upushnew pushes an undo pass for a new node
195func (po *poset) upushnew(id ID, idx uint32) {
196	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
197}
198
199// upushneq pushes a new undo pass for a nonequal relation
200func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
201	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
202}
203
204// upushalias pushes a new undo pass for aliasing two nodes
205func (po *poset) upushalias(id ID, i2 uint32) {
206	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
207}
208
209// upushconst pushes a new undo pass for a new constant
210func (po *poset) upushconst(idx uint32, old uint32) {
211	po.undo = append(po.undo, posetUndo{typ: undoNewConstant, idx: idx, ID: ID(old)})
212}
213
// addchild adds i2 as direct child of i1, through a strict (<) edge if
// strict is true, or a non-strict (<=) edge otherwise.
// Every modified child slot of i1 is recorded on the undo chain; when a
// dummy node is needed, its creation is recorded by newnode.
func (po *poset) addchild(i1, i2 uint32, strict bool) {
	i1l, i1r := po.children(i1)
	e2 := newedge(i2, strict)

	if i1l == 0 {
		po.setchl(i1, e2)
		po.upush(undoSetChl, i1, 0)
	} else if i1r == 0 {
		po.setchr(i1, e2)
		po.upush(undoSetChr, i1, 0)
	} else {
		// If n1 already has two children, add an intermediate dummy
		// node to record the relation correctly (without relating
		// n2 to other existing nodes). Use a cheap index-dependent
		// (but deterministic) value to decide whether to append on the
		// left or the right, to avoid creating degenerated chains.
		//
		// Right-hand variant (the left-hand one is symmetric, with
		// i1l moved under the dummy instead):
		//
		//      n1
		//     /  \
		//   i1l  dummy
		//        /   \
		//      i1r   n2
		//
		// The edge from n1 to the dummy is non-strict; the original
		// edge (i1r or i1l) is moved unchanged, strictness included.
		dummy := po.newnode(nil)
		if (i1^i2)&1 != 0 { // parity-based pseudo-random choice
			po.setchl(dummy, i1r)
			po.setchr(dummy, e2)
			po.setchr(i1, newedge(dummy, false))
			po.upush(undoSetChr, i1, i1r)
		} else {
			po.setchl(dummy, i1l)
			po.setchr(dummy, e2)
			po.setchl(i1, newedge(dummy, false))
			po.upush(undoSetChl, i1, i1l)
		}
	}
}
252
253// newnode allocates a new node bound to SSA value n.
254// If n is nil, this is a dummy node (= only used internally).
255func (po *poset) newnode(n *Value) uint32 {
256	i := po.lastidx + 1
257	po.lastidx++
258	po.nodes = append(po.nodes, posetNode{})
259	if n != nil {
260		if po.values[n.ID] != 0 {
261			panic("newnode for Value already inserted")
262		}
263		po.values[n.ID] = i
264		po.upushnew(n.ID, i)
265	} else {
266		po.upushnew(0, i)
267	}
268	return i
269}
270
271// lookup searches for a SSA value into the forest of DAGS, and return its node.
272// Constants are materialized on the fly during lookup.
273func (po *poset) lookup(n *Value) (uint32, bool) {
274	i, f := po.values[n.ID]
275	if !f && n.isGenericIntConst() {
276		po.newconst(n)
277		i, f = po.values[n.ID]
278	}
279	return i, f
280}
281
// newconst creates a node for a constant. It links it to other constants, so
// that n<=5 is detected true when n<=3 is known to be true.
//
// If a different Value with the same constant value (interpreted as signed
// or unsigned according to posetFlagUnsigned) is already present, n is simply
// aliased to that node. Otherwise a new node is allocated and linked with
// strict edges to the closest lower and/or higher constant already known,
// which keeps every constant in the DAG at root #0. All mutations are
// recorded on the undo chain. It panics if n is not a generic integer constant.
// TODO: this is O(N), fix it.
func (po *poset) newconst(n *Value) {
	if !n.isGenericIntConst() {
		panic("newconst on non-constant")
	}

	// If the same constant is already present in the poset through a different
	// Value, just alias to it without allocating a new node.
	// In unsigned mode the map key is the unsigned value reinterpreted as int64.
	val := n.AuxInt
	if po.flags&posetFlagUnsigned != 0 {
		val = int64(n.AuxUnsigned())
	}
	if c, found := po.constants[val]; found {
		po.values[n.ID] = c
		po.upushalias(n.ID, 0)
		return
	}

	// Create the new node for this constant
	i := po.newnode(n)

	// If this is the first constant, put it as a new root, as
	// we can't record an existing connection so we don't have
	// a specific DAG to add it to. Notice that we want all
	// constants to be in root #0, so make sure the new root
	// goes there.
	if len(po.constants) == 0 {
		idx := len(po.roots)
		po.roots = append(po.roots, i)
		po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0]
		po.upush(undoNewRoot, i, 0)
		po.constants[val] = i
		po.upushconst(i, 0)
		return
	}

	// Find the lower and upper bound among existing constants. That is,
	// find the higher constant that is lower than the one that we're adding,
	// and the lower constant that is higher.
	// The loop is duplicated to handle signed and unsigned comparison,
	// depending on how the poset was configured.
	// A zero pointer means "no bound found yet" (node 0 is the null node).
	var lowerptr, higherptr uint32

	if po.flags&posetFlagUnsigned != 0 {
		var lower, higher uint64
		val1 := n.AuxUnsigned()
		for val2, ptr := range po.constants {
			val2 := uint64(val2)
			if val1 == val2 {
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	} else {
		var lower, higher int64
		val1 := n.AuxInt
		for val2, ptr := range po.constants {
			if val1 == val2 {
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	}

	if lowerptr == 0 && higherptr == 0 {
		// This should not happen, as at least one
		// other constant must exist if we get here.
		panic("no constant found")
	}

	// Connect the new node to the bounds, so that
	// lower < n < higher. We could have found both bounds or only one
	// of them, depending on what other constants are present in the poset.
	// Notice that we always link constants together, so they
	// are always part of the same DAG.
	switch {
	case lowerptr != 0 && higherptr != 0:
		// Both bounds are present, record lower < n < higher.
		po.addchild(lowerptr, i, true)
		po.addchild(i, higherptr, true)

	case lowerptr != 0:
		// Lower bound only, record lower < n.
		po.addchild(lowerptr, i, true)

	case higherptr != 0:
		// Higher bound only. To record n < higher, we need
		// a dummy root:
		//
		//        dummy
		//        /   \
		//      root   \
		//       /      n
		//     ....    /
		//       \    /
		//       higher
		//
		i2 := higherptr
		r2 := po.findroot(i2)
		if r2 != po.roots[0] { // all constants should be in root #0
			panic("constant not in root #0")
		}
		dummy := po.newnode(nil)
		po.changeroot(r2, dummy)
		po.upush(undoChangeRoot, dummy, newedge(r2, false))
		po.addchild(dummy, r2, false)
		po.addchild(dummy, i, false)
		po.addchild(i, i2, true)
	}

	// Register the constant so later Values with the same value alias to i.
	po.constants[val] = i
	po.upushconst(i, 0)
}
409
410// aliasnewnode records that a single node n2 (not in the poset yet) is an alias
411// of the master node n1.
412func (po *poset) aliasnewnode(n1, n2 *Value) {
413	i1, i2 := po.values[n1.ID], po.values[n2.ID]
414	if i1 == 0 || i2 != 0 {
415		panic("aliasnewnode invalid arguments")
416	}
417
418	po.values[n2.ID] = i1
419	po.upushalias(n2.ID, 0)
420}
421
// aliasnodes records that all the nodes i2s are aliases of a single master node n1.
// aliasnodes takes care of rearranging the DAG, changing references of parent/children
// of nodes in i2s, so that they point to n1 instead.
// It panics if n1 has no node or if i2s contains n1's own node.
// Every DAG edge, value mapping, and constant mapping it changes is recorded
// on the undo chain.
// Complexity is O(n) (with n being the total number of nodes in the poset, not just
// the number of nodes being aliased).
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parent/children of nodes in i2s.
	// range evaluates po.nodes once, so dummy nodes appended by addchild
	// below are not visited (they already point to valid children).
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we can create useless self-loops
		if uint32(idx) == i1 {
			continue
		}
		l, r := n.l, n.r

		// Rename all references to i2s into i1
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children
		// are in i2s as well, in which case it would be useless),
		// then detach the aliased node from the DAG. Note that l and r
		// still hold the edges as they were before the renaming above.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all existing IDs that point to i2 to i1.
	// This includes n2.ID.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}

	// If one of the aliased nodes is a constant, then make sure
	// po.constants is updated to point to the master node.
	for val, idx := range po.constants {
		if i2s.Test(idx) {
			po.constants[val] = i1
			po.upushconst(i1, idx)
		}
	}
}
488
489func (po *poset) isroot(r uint32) bool {
490	for i := range po.roots {
491		if po.roots[i] == r {
492			return true
493		}
494	}
495	return false
496}
497
498func (po *poset) changeroot(oldr, newr uint32) {
499	for i := range po.roots {
500		if po.roots[i] == oldr {
501			po.roots[i] = newr
502			return
503		}
504	}
505	panic("changeroot on non-root")
506}
507
508func (po *poset) removeroot(r uint32) {
509	for i := range po.roots {
510		if po.roots[i] == r {
511			po.roots = append(po.roots[:i], po.roots[i+1:]...)
512			return
513		}
514	}
515	panic("removeroot on non-root")
516}
517
// dfs performs a depth-first search within the DAG whose root is r.
// f is the visit function called for each node; if it returns true,
// the search is aborted and true is returned. Each node is passed to f
// at most once.
// If !strict, the root node is visited too.
// If strict, ignore edges across a path until at least one
// strict edge is found. For instance, for a chain A<=B<=C<D<=E<F,
// a strict walk visits D,E,F. In this mode the root itself is visited
// only if it is reachable from itself through a strict edge, which
// cannot happen in a DAG.
// If the visit completes without f returning true, false is returned.
func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
	closed := newBitset(int(po.lastidx + 1))
	open := make([]uint32, 1, 64)
	open[0] = r

	if strict {
		// Do a first DFS; walk all paths and stop when we find a strict
		// edge, building a "next" list of nodes reachable through strict
		// edges. This will be the bootstrap open list for the real DFS.
		next := make([]uint32, 0, 64)

		for len(open) > 0 {
			i := open[len(open)-1]
			open = open[:len(open)-1]

			// Don't visit the same node twice. Notice that all nodes
			// across non-strict paths are still visited at least once, so
			// a non-strict path can never obscure a strict path to the
			// same node.
			if !closed.Test(i) {
				closed.Set(i)

				// l, r are node i's edges (r shadows the root parameter,
				// which is no longer needed at this point).
				l, r := po.children(i)
				if l != 0 {
					if l.Strict() {
						next = append(next, l.Target())
					} else {
						open = append(open, l.Target())
					}
				}
				if r != 0 {
					if r.Strict() {
						next = append(next, r.Target())
					} else {
						open = append(open, r.Target())
					}
				}
			}
		}
		// Restart from the strict frontier with a clean visited set.
		open = next
		closed.Reset()
	}

	for len(open) > 0 {
		i := open[len(open)-1]
		open = open[:len(open)-1]

		if !closed.Test(i) {
			if f(i) {
				return true
			}
			closed.Set(i)
			l, r := po.children(i)
			if l != 0 {
				open = append(open, l.Target())
			}
			if r != 0 {
				open = append(open, r.Target())
			}
		}
	}
	return false
}
589
590// Returns true if there is a path from i1 to i2.
591// If strict ==  true: if the function returns true, then i1 <  i2.
592// If strict == false: if the function returns true, then i1 <= i2.
593// If the function returns false, no relation is known.
594func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
595	return po.dfs(i1, strict, func(n uint32) bool {
596		return n == i2
597	})
598}
599
600// findroot finds i's root, that is which DAG contains i.
601// Returns the root; if i is itself a root, it is returned.
602// Panic if i is not in any DAG.
603func (po *poset) findroot(i uint32) uint32 {
604	// TODO(rasky): if needed, a way to speed up this search is
605	// storing a bitset for each root using it as a mini bloom filter
606	// of nodes present under that root.
607	for _, r := range po.roots {
608		if po.reaches(r, i, false) {
609			return r
610		}
611	}
612	panic("findroot didn't find any root")
613}
614
615// mergeroot merges two DAGs into one DAG by creating a new dummy root
616func (po *poset) mergeroot(r1, r2 uint32) uint32 {
617	// Root #0 is special as it contains all constants. Since mergeroot
618	// discards r2 as root and keeps r1, make sure that r2 is not root #0,
619	// otherwise constants would move to a different root.
620	if r2 == po.roots[0] {
621		r1, r2 = r2, r1
622	}
623	r := po.newnode(nil)
624	po.setchl(r, newedge(r1, false))
625	po.setchr(r, newedge(r2, false))
626	po.changeroot(r1, r)
627	po.removeroot(r2)
628	po.upush(undoMergeRoot, r, 0)
629	return r
630}
631
632// collapsepath marks n1 and n2 as equal and collapses as equal all
633// nodes across all paths between n1 and n2. If a strict edge is
634// found, the function does not modify the DAG and returns false.
635// Complexity is O(n).
636func (po *poset) collapsepath(n1, n2 *Value) bool {
637	i1, i2 := po.values[n1.ID], po.values[n2.ID]
638	if po.reaches(i1, i2, true) {
639		return false
640	}
641
642	// Find all the paths from i1 to i2
643	paths := po.findpaths(i1, i2)
644	// Mark all nodes in all the paths as aliases of n1
645	// (excluding n1 itself)
646	paths.Clear(i1)
647	po.aliasnodes(n1, paths)
648	return true
649}
650
651// findpaths is a recursive function that calculates all paths from cur to dst
652// and return them as a bitset (the index of a node is set in the bitset if
653// that node is on at least one path from cur to dst).
654// We do a DFS from cur (stopping going deep any time we reach dst, if ever),
655// and mark as part of the paths any node that has a children which is already
656// part of the path (or is dst itself).
657func (po *poset) findpaths(cur, dst uint32) bitset {
658	seen := newBitset(int(po.lastidx + 1))
659	path := newBitset(int(po.lastidx + 1))
660	path.Set(dst)
661	po.findpaths1(cur, dst, seen, path)
662	return path
663}
664
665func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
666	if cur == dst {
667		return
668	}
669	seen.Set(cur)
670	l, r := po.chl(cur), po.chr(cur)
671	if !seen.Test(l) {
672		po.findpaths1(l, dst, seen, path)
673	}
674	if !seen.Test(r) {
675		po.findpaths1(r, dst, seen, path)
676	}
677	if path.Test(l) || path.Test(r) {
678		path.Set(cur)
679	}
680}
681
682// Check whether it is recorded that i1!=i2
683func (po *poset) isnoneq(i1, i2 uint32) bool {
684	if i1 == i2 {
685		return false
686	}
687	if i1 < i2 {
688		i1, i2 = i2, i1
689	}
690
691	// Check if we recorded a non-equal relation before
692	if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
693		return true
694	}
695	return false
696}
697
698// Record that i1!=i2
699func (po *poset) setnoneq(n1, n2 *Value) {
700	i1, f1 := po.lookup(n1)
701	i2, f2 := po.lookup(n2)
702
703	// If any of the nodes do not exist in the poset, allocate them. Since
704	// we don't know any relation (in the partial order) about them, they must
705	// become independent roots.
706	if !f1 {
707		i1 = po.newnode(n1)
708		po.roots = append(po.roots, i1)
709		po.upush(undoNewRoot, i1, 0)
710	}
711	if !f2 {
712		i2 = po.newnode(n2)
713		po.roots = append(po.roots, i2)
714		po.upush(undoNewRoot, i2, 0)
715	}
716
717	if i1 == i2 {
718		panic("setnoneq on same node")
719	}
720	if i1 < i2 {
721		i1, i2 = i2, i1
722	}
723	bs := po.noneq[i1]
724	if bs == nil {
725		// Given that we record non-equality relations using the
726		// higher index as a key, the bitsize will never change size.
727		// TODO(rasky): if memory is a problem, consider allocating
728		// a small bitset and lazily grow it when higher indices arrive.
729		bs = newBitset(int(i1))
730		po.noneq[i1] = bs
731	} else if bs.Test(i2) {
732		// Already recorded
733		return
734	}
735	bs.Set(i2)
736	po.upushneq(i1, i2)
737}
738
739// CheckIntegrity verifies internal integrity of a poset. It is intended
740// for debugging purposes.
741func (po *poset) CheckIntegrity() {
742	// Record which index is a constant
743	constants := newBitset(int(po.lastidx + 1))
744	for _, c := range po.constants {
745		constants.Set(c)
746	}
747
748	// Verify that each node appears in a single DAG, and that
749	// all constants are within the first DAG
750	seen := newBitset(int(po.lastidx + 1))
751	for ridx, r := range po.roots {
752		if r == 0 {
753			panic("empty root")
754		}
755
756		po.dfs(r, false, func(i uint32) bool {
757			if seen.Test(i) {
758				panic("duplicate node")
759			}
760			seen.Set(i)
761			if constants.Test(i) {
762				if ridx != 0 {
763					panic("constants not in the first DAG")
764				}
765			}
766			return false
767		})
768	}
769
770	// Verify that values contain the minimum set
771	for id, idx := range po.values {
772		if !seen.Test(idx) {
773			panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
774		}
775	}
776
777	// Verify that only existing nodes have non-zero children
778	for i, n := range po.nodes {
779		if n.l|n.r != 0 {
780			if !seen.Test(uint32(i)) {
781				panic(fmt.Errorf("children of unknown node %d->%v", i, n))
782			}
783			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
784				panic(fmt.Errorf("self-loop on node %d", i))
785			}
786		}
787	}
788}
789
790// CheckEmpty checks that a poset is completely empty.
791// It can be used for debugging purposes, as a poset is supposed to
792// be empty after it's fully rolled back through Undo.
793func (po *poset) CheckEmpty() error {
794	if len(po.nodes) != 1 {
795		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
796	}
797	if len(po.values) != 0 {
798		return fmt.Errorf("non-empty value map: %v", po.values)
799	}
800	if len(po.roots) != 0 {
801		return fmt.Errorf("non-empty root list: %v", po.roots)
802	}
803	if len(po.constants) != 0 {
804		return fmt.Errorf("non-empty constants: %v", po.constants)
805	}
806	if len(po.undo) != 0 {
807		return fmt.Errorf("non-empty undo list: %v", po.undo)
808	}
809	if po.lastidx != 0 {
810		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
811	}
812	for _, bs := range po.noneq {
813		for _, x := range bs {
814			if x != 0 {
815				return fmt.Errorf("non-empty noneq map")
816			}
817		}
818	}
819	return nil
820}
821
// DotDump dumps the poset in graphviz format to file fn, with the specified title.
// Each root becomes a subgraph; constants are drawn as filled boxes, strict
// edges are red and labelled "<", non-strict edges are green and labelled "<=".
// Only the error from creating the file is reported; write errors are ignored.
func (po *poset) DotDump(fn string, title string) error {
	f, err := os.Create(fn)
	if err != nil {
		return err
	}
	defer f.Close()

	// Create reverse index mapping (taking aliases into account).
	// Map iteration order is random, so the order of aliased names varies.
	names := make(map[uint32]string)
	for id, i := range po.values {
		s := names[i]
		if s == "" {
			s = fmt.Sprintf("v%d", id)
		} else {
			s += fmt.Sprintf(", v%d", id)
		}
		names[i] = s
	}

	// Create reverse constant mapping
	consts := make(map[uint32]int64)
	for val, idx := range po.constants {
		consts[idx] = val
	}

	fmt.Fprintf(f, "digraph poset {\n")
	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
	for ridx, r := range po.roots {
		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
		po.dfs(r, false, func(i uint32) bool {
			if val, ok := consts[i]; ok {
				// Constant; print it with the poset's signedness.
				var vals string
				if po.flags&posetFlagUnsigned != 0 {
					vals = fmt.Sprint(uint64(val))
				} else {
					vals = fmt.Sprint(int64(val))
				}
				fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
					i, vals, names[i], i)
			} else {
				// Normal SSA value
				fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
			}
			chl, chr := po.children(i)
			for _, ch := range []posetEdge{chl, chr} {
				if ch != 0 {
					if ch.Strict() {
						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
					} else {
						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
					}
				}
			}
			return false
		})
		fmt.Fprintf(f, "\t}\n")
	}
	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
	fmt.Fprintf(f, "\tlabel=%q\n", title)
	fmt.Fprintf(f, "}\n")
	return nil
}
887
888// Ordered reports whether n1<n2. It returns false either when it is
889// certain that n1<n2 is false, or if there is not enough information
890// to tell.
891// Complexity is O(n).
892func (po *poset) Ordered(n1, n2 *Value) bool {
893	if debugPoset {
894		defer po.CheckIntegrity()
895	}
896	if n1.ID == n2.ID {
897		panic("should not call Ordered with n1==n2")
898	}
899
900	i1, f1 := po.lookup(n1)
901	i2, f2 := po.lookup(n2)
902	if !f1 || !f2 {
903		return false
904	}
905
906	return i1 != i2 && po.reaches(i1, i2, true)
907}
908
909// Ordered reports whether n1<=n2. It returns false either when it is
910// certain that n1<=n2 is false, or if there is not enough information
911// to tell.
912// Complexity is O(n).
913func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
914	if debugPoset {
915		defer po.CheckIntegrity()
916	}
917	if n1.ID == n2.ID {
918		panic("should not call Ordered with n1==n2")
919	}
920
921	i1, f1 := po.lookup(n1)
922	i2, f2 := po.lookup(n2)
923	if !f1 || !f2 {
924		return false
925	}
926
927	return i1 == i2 || po.reaches(i1, i2, false)
928}
929
930// Equal reports whether n1==n2. It returns false either when it is
931// certain that n1==n2 is false, or if there is not enough information
932// to tell.
933// Complexity is O(1).
934func (po *poset) Equal(n1, n2 *Value) bool {
935	if debugPoset {
936		defer po.CheckIntegrity()
937	}
938	if n1.ID == n2.ID {
939		panic("should not call Equal with n1==n2")
940	}
941
942	i1, f1 := po.lookup(n1)
943	i2, f2 := po.lookup(n2)
944	return f1 && f2 && i1 == i2
945}
946
947// NonEqual reports whether n1!=n2. It returns false either when it is
948// certain that n1!=n2 is false, or if there is not enough information
949// to tell.
950// Complexity is O(n) (because it internally calls Ordered to see if we
951// can infer n1!=n2 from n1<n2 or n2<n1).
952func (po *poset) NonEqual(n1, n2 *Value) bool {
953	if debugPoset {
954		defer po.CheckIntegrity()
955	}
956	if n1.ID == n2.ID {
957		panic("should not call NonEqual with n1==n2")
958	}
959
960	// If we never saw the nodes before, we don't
961	// have a recorded non-equality.
962	i1, f1 := po.lookup(n1)
963	i2, f2 := po.lookup(n2)
964	if !f1 || !f2 {
965		return false
966	}
967
968	// Check if we recored inequality
969	if po.isnoneq(i1, i2) {
970		return true
971	}
972
973	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
974	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
975		return true
976	}
977
978	return false
979}
980
// setOrder records that n1<n2 or n1<=n2 (depending on strict). Returns false
// if this is a contradiction.
// Implements SetOrder() and SetOrderOrEqual()
//
// Parameters:
//   - n1, n2: the values to relate; callers guarantee n1.ID != n2.ID.
//   - strict: true to record n1<n2, false to record n1<=n2.
//
// Values not yet in the poset are allocated as new nodes. When n1 and n2
// live in different DAGs and have no known relation, the two DAGs are
// merged before the edge is added.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// dummy root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			// n1 replaces n2 as the root of its DAG, with n2 as its child.
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  dummy
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		// The dummy root is not bound to any SSA value (newnode(nil)),
		// so its edges to r and i1 are added as non-strict.
		dummy := po.newnode(nil)
		po.changeroot(r, dummy)
		po.upush(undoChangeRoot, dummy, newedge(r, false))
		po.addchild(dummy, r, false)
		po.addchild(dummy, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we learned that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow reaches n2
		if po.reaches(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Cases #1, #3 or #4: nothing to do
			return true
		}

		// Check if n2 somehow reaches n1
		if po.reaches(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	return true
}
1109
1110// SetOrder records that n1<n2. Returns false if this is a contradiction
1111// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
1112func (po *poset) SetOrder(n1, n2 *Value) bool {
1113	if debugPoset {
1114		defer po.CheckIntegrity()
1115	}
1116	if n1.ID == n2.ID {
1117		panic("should not call SetOrder with n1==n2")
1118	}
1119	return po.setOrder(n1, n2, true)
1120}
1121
1122// SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
1123// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
1124func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
1125	if debugPoset {
1126		defer po.CheckIntegrity()
1127	}
1128	if n1.ID == n2.ID {
1129		panic("should not call SetOrder with n1==n2")
1130	}
1131	return po.setOrder(n1, n2, false)
1132}
1133
1134// SetEqual records that n1==n2. Returns false if this is a contradiction
1135// (that is, if it is already recorded that n1<n2 or n2<n1).
1136// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
1137func (po *poset) SetEqual(n1, n2 *Value) bool {
1138	if debugPoset {
1139		defer po.CheckIntegrity()
1140	}
1141	if n1.ID == n2.ID {
1142		panic("should not call Add with n1==n2")
1143	}
1144
1145	i1, f1 := po.lookup(n1)
1146	i2, f2 := po.lookup(n2)
1147
1148	switch {
1149	case !f1 && !f2:
1150		i1 = po.newnode(n1)
1151		po.roots = append(po.roots, i1)
1152		po.upush(undoNewRoot, i1, 0)
1153		po.aliasnewnode(n1, n2)
1154	case f1 && !f2:
1155		po.aliasnewnode(n1, n2)
1156	case !f1 && f2:
1157		po.aliasnewnode(n2, n1)
1158	case f1 && f2:
1159		if i1 == i2 {
1160			// Already aliased, ignore
1161			return true
1162		}
1163
1164		// If we recorded that n1!=n2, this is a contradiction.
1165		if po.isnoneq(i1, i2) {
1166			return false
1167		}
1168
1169		// If we already knew that n1<=n2, we can collapse the path to
1170		// record n1==n2 (and viceversa).
1171		if po.reaches(i1, i2, false) {
1172			return po.collapsepath(n1, n2)
1173		}
1174		if po.reaches(i2, i1, false) {
1175			return po.collapsepath(n2, n1)
1176		}
1177
1178		r1 := po.findroot(i1)
1179		r2 := po.findroot(i2)
1180		if r1 != r2 {
1181			// Merge the two DAGs so we can record relations between the nodes
1182			po.mergeroot(r1, r2)
1183		}
1184
1185		// Set n2 as alias of n1. This will also update all the references
1186		// to n2 to become references to n1
1187		i2s := newBitset(int(po.lastidx) + 1)
1188		i2s.Set(i2)
1189		po.aliasnodes(n1, i2s)
1190	}
1191	return true
1192}
1193
1194// SetNonEqual records that n1!=n2. Returns false if this is a contradiction
1195// (that is, if it is already recorded that n1==n2).
1196// Complexity is O(n).
1197func (po *poset) SetNonEqual(n1, n2 *Value) bool {
1198	if debugPoset {
1199		defer po.CheckIntegrity()
1200	}
1201	if n1.ID == n2.ID {
1202		panic("should not call SetNonEqual with n1==n2")
1203	}
1204
1205	// Check whether the nodes are already in the poset
1206	i1, f1 := po.lookup(n1)
1207	i2, f2 := po.lookup(n2)
1208
1209	// If either node wasn't present, we just record the new relation
1210	// and exit.
1211	if !f1 || !f2 {
1212		po.setnoneq(n1, n2)
1213		return true
1214	}
1215
1216	// See if we already know this, in which case there's nothing to do.
1217	if po.isnoneq(i1, i2) {
1218		return true
1219	}
1220
1221	// Check if we're contradicting an existing equality relation
1222	if po.Equal(n1, n2) {
1223		return false
1224	}
1225
1226	// Record non-equality
1227	po.setnoneq(n1, n2)
1228
1229	// If we know that i1<=i2 but not i1<i2, learn that as we
1230	// now know that they are not equal. Do the same for i2<=i1.
1231	// Do this check only if both nodes were already in the DAG,
1232	// otherwise there cannot be an existing relation.
1233	if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
1234		po.addchild(i1, i2, true)
1235	}
1236	if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
1237		po.addchild(i2, i1, true)
1238	}
1239
1240	return true
1241}
1242
1243// Checkpoint saves the current state of the DAG so that it's possible
1244// to later undo this state.
1245// Complexity is O(1).
1246func (po *poset) Checkpoint() {
1247	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
1248}
1249
// Undo restores the state of the poset to the previous checkpoint.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
//
// Passes are popped off the undo stack most-recent first and reverted one
// by one. Undo stops as soon as it pops an undoCheckpoint pass, and that
// checkpoint is consumed. If the stack runs out without meeting a
// checkpoint, the poset is expected to be empty again; this is verified only
// when debugPoset is set.
// It panics if the undo stack is empty on entry, or if a pass does not match
// the current state of the poset.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		// Pop the most recent pass.
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Everything recorded after the checkpoint has been reverted.
			return

		case undoSetChl:
			// Put back the previous left child of node pass.idx.
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			// Put back the previous right child of node pass.idx.
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// Clear bit pass.idx in the non-equality bitset stored under
			// uint32(pass.ID).
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Nodes are removed in reverse allocation order, so the node to
			// drop must be the last one allocated.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// pass.ID == 0 means the node was not bound to an SSA value;
			// otherwise the value must still map to this node.
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoNewConstant:
			// Find which constant value maps to node pass.idx.
			// FIXME: remove this O(n) loop
			var val int64
			var i uint32
			for val, i = range po.constants {
				if i == pass.idx {
					break
				}
			}
			if i != pass.idx {
				panic("constant not found in undo pass")
			}
			// pass.ID == 0: the constant had no node before, so forget it.
			// Otherwise pass.ID stores the node index it had before.
			if pass.ID == 0 {
				delete(po.constants, val)
			} else {
				// Restore previous index as constant node
				// (also restoring the invariant on correct bounds)
				oldidx := uint32(pass.ID)
				po.constants[val] = oldidx
			}

		case undoAliasNode:
			// pass.idx is the node that value pass.ID mapped to before it was
			// aliased; 0 means it was not in the poset before.
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// Drop root pass.idx from the root list; it must have no
			// children left at this point.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Swap root pass.idx back to the root it replaced
			// (pass.edge.Target()); it must have no children left.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// Split the merged DAG: the left child of root pass.idx takes its
			// place in the root list, and the right child is re-added as a
			// separate root.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			// Unknown pass type: internal error.
			panic(pass.typ)
		}
	}

	// The whole stack was consumed without meeting a checkpoint, so every
	// mutation has been reverted and the poset should be empty.
	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
1360