1(* Rbset -- functional sets using Okasaki-style red-black trees *)
2(* Ken Friis Larsen <ken@friislarsen.net>                                *)
3(* Various extensions, and test: sestoft@dina.kvl.dk * 2001-10-21 *)
4
5structure Rbset :>  Rbset =
6struct
7
  (* Red-black tree: either an empty LEAF, or a RED / BLACK node
     carrying (item, left subtree, right subtree). *)
  datatype 'item tree = LEAF
                      | RED   of 'item * 'item tree * 'item tree
                      | BLACK of 'item * 'item tree * 'item tree

  (* A set is a triple (ordering function, tree of items, item count). *)
  type 'item set  = ('item * 'item -> order) * 'item tree * int

  (* Interval selector consumed by sublist and subset.  As implemented
     by sublist: All selects every item, From lo the items k with
     lo <= k, To hi the items k with k < hi, and FromTo(lo, hi) the
     items k with lo <= k < hi. *)
  datatype 'item intv =
      All
    | From of 'item
    | To   of 'item
    | FromTo of 'item * 'item

  (* Raised by retrieve and delete when the requested item is absent. *)
  exception NotFound
21
22  fun empty compare = (compare, LEAF, 0)
23
24  fun getOrder (compare, _, _) = compare
25
26  fun numItems (_, _, n) = n
27
28  fun singleton compare x = (compare, BLACK(x, LEAF, LEAF), 1)
29
30  fun isEmpty (_, LEAF, _) = true
31    | isEmpty _            = false
32
33  fun member ((compare, tree, n), elm) =
34      let fun memShared x left right =
35              case compare(elm,x) of
36                  EQUAL   => true
37                | LESS    => mem left
38                | GREATER => mem right
39          and mem LEAF                    = false
40            | mem (RED(x, left, right))   = memShared x left right
41            | mem (BLACK(x, left, right)) = memShared x left right
42      in  mem tree end
43
44  fun retrieve (set, x) = if member(set, x) then x else raise NotFound
45
46  fun peek (set, x) = if member(set, x) then SOME x else NONE
47
48  fun lbalance z (RED(y,RED(x,a,b),c)) d =
49      RED(y,BLACK(x,a,b),BLACK(z,c,d))
50    | lbalance z (RED(x,a,RED(y,b,c))) d =
51      RED(y,BLACK(x,a,b),BLACK(z,c,d))
52    | lbalance x left right = BLACK(x, left, right)
53
54  fun rbalance x a (RED(y,b,RED(z,c,d))) =
55      RED(y,BLACK(x,a,b),BLACK(z,c,d))
56    | rbalance x a (RED(z,RED(y,b,c),d)) =
57      RED(y,BLACK(x,a,b),BLACK(z,c,d))
58    | rbalance x left right = BLACK(x, left, right)
59
60  exception GETOUT
61
62  local
63      fun insert compare elm =
64          let fun ins LEAF = RED(elm,LEAF,LEAF)
65	        | ins (BLACK(x,left,right)) =
66                  (case compare(elm, x) of
67                       LESS    => lbalance x (ins left) right
68                     | GREATER => rbalance x left (ins right)
69                     | EQUAL   => raise GETOUT)
70	        | ins (RED(x,left,right)) =
71                  (case compare(elm, x) of
72                       LESS    => RED(x, (ins left), right)
73                     | GREATER => RED(x, left, (ins right))
74                     | EQUAL   => raise GETOUT)
75          in  ins end
76  in
77
78  fun add (set as (compare, tree, n), elm) =
79      ( compare
80      , case insert compare elm tree of
81            RED(e, l, r) => BLACK(e, l, r)
82          | tree         => tree
83      , n+1)
84      handle GETOUT => set
85
86  fun add' (elm, set) = add(set, elm)
87
88  fun addList (set, xs) = List.foldl add' set xs
89  end
90
91  fun push LEAF stack = stack
92    | push tree stack = tree :: stack
93
94  fun pushNode x right stack = BLACK(x, LEAF, LEAF) :: push right stack
95
96  fun getMin []             some none = none
97    | getMin (tree :: rest) some none =
98      let fun descend tree stack =
99	  case tree of
100	      LEAF                  => getMin stack some none
101	    | RED  (x, LEAF, right) => some x (push right stack)
102	    | BLACK(x, LEAF, right) => some x (push right stack)
103	    | RED  (x, left, right) => descend left (pushNode x right stack)
104	    | BLACK(x, left, right) => descend left (pushNode x right stack)
105      in descend tree rest end
106
107(*   fun getMin []             some none = none
108    | getMin (tree :: rest) some none =
109      case tree of
110          LEAF                  => getMin rest some none
111        | RED  (x, LEAF, right) => some x (push right rest)
112        | BLACK(x, LEAF, right) => some x (push right rest)
113        | RED  (x, left, right) => getMin(pushNode left x right rest) some none
114        | BLACK(x, left, right) => getMin(pushNode left x right rest) some none
115 *)
116
117  fun getMax []             some none = none
118    | getMax (tree :: rest) some none =
119      let fun descend tree stack =
120	  case tree of
121	      LEAF                  => getMax stack some none
122	    | RED  (x, left, LEAF)  => some x (push left stack)
123	    | BLACK(x, left, LEAF)  => some x (push left stack)
124	    | RED  (x, left, right) => descend right (pushNode x left stack)
125	    | BLACK(x, left, right) => descend right (pushNode x left stack)
126      in descend tree rest end
127
128(*   fun getMax []             some none = none
129    | getMax (tree :: rest) some none =
130      case tree of
131          LEAF                  => getMax rest some none
132        | RED  (x, left, LEAF)  => some x (push left rest)
133        | BLACK(x, left, LEAF)  => some x (push left rest)
134        | RED  (x, left, right) => getMax(pushNode right x left rest) some none
135        | BLACK(x, left, right) => getMax(pushNode right x left rest) some none
136 *)
137  fun fold get f e (compare, tree, n) =
138      let fun loop stack acc =
139              get stack (fn x => fn stack => loop stack (f(x, acc))) acc
140      in  loop (push tree []) e end
141
142  fun foldl f = fold getMin f
143
144  fun foldr f = fold getMax f
145
146  fun listItems set = foldr op:: [] set
147
148  fun appAll get f (compare, tree, n) =
149      let fun loop stack = get stack (fn x => (f x; loop)) ()
150      in  loop [tree] end
151
152  fun app f = appAll getMin f
153
154  fun revapp f = appAll getMax f
155
156  fun find p (compare, tree, n) =
157      let fun loop stack =
158              getMin stack (fn x => fn stack =>
159                                       if p x then SOME x else loop stack) NONE
160      in  loop (push tree []) end
161
162  fun map (f, compare) s =
163      foldl (fn (k, res) => add(res, f k)) (empty compare) s
164
165  (*  Ralf Hinze's convert a sorted list to RB tree *)
166  local
      (* Accumulator for building a tree from items supplied in
         ascending order.  It is a list of "digits": ZERO ends the list,
         ONE holds one (item, tree) pair and TWO holds two such pairs,
         each followed by the rest of the list.  The most recently
         inserted item is at the head (see incr and build). *)
      datatype 'item digits =
               ZERO
             | ONE of 'item * 'item tree * 'item digits
             | TWO of 'item * 'item tree * 'item * 'item tree * 'item digits
171
      (* incr x a ds adds the pair (item x, tree a) at the head of the
         digit list ds, much like incrementing a binary counter:
         ZERO becomes ONE, ONE becomes TWO, and a TWO overflows. *)
      fun incr x a ZERO                  = ONE(x, a, ZERO)
        | incr x a (ONE(y, b, ds))       = TWO(x, a, y, b, ds)
        (* Overflow: the older pair (x, a) and the tree b of the newer
           pair are fused into BLACK(x, a, b), which is carried into the
           rest of the list together with the item y. *)
        | incr z c (TWO(y, b, x, a, ds)) =
          ONE(z, c, incr y (BLACK(x, a, b)) ds)

      (* insertMax (a, digits) adds the single item a, paired with an
         empty tree, at the head of digits.  The argument order matches
         what List.foldl expects of its step function. *)
      fun insertMax(a, digits) = incr a LEAF digits

      (* build ds t folds the digit list into one tree, walking from the
         head of the list and hanging the tree built so far as the right
         subtree of each digit's node.  A TWO digit contributes a BLACK
         node whose right child is a RED node. *)
      fun build ZERO                  a = a
        | build (ONE(x, a, ds))       b = build ds (BLACK(x, a, b))
        | build (TWO(y, b, x, a, ds)) c = build ds (BLACK(x, a, RED(y, b, c)))

      (* buildAll digits is the complete tree for a digit list. *)
      fun buildAll digits = build digits LEAF
184
185      fun toInt digits =
186          let fun loop ZERO power acc            = acc
187                | loop (ONE(_,_,rest)) power acc =
188                  loop rest (2*power) (power + acc)
189                | loop (TWO(_,_,_,_,rest)) power acc =
190                  loop rest (2*power) (2*power + acc)
191          in  loop digits 1 0 end
192
193      fun get stack = getMin stack (fn x => fn stack => SOME(x,stack)) NONE
194
195      fun insRest stack acc =
196          getMin stack (fn x => fn stack => insRest stack (insertMax(x,acc)))
197                 acc
198
199  in
200  fun fromSortedList (compare, ls) =
201      let val digits = List.foldl insertMax ZERO ls
202      in  (compare, buildAll digits, toInt digits) end
203
204
205  (* FIXME: it *must* be possible to write union, equal, isSubset,
206            intersection, and difference more elegantly.
207  *)
208  fun union (s1 as (compare, t1, n1), s2 as (_, t2, n2)) =
209      let fun loop x y stack1 stack2 res =
210              case compare(x, y) of
211                  EQUAL =>
212                  let val res = insertMax(x, res)
213                  in  case (get stack1, get stack2) of
214                          (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2 res
215                        | (NONE, NONE)               => res
216                        | (SOME _, _)                => insRest stack1 res
217                        | (_, SOME _)                => insRest stack2 res
218                  end
219                | LESS =>
220                  let val res = insertMax(x, res)
221                  in  case get stack1 of
222                          NONE => insRest stack2 (insertMax(y, res))
223                        | SOME(x, stack1) => loop x y stack1 stack2 res
224                  end
225                | GREATER =>
226                  let val res = insertMax(y, res)
227                  in  case get stack2 of
228                          NONE => insRest stack1 (insertMax(x, res))
229                        | SOME(y, stack2) => loop x y stack1 stack2 res
230                  end
231      in  (* FIXME: here is lots of room for optimizations *)
232          case (get [t1], get [t2]) of
233              (SOME(x, stack1), SOME(y, stack2)) =>
234              let val digits = loop x y stack1 stack2 ZERO
235              in  (compare, buildAll digits, toInt digits) end
236            | (_, SOME _) => s2
237            | _           => s1 end
238
239
240  fun intersection (s1 as (compare, t1, n1), s2 as (_, t2, n2)) =
241      let fun loop x y stack1 stack2 res =
242              case compare(x, y) of
243                  EQUAL =>
244                  let val res = insertMax(x, res)
245                  in  case (get stack1, get stack2) of
246                          (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2 res
247                        | _                          => res
248                  end
249                | LESS =>
250                  (case get stack1 of
251                       NONE            => res
252                     | SOME(x, stack1) => loop x y stack1 stack2 res)
253                | GREATER =>
254                  (case get stack2 of
255                       NONE            => res
256                     | SOME(y, stack2) => loop x y stack1 stack2 res)
257      in  (* FIXME: here is lots of room for optimizations *)
258          case (get [t1], get [t2]) of
259              (SOME(x, stack1), SOME(y, stack2)) =>
260              let val digits = loop x y stack1 stack2 ZERO
261              in  (compare, buildAll digits, toInt digits) end
262            | _           => empty compare end
263
264
265  fun difference (s1 as (compare, t1, n1), s2 as (_, t2, n2)) =
266      let fun loop x y stack1 stack2 res =
267              case compare(x, y) of
268                  EQUAL =>
269                  (case (get stack1, get stack2) of
270                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2 res
271                     | (SOME _, _)                => insRest stack1 res
272                     | _                          => res)
273                | LESS =>
274                  let val res = insertMax(x, res)
275                  in  case get stack1 of
276                          NONE            => res
277                        | SOME(x, stack1) => loop x y stack1 stack2 res
278                  end
279                | GREATER =>
280                  (case get stack2 of
281                       NONE => insRest stack1 (insertMax(x, res))
282                     | SOME(y, stack2) => loop x y stack1 stack2 res)
283      in  (* FIXME: here is lots of room for optimizations *)
284          case (get [t1], get [t2]) of
285              (SOME(x, stack1), SOME(y, stack2)) =>
286              let val digits = loop x y stack1 stack2 ZERO
287              in  (compare, buildAll digits, toInt digits) end
288            | (_, SOME _) => empty compare
289            | _           => s1 end
290
291  fun equal ((compare, t1, _), (_, t2, _)) =
292      let fun loop x y stack1 stack2 =
293              case compare(x, y) of
294                  EQUAL =>
295                  (case (get stack1, get stack2) of
296                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2
297                     | (NONE, NONE)               => true
298                     | _                          => false)
299                | _ => false
300      in  (* FIXME: here is lots of room for optimizations *)
301          case (get [t1], get [t2]) of
302              (SOME(x, stack1), SOME(y, stack2)) => loop x y stack1 stack2
303            | (NONE, NONE)                       => true
304            | _                                  => false end
305
306  fun compare ((cmp, t1, _), (_, t2, _)) =
307      let fun loop x y stack1 stack2 =
308	  case cmp(x, y) of
309	      EQUAL =>
310                  (case (get stack1, get stack2) of
311                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2
312                     | (NONE, NONE)               => EQUAL
313                     | (NONE, _)                  => LESS
314                     | (_, NONE)                  => GREATER)
315	    | order => order
316      in
317          case (get [t1], get [t2]) of
318              (SOME(x, stack1), SOME(y, stack2)) => loop x y stack1 stack2
319            | (NONE, NONE)                       => EQUAL
320	    | (NONE, _)                          => LESS
321	    | (_, NONE)                          => GREATER
322      end
323
324  fun isSubset ((compare, t1, _), (_, t2, _)) =
325      let fun loop x y stack1 stack2 =
326              case compare(x, y) of
327                  EQUAL =>
328                  (case (get stack1, get stack2) of
329                       (SOME(x, s1), SOME(y, s2)) => loop x y s1 s2
330                     | (NONE, _)                  => true
331                     | _                          => false)
332                | LESS => false
333                | GREATER =>
334                  (case get stack2 of
335                       SOME(y, stack2) => loop x y stack1 stack2
336                     | NONE            => false)
337      in  (* FIXME: here is lots of room for optimizations *)
338          case (get [t1], get [t2]) of
339              (SOME(x, stack1), SOME(y, stack2)) => loop x y stack1 stack2
340            | (NONE, _)                          => true
341            | _                                  => false end
342
343  end
344
345  (* Function f must be strictly monotonically increasing on the
346     elements of s; we check this requirement: *)
347
348  exception NonMonotonic
349
350  fun mapMono (f, compare) s =
351      let val fxs = foldl (fn (x, res) => f x :: res) [] s
352	  fun sorted []         = true
353	    | sorted (y1 :: yr) =
354	      let fun h x0 []       = true
355		    | h x0 (x1::xr) = compare(x0, x1) = LESS andalso h x1 xr
356	      in h y1 yr end
357      in
358	  if sorted fxs then
359	      fromSortedList (compare, fxs)
360	  else
361	      raise NonMonotonic
362      end
363
364  (* Peter Sestoft's convert a sorted list to RB tree *)
365  (* Did I write this?  I'm impressed, but let's ignore it for now.
366
367  fun fromSortedList' (compare, ls) =
368      let val len = List.length ls
369	  fun log2 n =
370	      let fun loop k p = if p >= n then k else loop (k+1) (2*p)
371	      in loop 0 1 end
372	  fun h 0 _ xs = (LEAF, xs)
373	    | h n d xs =
374	      let val m = n div 2
375	          val (t1, y :: yr) = h m       (d-1) xs
376	          val (t2, zs)      = h (n-m-1) (d-1) yr
377	      in (if d=0 then RED(y, t1, t2) else BLACK(y, t1, t2), zs) end
378      in  (compare,
379	   case #1 (h len (log2 (len + 1) - 1) ls) of
380                RED(x, left, right) => BLACK(x, left, right)
381              | tree                => tree
382          , len)
383      end
384  *)
385
386  (* delete a la Stefan M. Kahrs *)
387
388  fun sub1 (BLACK arg) = RED arg
389    | sub1 _ = raise Fail "Rbset.sub1: impossible"
390
391  fun balleft y (RED(x,a,b)) c             = RED(y, BLACK(x, a, b), c)
392    | balleft x bl (BLACK(y, a, b))        = rbalance x bl (RED(y, a, b))
393    | balleft x bl (RED(z,BLACK(y,a,b),c)) =
394      RED(y, BLACK(x, bl, a), rbalance z b (sub1 c))
395    | balleft _ _ _ = raise Fail "Rbset.balleft: impossible"
396
397  fun balright x a             (RED(y,b,c)) = RED(x, a, BLACK(y, b, c))
398    | balright y (BLACK(x,a,b))          br = lbalance y (RED(x,a,b)) br
399    | balright z (RED(x,a,BLACK(y,b,c))) br =
400      RED(y, lbalance x (sub1 a) b, BLACK(z, c, br))
401    | balright _ _ _ = raise Fail "Rbset.balright: impossible"
402
  (* [append left right] constructs a new tree t holding the items of
     both arguments; delete uses it to join the two subtrees of the
     node being removed.
  PRECONDITIONS: RB left /\ RB right
              /\ !e in left => !x in right e < x
  POSTCONDITION: not (RB t), i.e. t is not guaranteed to satisfy the
     red-black invariants (for instance a RED result may have a RED
     child).
  The clauses are tried in order, so the two mixed-colour clauses only
  apply when the RED/RED clause above them did not.
  *)
  fun append LEAF right                    = right
    | append left LEAF                     = left
    (* Two RED roots: join the inner subtrees b and c first. *)
    | append (RED(x,a,b)) (RED(y,c,d))     =
      (case append b c of
	   RED(z, b, c) => RED(z, RED(x, a, b), RED(y, c, d))
         | bc           => RED(x, a, RED(y, bc, d)))
    (* Exactly one RED root: push the other tree down on its inner side. *)
    | append a (RED(x,b,c))                = RED(x, append a b, c)
    | append (RED(x,a,b)) c                = RED(x, a, append b c)
    (* Two BLACK roots: join the inner subtrees; if the join does not
       come back with a RED root, rebalance with balleft. *)
    | append (BLACK(x,a,b)) (BLACK(y,c,d)) =
      (case append b c of
	   RED(z, b, c) => RED(z, BLACK(x, a, b), BLACK(y, c, d))
         | bc           => balleft x a (BLACK(y, bc, d)))
420
421  fun delete (set as (compare, tree, n), x) =
422      let fun delShared y a b =
423              case compare(x,y) of
424                  EQUAL   => append a b
425                | LESS    => (case a of
426                                  BLACK _ => balleft y (del a) b
427                                | _       => RED(y, del a, b))
428                | GREATER => (case b of
429                                  BLACK _ => balright y a (del b)
430                                | _       => RED(y, a, del b))
431          and del LEAF             = raise NotFound
432            | del (RED(y, a, b))   = delShared y a b
433            | del (BLACK(y, a, b)) = delShared y a b
434      in  ( compare
435          , case del tree of
436                RED arg => BLACK arg
437              | tree    => tree
438          , n-1) end
439
440  fun min (_, t, _) =
441      let fun h LEAF = NONE
442	    | h (RED  (k, LEAF, t2)) = SOME k
443	    | h (RED  (k, t1,   t2)) = h t1
444	    | h (BLACK(k, LEAF, t2)) = SOME k
445	    | h (BLACK(k, t1,   t2)) = h t1
446      in h t end
447
448  fun max (_, t, _) =
449      let fun h LEAF = NONE
450	    | h (RED  (k, t1, LEAF)) = SOME k
451	    | h (RED  (k, t1, t2  )) = h t2
452	    | h (BLACK(k, t1, LEAF)) = SOME k
453	    | h (BLACK(k, t1, t2  )) = h t2
454      in h t end
455
456  fun hash (h : 'item -> word) (s : 'item set) =
457      foldl (fn (k, res) => h k + res) 0w0 s
458
459  (* Extract sublist containing the elements that are in the given interval *)
460
461  fun sublist((cmp, t, _), intv) =
462      let fun collectall LEAF res = res
463	    | collectall (RED(k, t1, t2)) res =
464	      collectall t1 (k :: collectall t2 res)
465	    | collectall (BLACK(k, t1, t2)) res =
466	      collectall t1 (k :: collectall t2 res)
467	  (* Collect from `from' till end *)
468	  fun collectfrom LEAF res = res
469	    | collectfrom (tree as RED  (k, t1, t2)) res =
470	      collnode tree k t1 t2 res
471	    | collectfrom (tree as BLACK(k, t1, t2)) res =
472	      collnode tree k t1 t2 res
473	  and collnode tree k t1 t2 res =
474	      case intv of
475		  From from =>
476		      if cmp(from, k) = GREATER then (* ignore left *)
477			  collectfrom t2 res
478		      else (* from <= k *)
479			  collectfrom t1 (k :: collectall t2 res)
480		| FromTo (from, _) =>
481		      if cmp(from, k) = GREATER then (* ignore left *)
482			  collectfrom t2 res
483		      else (* from <= k *)
484			  collectfrom t1 (k :: collectfrom t2 res)
485		| _ => collectall tree res
486	  (* Collect from beginning to `to', exclusive *)
487	  fun collectto LEAF res = res
488	    | collectto (tree as RED  (k, t1, t2)) res =
489	      collnode tree k t1 t2 res
490	    | collectto (tree as BLACK(k, t1, t2)) res =
491	      collnode tree k t1 t2 res
492	  and collnode tree k t1 t2 res =
493	      case intv of
494		  To to =>
495		      if cmp(k, to) = LESS then
496			  collectall t1 (k :: collectto t2 res)
497		      else (* ignore right, k >= to *)
498			  collectto t1 res
499		| FromTo (_, to) =>
500		      if cmp(k, to) = LESS then
501			  collectall t1 (k :: collectto t2 res)
502		      else (* ignore right, k >= to *)
503			  collectto t1 res
504		| _ => collectall tree res
505	  (* Collect from `from' to `to' *)
506	  fun collectfromto LEAF res = res
507	    | collectfromto (tree as RED  (k, t1, t2)) res =
508	      collnode tree k t1 t2 res
509	    | collectfromto (tree as BLACK(k, t1, t2)) res =
510	      collnode tree k t1 t2 res
511	  and collnode tree k t1 t2 res =
512	      case intv of
513		  From from => collectfrom tree res
514		| To to     => collectto tree res
515		| FromTo (from, to) =>
516		      if cmp(from, k) = GREATER then (* ignore left *)
517			  collectfromto t2 res
518		      else if cmp(k, to) = LESS then  (* from <= k < to *)
519			  collectfrom t1 (k :: collectto t2 res)
520		      else (* ignore right *)
521			  collectfromto t1 res
522		| All => collectall tree res
523      in collectfromto t [] end
524
525    (* Note: builds an intermediate list of elements *)
526    fun subset (s as (cmp, t, _), intv) =
527	fromSortedList(cmp, sublist(s, intv))
528
529    (* For debugging only *)
530
531    fun depth LEAF = 0
532      | depth (RED  (_, t1, t2)) = 1 + Int.max(depth t1, depth t2)
533      | depth (BLACK(_, t1, t2)) = 1 + Int.max(depth t1, depth t2)
534
535    val depth = fn (_, t, _) => depth t
536end
537