1{-# LANGUAGE CPP #-}
2{-# LANGUAGE ExistentialQuantification #-}
3{-# LANGUAGE GeneralizedNewtypeDeriving #-}
4{-# LANGUAGE RankNTypes #-}
5{-# LANGUAGE FlexibleInstances #-}
6{-# LANGUAGE MultiParamTypeClasses #-}
7{-# LANGUAGE UndecidableInstances #-}
8
9#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
10{-# LANGUAGE Trustworthy #-}
11#endif
12
13#ifndef MIN_VERSION_transformers
14#define MIN_VERSION_transformers(x,y,z) 1
15#endif
16
17#ifndef MIN_VERSION_mtl
18#define MIN_VERSION_mtl(x,y,z) 1
19#endif
20
21--------------------------------------------------------------------
22-- |
23-- Copyright   :  (C) Edward Kmett 2013-2015, (c) Google Inc. 2012
24-- License     :  BSD-style (see the file LICENSE)
25-- Maintainer  :  Edward Kmett <ekmett@gmail.com>
26-- Stability   :  experimental
27-- Portability :  non-portable
28--
29-- This module supplies a \'pure\' monad transformer that can be used for
30-- mock-testing code that throws exceptions, so long as those exceptions
31-- are always thrown with 'throwM'.
32--
33-- Do not mix 'CatchT' with 'IO'. Choose one or the other for the
34-- bottom of your transformer stack!
35--------------------------------------------------------------------
36
37module Control.Monad.Catch.Pure (
38    -- * Transformer
39    -- $transformer
40    CatchT(..), Catch
41  , runCatch
42  , mapCatchT
43
44  -- * Typeclass
45  -- $mtl
46  , module Control.Monad.Catch
47  ) where
48
49#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 706)
50import Prelude hiding (foldr)
51#else
52import Prelude hiding (catch, foldr)
53#endif
54
55import Control.Applicative
56import Control.Monad.Catch
57import qualified Control.Monad.Fail as Fail
58import Control.Monad.Reader as Reader
59import Control.Monad.RWS
60#if __GLASGOW_HASKELL__ < 710
61import Data.Foldable
62#endif
63import Data.Functor.Identity
64import Data.Traversable as Traversable
65
66------------------------------------------------------------------------------
67-- $mtl
68-- The mtl style typeclass
69------------------------------------------------------------------------------
70
71------------------------------------------------------------------------------
72-- $transformer
-- The @transformers@-style monad transformer
74------------------------------------------------------------------------------
75
-- | Add 'Exception' handling abilities to a 'Monad'.
--
-- A @'CatchT' m a@ wraps a computation in @m@ that produces either a
-- 'SomeException' ('Left') or a result of type @a@ ('Right').
--
-- This should /never/ be used in combination with 'IO'. Think of 'CatchT'
-- as an alternative base monad for use with mocking code that solely throws
-- exceptions via 'throwM'.
--
-- Note that the 'IO' monad has these abilities already, so stacking 'CatchT'
-- on top of it does not add any value and can possibly be confusing:
--
-- >>> (error "Hello!" :: IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
--
-- >>> runCatchT $ (error "Hello!" :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- *** Exception: Hello!
--
-- >>> runCatchT $ (throwM (ErrorCall "Hello!") :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
newtype CatchT m a = CatchT
  { runCatchT :: m (Either SomeException a)
  }
95
-- | 'CatchT' specialised to the 'Identity' base monad.
type Catch = CatchT Identity
97
-- | Run a 'Catch' computation, yielding either the 'SomeException' it
-- carries ('Left') or its result ('Right').
runCatch :: Catch a -> Either SomeException a
runCatch action = runIdentity (runCatchT action)
100
-- | Maps over a successful ('Right') result; a stored exception ('Left') is
-- passed through untouched.
instance Monad m => Functor (CatchT m) where
  fmap f action = CatchT $ liftM (either Left (Right . f)) (runCatchT action)
103
-- | 'pure' wraps a value as a successful result; '<*>' sequences the
-- function computation before the argument computation using the 'Monad'
-- instance.
instance Monad m => Applicative (CatchT m) where
  pure = CatchT . return . Right
  -- Spelled-out form of 'ap': run the function side first, then the argument.
  ff <*> fa = do
    f <- ff
    a <- fa
    return (f a)
107
-- | Binding runs the first computation in the base monad; a 'Left' result
-- short-circuits and the continuation is never invoked.
instance Monad m => Monad (CatchT m) where
  return = pure
  action >>= next = CatchT $ do
    outcome <- runCatchT action
    case outcome of
      Left err  -> return (Left err)
      Right val -> runCatchT (next val)
#if !(MIN_VERSION_base(4,13,0))
  -- Before base-4.13 'fail' is still a 'Monad' method; defer to 'Fail.fail'.
  fail = Fail.fail
#endif
116
-- | 'fail' stores the message as a 'userError' exception in the 'Left'
-- branch instead of failing in the base monad.
instance Monad m => Fail.MonadFail (CatchT m) where
  fail msg = CatchT (return (Left (toException (userError msg))))
119
-- | Ties the knot through the base monad's 'mfix'. The value fed back to the
-- function is the 'Right' payload of the final result; if the computation
-- ends in 'Left', forcing that value calls 'error'.
instance MonadFix m => MonadFix (CatchT m) where
  mfix f = CatchT (mfix (runCatchT . f . unwrap))
    where
      -- Applied lazily: only demanding the fed-back value inspects the result.
      unwrap (Right r) = r
      unwrap _         = error "empty mfix argument"
124
-- | Folds over the successful results held in the base structure; each
-- 'Left' contributes 'mempty'.
instance Foldable m => Foldable (CatchT m) where
  foldMap f = foldMap (either (const mempty) f) . runCatchT
129
-- | Traverses the successful results held in the base structure; a 'Left'
-- is rebuilt unchanged without running the effectful function.
instance (Monad m, Traversable m) => Traversable (CatchT m) where
  traverse f action = fmap CatchT (Traversable.traverse step (runCatchT action))
    where
      step (Left err)  = pure (Left err)
      step (Right val) = fmap Right (f val)
134
-- | Defined entirely in terms of the 'MonadPlus' instance.
instance Monad m => Alternative (CatchT m) where
  empty = mzero
  primary <|> fallback = mplus primary fallback
138
-- | 'mzero' is a 'userError' with an empty message. 'mplus' runs its first
-- argument and, if that ends in 'Left', discards the exception and runs the
-- second argument instead.
instance Monad m => MonadPlus (CatchT m) where
  mzero = CatchT (return (Left (toException (userError ""))))
  mplus primary fallback = CatchT $ do
    outcome <- runCatchT primary
    case outcome of
      Left _    -> runCatchT fallback
      Right val -> return (Right val)
144
-- | Lifting runs the base computation and tags its result as a success.
instance MonadTrans CatchT where
  lift = CatchT . liftM Right
149
-- | Lifts an 'IO' action through the base monad and tags its result as a
-- success.
instance MonadIO m => MonadIO (CatchT m) where
  liftIO io = CatchT (liftM Right (liftIO io))
154
-- | Throwing stores the exception, converted with 'toException', in the
-- 'Left' branch.
instance Monad m => MonadThrow (CatchT m) where
  throwM err = CatchT (return (Left (toException err)))
-- | Catching inspects a 'Left' result: if 'fromException' yields the type
-- the handler expects, the handler runs; otherwise the exception is
-- propagated unchanged. Successful results pass straight through.
instance Monad m => MonadCatch (CatchT m) where
  catch action handler = CatchT $ do
    outcome <- runCatchT action
    case outcome of
      Right val -> return (Right val)
      Left err  ->
        maybe (return (Left err)) (runCatchT . handler) (fromException err)
-- | Note: This instance is only valid if the underlying monad has a single
-- exit point!
--
-- For example, @IO@ or @Either@ would be invalid base monads, but
-- @Reader@ or @State@ would be acceptable.
--
-- 'mask' and 'uninterruptibleMask' hand the callback 'id' as its restore
-- function and do nothing else.
instance Monad m => MonadMask (CatchT m) where
  mask action = action id
  uninterruptibleMask action = action id
  generalBracket acquire release use = CatchT $ do
    acquired <- runCatchT acquire
    case acquired of
      -- Acquisition failed: nothing to release.
      Left err       -> return (Left err)
      Right resource -> runCatchT (use resource) >>= runCatchT . finish resource
    where
      -- Release with the matching 'ExitCase'; on failure the original
      -- exception is rethrown once the release action has run.
      finish resource (Left err) =
        release resource (ExitCaseException err) >> throwM err
      finish resource (Right val) = do
        released <- release resource (ExitCaseSuccess val)
        return (val, released)
184
-- | State operations are lifted unchanged from the base monad.
instance MonadState s m => MonadState s (CatchT m) where
  get = lift get
  put newState = lift (put newState)
#if MIN_VERSION_mtl(2,1,0)
  state step = lift (state step)
#endif
191
-- | 'ask' is lifted from the base monad; 'local' is applied to the wrapped
-- base computation.
instance MonadReader e m => MonadReader e (CatchT m) where
  ask = lift ask
  local f = mapCatchT (local f)
195
-- | Writer operations are delegated to the base monad. 'listen' pairs the
-- collected output with a successful result only; 'pass' applies the
-- supplied output transformation on success and 'id' when the computation
-- ended in 'Left'.
instance MonadWriter w m => MonadWriter w (CatchT m) where
  tell output = lift (tell output)
  listen action = CatchT $ do
    (outcome, logged) <- listen (runCatchT action)
    -- Force the 'Either' constructor before returning.
    return $! either Left (\val -> Right (val, logged)) outcome
  pass action = CatchT $ pass $ do
    outcome <- runCatchT action
    return $! split outcome
    where
      split (Left err)            = (Left err, id)
      split (Right (val, adjust)) = (Right val, adjust)
#if MIN_VERSION_mtl(2,1,0)
  writer = lift . writer
#endif
209
-- | 'MonadRWS' instance for 'CatchT' over a 'MonadRWS' base monad; the
-- declaration defines no methods of its own.
instance MonadRWS r w s m => MonadRWS r w s (CatchT m)
211
-- | Transform the underlying computation of a 'CatchT' with the given
-- function, which may change both the base monad and the result type.
--
-- @'runCatchT' ('mapCatchT' f m) = f ('runCatchT' m)@
mapCatchT :: (m (Either SomeException a) -> n (Either SomeException b))
          -> CatchT m a
          -> CatchT n b
mapCatchT f = CatchT . f . runCatchT
219