{-# LANGUAGE NondecreasingIndentation #-}
-- | A simple mutable union-find data structure.
--
-- It is used in a unification algorithm for backpack mix-in linking.
--
-- This implementation is based off of the one in \"The Essence of ML Type
-- Inference\". (N.B. the union-find package is also based off of this.)
--
module Distribution.Utils.UnionFind (
    Point,
    fresh,
    find,
    union,
    equivalent,
) where

import Data.STRef
import Control.Monad
import Control.Monad.ST

-- | A variable which can be unified; alternately, this can be thought
-- of as an equivalence class with a distinguished representative.
--
-- Equality is the derived equality of the wrapped 'STRef', i.e. two
-- 'Point's are equal only when they are the very same mutable cell.
newtype Point s a = Point (STRef s (Link s a))
    deriving (Eq)

-- | Overwrite the 'Link' stored in a 'Point'.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point cell) newLink = writeSTRef cell newLink

-- | Fetch the 'Link' currently stored in a 'Point'.
readPoint :: Point s a -> ST s (Link s a)
readPoint (Point cell) = readSTRef cell

-- | The internal state of a 'Point'. A root point holds 'Info': a mutable
-- weight (starts at 1 in 'fresh' and is summed when classes are merged,
-- so it counts the points in the class) and a mutable descriptor. Any
-- other point holds a 'Link' to the next 'Point' on the way to its root.
data Link s a
    -- NB: it is too bad we can't say STRef Int#; the weights remain boxed
    = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
    | Link {-# UNPACK #-} !(Point s a)

-- | Create a fresh equivalence class with one element.
fresh :: a -> ST s (Point s a)
fresh desc =
    -- Allocate the weight cell, then the descriptor cell, and finally
    -- the point that owns them as a root.
    liftM2 Info (newSTRef 1) (newSTRef desc) >>= fmap Point . newSTRef

-- | Find the root 'Point' of the class containing the given point,
-- compressing the path on the way so that each non-root point walked
-- over ends up linking straight to that root.
repr :: Point s a -> ST s (Point s a)
repr point = do
    link <- readPoint point
    case link of
        -- A point that still holds 'Info' is its own root.
        Info _ _ -> return point
        Link parent -> do
            root <- repr parent
            -- The recursive call left 'parent' linking straight to 'root';
            -- copy that link into 'point' too. When 'parent' is itself the
            -- root, 'point' already links there and nothing is written.
            unless (root == parent) $
                readPoint parent >>= writePoint point
            return root

-- | Return the canonical element of an equivalence
-- class 'Point'.
find :: Point s a -> ST s a
find point = readPoint point >>= atDepth0
  where
    -- Fast paths for points zero or one link away from their root, at
    -- the expense of the general case (which re-reads after 'repr').
    atDepth0 (Info _ descRef) = readSTRef descRef
    atDepth0 (Link parent)    = readPoint parent >>= atDepth1

    atDepth1 (Info _ descRef) = readSTRef descRef
    -- Longer chain: compress the path, then retry from the root.
    atDepth1 (Link _)         = repr point >>= find

-- | Unify two equivalence classes, so that they share
-- a canonical element. Keeps the descriptor of point2.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
    root1 <- repr refpoint1
    root2 <- repr refpoint2
    unless (root1 == root2) $ do
        link1 <- readPoint root1
        link2 <- readPoint root2
        case (link1, link2) of
            (Info wref1 dref1, Info wref2 dref2) -> do
                size1 <- readSTRef wref1
                size2 <- readSTRef wref2
                -- Union by size: a weight starts at 1 in 'fresh' and the
                -- two classes here are disjoint, so the sum below is the
                -- merged class size. The smaller class is hung under the
                -- larger root; on a tie root1 survives.
                -- (The equal-size case could be optimized separately.)
                if size2 > size1
                    then do
                        writePoint root1 (Link root2)
                        writeSTRef wref2 (size1 + size2)
                    else do
                        writePoint root2 (Link root1)
                        writeSTRef wref1 (size1 + size2)
                        -- root1 survives, so pull in root2's descriptor to
                        -- keep the second argument's descriptor.
                        writeSTRef dref1 =<< readSTRef dref2
            _ -> error "UnionFind.union: repr invariant broken"

-- | Test if two points are in the same equivalence class.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = do
    root1 <- repr point1
    root2 <- repr point2
    return (root1 == root2)