1{-# LANGUAGE TupleSections #-}
2module Foundation.Monad.State
3    ( -- * MonadState
4      MonadState(..)
5    , get
6    , put
7
8    , -- * StateT
9      StateT
10    , runStateT
11    ) where
12
13import Basement.Compat.Bifunctor (first)
14import Basement.Compat.Base (($), (.), const)
15import Foundation.Monad.Base
16import Control.Monad ((>=>))
17
-- | Monads that thread a state value of type @'State' m@ through their
-- computations. Every state operation is expressed through 'withState'.
class Monad m => MonadState m where
    -- | The type of the state carried by @m@.
    type State m
    -- | Run a pure state transition: the function receives the current
    -- state and returns a result together with the next state.
    withState :: (State m -> (r, State m)) -> m r
21
-- | Return the current state, leaving it unchanged.
get :: MonadState m => m (State m)
get = withState duplicate
  where
    duplicate st = (st, st)
24
-- | Replace the current state with the given value; the result is @()@.
-- The previous state is ignored.
put :: MonadState m => State m -> m ()
put newState = withState (const ((), newState))
27
-- | State transformer: a computation in the underlying monad @m@ that reads
-- an input state of type @s@ and produces a result of type @a@ together with
-- an output state.
newtype StateT s m a = StateT
    { runStateT :: s -> m (a, s)
      -- ^ Run the computation from the given initial state, returning the
      -- result paired with the final state.
    }
30
-- | Maps over the result value only; the final state is passed through
-- unchanged.
instance Functor m => Functor (StateT s m) where
    fmap g action = StateT (fmap (first g) . runStateT action)
    {-# INLINE fmap #-}
34
-- | 'pure' leaves the state untouched. '<*>' runs the function-producing
-- action first and feeds its output state into the argument action.
instance (Applicative m, Monad m) => Applicative (StateT s m) where
    pure x = StateT $ \st -> fmap (\v -> (v, st)) (pure x)
    {-# INLINE pure #-}
    mf <*> mx = StateT $ \st0 ->
        runStateT mf st0 >>= \(f, st1) ->
        runStateT mx st1 >>= \(x, st2) ->
        return (f x, st2)
    {-# INLINE (<*>) #-}
43
-- | Binding threads the state left to right: the state returned by the first
-- computation becomes the input state of the continuation.
instance (Functor m, Monad m) => Monad (StateT s m) where
    -- Canonical definition: 'return' must coincide with 'pure'. Defining it
    -- separately is flagged by GHC's -Wnoncanonical-monad-instances.
    return = pure
    {-# INLINE return #-}
    -- The result pair of the first action is matched strictly before the
    -- continuation runs with the intermediate state.
    ma >>= mab = StateT $ runStateT ma >=> (\(a, s2) -> runStateT (mab a) s2)
    {-# INLINE (>>=) #-}
49
-- | Value recursion through the underlying monad's 'mfix'. The result is fed
-- back into @f@, and every iteration starts from the same incoming state.
instance (Functor m, MonadFix m) => MonadFix (StateT s m) where
    mfix f = StateT $ \s ->
        -- The lazy pattern is essential: forcing the pair here would demand
        -- the fixed point before it has been produced.
        let step ~(a, _) = runStateT (f a) s
        in mfix step
    {-# INLINE mfix #-}
53
-- | Lifting runs the inner action and pairs its result with the unchanged
-- incoming state.
instance MonadTrans (StateT s) where
    lift inner = StateT $ \s -> do
        a <- inner
        return (a, s)
    {-# INLINE lift #-}
57
-- | IO actions are lifted into the underlying monad first, then into
-- 'StateT' via 'lift', so the state is left unchanged.
instance (Functor m, MonadIO m) => MonadIO (StateT s m) where
    liftIO = lift . liftIO
    {-# INLINE liftIO #-}
61
-- | Failures are delegated to the underlying monad @m@, whose failure type is
-- reused unchanged. The incoming state is paired with the result of the
-- underlying 'mFail'.
instance (Functor m, MonadFailure m) => MonadFailure (StateT s m) where
    type Failure (StateT s m) = Failure m
    mFail e = StateT $ \s -> (,s) `fmap` mFail e
    {-# INLINE mFail #-}
65
-- | Throwing ignores the current state and delegates directly to the
-- underlying monad's 'throw'.
instance (Functor m, MonadThrow m) => MonadThrow (StateT s m) where
    throw e = StateT $ \_ -> throw e
    {-# INLINE throw #-}
68
-- | The handler is run from the state that was current when 'catch' was
-- entered. Any state changes the failing action made before raising are
-- discarded, because the underlying 'catch' passes only the exception to
-- the handler.
instance (Functor m, MonadCatch m) => MonadCatch (StateT s m) where
    catch (StateT m) c = StateT $ \s1 -> m s1 `catch` (\e -> runStateT (c e) s1)
    {-# INLINE catch #-}
71
-- | The state of @StateT s m@ is @s@. A transition is applied purely to the
-- incoming state, and its (result, new state) pair is returned in the
-- underlying monad. 'get' and 'put' are built on this method, so it is
-- inlined like the other 'StateT' methods.
instance (Functor m, Monad m) => MonadState (StateT s m) where
    type State (StateT s m) = s
    withState f = StateT $ return . f
    {-# INLINE withState #-}
75