1// Copyright (c) 2015, Emir Pasic. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package redblacktree implements a red-black tree.
6//
7// Used by TreeSet and TreeMap.
8//
9// Structure is not thread safe.
10//
11// References: http://en.wikipedia.org/wiki/Red%E2%80%93black_tree
12package redblacktree
13
14import (
15	"fmt"
16	"github.com/emirpasic/gods/trees"
17	"github.com/emirpasic/gods/utils"
18)
19
20func assertTreeImplementation() {
21	var _ trees.Tree = (*Tree)(nil)
22}
23
// color is the red/black tag stored on every node. It is a bool with black
// as true and red as false, so the zero value of color is red.
type color bool

const (
	black color = true
	red   color = false
)
29
// Tree holds elements of the red-black tree.
//
// Root is the root node, or nil when the tree is empty. size counts the
// stored nodes and is maintained by Put, Remove and Clear. Comparator orders
// the keys. It must be set (for example via NewWith) before keys are added or
// looked up, because Put, Get, Remove, Floor and Ceiling all call it and
// calling a nil func panics.
//
// As the package documentation states, a Tree is not thread safe.
type Tree struct {
	Root       *Node
	size       int
	Comparator utils.Comparator
}
36
// Node is a single element within the tree.
//
// Key and Value are the stored pair. Left and Right are the children.
// Parent links back up the tree and is nil for the root. color is the node's
// red/black tag; missing (nil) children are treated as black by nodeColor.
type Node struct {
	Key    interface{}
	Value  interface{}
	color  color
	Left   *Node
	Right  *Node
	Parent *Node
}
46
47// NewWith instantiates a red-black tree with the custom comparator.
48func NewWith(comparator utils.Comparator) *Tree {
49	return &Tree{Comparator: comparator}
50}
51
52// NewWithIntComparator instantiates a red-black tree with the IntComparator, i.e. keys are of type int.
53func NewWithIntComparator() *Tree {
54	return &Tree{Comparator: utils.IntComparator}
55}
56
57// NewWithStringComparator instantiates a red-black tree with the StringComparator, i.e. keys are of type string.
58func NewWithStringComparator() *Tree {
59	return &Tree{Comparator: utils.StringComparator}
60}
61
62// Put inserts node into the tree.
63// Key should adhere to the comparator's type assertion, otherwise method panics.
64func (tree *Tree) Put(key interface{}, value interface{}) {
65	var insertedNode *Node
66	if tree.Root == nil {
67		// Assert key is of comparator's type for initial tree
68		tree.Comparator(key, key)
69		tree.Root = &Node{Key: key, Value: value, color: red}
70		insertedNode = tree.Root
71	} else {
72		node := tree.Root
73		loop := true
74		for loop {
75			compare := tree.Comparator(key, node.Key)
76			switch {
77			case compare == 0:
78				node.Key = key
79				node.Value = value
80				return
81			case compare < 0:
82				if node.Left == nil {
83					node.Left = &Node{Key: key, Value: value, color: red}
84					insertedNode = node.Left
85					loop = false
86				} else {
87					node = node.Left
88				}
89			case compare > 0:
90				if node.Right == nil {
91					node.Right = &Node{Key: key, Value: value, color: red}
92					insertedNode = node.Right
93					loop = false
94				} else {
95					node = node.Right
96				}
97			}
98		}
99		insertedNode.Parent = node
100	}
101	tree.insertCase1(insertedNode)
102	tree.size++
103}
104
105// Get searches the node in the tree by key and returns its value or nil if key is not found in tree.
106// Second return parameter is true if key was found, otherwise false.
107// Key should adhere to the comparator's type assertion, otherwise method panics.
108func (tree *Tree) Get(key interface{}) (value interface{}, found bool) {
109	node := tree.lookup(key)
110	if node != nil {
111		return node.Value, true
112	}
113	return nil, false
114}
115
// Remove removes the node with the given key from the tree.
// It does nothing if the key is not present.
// Key should adhere to the comparator's type assertion, otherwise method panics.
func (tree *Tree) Remove(key interface{}) {
	var child *Node
	node := tree.lookup(key)
	if node == nil {
		return
	}
	// A node with two children is not unlinked directly. Its key and value
	// are overwritten with those of its in-order predecessor (the maximum of
	// the left subtree, which has no right child), and the predecessor node
	// is removed instead. NOTE: this means a *Node previously returned for
	// the removed key keeps living in the tree, now holding the predecessor's
	// key and value.
	if node.Left != nil && node.Right != nil {
		pred := node.Left.maximumNode()
		node.Key = pred.Key
		node.Value = pred.Value
		node = pred
	}
	// After the swap above, node has at most one child, so this condition
	// always holds.
	if node.Left == nil || node.Right == nil {
		if node.Right == nil {
			child = node.Left
		} else {
			child = node.Right
		}
		// Removing a black node would unbalance black heights, so rebalance
		// before splicing it out. node first takes its child's color
		// (nil counts as black).
		if node.color == black {
			node.color = nodeColor(child)
			tree.deleteCase1(node)
		}
		tree.replaceNode(node, child)
		// If node was the root, child has become the new root and is
		// painted black.
		if node.Parent == nil && child != nil {
			child.color = black
		}
	}
	tree.size--
}
147
148// Empty returns true if tree does not contain any nodes
149func (tree *Tree) Empty() bool {
150	return tree.size == 0
151}
152
153// Size returns number of nodes in the tree.
154func (tree *Tree) Size() int {
155	return tree.size
156}
157
158// Keys returns all keys in-order
159func (tree *Tree) Keys() []interface{} {
160	keys := make([]interface{}, tree.size)
161	it := tree.Iterator()
162	for i := 0; it.Next(); i++ {
163		keys[i] = it.Key()
164	}
165	return keys
166}
167
168// Values returns all values in-order based on the key.
169func (tree *Tree) Values() []interface{} {
170	values := make([]interface{}, tree.size)
171	it := tree.Iterator()
172	for i := 0; it.Next(); i++ {
173		values[i] = it.Value()
174	}
175	return values
176}
177
178// Left returns the left-most (min) node or nil if tree is empty.
179func (tree *Tree) Left() *Node {
180	var parent *Node
181	current := tree.Root
182	for current != nil {
183		parent = current
184		current = current.Left
185	}
186	return parent
187}
188
189// Right returns the right-most (max) node or nil if tree is empty.
190func (tree *Tree) Right() *Node {
191	var parent *Node
192	current := tree.Root
193	for current != nil {
194		parent = current
195		current = current.Right
196	}
197	return parent
198}
199
200// Floor Finds floor node of the input key, return the floor node or nil if no floor is found.
201// Second return parameter is true if floor was found, otherwise false.
202//
203// Floor node is defined as the largest node that is smaller than or equal to the given node.
204// A floor node may not be found, either because the tree is empty, or because
205// all nodes in the tree are larger than the given node.
206//
207// Key should adhere to the comparator's type assertion, otherwise method panics.
208func (tree *Tree) Floor(key interface{}) (floor *Node, found bool) {
209	found = false
210	node := tree.Root
211	for node != nil {
212		compare := tree.Comparator(key, node.Key)
213		switch {
214		case compare == 0:
215			return node, true
216		case compare < 0:
217			node = node.Left
218		case compare > 0:
219			floor, found = node, true
220			node = node.Right
221		}
222	}
223	if found {
224		return floor, true
225	}
226	return nil, false
227}
228
229// Ceiling finds ceiling node of the input key, return the ceiling node or nil if no ceiling is found.
230// Second return parameter is true if ceiling was found, otherwise false.
231//
232// Ceiling node is defined as the smallest node that is larger than or equal to the given node.
233// A ceiling node may not be found, either because the tree is empty, or because
234// all nodes in the tree are smaller than the given node.
235//
236// Key should adhere to the comparator's type assertion, otherwise method panics.
237func (tree *Tree) Ceiling(key interface{}) (ceiling *Node, found bool) {
238	found = false
239	node := tree.Root
240	for node != nil {
241		compare := tree.Comparator(key, node.Key)
242		switch {
243		case compare == 0:
244			return node, true
245		case compare < 0:
246			ceiling, found = node, true
247			node = node.Left
248		case compare > 0:
249			node = node.Right
250		}
251	}
252	if found {
253		return ceiling, true
254	}
255	return nil, false
256}
257
258// Clear removes all nodes from the tree.
259func (tree *Tree) Clear() {
260	tree.Root = nil
261	tree.size = 0
262}
263
264// String returns a string representation of container
265func (tree *Tree) String() string {
266	str := "RedBlackTree\n"
267	if !tree.Empty() {
268		output(tree.Root, "", true, &str)
269	}
270	return str
271}
272
273func (node *Node) String() string {
274	return fmt.Sprintf("%v", node.Key)
275}
276
277func output(node *Node, prefix string, isTail bool, str *string) {
278	if node.Right != nil {
279		newPrefix := prefix
280		if isTail {
281			newPrefix += "│   "
282		} else {
283			newPrefix += "    "
284		}
285		output(node.Right, newPrefix, false, str)
286	}
287	*str += prefix
288	if isTail {
289		*str += "└── "
290	} else {
291		*str += "┌── "
292	}
293	*str += node.String() + "\n"
294	if node.Left != nil {
295		newPrefix := prefix
296		if isTail {
297			newPrefix += "    "
298		} else {
299			newPrefix += "│   "
300		}
301		output(node.Left, newPrefix, true, str)
302	}
303}
304
305func (tree *Tree) lookup(key interface{}) *Node {
306	node := tree.Root
307	for node != nil {
308		compare := tree.Comparator(key, node.Key)
309		switch {
310		case compare == 0:
311			return node
312		case compare < 0:
313			node = node.Left
314		case compare > 0:
315			node = node.Right
316		}
317	}
318	return nil
319}
320
321func (node *Node) grandparent() *Node {
322	if node != nil && node.Parent != nil {
323		return node.Parent.Parent
324	}
325	return nil
326}
327
328func (node *Node) uncle() *Node {
329	if node == nil || node.Parent == nil || node.Parent.Parent == nil {
330		return nil
331	}
332	return node.Parent.sibling()
333}
334
335func (node *Node) sibling() *Node {
336	if node == nil || node.Parent == nil {
337		return nil
338	}
339	if node == node.Parent.Left {
340		return node.Parent.Right
341	}
342	return node.Parent.Left
343}
344
345func (tree *Tree) rotateLeft(node *Node) {
346	right := node.Right
347	tree.replaceNode(node, right)
348	node.Right = right.Left
349	if right.Left != nil {
350		right.Left.Parent = node
351	}
352	right.Left = node
353	node.Parent = right
354}
355
356func (tree *Tree) rotateRight(node *Node) {
357	left := node.Left
358	tree.replaceNode(node, left)
359	node.Left = left.Right
360	if left.Right != nil {
361		left.Right.Parent = node
362	}
363	left.Right = node
364	node.Parent = left
365}
366
367func (tree *Tree) replaceNode(old *Node, new *Node) {
368	if old.Parent == nil {
369		tree.Root = new
370	} else {
371		if old == old.Parent.Left {
372			old.Parent.Left = new
373		} else {
374			old.Parent.Right = new
375		}
376	}
377	if new != nil {
378		new.Parent = old.Parent
379	}
380}
381
382func (tree *Tree) insertCase1(node *Node) {
383	if node.Parent == nil {
384		node.color = black
385	} else {
386		tree.insertCase2(node)
387	}
388}
389
390func (tree *Tree) insertCase2(node *Node) {
391	if nodeColor(node.Parent) == black {
392		return
393	}
394	tree.insertCase3(node)
395}
396
397func (tree *Tree) insertCase3(node *Node) {
398	uncle := node.uncle()
399	if nodeColor(uncle) == red {
400		node.Parent.color = black
401		uncle.color = black
402		node.grandparent().color = red
403		tree.insertCase1(node.grandparent())
404	} else {
405		tree.insertCase4(node)
406	}
407}
408
// insertCase4 is reached from insertCase3 when node's parent is red and its
// uncle is black (nil counts as black).
//
// If node is an "inner" grandchild (the right child of a left child, or the
// left child of a right child), the parent is rotated away from node. node
// then takes the parent's place and the former parent becomes an "outer"
// grandchild. That former parent is what insertCase5 continues with.
func (tree *Tree) insertCase4(node *Node) {
	grandparent := node.grandparent()
	if node == node.Parent.Right && node.Parent == grandparent.Left {
		tree.rotateLeft(node.Parent)
		// After the left rotation the old parent is node's left child.
		node = node.Left
	} else if node == node.Parent.Left && node.Parent == grandparent.Right {
		tree.rotateRight(node.Parent)
		// After the right rotation the old parent is node's right child.
		node = node.Right
	}
	tree.insertCase5(node)
}
420
421func (tree *Tree) insertCase5(node *Node) {
422	node.Parent.color = black
423	grandparent := node.grandparent()
424	grandparent.color = red
425	if node == node.Parent.Left && node.Parent == grandparent.Left {
426		tree.rotateRight(grandparent)
427	} else if node == node.Parent.Right && node.Parent == grandparent.Right {
428		tree.rotateLeft(grandparent)
429	}
430}
431
432func (node *Node) maximumNode() *Node {
433	if node == nil {
434		return nil
435	}
436	for node.Right != nil {
437		node = node.Right
438	}
439	return node
440}
441
442func (tree *Tree) deleteCase1(node *Node) {
443	if node.Parent == nil {
444		return
445	}
446	tree.deleteCase2(node)
447}
448
449func (tree *Tree) deleteCase2(node *Node) {
450	sibling := node.sibling()
451	if nodeColor(sibling) == red {
452		node.Parent.color = red
453		sibling.color = black
454		if node == node.Parent.Left {
455			tree.rotateLeft(node.Parent)
456		} else {
457			tree.rotateRight(node.Parent)
458		}
459	}
460	tree.deleteCase3(node)
461}
462
463func (tree *Tree) deleteCase3(node *Node) {
464	sibling := node.sibling()
465	if nodeColor(node.Parent) == black &&
466		nodeColor(sibling) == black &&
467		nodeColor(sibling.Left) == black &&
468		nodeColor(sibling.Right) == black {
469		sibling.color = red
470		tree.deleteCase1(node.Parent)
471	} else {
472		tree.deleteCase4(node)
473	}
474}
475
476func (tree *Tree) deleteCase4(node *Node) {
477	sibling := node.sibling()
478	if nodeColor(node.Parent) == red &&
479		nodeColor(sibling) == black &&
480		nodeColor(sibling.Left) == black &&
481		nodeColor(sibling.Right) == black {
482		sibling.color = red
483		node.Parent.color = black
484	} else {
485		tree.deleteCase5(node)
486	}
487}
488
// deleteCase5 prepares for deleteCase6.
//
// The sibling's child on node's side is the "near" nephew; the other child
// is the "far" nephew. When the sibling is black, the near nephew is red and
// the far nephew is black, this function:
//   - paints the sibling red and the near nephew black;
//   - rotates at the sibling, away from node.
//
// The formerly-near red nephew becomes node's new sibling and the old
// sibling (now red) becomes its far child. In all cases the fix-up then
// continues in deleteCase6.
func (tree *Tree) deleteCase5(node *Node) {
	sibling := node.sibling()
	if node == node.Parent.Left &&
		nodeColor(sibling) == black &&
		nodeColor(sibling.Left) == red &&
		nodeColor(sibling.Right) == black {
		// node is a left child: the near nephew is sibling.Left.
		sibling.color = red
		sibling.Left.color = black
		tree.rotateRight(sibling)
	} else if node == node.Parent.Right &&
		nodeColor(sibling) == black &&
		nodeColor(sibling.Right) == red &&
		nodeColor(sibling.Left) == black {
		// Mirror image: node is a right child, near nephew is sibling.Right.
		sibling.color = red
		sibling.Right.color = black
		tree.rotateLeft(sibling)
	}
	tree.deleteCase6(node)
}
508
// deleteCase6 is the final delete case. The sibling takes the parent's
// color and the parent is painted black. The far nephew (the sibling's child
// on the side away from node) is painted black, and the parent is rotated
// toward node.
//
// NOTE(review): the else-if branch is also reached when node is a left child
// whose far nephew (sibling.Right) is black. This presumably never happens on
// a valid tree, because deleteCase5 leaves the far nephew red. Confirm
// before relying on this for corrupted trees.
func (tree *Tree) deleteCase6(node *Node) {
	sibling := node.sibling()
	sibling.color = nodeColor(node.Parent)
	node.Parent.color = black
	if node == node.Parent.Left && nodeColor(sibling.Right) == red {
		sibling.Right.color = black
		tree.rotateLeft(node.Parent)
	} else if nodeColor(sibling.Left) == red {
		sibling.Left.color = black
		tree.rotateRight(node.Parent)
	}
}
521
522func nodeColor(node *Node) color {
523	if node == nil {
524		return black
525	}
526	return node.color
527}
528