1{-# LANGUAGE CPP #-}
2module Control.Monad.Free.Zip (zipFree, zipFree_) where
3
4import Control.Monad.Free
5import Control.Monad.Trans.Class
6import Control.Monad.Trans.State
7import Data.Foldable
8import Data.Traversable as T
9import Prelude hiding (fail)
10
-- | Zip two 'Free' terms one layer deep.
--
-- When both arguments are 'Impure' nodes whose functor layers have the same
-- shape (compared by erasing every element to @()@), the supplied function
-- is run on each pair of corresponding children. The results are
-- reassembled under a new 'Impure' node that has the second argument's
-- layer shape.
--
-- Every other combination of arguments (either side not 'Impure', or
-- layers of different shape) is reported through 'fail'.
zipFree
  :: (Traversable f, Eq (f ()), MonadFail m)
  => (Free f a -> Free f b -> m (Free f c))
  -> Free f a
  -> Free f b
  -> m (Free f c)
zipFree f (Impure a) (Impure b)
  -- The shape check guarantees both layers hold the same number of children,
  -- which is the precondition 'unsafeZipWithG' relies on.
  | fmap (const ()) a == fmap (const ()) b = Impure <$> unsafeZipWithG f a b
zipFree _ _ _ = fail "zipFree: structure mismatch"
20
-- | Effect-only counterpart of 'zipFree'.
--
-- If both terms are 'Impure' nodes whose layers agree in shape, the function
-- is run on each pair of corresponding children, in 'toList' order, and the
-- results are discarded. Any other pairing of arguments is reported through
-- 'fail'.
zipFree_
  :: (Traversable f, Eq (f ()), MonadFail m)
  => (Free f a -> Free f b -> m ()) -> Free f a -> Free f b -> m ()
zipFree_ f lhs rhs =
  case (lhs, rhs) of
    (Impure a, Impure b)
      | shapeOf a == shapeOf b ->
          sequenceA_ (zipWith f (toList a) (toList b))
    _ -> fail "zipFree_: structure mismatch"
  where
    -- Erase a layer's elements so only its constructor structure is compared.
    shapeOf :: Functor g => g x -> g ()
    shapeOf layer = fmap (const ()) layer
27
28
-- | Zip the elements of a 'Foldable' container into the positions of a
-- 'Traversable' one. A monadic combining function is run on each pair, and
-- the second container's shape is rebuilt from the results.
--
-- Elements of the first container are consumed in 'toList' order, one per
-- position visited while traversing the second. The function is \"unsafe\"
-- because it does not check that the two containers have the same size:
--
-- * if the first container runs out of elements, the computation fails via
--   'fail';
-- * any leftover elements of the first container are silently ignored.
unsafeZipWithG
  :: (Foldable t1, Traversable t2, MonadFail m)
  => (a -> b -> m c) -> t1 a -> t2 b -> m (t2 c)
unsafeZipWithG f t1 t2 = evalStateT (T.mapM step t2) (toList t1)
  where
    -- Pop the next pending element of @t1@ and combine it with @y@.
    step y = do
      pending <- get
      case pending of
        x : rest -> put rest >> lift (f x y)
        [] -> lift (fail "unsafeZipWithG: first container has too few elements")
36