1package gtreap
2
// Treap is a binary search tree of Items ordered by a Compare function,
// in which every node additionally carries an integer priority; when
// nodes are combined, the node with the higher priority ends up closer
// to the root. Upsert and Delete return a new *Treap and build new
// nodes instead of modifying the receiver.
type Treap struct {
	compare Compare // ordering used for every item comparison
	root    *node   // top of the tree; nil when the treap is empty
}
7
// Compare returns an integer comparing the two items
// lexicographically. The result will be 0 if a==b, -1 if a < b, and
// +1 if a > b. The treap methods in this file only inspect the sign of
// the result (zero, negative, or positive).
type Compare func(a, b interface{}) int
12
// Item is the element type stored in a Treap. It is the empty
// interface, so any value can be stored; ordering between items is
// decided solely by the Compare function.
type Item interface{}
15
// node is a single tree node: the stored item, its priority, and the
// two child subtrees.
type node struct {
	item     Item  // the stored value
	priority int   // the higher priority wins the root position in union and join
	left     *node // subtree of items that compare less than item
	right    *node // subtree of items that compare greater than item
}
22
23func NewTreap(c Compare) *Treap {
24	return &Treap{compare: c, root: nil}
25}
26
27func (t *Treap) Min() Item {
28	n := t.root
29	if n == nil {
30		return nil
31	}
32	for n.left != nil {
33		n = n.left
34	}
35	return n.item
36}
37
38func (t *Treap) Max() Item {
39	n := t.root
40	if n == nil {
41		return nil
42	}
43	for n.right != nil {
44		n = n.right
45	}
46	return n.item
47}
48
49func (t *Treap) Get(target Item) Item {
50	n := t.root
51	for n != nil {
52		c := t.compare(target, n.item)
53		if c < 0 {
54			n = n.left
55		} else if c > 0 {
56			n = n.right
57		} else {
58			return n.item
59		}
60	}
61	return nil
62}
63
64// Note: only the priority of the first insert of an item is used.
65// Priorities from future updates on already existing items are
66// ignored.  To change the priority for an item, you need to do a
67// Delete then an Upsert.
68func (t *Treap) Upsert(item Item, itemPriority int) *Treap {
69	r := t.union(t.root, &node{item: item, priority: itemPriority})
70	return &Treap{compare: t.compare, root: r}
71}
72
73func (t *Treap) union(this *node, that *node) *node {
74	if this == nil {
75		return that
76	}
77	if that == nil {
78		return this
79	}
80	if this.priority > that.priority {
81		left, middle, right := t.split(that, this.item)
82		if middle == nil {
83			return &node{
84				item:     this.item,
85				priority: this.priority,
86				left:     t.union(this.left, left),
87				right:    t.union(this.right, right),
88			}
89		}
90		return &node{
91			item:     middle.item,
92			priority: this.priority,
93			left:     t.union(this.left, left),
94			right:    t.union(this.right, right),
95		}
96	}
97	// We don't use middle because the "that" has precendence.
98	left, _, right := t.split(this, that.item)
99	return &node{
100		item:     that.item,
101		priority: that.priority,
102		left:     t.union(left, that.left),
103		right:    t.union(right, that.right),
104	}
105}
106
107// Splits a treap into two treaps based on a split item "s".
108// The result tuple-3 means (left, X, right), where X is either...
109// nil - meaning the item s was not in the original treap.
110// non-nil - returning the node that had item s.
111// The tuple-3's left result treap has items < s,
112// and the tuple-3's right result treap has items > s.
113func (t *Treap) split(n *node, s Item) (*node, *node, *node) {
114	if n == nil {
115		return nil, nil, nil
116	}
117	c := t.compare(s, n.item)
118	if c == 0 {
119		return n.left, n, n.right
120	}
121	if c < 0 {
122		left, middle, right := t.split(n.left, s)
123		return left, middle, &node{
124			item:     n.item,
125			priority: n.priority,
126			left:     right,
127			right:    n.right,
128		}
129	}
130	left, middle, right := t.split(n.right, s)
131	return &node{
132		item:     n.item,
133		priority: n.priority,
134		left:     n.left,
135		right:    left,
136	}, middle, right
137}
138
139func (t *Treap) Delete(target Item) *Treap {
140	left, _, right := t.split(t.root, target)
141	return &Treap{compare: t.compare, root: t.join(left, right)}
142}
143
144// All the items from this are < items from that.
145func (t *Treap) join(this *node, that *node) *node {
146	if this == nil {
147		return that
148	}
149	if that == nil {
150		return this
151	}
152	if this.priority > that.priority {
153		return &node{
154			item:     this.item,
155			priority: this.priority,
156			left:     this.left,
157			right:    t.join(this.right, that),
158		}
159	}
160	return &node{
161		item:     that.item,
162		priority: that.priority,
163		left:     t.join(this, that.left),
164		right:    that.right,
165	}
166}
167
// ItemVisitor is the callback invoked with each item during a
// traversal. Returning false stops the traversal; returning true lets
// it continue with the next item.
type ItemVisitor func(i Item) bool
169
170// Visit items greater-than-or-equal to the pivot.
171func (t *Treap) VisitAscend(pivot Item, visitor ItemVisitor) {
172	t.visitAscend(t.root, pivot, visitor)
173}
174
175func (t *Treap) visitAscend(n *node, pivot Item, visitor ItemVisitor) bool {
176	if n == nil {
177		return true
178	}
179	if t.compare(pivot, n.item) <= 0 {
180		if !t.visitAscend(n.left, pivot, visitor) {
181			return false
182		}
183		if !visitor(n.item) {
184			return false
185		}
186	}
187	return t.visitAscend(n.right, pivot, visitor)
188}
189