1{-# LANGUAGE CPP #-}
2#if __GLASGOW_HASKELL__ >= 702
3{-# LANGUAGE Safe #-}
4#endif
5#if __GLASGOW_HASKELL__ >= 710
6{-# LANGUAGE AutoDeriveTypeable #-}
7#endif
8-----------------------------------------------------------------------------
9-- |
10-- Module      :  Control.Monad.Trans.State.Strict
11-- Copyright   :  (c) Andy Gill 2001,
12--                (c) Oregon Graduate Institute of Science and Technology, 2001
13-- License     :  BSD-style (see the file LICENSE)
14--
15-- Maintainer  :  R.Paterson@city.ac.uk
16-- Stability   :  experimental
17-- Portability :  portable
18--
19-- Strict state monads, passing an updatable state through a computation.
20-- See below for examples.
21--
22-- Some computations may not require the full power of state transformers:
23--
24-- * For a read-only state, see "Control.Monad.Trans.Reader".
25--
26-- * To accumulate a value without using it on the way, see
27--   "Control.Monad.Trans.Writer".
28--
29-- In this version, sequencing of computations is strict (but computations
30-- are not strict in the state unless you force it with 'seq' or the like).
31-- For a lazy version with the same interface, see
32-- "Control.Monad.Trans.State.Lazy".
33-----------------------------------------------------------------------------
34
35module Control.Monad.Trans.State.Strict (
36    -- * The State monad
37    State,
38    state,
39    runState,
40    evalState,
41    execState,
42    mapState,
43    withState,
44    -- * The StateT monad transformer
45    StateT(..),
46    evalStateT,
47    execStateT,
48    mapStateT,
49    withStateT,
50    -- * State operations
51    get,
52    put,
53    modify,
54    modify',
55    gets,
56    -- * Lifting other operations
57    liftCallCC,
58    liftCallCC',
59    liftCatch,
60    liftListen,
61    liftPass,
62    -- * Examples
63    -- ** State monads
64    -- $examples
65
66    -- ** Counting
67    -- $counting
68
69    -- ** Labelling trees
70    -- $labelling
71  ) where
72
73import Control.Monad.IO.Class
74import Control.Monad.Signatures
75import Control.Monad.Trans.Class
76#if MIN_VERSION_base(4,12,0)
77import Data.Functor.Contravariant
78#endif
79import Data.Functor.Identity
80
81import Control.Applicative
82import Control.Monad
83#if MIN_VERSION_base(4,9,0)
84import qualified Control.Monad.Fail as Fail
85#endif
86import Control.Monad.Fix
87
88-- ---------------------------------------------------------------------------
-- | A state monad parameterized by the type @s@ of the state to carry.
--
-- This is 'StateT' specialised to the 'Identity' base monad; use
-- 'runState' to unwrap a computation of this type.
--
-- The 'return' function leaves the state unchanged, while @>>=@ uses
-- the final state of the first computation as the initial state of
-- the second.
type State s = StateT s Identity
95
-- | Embed a pure state-transition function as a state-passing
-- computation: the function's result is simply 'return'ed in the
-- underlying monad.
-- (For 'State', this is the inverse of 'runState'.)
state :: (Monad m)
      => (s -> (a, s))  -- ^pure state transformer
      -> StateT s m a   -- ^equivalent state-passing computation
state f = StateT $ \ s -> return (f s)
{-# INLINE state #-}
103
-- | Run a 'State' computation from the given initial state, producing
-- the return value paired with the final state.
-- (The inverse of 'state'.)
runState :: State s a   -- ^state-passing computation to execute
         -> s           -- ^initial state
         -> (a, s)      -- ^return value and final state
-- Written with a single left-hand-side argument (as before) so the
-- INLINE pragma applies as soon as the computation is supplied.
runState m = \ s -> runIdentity (runStateT m s)
{-# INLINE runState #-}
111
-- | Run a state computation from the given initial state and keep
-- only its return value; the final state is thrown away.
--
-- * @'evalState' m s = 'fst' ('runState' m s)@
evalState :: State s a  -- ^state-passing computation to execute
          -> s          -- ^initial state
          -> a          -- ^return value of the state computation
evalState m s = case runState m s of
    (a, _) -> a
{-# INLINE evalState #-}
121
-- | Run a state computation from the given initial state and keep
-- only the final state; the return value is thrown away.
--
-- * @'execState' m s = 'snd' ('runState' m s)@
execState :: State s a  -- ^state-passing computation to execute
          -> s          -- ^initial state
          -> s          -- ^final state
execState m s = case runState m s of
    (_, s') -> s'
{-# INLINE execState #-}
131
-- | Transform the (return value, final state) pair produced by a
-- computation with the given function.
--
-- * @'runState' ('mapState' f m) = f . 'runState' m@
mapState :: ((a, s) -> (b, s)) -> State s a -> State s b
mapState f = mapStateT $ \ result -> Identity (f (runIdentity result))
{-# INLINE mapState #-}
139
-- | @'withState' f m@ executes action @m@ on a state modified by
-- applying @f@.  This is 'withStateT' used at the 'State' type.
--
-- * @'withState' f m = 'modify' f >> m@
withState :: (s -> s) -> State s a -> State s a
withState = withStateT
{-# INLINE withState #-}
147
148-- ---------------------------------------------------------------------------
-- | A state transformer monad parameterized by:
--
--   * @s@ - The state.
--
--   * @m@ - The inner monad.
--
-- A value wraps a function from an initial state to an action in @m@
-- that yields a return value paired with the final state; 'runStateT'
-- unwraps that function.
--
-- The 'return' function leaves the state unchanged, while @>>=@ uses
-- the final state of the first computation as the initial state of
-- the second.
newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }
159
-- | Run a state computation from the given initial state and yield
-- only its return value in the inner monad, dropping the final state.
--
-- * @'evalStateT' m s = 'liftM' 'fst' ('runStateT' m s)@
evalStateT :: (Monad m) => StateT s m a -> s -> m a
evalStateT m s = runStateT m s >>= \ (a, _) -> return a
{-# INLINE evalStateT #-}
169
-- | Run a state computation from the given initial state and yield
-- only the final state in the inner monad, dropping the return value.
--
-- * @'execStateT' m s = 'liftM' 'snd' ('runStateT' m s)@
execStateT :: (Monad m) => StateT s m a -> s -> m s
execStateT m s = runStateT m s >>= \ (_, s') -> return s'
{-# INLINE execStateT #-}
179
-- | Transform the inner-monad action that carries the return value and
-- final state, possibly changing the inner monad and the result type.
--
-- * @'runStateT' ('mapStateT' f m) = f . 'runStateT' m@
mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT (\ s -> f (runStateT m s))
{-# INLINE mapStateT #-}
187
-- | @'withStateT' f m@ runs @m@ starting from the incoming state after
-- @f@ has been applied to it.
--
-- * @'withStateT' f m = 'modify' f >> m@
withStateT :: (s -> s) -> StateT s m a -> StateT s m a
withStateT f m = StateT (\ s -> runStateT m (f s))
{-# INLINE withStateT #-}
195
-- | 'fmap' rewrites the return value and passes the final state through
-- untouched.
instance (Functor m) => Functor (StateT s m) where
    fmap f (StateT run) = StateT $ \ s0 -> fmap onResult (run s0)
      where
        -- Strict match on the result pair, as befits the strict variant.
        onResult (a, s1) = (f a, s1)
    {-# INLINE fmap #-}
200
-- | 'pure' leaves the state unchanged; '<*>' runs the function
-- computation first and threads its final state into the argument
-- computation.
instance (Functor m, Monad m) => Applicative (StateT s m) where
    pure x = StateT (\ s0 -> return (x, s0))
    {-# INLINE pure #-}
    ff <*> fx = StateT $ \ s0 ->
        runStateT ff s0 >>= \ (g, s1) ->
        runStateT fx s1 >>= \ (y, s2) ->
        return (g y, s2)
    {-# INLINE (<*>) #-}
    -- Sequence via the Monad instance, ignoring the first result.
    first *> second = first >>= const second
    {-# INLINE (*>) #-}
211
-- | Choice is delegated to the inner 'MonadPlus': both alternatives are
-- started from the same incoming state.
instance (Functor m, MonadPlus m) => Alternative (StateT s m) where
    empty = StateT (const mzero)
    {-# INLINE empty #-}
    lhs <|> rhs = StateT $ \ s0 ->
        mplus (runStateT lhs s0) (runStateT rhs s0)
    {-# INLINE (<|>) #-}
217
-- | '>>=' runs the first computation, then feeds its return value to
-- the continuation and its final state to the resulting computation.
instance (Monad m) => Monad (StateT s m) where
#if !(MIN_VERSION_base(4,8,0))
    return x = StateT (\ s0 -> return (x, s0))
    {-# INLINE return #-}
#endif
    StateT run >>= next = StateT $ \ s0 ->
        run s0 >>= \ (x, s1) -> runStateT (next x) s1
    {-# INLINE (>>=) #-}
#if !(MIN_VERSION_base(4,13,0))
    -- Failure is delegated to the inner monad; the state is ignored.
    fail msg = StateT (const (fail msg))
    {-# INLINE fail #-}
#endif
231
#if MIN_VERSION_base(4,9,0)
-- | Failure is delegated to the inner monad's 'Fail.fail'; the incoming
-- state is ignored.
instance (Fail.MonadFail m) => Fail.MonadFail (StateT s m) where
    fail msg = StateT (const (Fail.fail msg))
    {-# INLINE fail #-}
#endif
237
-- | 'mzero' and 'mplus' are those of the inner monad; both branches of
-- 'mplus' receive the same incoming state.
instance (MonadPlus m) => MonadPlus (StateT s m) where
    mzero = StateT (const mzero)
    {-# INLINE mzero #-}
    lhs `mplus` rhs = StateT $ \ s0 ->
        runStateT lhs s0 `mplus` runStateT rhs s0
    {-# INLINE mplus #-}
243
-- | The fixed point is taken in the inner monad over the whole
-- (return value, state) pair, feeding only the return value back in.
instance (MonadFix m) => MonadFix (StateT s m) where
    mfix f = StateT $ \ s0 ->
        -- 'fst' is applied lazily: the fed-back pair must not be
        -- inspected before the inner 'mfix' has produced it.
        mfix (\ result -> runStateT (f (fst result)) s0)
    {-# INLINE mfix #-}
247
-- | 'lift' runs the inner action and pairs its result with the
-- unchanged incoming state.
instance MonadTrans (StateT s) where
    lift inner = StateT $ \ s0 ->
        inner >>= \ x -> return (x, s0)
    {-# INLINE lift #-}
253
-- | An 'IO' action is first lifted into the inner monad and then into
-- 'StateT' with 'lift'.
instance (MonadIO m) => MonadIO (StateT s m) where
    -- No left-hand-side arguments (as before), so the INLINE pragma
    -- does not wait for an argument.
    liftIO = \ action -> lift (liftIO action)
    {-# INLINE liftIO #-}
257
#if MIN_VERSION_base(4,12,0)
-- | 'contramap' is pushed through to the inner functor, adapting the
-- value component of the pair and leaving the state component alone.
instance Contravariant m => Contravariant (StateT s m) where
    contramap f (StateT run) = StateT $ \ s0 -> contramap adapt (run s0)
      where
        adapt (a, s1) = (f a, s1)
    {-# INLINE contramap #-}
#endif
264
-- | Return the current state as the result, leaving it unchanged.
get :: (Monad m) => StateT s m s
get = state duplicate
  where
    duplicate s = (s, s)
{-# INLINE get #-}
269
-- | @'put' s@ replaces the state within the monad with @s@, discarding
-- the previous state.
put :: (Monad m) => s -> StateT s m ()
put s = state (const ((), s))
{-# INLINE put #-}
274
-- | @'modify' f@ replaces the current state with the result of applying
-- @f@ to it.  The new state is not forced; see 'modify'' for a variant
-- that is strict in the new state.
--
-- * @'modify' f = 'get' >>= ('put' . f)@
modify :: (Monad m) => (s -> s) -> StateT s m ()
modify f = state update
  where
    update old = ((), f old)
{-# INLINE modify #-}
282
-- | Like 'modify', but the new state is evaluated (with '$!') before it
-- is stored.
--
-- * @'modify'' f = 'get' >>= (('$!') 'put' . f)@
modify' :: (Monad m) => (s -> s) -> StateT s m ()
modify' f = get >>= \ old -> put $! f old
{-# INLINE modify' #-}
292
-- | Return the result of applying the supplied projection function to
-- the current state, leaving the state unchanged.
--
-- * @'gets' f = 'liftM' f 'get'@
gets :: (Monad m) => (s -> a) -> StateT s m a
gets f = state project
  where
    project s = (f s, s)
{-# INLINE gets #-}
300
-- | Uniform lifting of a @callCC@ operation to the new monad.
-- Invoking the captured continuation discards whatever state is
-- current at that point and resumes with the state that was in effect
-- when 'liftCallCC' was entered.
liftCallCC :: CallCC m (a,s) (b,s) -> CallCC (StateT s m) a b
liftCallCC callCC f = StateT $ \ s0 ->
    callCC $ \ exit ->
        -- The escape ignores its own incoming state and reinstates s0.
        let escape a = StateT (const (exit (a, s0)))
        in  runStateT (f escape) s0
{-# INLINE liftCallCC #-}
309
-- | In-situ lifting of a @callCC@ operation to the new monad.
-- Invoking the captured continuation carries along the state that is
-- current at the point of invocation.
-- It does not satisfy the uniformity property (see "Control.Monad.Signatures").
liftCallCC' :: CallCC m (a,s) (b,s) -> CallCC (StateT s m) a b
liftCallCC' callCC f = StateT $ \ s0 ->
    callCC $ \ exit ->
        -- The escape forwards the state it is invoked with.
        let escape a = StateT (\ s1 -> exit (a, s1))
        in  runStateT (f escape) s0
{-# INLINE liftCallCC' #-}
318
-- | Lift a @catchE@ operation to the new monad.
-- The handler is run from the state that was current when the
-- protected computation started.
liftCatch :: Catch e m (a,s) -> Catch e (StateT s m) a
liftCatch catchE m h = StateT $ \ s0 ->
    catchE (runStateT m s0) (\ err -> runStateT (h err) s0)
{-# INLINE liftCatch #-}
324
-- | Lift a @listen@ operation to the new monad.
-- The listened output is moved next to the return value, and the final
-- state takes the outer position of the result pair.
liftListen :: (Monad m) => Listen w m (a,s) -> Listen w (StateT s m) a
liftListen listen m = StateT $ \ s0 ->
    listen (runStateT m s0) >>= \ ((a, s1), w) ->
    return ((a, w), s1)
{-# INLINE liftListen #-}
331
-- | Lift a @pass@ operation to the new monad.
-- The function returned alongside the value is handed to the inner
-- @pass@, while the final state is paired with the value.
liftPass :: (Monad m) => Pass w m (a,s) -> Pass w (StateT s m) a
liftPass pass m = StateT $ \ s0 ->
    pass (runStateT m s0 >>= \ ((a, f), s1) -> return ((a, s1), f))
{-# INLINE liftPass #-}
338
339{- $examples
340
341Parser from ParseLib with Hugs:
342
343> type Parser a = StateT String [] a
344>    ==> StateT (String -> [(a,String)])
345
346For example, item can be written as:
347
348> item = do (x:xs) <- get
349>        put xs
350>        return x
351>
352> type BoringState s a = StateT s Identity a
353>      ==> StateT (s -> Identity (a,s))
354>
355> type StateWithIO s a = StateT s IO a
356>      ==> StateT (s -> IO (a,s))
357>
358> type StateWithErr s a = StateT s Maybe a
359>      ==> StateT (s -> Maybe (a,s))
360
361-}
362
363{- $counting
364
365A function to increment a counter.
366Taken from the paper \"Generalising Monads to Arrows\",
367John Hughes (<http://www.cse.chalmers.se/~rjmh/>), November 1998:
368
369> tick :: State Int Int
370> tick = do n <- get
371>           put (n+1)
372>           return n
373
374Add one to the given number using the state monad:
375
376> plusOne :: Int -> Int
377> plusOne n = execState tick n
378
A contrived addition example. Works only when the first argument is non-negative:
380
381> plus :: Int -> Int -> Int
382> plus n x = execState (sequence $ replicate n tick) x
383
384-}
385
386{- $labelling
387
388An example from /The Craft of Functional Programming/, Simon
389Thompson (<http://www.cs.kent.ac.uk/people/staff/sjt/>),
390Addison-Wesley 1999: \"Given an arbitrary tree, transform it to a
391tree of integers in which the original elements are replaced by
392natural numbers, starting from 0.  The same element has to be
393replaced by the same number at every occurrence, and when we meet
394an as-yet-unvisited element we have to find a \'new\' number to match
395it with:\"
396
397> data Tree a = Nil | Node a (Tree a) (Tree a) deriving (Show, Eq)
398> type Table a = [a]
399
400> numberTree :: Eq a => Tree a -> State (Table a) (Tree Int)
401> numberTree Nil = return Nil
402> numberTree (Node x t1 t2) = do
403>     num <- numberNode x
404>     nt1 <- numberTree t1
405>     nt2 <- numberTree t2
406>     return (Node num nt1 nt2)
407>   where
408>     numberNode :: Eq a => a -> State (Table a) Int
409>     numberNode x = do
410>         table <- get
411>         case elemIndex x table of
412>             Nothing -> do
413>                 put (table ++ [x])
414>                 return (length table)
415>             Just i -> return i
416
417numTree applies numberTree with an initial state:
418
419> numTree :: (Eq a) => Tree a -> Tree Int
420> numTree t = evalState (numberTree t) []
421
422> testTree = Node "Zero" (Node "One" (Node "Two" Nil Nil) (Node "One" (Node "Zero" Nil Nil) Nil)) Nil
423> numTree testTree => Node 0 (Node 1 (Node 2 Nil Nil) (Node 1 (Node 0 Nil Nil) Nil)) Nil
424
425-}
426