1{-# LANGUAGE NondecreasingIndentation #-}
2-- | A simple mutable union-find data structure.
3--
-- It is used by the unification algorithm for Backpack mix-in linking.
--
-- This implementation is based on the one in \"The Essence of ML Type
-- Inference\". (N.B. the @union-find@ package is also based on it.)
8--
9module Distribution.Utils.UnionFind (
10    Point,
11    fresh,
12    find,
13    union,
14    equivalent,
15) where
16
17import Data.STRef
18import Control.Monad
19import Control.Monad.ST
20
-- | A unifiable variable. Equivalently, a 'Point' names an equivalence
-- class: following its chain of 'Link's leads to the class's
-- distinguished representative. Equality is identity of the underlying
-- 'STRef', not equality of descriptors.
newtype Point s a = Point (STRef s (Link s a))
    deriving Eq
25
-- | Overwrite the 'Link' stored in a 'Point'.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point ref) link = writeSTRef ref link
29
-- | Fetch the 'Link' currently stored in a 'Point'.
readPoint :: Point s a -> ST s (Link s a)
readPoint (Point ref) = readSTRef ref
33
-- | Internal state of a 'Point'. A root point holds 'Info': the class's
-- weight and its descriptor, each in its own mutable cell. Any other
-- point holds a 'Link' to another 'Point' of the same class.
data Link s a
    = Info
        {-# UNPACK #-} !(STRef s Int) -- weight: element count of the class
        {-# UNPACK #-} !(STRef s a)   -- descriptor of the class
    -- NB: ideally the weight cell would be an @STRef Int#@, but that is not
    -- expressible, so the weights remain boxed.
    | Link {-# UNPACK #-} !(Point s a)
41
-- | Create a new singleton equivalence class whose descriptor is the given
-- value. The new class starts with weight 1.
fresh :: a -> ST s (Point s a)
fresh desc = do
    -- The weight cell is allocated first, then the descriptor cell.
    info <- liftM2 Info (newSTRef 1) (newSTRef desc)
    ref <- newSTRef info
    return (Point ref)
48
-- | Follow the chain of 'Link's from a 'Point' to the root of its class,
-- and return that root. The path is compressed along the way: afterwards
-- every point on the traversed path links directly to the root.
repr :: Point s a -> ST s (Point s a)
repr point = do
    link <- readPoint point
    case link of
        Info _ _ -> return point
        Link next -> do
            root <- repr next
            -- If 'next' is not itself the root, the recursive call has
            -- already re-pointed it at the root; copy its link into 'point'.
            unless (root == next) $
                readPoint next >>= writePoint point
            return root
60
-- | Return the descriptor of the equivalence class containing a 'Point'.
find :: Point s a -> ST s a
find point = do
    link <- readPoint point
    case link of
        Info _ descRef -> readSTRef descRef
        Link next -> do
            -- Fast path: a point at distance one from its root is resolved
            -- without calling 'repr'. Longer paths are compressed first.
            nextLink <- readPoint next
            case nextLink of
                Info _ descRef -> readSTRef descRef
                Link _ -> repr point >>= find
74
-- | Merge the equivalence classes of two points so that they share a single
-- root. The merged class keeps the descriptor of @refpoint2@'s class.
--
-- This is union by size:
--
-- * The root with the larger weight becomes the new root. Ties go to
--   @refpoint1@'s root.
-- * The new root's weight becomes the sum of both weights.
-- * Nothing happens if both points are already in the same class.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
    point1 <- repr refpoint1
    point2 <- repr refpoint2
    when (point1 /= point2) $ do
        l1 <- readPoint point1
        l2 <- readPoint point2
        case (l1, l2) of
            (Info wref1 dref1, Info wref2 dref2) -> do
                weight1 <- readSTRef wref1
                weight2 <- readSTRef wref2
                -- The merged class size. It is stored strictly (via '$!') so
                -- the weight cell never holds a suspended addition.
                let total = weight1 + weight2
                if weight1 >= weight2
                    then do
                        writePoint point2 (Link point1)
                        writeSTRef wref1 $! total
                        -- point1 is the new root, but the merged class must
                        -- carry point2's descriptor.
                        writeSTRef dref1 =<< readSTRef dref2
                    else do
                        writePoint point1 (Link point2)
                        writeSTRef wref2 $! total
            -- 'repr' only ever returns points that hold 'Info'.
            _ -> error "UnionFind.union: repr invariant broken"
99
-- | Decide whether two points belong to the same equivalence class, i.e.
-- whether they share a root. Both paths are compressed as a side effect.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = do
    root1 <- repr point1
    root2 <- repr point2
    return (root1 == root2)
103