1{-# LANGUAGE CPP #-} 2{-# LANGUAGE ExistentialQuantification #-} 3{-# LANGUAGE GeneralizedNewtypeDeriving #-} 4{-# LANGUAGE RankNTypes #-} 5{-# LANGUAGE FlexibleInstances #-} 6{-# LANGUAGE MultiParamTypeClasses #-} 7{-# LANGUAGE UndecidableInstances #-} 8 9#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702 10{-# LANGUAGE Trustworthy #-} 11#endif 12 13#ifndef MIN_VERSION_transformers 14#define MIN_VERSION_transformers(x,y,z) 1 15#endif 16 17#ifndef MIN_VERSION_mtl 18#define MIN_VERSION_mtl(x,y,z) 1 19#endif 20 21-------------------------------------------------------------------- 22-- | 23-- Copyright : (C) Edward Kmett 2013-2015, (c) Google Inc. 2012 24-- License : BSD-style (see the file LICENSE) 25-- Maintainer : Edward Kmett <ekmett@gmail.com> 26-- Stability : experimental 27-- Portability : non-portable 28-- 29-- This module supplies a \'pure\' monad transformer that can be used for 30-- mock-testing code that throws exceptions, so long as those exceptions 31-- are always thrown with 'throwM'. 32-- 33-- Do not mix 'CatchT' with 'IO'. Choose one or the other for the 34-- bottom of your transformer stack! 
--------------------------------------------------------------------

module Control.Monad.Catch.Pure (
    -- * Transformer
    -- $transformer
    CatchT(..), Catch
  , runCatch
  , mapCatchT

  -- * Typeclass
  -- $mtl
  , module Control.Monad.Catch
  ) where

#if defined(__GLASGOW_HASKELL__) && (__GLASGOW_HASKELL__ >= 706)
import Prelude hiding (foldr)
#else
import Prelude hiding (catch, foldr)
#endif

import Control.Applicative
import Control.Monad.Catch
import qualified Control.Monad.Fail as Fail
import Control.Monad.Reader as Reader
import Control.Monad.RWS
#if __GLASGOW_HASKELL__ < 710
import Data.Foldable
#endif
import Data.Functor.Identity
import Data.Traversable as Traversable

------------------------------------------------------------------------------
-- $mtl
-- The @mtl@-style typeclasses, re-exported from "Control.Monad.Catch".
------------------------------------------------------------------------------

------------------------------------------------------------------------------
-- $transformer
-- The @transformers@-style monad transformer
------------------------------------------------------------------------------

-- | Add 'Exception' handling abilities to a 'Monad'.
--
-- This should /never/ be used in combination with 'IO'. Think of 'CatchT'
-- as an alternative base monad for use with mocking code that solely throws
-- exceptions via 'throwM'.
--
-- Note that the 'IO' monad already has these abilities, so stacking 'CatchT'
-- on top of it adds no value and can be confusing: an exception raised with
-- 'error' (rather than 'throwM') escapes 'CatchT' entirely, as the second
-- example shows:
--
-- >>> (error "Hello!" :: IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
--
-- >>> runCatchT $ (error "Hello!" :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- *** Exception: Hello!
--
-- >>> runCatchT $ (throwM (ErrorCall "Hello!") :: CatchT IO ()) `catch` (\(e :: ErrorCall) -> liftIO $ print e)
-- Hello!
-- Right ()
newtype CatchT m a = CatchT
  { runCatchT :: m (Either SomeException a)
    -- ^ Unwrap the computation: 'Left' holds the exception that aborted it,
    -- 'Right' holds its result.
  }

-- | 'CatchT' over 'Identity': a pure computation that can throw and catch
-- exceptions.
type Catch = CatchT Identity

-- | Run a pure 'Catch' computation, returning either the exception that
-- escaped it or its result.
runCatch :: Catch a -> Either SomeException a
runCatch c = runIdentity (runCatchT c)

-- The result is mapped only when it is a 'Right'; a 'Left' passes through.
-- 'liftM' needs only 'Monad m'; on the pre-7.10 GHCs this file still
-- supports, 'Monad' does not imply 'Functor'.
instance Monad m => Functor (CatchT m) where
  fmap f = CatchT . liftM (fmap f) . runCatchT

instance Monad m => Applicative (CatchT m) where
  pure x = CatchT (return (Right x))
  mf <*> mx = mf `ap` mx

-- Binding short-circuits: once a 'Left' appears, the continuation is skipped
-- and the exception is returned unchanged.
instance Monad m => Monad (CatchT m) where
  return = pure
  m >>= k = CatchT $ do
    outcome <- runCatchT m
    either (return . Left) (runCatchT . k) outcome
#if !(MIN_VERSION_base(4,13,0))
  fail = Fail.fail
#endif

-- 'Fail.fail' becomes a thrown 'userError' carrying the message.
instance Monad m => Fail.MonadFail (CatchT m) where
  fail msg = CatchT (return (Left (toException (userError msg))))

-- The fixed point is projected lazily out of 'Right'; should the knot
-- resolve to a 'Left', demanding the value raises an 'error'.
instance MonadFix m => MonadFix (CatchT m) where
  mfix f = CatchT (mfix (runCatchT . f . fromRightLazily))
    where
      fromRightLazily = either (const (error "empty mfix argument")) id

-- Only 'Right' results are visited; a 'Left' contributes 'mempty'.
instance Foldable m => Foldable (CatchT m) where
  foldMap f = foldMap (either (const mempty) f) . runCatchT

-- Only 'Right' results are visited; a 'Left' is rebuilt as-is.
instance (Monad m, Traversable m) => Traversable (CatchT m) where
  traverse visit = fmap CatchT . Traversable.traverse step . runCatchT
    where
      step = either (pure . Left) (fmap Right . visit)

instance Monad m => Alternative (CatchT m) where
  empty = mzero
  a <|> b = a `mplus` b

-- 'mzero' throws an empty 'userError'; 'mplus' falls back to its second
-- argument only when the first yields a 'Left', discarding that exception.
instance Monad m => MonadPlus (CatchT m) where
  mzero = CatchT (return (Left (toException (userError ""))))
  mplus m n = CatchT $ do
    lhs <- runCatchT m
    either (const (runCatchT n)) (return . Right) lhs

instance MonadTrans CatchT where
  lift = CatchT . liftM Right

instance MonadIO m => MonadIO (CatchT m) where
  liftIO = lift . liftIO

-- Exceptions are stored as 'Left'; the base monad is never asked to throw.
instance Monad m => MonadThrow (CatchT m) where
  throwM e = CatchT (return (Left (toException e)))

-- Only exceptions stored as 'Left' can be caught. A 'Left' whose exception
-- does not match the handler's type is passed through untouched.
instance Monad m => MonadCatch (CatchT m) where
  catch m handler = CatchT $ do
    outcome <- runCatchT m
    case outcome of
      Right a -> return (Right a)
      Left e -> maybe (return (Left e)) (runCatchT . handler) (fromException e)

-- | Note: This instance is only valid if the underlying monad has a single
-- exit point!
--
-- For example, @IO@ or @Either@ would be invalid base monads, but
-- @Reader@ or @State@ would be acceptable.
instance Monad m => MonadMask (CatchT m) where
  -- Masking is a no-op here: the restore function handed to the body is 'id'.
  mask act = act id
  uninterruptibleMask act = act id
  generalBracket acquire release use = CatchT $ do
    acquired <- runCatchT acquire
    either (return . Left) withResource acquired
    where
      -- Run the body, then the release action with the matching 'ExitCase'.
      withResource resource = do
        outcome <- runCatchT (use resource)
        runCatchT (finish resource outcome)
      -- On failure the body's exception is rethrown after release (unless
      -- release itself produces a 'Left' first).
      finish resource (Left e) = do
        _ <- release resource (ExitCaseException e)
        throwM e
      finish resource (Right b) = do
        c <- release resource (ExitCaseSuccess b)
        return (b, c)

instance MonadState s m => MonadState s (CatchT m) where
  get = lift get
  put s = lift (put s)
#if MIN_VERSION_mtl(2,1,0)
  state f = lift (state f)
#endif

instance MonadReader e m => MonadReader e (CatchT m) where
  ask = lift ask
  local f = mapCatchT (local f)

instance MonadWriter w m => MonadWriter w (CatchT m) where
  tell w = lift (tell w)
  -- Attach the accumulated output to a 'Right' result; a 'Left' is kept.
  listen = mapCatchT $ \inner -> do
    (outcome, w) <- listen inner
    return $! fmap (flip (,) w) outcome
  -- A 'Left' leaves the output unmodified ('id'); a 'Right' applies the
  -- computation's own modifier.
  pass = mapCatchT $ \inner -> pass $ do
    outcome <- inner
    return $! either (\l -> (Left l, id)) (\(r, f) -> (Right r, f)) outcome
#if MIN_VERSION_mtl(2,1,0)
  writer aw = CatchT (liftM Right (writer aw))
#endif

instance MonadRWS r w s m => MonadRWS r w s (CatchT m)

-- | Transform the underlying computation of a 'CatchT', including its
-- @'Either' 'SomeException'@ result, possibly changing the base monad.
--
-- @'runCatchT' ('mapCatchT' f m) = f ('runCatchT' m)@
mapCatchT :: (m (Either SomeException a) -> n (Either SomeException b))
          -> CatchT m a
          -> CatchT n b
mapCatchT f = CatchT . f . runCatchT