1{-# LANGUAGE CPP #-}
2module Test.Hspec.Core.Shuffle (
3  shuffleForest
4#ifdef TEST
5, shuffle
6, mkArray
7#endif
8) where
9
10import           Prelude ()
11import           Test.Hspec.Core.Compat
12import           Test.Hspec.Core.Tree
13
14import           System.Random
15import           Control.Monad.ST
16import           Data.STRef
17import           Data.Array.ST
18
-- | Randomly reorder a forest: first permute the top-level trees with
-- 'shuffle', then recursively shuffle the children of every tree via
-- 'shuffleTree'.
--
-- The generator held in the 'STRef' is read and advanced as a side effect.
shuffleForest :: STRef s StdGen -> [Tree c a] -> ST s [Tree c a]
shuffleForest ref xs = shuffle ref xs >>= mapM (shuffleTree ref)
21
-- | Shuffle the children of a single tree node, recursing through
-- 'shuffleForest'.
--
-- Inner nodes ('Node' and 'NodeWithCleanup') keep their own label/cleanup
-- payload and only have their child forest reordered; a 'Leaf' has no
-- children and is returned unchanged.
shuffleTree :: STRef s StdGen -> Tree c a -> ST s (Tree c a)
shuffleTree ref (Node d xs) = Node d <$> shuffleForest ref xs
shuffleTree ref (NodeWithCleanup c xs) = NodeWithCleanup c <$> shuffleForest ref xs
shuffleTree _ t@Leaf{} = return t
27
-- | Randomly permute a list with a forward Fisher–Yates pass over a
-- 1-indexed mutable copy of it (built by 'mkArray').
--
-- For every position one index is drawn from the generator held in the
-- 'STRef', and the advanced generator is written back after each draw, so a
-- list of length @n@ consumes exactly @n@ draws. The empty list yields @[]@.
shuffle :: STRef s StdGen -> [a] -> ST s [a]
shuffle ref xs = do
  arr <- mkArray xs
  (lo, hi) <- getBounds arr
  forM [lo .. hi] $ \ pos -> do
    -- Pick a partner among the slots that are not yet fixed: [pos .. hi].
    other <- draw pos hi
    here <- readArray arr pos
    there <- readArray arr other
    -- Slot @pos@ is never read again, so it is enough to park the displaced
    -- element at @other@ and emit the chosen element as this position's result.
    writeArray arr other here
    return there
  where
    -- Draw a value in the inclusive range [from, to] and store the advanced
    -- generator back into the reference.
    draw from to = do
      gen <- readSTRef ref
      case randomR (from, to) gen of
        (k, gen') -> do
          writeSTRef ref gen'
          return k
43
-- | Copy a list into a fresh mutable array indexed from @1@ to the list's
-- length. An empty list produces an array with the empty bounds @(1, 0)@.
mkArray :: [a] -> ST s (STArray s Int a)
mkArray xs = newListArray bnds xs
  where
    bnds = (1, length xs)
46