1{-# LANGUAGE CPP        #-}
2{-# LANGUAGE RankNTypes #-}
3{-# LANGUAGE Safe       #-}
4#if __GLASGOW_HASKELL__ >= 706
5{-# LANGUAGE PolyKinds  #-}
6#endif
7
8{-| A monad morphism is a natural transformation:
9
10> morph :: forall a . m a -> n a
11
12    ... that obeys the following two laws:
13
14> morph $ do x <- m  =  do x <- morph m
15>            f x           morph (f x)
16>
17> morph (return x) = return x
18
19    ... which are equivalent to the following two functor laws:
20
21> morph . (f >=> g) = morph . f >=> morph . g
22>
23> morph . return = return
24
25    Examples of monad morphisms include:
26
27    * 'lift' (from 'MonadTrans')
28
29    * 'squash' (See below)
30
31    * @'hoist' f@ (See below), if @f@ is a monad morphism
32
33    * @(f . g)@, if @f@ and @g@ are both monad morphisms
34
35    * 'id'
36
37    Monad morphisms commonly arise when manipulating existing monad transformer
38    code for compatibility purposes.  The 'MFunctor', 'MonadTrans', and
39    'MMonad' classes define standard ways to change monad transformer stacks:
40
41    * 'lift' introduces a new monad transformer layer of any type.
42
43    * 'squash' flattens two identical monad transformer layers into a single
44      layer of the same type.
45
46    * 'hoist' maps monad morphisms to modify deeper layers of the monad
47       transformer stack.
48
49-}
50
51module Control.Monad.Morph (
52    -- * Functors over Monads
53    MFunctor(..),
54    generalize,
55    -- * Monads over Monads
56    MMonad(..),
57    MonadTrans(lift),
58    squash,
59    (>|>),
60    (<|<),
61    (=<|),
62    (|>=)
63
64    -- * Tutorial
65    -- $tutorial
66
67    -- ** Generalizing base monads
68    -- $generalize
69
70    -- ** Monad morphisms
71    -- $mmorph
72
73    -- ** Mixing diverse transformers
74    -- $interleave
75
76    -- ** Embedding transformers
77    -- $embed
78    ) where
79
80import Control.Monad.Trans.Class (MonadTrans(lift))
81import qualified Control.Monad.Trans.Error         as E
82import qualified Control.Monad.Trans.Except        as Ex
83import qualified Control.Monad.Trans.Identity      as I
84import qualified Control.Monad.Trans.List          as L
85import qualified Control.Monad.Trans.Maybe         as M
86import qualified Control.Monad.Trans.Reader        as R
87import qualified Control.Monad.Trans.RWS.Lazy      as RWS
88import qualified Control.Monad.Trans.RWS.Strict    as RWS'
89import qualified Control.Monad.Trans.State.Lazy    as S
90import qualified Control.Monad.Trans.State.Strict  as S'
91import qualified Control.Monad.Trans.Writer.Lazy   as W'
92import qualified Control.Monad.Trans.Writer.Strict as W
93import Data.Monoid (Monoid, mappend)
94import Data.Functor.Compose (Compose (Compose))
95import Data.Functor.Identity (runIdentity)
96import Data.Functor.Product (Product (Pair))
97import Control.Applicative.Backwards (Backwards (Backwards))
98import Control.Applicative.Lift (Lift (Pure, Other))
99
100-- For documentation
101import Control.Exception (try, IOException)
102import Control.Monad ((=<<), (>=>), (<=<), join)
103import Data.Functor.Identity (Identity)
104
{-| Functors in the category of monads.  'hoist' plays the role that 'fmap'
    plays for ordinary functors and must obey the corresponding laws:

> hoist (f . g) = hoist f . hoist g
>
> hoist id = id
-}
class MFunctor t where
    {-| Turn a monad morphism from @m@ to @n@ into a monad morphism from
        @(t m)@ to @(t n)@.

        The first argument must really be a monad morphism; the type system
        cannot check this for you.
    -}
    hoist
        :: (Monad m)
        => (forall x . m x -> n x)
        -> t m r
        -> t n r
119
-- Delegate to 'E.mapErrorT', which applies @nat@ to the wrapped computation
-- and rewraps the result.
instance MFunctor (E.ErrorT e) where
    hoist nat = E.mapErrorT nat
122
-- Delegate to 'Ex.mapExceptT', which applies @nat@ to the wrapped computation
-- and rewraps the result.
instance MFunctor (Ex.ExceptT e) where
    hoist nat = Ex.mapExceptT nat
125
-- Delegate to 'I.mapIdentityT', which applies @nat@ to the wrapped
-- computation and rewraps the result.
instance MFunctor I.IdentityT where
    hoist nat = I.mapIdentityT nat
128
-- Delegate to 'L.mapListT', which applies @nat@ to the wrapped computation
-- and rewraps the result.
instance MFunctor L.ListT where
    hoist nat = L.mapListT nat
131
-- Delegate to 'M.mapMaybeT', which applies @nat@ to the wrapped computation
-- and rewraps the result.
instance MFunctor M.MaybeT where
    hoist nat = M.mapMaybeT nat
134
-- Delegate to 'R.mapReaderT', which applies @nat@ to the computation obtained
-- for each environment and rewraps the result.
instance MFunctor (R.ReaderT r) where
    hoist nat = R.mapReaderT nat
137
-- Delegate to 'RWS.mapRWST', which applies @nat@ to the computation obtained
-- for each environment/state pair and rewraps the result.
instance MFunctor (RWS.RWST r w s) where
    hoist nat = RWS.mapRWST nat
140
-- Delegate to 'RWS''.mapRWST', which applies @nat@ to the computation
-- obtained for each environment/state pair and rewraps the result.
instance MFunctor (RWS'.RWST r w s) where
    hoist nat = RWS'.mapRWST nat
143
-- Delegate to 'S.mapStateT', which applies @nat@ to the computation obtained
-- for each initial state and rewraps the result.
instance MFunctor (S.StateT s) where
    hoist nat = S.mapStateT nat
146
-- Delegate to 'S''.mapStateT', which applies @nat@ to the computation
-- obtained for each initial state and rewraps the result.
instance MFunctor (S'.StateT s) where
    hoist nat = S'.mapStateT nat
149
-- Delegate to 'W.mapWriterT', which applies @nat@ to the wrapped computation
-- and rewraps the result.
instance MFunctor (W.WriterT w) where
    hoist nat = W.mapWriterT nat
152
-- Delegate to 'W''.mapWriterT', which applies @nat@ to the wrapped
-- computation and rewraps the result.
instance MFunctor (W'.WriterT w) where
    hoist nat = W'.mapWriterT nat
155
-- Transform every inner computation by mapping the morphism over the outer
-- functor @f@.
instance Functor f => MFunctor (Compose f) where
    hoist nat (Compose outer) = Compose (fmap nat outer)
158
-- Only the second component of the 'Pair' is transformed; the first is
-- rebuilt unchanged.
instance MFunctor (Product f) where
    hoist nat (Pair left right) = Pair left (nat right)
161
-- Unwrap the computation, apply the morphism, and wrap it again.
instance MFunctor Backwards where
    hoist nat (Backwards action) = Backwards (nat action)
164
-- A 'Pure' value is rebuilt unchanged; only an 'Other' computation has the
-- morphism applied to it.
instance MFunctor Lift where
    hoist nat lifted = case lifted of
        Pure x       -> Pure x
        Other action -> Other (nat action)
168
{-| Re-run a computation over the 'Identity' base monad in any other monad.

    This is the morphism to pass to 'hoist' when code written against an
    'Identity' base should be reused over a different base monad.
-}
generalize :: Monad m => Identity a -> m a
generalize action = return (runIdentity action)
{-# INLINABLE generalize #-}
173
{-| Monads in the category of monads.  'lift' (from 'MonadTrans') plays the
    role of 'return' and 'embed' plays the role of ('=<<'); together they must
    satisfy:

> embed lift = id
>
> embed f (lift m) = f m
>
> embed g (embed f t) = embed (\m -> embed g (f m)) t
-}
class (MFunctor t, MonadTrans t) => MMonad t where
    {-| Embed a freshly built 'MMonad' layer inside an existing one.

        'embed' is analogous to ('=<<').
    -}
    embed
        :: (Monad n)
        => (forall x . m x -> t n x)
        -> t m r
        -> t n r
189
{-| Collapse two identical, stacked 'MMonad' layers into a single layer by
    embedding with the identity morphism.

    'squash' is analogous to 'join'.
-}
squash :: (Monad m, MMonad t) => t (t m) a -> t m a
squash nested = embed id nested
{-# INLINABLE squash #-}
197
-- ('>|>') and ('=<|') associate to the right; ('<|<') and ('|>=') associate
-- to the left.  All four operators share precedence 2.
infixr 2 >|>
infixr 2 =<|
infixl 2 <|<
infixl 2 |>=
200
{-| Left-to-right composition of two 'MMonad' layer-building functions.
    It applies the first function, then uses 'embed' to run the second
    function over the layer that the first one produced.

    ('>|>') is analogous to ('>=>').
-}
(>|>)
    :: (Monad m3, MMonad t)
    => (forall a . m1 a -> t m2 a)
    -> (forall b . m2 b -> t m3 b)
    ->             m1 c -> t m3 c
(>|>) first second action = embed second (first action)
{-# INLINABLE (>|>) #-}
212
{-| Right-to-left composition of two 'MMonad' layer-building functions.
    This is the same as ('>|>') with its arguments swapped.

    ('<|<') is analogous to ('<=<').
-}
(<|<)
    :: (Monad m3, MMonad t)
    => (forall b . m2 b -> t m3 b)
    -> (forall a . m1 a -> t m2 a)
    ->             m1 c -> t m3 c
(<|<) outer inner action = embed outer (inner action)
{-# INLINABLE (<|<) #-}
224
{-| Infix synonym for 'embed'.

    ('=<|') is analogous to ('=<<').
-}
(=<|) :: (Monad n, MMonad t) => (forall a . m a -> t n a) -> t m b -> t n b
morph =<| layer = embed morph layer
{-# INLINABLE (=<|) #-}
232
{-| ('=<|') with its arguments swapped: the existing layer comes first and
    the layer-building function comes second.

    ('|>=') is analogous to ('>>=').
-}
(|>=) :: (Monad n, MMonad t) => t m b -> (forall a . m a -> t n a) -> t n b
(|>=) layer morph = embed morph layer
{-# INLINABLE (|>=) #-}
240
-- Run the embedded layer, then collapse the two nested 'Either' results into
-- one.  A 'Left' produced by either layer becomes the overall result.
instance (E.Error e) => MMonad (E.ErrorT e) where
    embed f m = E.ErrorT $
        E.runErrorT (f (E.runErrorT m)) >>= \nested ->
            return (either Left id nested)
248
-- Run the embedded layer, then collapse the two nested 'Either' results into
-- one.  A 'Left' produced by either layer becomes the overall result.
instance MMonad (Ex.ExceptT e) where
    embed f m = Ex.ExceptT $
        Ex.runExceptT (f (Ex.runExceptT m)) >>= \nested ->
            return (either Left id nested)
256
-- Unwrap the original 'I.IdentityT' layer and hand the underlying computation
-- straight to @f@; no merging of results is needed.
instance MMonad I.IdentityT where
    embed f = f . I.runIdentityT
259
-- Run the embedded layer, then flatten the resulting list of lists.
instance MMonad L.ListT where
    embed f m = L.ListT $
        L.runListT (f (L.runListT m)) >>= \batches ->
            return (concat batches)
264
-- Run the embedded layer, then collapse the two nested 'Maybe' results into
-- one.  A 'Nothing' from either layer becomes the overall result.
instance MMonad M.MaybeT where
    embed f m = M.MaybeT $
        M.runMaybeT (f (M.runMaybeT m)) >>= \nested ->
            return (maybe Nothing id nested)
272
-- Both the original layer and the embedded layer read the same environment.
instance MMonad (R.ReaderT r) where
    embed f m = R.ReaderT $ \env ->
        R.runReaderT (f (R.runReaderT m env)) env
275
-- W is the strict writer (Control.Monad.Trans.Writer.Strict).  The result of
-- the embedded layer is therefore matched strictly, consistent with the
-- strictness of that module's own (>>=).  An irrefutable pattern here would
-- give the strict writer lazy-writer semantics.
--
-- w1 is the log of the original layer; w2 is the log added by f.  They are
-- combined in that order.
instance (Monoid w) => MMonad (W.WriterT w) where
    embed f m = W.WriterT (do
        ((a, w1), w2) <- W.runWriterT (f (W.runWriterT m))
        return (a, mappend w1 w2) )
280
-- W' is the lazy writer (Control.Monad.Trans.Writer.Lazy).  The result of the
-- embedded layer is therefore matched with an irrefutable pattern, consistent
-- with that module's own lazy (>>=).  A strict match would force the
-- underlying result pairs as soon as the combined result is demanded.
--
-- w1 is the log of the original layer; w2 is the log added by f.  They are
-- combined in that order.
instance (Monoid w) => MMonad (W'.WriterT w) where
    embed f m = W'.WriterT (do
        ~((a, w1), w2) <- W'.runWriterT (f (W'.runWriterT m))
        return (a, mappend w1 w2) )
285
286{- $tutorial
287    Monad morphisms solve the common problem of fixing monadic code after the
288    fact without modifying the original source code or type signatures.  The
289    following sections illustrate various examples of transparently modifying
290    existing functions.
291-}
292
293{- $generalize
294    Imagine that some library provided the following 'S.State' code:
295
296> import Control.Monad.Trans.State
297>
298> tick :: State Int ()
299> tick = modify (+1)
300
301    ... but we would prefer to reuse @tick@ within a larger
302    @('S.StateT' Int 'IO')@ block in order to mix in 'IO' actions.
303
304    We could patch the original library to generalize @tick@'s type signature:
305
306> tick :: (Monad m) => StateT Int m ()
307
308    ... but we would prefer not to fork upstream code if possible.  How could
309    we generalize @tick@'s type without modifying the original code?
310
311    We can solve this if we realize that 'S.State' is a type synonym for
312    'S.StateT' with an 'Identity' base monad:
313
314> type State s = StateT s Identity
315
316    ... which means that @tick@'s true type is actually:
317
318> tick :: StateT Int Identity ()
319
320    Now all we need is a function that @generalize@s the 'Identity' base monad
321    to be any monad:
322
323> import Data.Functor.Identity
324>
325> generalize :: (Monad m) => Identity a -> m a
326> generalize m = return (runIdentity m)
327
328    ... which we can 'hoist' to change @tick@'s base monad:
329
330> hoist :: (Monad m, MFunctor t) => (forall a . m a -> n a) -> t m b -> t n b
331>
332> hoist generalize :: (Monad m, MFunctor t) => t Identity b -> t m b
333>
334> hoist generalize tick :: (Monad m) => StateT Int m ()
335
336    This lets us mix @tick@ alongside 'IO' using 'lift':
337
338> import Control.Monad.Morph
339> import Control.Monad.Trans.Class
340>
341> tock                        ::                   StateT Int IO ()
342> tock = do
343>     hoist generalize tick   :: (Monad      m) => StateT Int m  ()
344>     lift $ putStrLn "Tock!" :: (MonadTrans t) => t          IO ()
345
346>>> runStateT tock 0
347Tock!
348((), 1)
349
350-}
351
352{- $mmorph
353    Notice that @generalize@ is a monad morphism, and the following two proofs
354    show how @generalize@ satisfies the monad morphism laws.  You can refer to
355    these proofs as an example for how to prove a function obeys the monad
356    morphism laws:
357
358> generalize (return x)
359>
360> -- Definition of 'return' for the Identity monad
361> = generalize (Identity x)
362>
363> -- Definition of 'generalize'
364> = return (runIdentity (Identity x))
365>
366> -- runIdentity (Identity x) = x
367> = return x
368
369> generalize $ do x <- m
370>                 f x
371>
372> -- Definition of (>>=) for the Identity monad
373> = generalize (f (runIdentity m))
374>
375> -- Definition of 'generalize'
376> = return (runIdentity (f (runIdentity m)))
377>
378> -- Monad law: Left identity
379> = do x <- return (runIdentity m)
380>      return (runIdentity (f x))
381>
382> -- Definition of 'generalize' in reverse
383> = do x <- generalize m
384>      generalize (f x)
385-}
386
387{- $interleave
388    You can combine 'hoist' and 'lift' to insert arbitrary layers anywhere
389    within a monad transformer stack.  This comes in handy when interleaving two
390    diverse stacks.
391
392    For example, we might want to combine the following @save@ function:
393
394> import Control.Monad.Trans.Writer
395>
396> -- i.e. :: StateT Int (WriterT [Int] Identity) ()
397> save    :: StateT Int (Writer  [Int]) ()
398> save = do
399>     n <- get
400>     lift $ tell [n]
401
402    ... with our previous @tock@ function:
403
404> tock :: StateT Int IO ()
405
406    However, @save@ and @tock@ differ in two ways:
407
408    * @tock@ lacks a 'W.WriterT' layer
409
410    * @save@ has an 'Identity' base monad
411
412    We can mix the two by inserting a 'W.WriterT' layer for @tock@ and
413    generalizing @save@'s base monad:
414
415> import Control.Monad
416>
417> program ::                   StateT Int (WriterT [Int] IO) ()
418> program = replicateM_ 4 $ do
419>     hoist lift tock
420>         :: (MonadTrans t) => StateT Int (t             IO) ()
421>     hoist (hoist generalize) save
422>         :: (Monad      m) => StateT Int (WriterT [Int] m ) ()
423
424>>> execWriterT (runStateT program 0)
425Tock!
426Tock!
427Tock!
428Tock!
429[1,2,3,4]
430
431-}
432
433{- $embed
    Suppose we decided to @check@ all 'IOException's using a combination of
    'try' and 'E.ErrorT':
436
437> import Control.Exception
438> import Control.Monad.Trans.Class
439> import Control.Monad.Trans.Error
440>
441> check :: IO a -> ErrorT IOException IO a
442> check io = ErrorT (try io)
443
444    ... but then we forget to use @check@ in one spot, mistakenly using 'lift'
445    instead:
446
447> program :: ErrorT IOException IO ()
448> program = do
449>     str <- lift $ readFile "test.txt"
450>     check $ putStr str
451
452>>> runErrorT program
453*** Exception: test.txt: openFile: does not exist (No such file or directory)
454
    How could we go back and fix @program@ without modifying its source code?
456
457    Well, @check@ is a monad morphism, but we can't 'hoist' it to modify the
458    base monad because then we get two 'E.ErrorT' layers instead of one:
459
460> hoist check :: (MFunctor t) => t IO a -> t (ErrorT IOException IO) a
461>
462> hoist check program :: ErrorT IOException (ErrorT IOException IO) ()
463
464    We'd prefer to 'embed' all newly generated exceptions in the existing
465    'E.ErrorT' layer:
466
467> embed check :: ErrorT IOException IO a -> ErrorT IOException IO a
468>
469> embed check program :: ErrorT IOException IO ()
470
471    This correctly checks the exceptions that slipped through the cracks:
472
473>>> import Control.Monad.Morph
474>>> runErrorT (embed check program)
475Left test.txt: openFile: does not exist (No such file or directory)
476
477-}
478