{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module      : Foundation.Monad.State
--
-- A 'MonadState' class for monads that thread a state value, and 'StateT',
-- a state transformer over an arbitrary base monad.
module Foundation.Monad.State
    ( -- * MonadState
      MonadState(..)
    , get
    , put

    , -- * StateT
      StateT
    , runStateT
    ) where

import Basement.Compat.Bifunctor (first)
import Basement.Compat.Base (($), (.), const)
import Foundation.Monad.Base
import Control.Monad ((>=>))

-- | Monads that thread a state value of type @'State' m@.
class Monad m => MonadState m where
    -- | The type of the state carried by @m@.
    type State m
    -- | Run a pure state transition. The function receives the current
    -- state and returns a result together with the new state.
    withState :: (State m -> (a, State m)) -> m a

-- | Return the current state, leaving it unchanged.
get :: MonadState m => m (State m)
get = withState $ \s -> (s, s)

-- | Replace the current state with the given value.
put :: MonadState m => State m -> m ()
put s = withState $ const ((), s)

-- | State transformer: given an initial state of type @s@, runs in the base
-- monad @m@ and yields a result paired with the final state.
newtype StateT s m a = StateT { runStateT :: s -> m (a, s) }

instance Functor m => Functor (StateT s m) where
    -- Map over the result component only; the state passes through.
    fmap f m = StateT $ \s1 -> first f `fmap` runStateT m s1
    {-# INLINE fmap #-}

instance (Applicative m, Monad m) => Applicative (StateT s m) where
    pure a = StateT $ \s -> (,s) `fmap` pure a
    {-# INLINE pure #-}
    -- Run the function action first, then the argument action, threading
    -- the state from one to the next.
    fab <*> fa = StateT $ \s1 -> do
        (ab, s2) <- runStateT fab s1
        (a, s3)  <- runStateT fa s2
        return (ab a, s3)
    {-# INLINE (<*>) #-}

instance (Functor m, Monad m) => Monad (StateT s m) where
    -- Canonical definition: 'return' must agree with 'pure'. Defining it
    -- independently is flagged by -Wnoncanonical-monad-instances.
    return = pure
    {-# INLINE return #-}
    -- The (result, state) pair produced by the first action is matched with
    -- a strict pattern before the continuation runs.
    ma >>= mab = StateT $ runStateT ma >=> (\(a, s2) -> runStateT (mab a) s2)
    {-# INLINE (>>=) #-}

instance (Functor m, MonadFix m) => MonadFix (StateT s m) where
    -- The lazy pattern is essential: the result @a@ is fed back into @f@
    -- before the recursive computation has produced it.
    mfix f = StateT $ \s -> mfix $ \ ~(a, _) -> runStateT (f a) s
    {-# INLINE mfix #-}

instance MonadTrans (StateT s) where
    -- Run the base action and pair its result with the unchanged state.
    lift f = StateT $ \s -> f >>= return . (,s)
    {-# INLINE lift #-}

instance (Functor m, MonadIO m) => MonadIO (StateT s m) where
    liftIO f = lift (liftIO f)
    {-# INLINE liftIO #-}

instance (Functor m, MonadFailure m) => MonadFailure (StateT s m) where
    -- Failures are expressed in the base monad's failure type.
    type Failure (StateT s m) = Failure m
    mFail e = StateT $ \s -> (,s) `fmap` mFail e
    {-# INLINE mFail #-}

instance (Functor m, MonadThrow m) => MonadThrow (StateT s m) where
    -- The current state is ignored; it is not carried by the exception.
    throw e = StateT $ \_ -> throw e
    {-# INLINE throw #-}

instance (Functor m, MonadCatch m) => MonadCatch (StateT s m) where
    -- The handler resumes from the state that was current when 'catch' was
    -- entered; state changes made by the failing action are discarded.
    catch (StateT m) c = StateT $ \s1 -> m s1 `catch` (\e -> runStateT (c e) s1)
    {-# INLINE catch #-}

instance (Functor m, Monad m) => MonadState (StateT s m) where
    type State (StateT s m) = s
    -- Lift the pure transition into the base monad.
    withState f = StateT $ return . f
    {-# INLINE withState #-}