1{-# LANGUAGE CPP                 #-}
2{-# LANGUAGE TupleSections       #-}
3{-# LANGUAGE DeriveDataTypeable  #-}
4{-# LANGUAGE DeriveFunctor       #-}
5{-# LANGUAGE DeriveGeneric       #-}
6{-# LANGUAGE GADTs               #-}
7{-# LANGUAGE LambdaCase          #-}
8{-# LANGUAGE RankNTypes          #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE StandaloneDeriving  #-}
11module UnliftIO.Internals.Async where
12
13import           Control.Applicative
14import           Control.Concurrent       (threadDelay, getNumCapabilities)
15import qualified Control.Concurrent       as C
16import           Control.Concurrent.Async (Async)
17import qualified Control.Concurrent.Async as A
18import           Control.Concurrent.STM
19import           Control.Exception        (Exception, SomeException)
20import           Control.Monad            (forever, liftM, unless, void, (>=>))
21import           Control.Monad.IO.Unlift
22import           Data.Foldable            (for_, traverse_)
23import           Data.Typeable            (Typeable)
24import           Data.IORef (IORef, readIORef, atomicWriteIORef, newIORef, atomicModifyIORef')
25import qualified UnliftIO.Exception       as UE
26
-- For the implementation of Conc below, we do not want any of the
-- smart async exception handling logic from UnliftIO.Exception, since
-- (e.g.) we're low-level enough that we need to explicitly throw async
-- exceptions synchronously.
31import qualified Control.Exception        as E
32import           GHC.Generics             (Generic)
33
34#if MIN_VERSION_base(4,9,0)
35import           Data.Semigroup
36#else
37import           Data.Monoid              hiding (Alt)
38#endif
39import           Data.Foldable            (Foldable, toList)
40import           Data.Traversable         (Traversable, for, traverse)
41
-- | Unlifted 'A.async': the @m@ action is converted to 'IO' with
-- 'withRunInIO' and handed to 'A.async'.
--
-- @since 0.1.0.0
async :: MonadUnliftIO m => m a -> m (Async a)
async action = withRunInIO (\runInIO -> A.async (runInIO action))
47
-- | Unlifted 'A.asyncBound': the @m@ action is converted to 'IO' with
-- 'withRunInIO' and handed to 'A.asyncBound'.
--
-- @since 0.1.0.0
asyncBound :: MonadUnliftIO m => m a -> m (Async a)
asyncBound action = withRunInIO (\runInIO -> A.asyncBound (runInIO action))
53
-- | Unlifted 'A.asyncOn': like 'async', but the capability number is passed
-- through to 'A.asyncOn'.
--
-- @since 0.1.0.0
asyncOn :: MonadUnliftIO m => Int -> m a -> m (Async a)
asyncOn cap action = withRunInIO (\runInIO -> A.asyncOn cap (runInIO action))
59
-- | Unlifted 'A.asyncWithUnmask'. The child receives the @unmask@ function
-- lifted into @m@: each call runs the given @m@ action through the
-- 'withRunInIO' runner inside the 'IO'-level unmask.
--
-- @since 0.1.0.0
asyncWithUnmask :: MonadUnliftIO m => ((forall b. m b -> m b) -> m a) -> m (Async a)
asyncWithUnmask inner =
  withRunInIO $ \runInIO ->
    A.asyncWithUnmask (\unmask -> runInIO (inner (liftIO . unmask . runInIO)))
66
-- | Unlifted 'A.asyncOnWithUnmask': 'asyncWithUnmask' with an explicit
-- capability number passed through to 'A.asyncOnWithUnmask'.
--
-- @since 0.1.0.0
asyncOnWithUnmask :: MonadUnliftIO m => Int -> ((forall b. m b -> m b) -> m a) -> m (Async a)
asyncOnWithUnmask cap inner =
  withRunInIO $ \runInIO ->
    A.asyncOnWithUnmask cap (\unmask -> runInIO (inner (liftIO . unmask . runInIO)))
73
-- | Unlifted 'A.withAsync': both the forked action and the continuation
-- that receives its 'Async' handle are run through the 'withRunInIO' runner.
--
-- @since 0.1.0.0
withAsync :: MonadUnliftIO m => m a -> (Async a -> m b) -> m b
withAsync child k =
  withRunInIO $ \runInIO -> A.withAsync (runInIO child) (\task -> runInIO (k task))
79
-- | Unlifted 'A.withAsyncBound': the bound-thread counterpart of
-- 'withAsync'.
--
-- @since 0.1.0.0
withAsyncBound :: MonadUnliftIO m => m a -> (Async a -> m b) -> m b
withAsyncBound child k =
  withRunInIO $ \runInIO -> A.withAsyncBound (runInIO child) (\task -> runInIO (k task))
85
-- | Unlifted 'A.withAsyncOn': 'withAsync' with an explicit capability
-- number passed through to 'A.withAsyncOn'.
--
-- @since 0.1.0.0
withAsyncOn :: MonadUnliftIO m => Int -> m a -> (Async a -> m b) -> m b
withAsyncOn cap child k =
  withRunInIO $ \runInIO -> A.withAsyncOn cap (runInIO child) (\task -> runInIO (k task))
91
-- | Unlifted 'A.withAsyncWithUnmask'. The child receives the @unmask@
-- function lifted into @m@, and the continuation receives the child's
-- 'Async' handle.
--
-- @since 0.1.0.0
withAsyncWithUnmask
  :: MonadUnliftIO m
  => ((forall c. m c -> m c) -> m a)
  -> (Async a -> m b)
  -> m b
withAsyncWithUnmask inner k =
  withRunInIO $ \runInIO ->
    A.withAsyncWithUnmask
      (\unmask -> runInIO (inner (liftIO . unmask . runInIO)))
      (\task -> runInIO (k task))
104
-- | Unlifted 'A.withAsyncOnWithUnmask'.
--
-- Like 'withAsyncWithUnmask', but the capability number is passed through
-- to 'A.withAsyncOnWithUnmask'. The child receives the @unmask@ function
-- lifted into @m@, and the continuation receives the child's 'Async'
-- handle.
--
-- @since 0.1.0.0
withAsyncOnWithUnmask
  :: MonadUnliftIO m
  => Int
  -> ((forall c. m c -> m c) -> m a)
  -> (Async a -> m b)
  -> m b
withAsyncOnWithUnmask i a b =
  withRunInIO $ \run -> A.withAsyncOnWithUnmask i
    (\unmask -> run $ a $ liftIO . unmask . run)
    (run . b)
118
-- | Lifted 'A.wait': block until the 'Async' finishes, from any 'MonadIO'.
--
-- @since 0.1.0.0
wait :: MonadIO m => Async a -> m a
wait task = liftIO (A.wait task)
124
-- | Lifted 'A.poll'.
--
-- @since 0.1.0.0
poll :: MonadIO m => Async a -> m (Maybe (Either SomeException a))
poll task = liftIO (A.poll task)
130
-- | Lifted 'A.waitCatch'.
--
-- @since 0.1.0.0
waitCatch :: MonadIO m => Async a -> m (Either SomeException a)
waitCatch task = liftIO (A.waitCatch task)
136
-- | Lifted 'A.cancel'.
--
-- @since 0.1.0.0
cancel :: MonadIO m => Async a -> m ()
cancel task = liftIO (A.cancel task)
142
-- | Lifted 'A.uninterruptibleCancel'.
--
-- @since 0.1.0.0
uninterruptibleCancel :: MonadIO m => Async a -> m ()
uninterruptibleCancel task = liftIO (A.uninterruptibleCancel task)
148
-- | Lifted 'A.cancelWith'. The exception is first wrapped with
-- 'UE.toAsyncException', so the target sees it as an asynchronous
-- exception.
--
-- @since 0.1.0.0
cancelWith :: (Exception e, MonadIO m) => Async a -> e -> m ()
cancelWith task exc = liftIO $ A.cancelWith task $ UE.toAsyncException exc
155
-- | Lifted 'A.waitAny'.
--
-- @since 0.1.0.0
waitAny :: MonadIO m => [Async a] -> m (Async a, a)
waitAny tasks = liftIO (A.waitAny tasks)
161
-- | Lifted 'A.waitAnyCatch'.
--
-- @since 0.1.0.0
waitAnyCatch :: MonadIO m => [Async a] -> m (Async a, Either SomeException a)
waitAnyCatch tasks = liftIO (A.waitAnyCatch tasks)
167
-- | Lifted 'A.waitAnyCancel'.
--
-- @since 0.1.0.0
waitAnyCancel :: MonadIO m => [Async a] -> m (Async a, a)
waitAnyCancel tasks = liftIO (A.waitAnyCancel tasks)
173
-- | Lifted 'A.waitAnyCatchCancel'.
--
-- @since 0.1.0.0
waitAnyCatchCancel :: MonadIO m => [Async a] -> m (Async a, Either SomeException a)
waitAnyCatchCancel tasks = liftIO (A.waitAnyCatchCancel tasks)
179
-- | Lifted 'A.waitEither'.
--
-- @since 0.1.0.0
waitEither :: MonadIO m => Async a -> Async b -> m (Either a b)
waitEither left right = liftIO $ A.waitEither left right
185
-- | Lifted 'A.waitEitherCatch'.
--
-- @since 0.1.0.0
waitEitherCatch :: MonadIO m => Async a -> Async b -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatch left right = liftIO $ A.waitEitherCatch left right
191
-- | Lifted 'A.waitEitherCancel'.
--
-- @since 0.1.0.0
waitEitherCancel :: MonadIO m => Async a -> Async b -> m (Either a b)
waitEitherCancel left right = liftIO $ A.waitEitherCancel left right
197
-- | Lifted 'A.waitEitherCatchCancel'.
--
-- @since 0.1.0.0
waitEitherCatchCancel :: MonadIO m => Async a -> Async b -> m (Either (Either SomeException a) (Either SomeException b))
waitEitherCatchCancel left right = liftIO $ A.waitEitherCatchCancel left right
203
-- | Lifted 'A.waitEither_'.
--
-- @since 0.1.0.0
waitEither_ :: MonadIO m => Async a -> Async b -> m ()
waitEither_ left right = liftIO $ A.waitEither_ left right
209
-- | Lifted 'A.waitBoth'.
--
-- @since 0.1.0.0
waitBoth :: MonadIO m => Async a -> Async b -> m (a, b)
waitBoth left right = liftIO $ A.waitBoth left right
215
-- | Lifted 'A.link'.
--
-- @since 0.1.0.0
link :: MonadIO m => Async a -> m ()
link task = liftIO (A.link task)
221
-- | Lifted 'A.link2'.
--
-- @since 0.1.0.0
link2 :: MonadIO m => Async a -> Async b -> m ()
link2 left right = liftIO $ A.link2 left right
227
-- | Unlifted 'A.race': both @m@ actions are run through the 'withRunInIO'
-- runner and handed to 'A.race'.
--
-- @since 0.1.0.0
race :: MonadUnliftIO m => m a -> m b -> m (Either a b)
race left right = withRunInIO (\runInIO -> A.race (runInIO left) (runInIO right))
233
-- | Unlifted 'A.race_'.
--
-- @since 0.1.0.0
race_ :: MonadUnliftIO m => m a -> m b -> m ()
race_ left right = withRunInIO (\runInIO -> A.race_ (runInIO left) (runInIO right))
239
-- | Unlifted 'A.concurrently': both @m@ actions are run through the
-- 'withRunInIO' runner and handed to 'A.concurrently'.
--
-- @since 0.1.0.0
concurrently :: MonadUnliftIO m => m a -> m b -> m (a, b)
concurrently left right = withRunInIO (\runInIO -> A.concurrently (runInIO left) (runInIO right))
245
-- | Unlifted 'A.concurrently_'.
--
-- @since 0.1.0.0
concurrently_ :: MonadUnliftIO m => m a -> m b -> m ()
concurrently_ left right = withRunInIO (\runInIO -> A.concurrently_ (runInIO left) (runInIO right))
251
-- | Unlifted 'A.Concurrently': a wrapper around an @m@ action whose
-- 'Applicative' instance runs actions with 'concurrently' and whose
-- 'Alternative' instance runs them with 'race'.
--
-- @since 0.1.0.0
newtype Concurrently m a = Concurrently { runConcurrently :: m a }
258
-- | Maps a function over the result of the wrapped action.
--
-- @since 0.1.0.0
instance Monad m => Functor (Concurrently m) where
  fmap f = Concurrently . liftM f . runConcurrently
262
-- | '<*>' runs the function-producing action and the argument-producing
-- action with 'concurrently', then applies one result to the other.
--
-- @since 0.1.0.0
instance MonadUnliftIO m => Applicative (Concurrently m) where
  pure x = Concurrently (return x)
  Concurrently mf <*> Concurrently mx =
    Concurrently (liftM (uncurry ($)) (concurrently mf mx))
268
-- | Composing two unlifted 'Concurrently' values with '<|>' is equivalent to
-- using the 'race' combinator: the asynchronous sub-routine that returns a
-- value first is the one whose value is returned, and the slower sub-routine
-- is cancelled and its thread killed.
--
-- @since 0.1.0.0
instance MonadUnliftIO m => Alternative (Concurrently m) where
  -- | Care should be taken when using the 'empty' value of the 'Alternative'
  -- interface: it never returns, looping forever on
  -- @'threadDelay' 'maxBound'@. The reasoning behind this implementation is
  -- that any other computation it is raced against will finish before it
  -- (unless that computation never returns either). This implementation is
  -- less than ideal; in a perfect world we would have a typeclass that
  -- provides '(<|>)' but not 'empty'.
  --
  -- @since 0.1.0.0
  empty = Concurrently $ liftIO (forever (threadDelay maxBound))
  Concurrently as <|> Concurrently bs =
    Concurrently $ liftM (either id id) (race as bs)
287
288--------------------------------------------------------------------------------
#if MIN_VERSION_base(4,9,0)
--------------------------------------------------------------------------------
-- | Runs both actions concurrently (via the 'Applicative' instance) and
-- combines their results with '<>'. Only defined by @async@ for
-- @base >= 4.9@.
--
-- @since 0.1.0.0
instance (MonadUnliftIO m, Semigroup a) => Semigroup (Concurrently m a) where
  x <> y = liftA2 (<>) x y

-- | 'mempty' performs no work and yields 'mempty'; 'mappend' is '<>'.
--
-- @since 0.1.0.0
instance (Semigroup a, Monoid a, MonadUnliftIO m) => Monoid (Concurrently m a) where
  mempty = Concurrently (return mempty)
  mappend x y = x <> y
--------------------------------------------------------------------------------
#else
--------------------------------------------------------------------------------
-- | @since 0.1.0.0
instance (Monoid a, MonadUnliftIO m) => Monoid (Concurrently m a) where
  mempty = Concurrently (return mempty)
  mappend x y = liftA2 mappend x y
--------------------------------------------------------------------------------
#endif
310--------------------------------------------------------------------------------
311
-- | 'mapConcurrently' with the container given before the function.
--
-- @since 0.1.0.0
forConcurrently :: MonadUnliftIO m => Traversable t => t a -> (a -> m b) -> m (t b)
forConcurrently xs f = mapConcurrently f xs
{-# INLINE forConcurrently #-}
318
-- | 'mapConcurrently_' with the container given before the function.
--
-- @since 0.1.0.0
forConcurrently_ :: MonadUnliftIO m => Foldable f => f a -> (a -> m b) -> m ()
forConcurrently_ xs f = mapConcurrently_ f xs
{-# INLINE forConcurrently_ #-}
325
-- | Unlifted 'A.replicateConcurrently': run @cnt@ copies of the action
-- concurrently and collect their results in a list.
--
-- A count below 1 returns @[]@ without running the action, and a count of
-- exactly 1 runs the action directly without going through
-- 'mapConcurrently'.
--
-- @since 0.1.0.0
#if MIN_VERSION_base(4,7,0)
replicateConcurrently :: (Applicative m, MonadUnliftIO m) => Int -> m a -> m [a]
#else
replicateConcurrently :: (Functor m, MonadUnliftIO m) => Int -> m a -> m [a]
#endif
replicateConcurrently cnt m =
  case compare cnt 1 of
    LT -> pure []
    EQ -> (:[]) <$> m
    GT -> mapConcurrently id (replicate cnt m)
{-# INLINE replicateConcurrently #-}
339
-- | Unlifted 'A.replicateConcurrently_': run @cnt@ copies of the action
-- concurrently, discarding their results.
--
-- A count below 1 does nothing, and a count of exactly 1 runs the action
-- directly without going through 'mapConcurrently_'.
--
-- @since 0.1.0.0
#if MIN_VERSION_base(4,7,0)
replicateConcurrently_ :: (Applicative m, MonadUnliftIO m) => Int -> m a -> m ()
#else
replicateConcurrently_ :: (MonadUnliftIO m) => Int -> m a -> m ()
#endif
replicateConcurrently_ cnt m
  | cnt < 1   = pure ()
  | cnt == 1  = void m
  | otherwise = mapConcurrently_ id (replicate cnt m)
{-# INLINE replicateConcurrently_ #-}
354
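-- Conc uses GHC features that older compilers do not support, so Conc (and
-- the Flat-based mapConcurrently / mapConcurrently_) is only defined when
-- building against base >= 4.8 (GHC 7.10); otherwise the async-based
-- fallbacks further below are used.
-- NOTE(review): this comment previously said "<= ghc-7.10", which disagrees
-- with the base-4.8 guard; confirm which lower bound is intended.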
357--------------------------------------------------------------------------------
358#if MIN_VERSION_base(4,8,0)
359--------------------------------------------------------------------------------
360
-- | Apply the function to every element of a 'Traversable' container and
-- run the resulting actions concurrently, keeping the container's shape.
-- Each element becomes a 'FlatAction' leaf, and the combined tree is
-- executed by 'runFlat'.
--
-- @since 0.1.0.0
mapConcurrently :: MonadUnliftIO m => Traversable t => (a -> m b) -> t a -> m (t b)
mapConcurrently f t = withRunInIO $ \runInIO ->
  let toFlat x = FlatApp (FlatAction (runInIO (f x)))
  in runFlat (traverse toFlat t)
{-# INLINE mapConcurrently #-}
370
-- | Apply the function to every element of a 'Foldable' container and run
-- the resulting actions concurrently, discarding the results. Each element
-- becomes a 'FlatAction' leaf, and the combined tree is executed by
-- 'runFlat'.
--
-- @since 0.1.0.0
mapConcurrently_ :: MonadUnliftIO m => Foldable f => (a -> m b) -> f a -> m ()
mapConcurrently_ f t = withRunInIO $ \runInIO ->
  let toFlat x = FlatApp (FlatAction (runInIO (f x)))
  in runFlat (traverse_ toFlat t)
{-# INLINE mapConcurrently_ #-}
380
381
382-- More efficient Conc implementation
383
-- | A more efficient alternative to 'Concurrently', which reduces the
-- number of threads that need to be forked. For more information, see
-- @FIXME link to blog post@. This is provided as a separate type to
-- @Concurrently@ as it has a slightly different API.
--
-- Use the 'conc' function to construct values of type 'Conc', and
-- 'runConc' to execute the composed actions. You can use the
-- @Applicative@ instance to run different actions and wait for all of
-- them to complete, or the @Alternative@ instance to wait for the
-- first thread to complete.
--
-- In the event of a runtime exception thrown by any of the child
-- threads, or an asynchronous exception received in the parent
-- thread, all threads will be killed with 'C.killThread' (that is, with a
-- 'E.ThreadKilled' exception) and the original exception rethrown. If
-- multiple exceptions are generated by different threads, there are no
-- guarantees on which exception will end up getting rethrown.
--
-- For many common use cases, you may prefer using helper functions in
-- this module like 'mapConcurrently'.
--
-- There are some intentional differences in behavior to
-- @Concurrently@:
--
-- * Child threads are always launched in an unmasked state, not
--   the inherited state of the parent thread.
--
-- Note that it is a programmer error to use the @Alternative@
-- instance in such a way that there are no alternatives to an empty,
-- e.g. @runConc (empty <|> empty)@. In such a case, a 'ConcException'
-- will be thrown. If there was an @Alternative@ in the standard
-- libraries without @empty@, this library would use it instead.
--
-- @since 0.2.9.0
data Conc m a where
  -- A single action, to be run on its own thread.
  Action :: m a -> Conc m a
  -- '<*>' of two sub-computations.
  Apply   :: Conc m (v -> a) -> Conc m v -> Conc m a
  -- 'liftA2' of two sub-computations.
  LiftA2 :: (x -> y -> a) -> Conc m x -> Conc m y -> Conc m a

  -- Just an optimization to avoid spawning extra threads
  Pure :: a -> Conc m a

  -- I thought there would be an optimization available from having a
  -- data constructor that explicitly doesn't care about the first
  -- result. Turns out it doesn't help much: we still need to keep a
  -- TMVar below to know when the thread completes.
  --
  -- Then :: Conc m a -> Conc m b -> Conc m b

  -- A race between two sub-computations ('<|>').
  Alt :: Conc m a -> Conc m a -> Conc m a
  -- 'empty'; only valid underneath an 'Alt' that has another alternative.
  Empty :: Conc m a

deriving instance Functor m => Functor (Conc m)
-- The derived instance is equivalent to this sketch:
--
-- fmap f (Action routine) = Action (fmap f routine)
-- fmap f (Apply g x)      = Apply (fmap (f .) g) x
-- fmap f (LiftA2 g x y)   = LiftA2 (\a b -> f (g a b)) x y
-- fmap f (Pure val)       = Pure (f val)
-- fmap f (Alt a b)        = Alt (fmap f a) (fmap f b)
-- fmap f Empty            = Empty
442
-- | Wrap a single action as a 'Conc'. Combine such values with the
-- 'Applicative' and 'Alternative' instances, then execute the result with
-- 'runConc'.
--
-- @since 0.2.9.0
conc :: m a -> Conc m a
conc action = Action action
{-# INLINE conc #-}
451
-- | Execute a 'Conc' value on multiple threads: the tree is first
-- simplified by 'flatten' and then executed in 'IO' by 'runFlat'.
--
-- @since 0.2.9.0
runConc :: MonadUnliftIO m => Conc m a -> m a
runConc c = flatten c >>= liftIO . runFlat
{-# INLINE runConc #-}
458
-- | @since 0.2.9.0
instance MonadUnliftIO m => Applicative (Conc m) where
  pure x = Pure x
  {-# INLINE pure #-}

  -- Applicative combinators only build a tree of constructors; nothing
  -- runs until 'runConc'. For example, with
  --
  --   downloadA :: IO String
  --   downloadB :: IO String
  --
  --   f <$> conc downloadA <*> conc downloadB <*> pure 123
  --
  -- which parses as ((f <$> a) <*> b) <*> c, the tree grows like this:
  --
  --   f <$> a       ==>  Action (fmap f downloadA)
  --   ... <*> b     ==>  Apply (Action (fmap f downloadA)) (Action downloadB)
  --   ... <*> c     ==>  Apply (Apply (Action (fmap f downloadA)) (Action downloadB))
  --                            (Pure 123)
  cf <*> cx = Apply cf cx
  {-# INLINE (<*>) #-}

#if MIN_VERSION_base(4,11,0)
  liftA2 f cx cy = LiftA2 f cx cy
  {-# INLINE liftA2 #-}
#endif

  -- A dedicated "Then" constructor was tried (see the note inside 'Conc')
  -- and did not help, so '*>' is a 'LiftA2' that keeps the second result.
  (*>) = LiftA2 (\_ y -> y)
  {-# INLINE (*>) #-}
494
-- | @since 0.2.9.0
instance MonadUnliftIO m => Alternative (Conc m) where
  -- 'empty' exists only because 'Alternative' requires it; a 'Conc' whose
  -- 'empty' has no sibling alternative makes 'runConc' throw
  -- 'EmptyWithNoAlternative'.
  empty = Empty -- this is so ugly, we don't actually want to provide it!
  {-# INLINE empty #-}
  cx <|> cy = Alt cx cy
  {-# INLINE (<|>) #-}
501
#if MIN_VERSION_base(4, 11, 0)
-- | Combines the results of both sides with '<>', using the applicative
-- ('liftA2') structure so both sides run concurrently under 'runConc'.
--
-- @since 0.2.9.0
instance (MonadUnliftIO m, Semigroup a) => Semigroup (Conc m a) where
  cx <> cy = liftA2 (<>) cx cy
  {-# INLINE (<>) #-}
#endif
508
-- | 'mempty' is a pure 'mempty' (no thread is needed for it); 'mappend'
-- combines both sides' results with 'mappend' via 'liftA2'.
--
-- @since 0.2.9.0
instance (Monoid a, MonadUnliftIO m) => Monoid (Conc m a) where
  mempty = pure mempty
  {-# INLINE mempty #-}
  mappend cx cy = liftA2 mappend cx cy
  {-# INLINE mappend #-}
515
516-------------------------
517-- Conc implementation --
518-------------------------
519
520-- Data types for flattening out the original @Conc@ into a simplified
521-- view. Goals:
522--
523-- * We want to get rid of the Empty data constructor. We don't want
524--   it anyway, it's only there because of the Alternative typeclass.
525--
526-- * We want to ensure that there is no nesting of Alt data
527--   constructors. There is a bookkeeping overhead to each time we
528--   need to track raced threads, and we want to minimize that
529--   bookkeeping.
530--
531-- * We want to ensure that, when racing, we're always racing at least
532--   two threads.
533--
534-- * We want to simplify down to IO.
535
-- | Flattened structure, either Applicative or Alternative. Produced by
-- 'flatten' and executed by 'runFlat'. All fields are strict.
data Flat a
  = FlatApp !(FlatApp a)
  -- | Flattened Alternative. Has at least 2 entries, which must be
  -- FlatApp (no nesting of FlatAlts).
  | FlatAlt !(FlatApp a) !(FlatApp a) ![FlatApp a]

deriving instance Functor Flat
-- The derived instance is equivalent to this sketch:
--
-- fmap f (FlatApp a)      = FlatApp (fmap f a)
-- fmap f (FlatAlt a b xs) = FlatAlt (fmap f a) (fmap f b) (map (fmap f) xs)
-- Applicative structure on 'Flat', which lets 'mapConcurrently' and
-- 'mapConcurrently_' assemble a tree of actions with 'traverse' /
-- 'traverse_'. Each binary combinator produces a single 'FlatLiftA2' node.
instance Applicative Flat where
  pure x = FlatApp (FlatPure x)
  ff <*> fx = FlatApp (FlatLiftA2 ($) ff fx)
#if MIN_VERSION_base(4,11,0)
  liftA2 g fx fy = FlatApp (FlatLiftA2 g fx fy)
#endif
554
-- | Flattened Applicative. No Alternative stuff directly in here, but may be in
-- the children. Notice this type doesn't have a type parameter for monadic
-- contexts; it hardwires the base monad to IO, since concurrency eventually
-- relies on that.
--
-- @since 0.2.9.0
data FlatApp a where
  -- A constant; 'runFlat' forks no thread for it.
  FlatPure   :: a -> FlatApp a
  -- A single IO action; 'runFlat' normally forks one worker thread for it.
  FlatAction :: IO a -> FlatApp a
  -- '<*>' of two flattened children.
  FlatApply   :: Flat (v -> a) -> Flat v -> FlatApp a
  -- 'liftA2' of two flattened children.
  FlatLiftA2 :: (x -> y -> a) -> Flat x -> Flat y -> FlatApp a

deriving instance Functor FlatApp
-- Applicative structure on 'FlatApp': 'pure' is a 'FlatPure' leaf, and the
-- binary combinators wrap each argument in the 'FlatApp' constructor of
-- 'Flat' to obtain the children that 'FlatApply' / 'FlatLiftA2' expect.
instance Applicative FlatApp where
  pure x = FlatPure x
  fa <*> xa = FlatApply (FlatApp fa) (FlatApp xa)
#if MIN_VERSION_base(4,11,0)
  liftA2 g xa ya = FlatLiftA2 g (FlatApp xa) (FlatApp ya)
#endif
574
-- | Structural mistakes in a 'Conc' value. These signal
-- /programmer errors/ in how the 'Conc' was built, not failures of the
-- actions it contains.
--
-- @since 0.2.9.0
data ConcException
  = EmptyWithNoAlternative
    -- ^ An 'empty' was reached with no sibling alternative to fall back on,
    -- e.g. @runConc empty@ or @runConc (empty <|> empty)@.
  deriving (Generic, Show, Typeable, Eq, Ord)

instance E.Exception ConcException
583
-- | Simple difference list: a list represented by the function that
-- prepends it, so concatenation is function composition.
type DList a = [a] -> [a]

-- | Append two difference lists.
dlistConcat :: DList a -> DList a -> DList a
dlistConcat xs ys = xs . ys
{-# INLINE dlistConcat #-}

-- | Prepend a single element.
dlistCons :: a -> DList a -> DList a
dlistCons x xs = (x :) . xs
{-# INLINE dlistCons #-}

-- | Concatenate a list of difference lists, left to right.
dlistConcatAll :: [DList a] -> DList a
dlistConcatAll ds = foldr dlistConcat dlistEmpty ds
{-# INLINE dlistConcatAll #-}

-- | Materialise a difference list as an ordinary list.
dlistToList :: DList a -> [a]
dlistToList d = d []
{-# INLINE dlistToList #-}

-- | A difference list holding exactly one element.
dlistSingleton :: a -> DList a
dlistSingleton x = (x :)
{-# INLINE dlistSingleton #-}

-- | The empty difference list.
dlistEmpty :: DList a
dlistEmpty = id
{-# INLINE dlistEmpty #-}
610
-- | Turn a 'Conc' into a 'Flat'. Note that thanks to the ugliness of
-- 'empty', this may fail, e.g. @flatten Empty@.
--
-- Every @m@ action in the tree is converted to 'IO' with the 'withRunInIO'
-- runner. Directly nested 'Alt' nodes are merged into a single list of
-- alternatives, and 'Empty' branches under an 'Alt' are dropped. Throws
-- 'EmptyWithNoAlternative' when an 'Empty' is reached outside of an 'Alt',
-- or when every branch of an 'Alt' is 'Empty'.
--
-- @since 0.2.9.0
flatten :: forall m a. MonadUnliftIO m => Conc m a -> m (Flat a)
flatten c0 = withRunInIO $ \run -> do

  -- 'both' flattens an arbitrary node into a 'Flat'. An 'Empty' reached
  -- here has no alternative to fall back on, so it is an error.
  let both :: forall k. Conc m k -> IO (Flat k)
      both Empty = E.throwIO EmptyWithNoAlternative
      both (Action m) = pure $ FlatApp $ FlatAction $ run m
      both (Apply cf ca) = do
        f <- both cf
        a <- both ca
        pure $ FlatApp $ FlatApply f a
      both (LiftA2 f ca cb) = do
        a <- both ca
        b <- both cb
        pure $ FlatApp $ FlatLiftA2 f a b
      -- Gather every alternative from both sides. With zero survivors there
      -- is nothing to run; with exactly one there is nothing to race, so it
      -- is returned as a plain FlatApp; otherwise build a FlatAlt.
      both (Alt ca cb) = do
        a <- alt ca
        b <- alt cb
        case dlistToList (a `dlistConcat` b) of
          []    -> E.throwIO EmptyWithNoAlternative
          [x]   -> pure $ FlatApp x
          x:y:z -> pure $ FlatAlt x y z
      both (Pure a) = pure $ FlatApp $ FlatPure a

      -- 'alt' flattens a node that sits directly under an 'Alt' into the
      -- list of alternatives it contributes: nested 'Alt's are spliced into
      -- the same list, 'Empty' contributes nothing, and the children of any
      -- other node are flattened with 'both'.
      -- Returns a difference list for cheaper concatenation
      alt :: forall k. Conc m k -> IO (DList (FlatApp k))
      alt Empty = pure dlistEmpty
      alt (Apply cf ca) = do
        f <- both cf
        a <- both ca
        pure (dlistSingleton $ FlatApply f a)
      alt (Alt ca cb) = do
        a <- alt ca
        b <- alt cb
        pure $ a `dlistConcat` b
      alt (Action m) = pure (dlistSingleton $ FlatAction (run m))
      alt (LiftA2 f ca cb) = do
        a <- both ca
        b <- both cb
        pure (dlistSingleton $ FlatLiftA2 f a b)
      alt (Pure a) = pure (dlistSingleton $ FlatPure a)

  both c0
658
-- | Run a @Flat a@ on multiple threads.
--
-- Every 'FlatAction' leaf is forked onto its own worker thread, and every
-- 'FlatAlt' additionally gets one helper thread. That helper waits for the
-- first of its alternatives to finish and then kills that 'FlatAlt''s
-- workers. Results are combined in STM. A tree that is just a single
-- action or a single pure value is run directly on the calling thread.
--
-- If any child fails, or the calling thread receives an asynchronous
-- exception while waiting, the child threads are killed with
-- 'C.killThread' (unless they have all already finished). The calling
-- thread then waits for every child to report completion before
-- rethrowing the exception.
runFlat :: Flat a -> IO a

-- Silly, simple optimizations: nothing runs concurrently here, so no
-- threads, masking or bookkeeping are needed.
runFlat (FlatApp (FlatAction io)) = io
runFlat (FlatApp (FlatPure x)) = pure x

-- Start off with all exceptions masked so we can install proper cleanup.
runFlat f0 = E.uninterruptibleMask $ \restore -> do
  -- How many threads have been spawned and finished their task? We need to
  -- ensure we kill all child threads and wait for them to die.
  --
  -- Each forked thread (worker or Alt helper) increments this counter once,
  -- in the same STM transaction that publishes its outcome.
  resultCountVar <- newTVarIO 0

  -- Forks off as many threads as necessary to run the given Flat a,
  -- and returns:
  --
  -- + An STM action that will block until completion and return the
  --   result.
  --
  -- + The IDs of all forked threads. These need to be tracked so they
  --   can be killed (either when an exception is thrown, or when one
  --   of the alternatives completes first).
  --
  -- It would be nice to have the returned STM action return an Either
  -- and keep the SomeException values somewhat explicit, but in all
  -- my testing this absolutely kills performance. Instead, we're
  -- going to use a hack of providing a TMVar to fill up with a
  -- SomeException when things fail.
  --
  -- TODO: Investigate why performance degradation on Either
  let go :: forall a.
            TMVar E.SomeException
         -> Flat a
         -> IO (STM a, DList C.ThreadId)
      -- Pure leaf: no thread; the result is available immediately.
      go _excVar (FlatApp (FlatPure x)) = pure (pure x, dlistEmpty)
      -- Action leaf: fork exactly one worker thread.
      go excVar (FlatApp (FlatAction io)) = do
        resVar <- newEmptyTMVarIO
        -- The worker starts with this thread's (uninterruptible) mask and
        -- unmasks only around the user action, so exceptions delivered to
        -- it (asynchronous ones included) are caught by E.try and the
        -- bookkeeping transaction below still runs.
        tid <- C.forkIOWithUnmask $ \restore1 -> do
          res <- E.try $ restore1 io
          atomically $ do
            modifyTVar' resultCountVar (+ 1)
            case res of
              -- tryPutTMVar: only the first failure is kept.
              Left e  -> void $ tryPutTMVar excVar e
              Right x -> putTMVar resVar x
        pure (readTMVar resVar, dlistSingleton tid)
      -- Applicative nodes fork nothing themselves: their STM action waits
      -- for both children's results, and the thread IDs are concatenated.
      go excVar (FlatApp (FlatApply cf ca)) = do
        (f, tidsf) <- go excVar cf
        (a, tidsa) <- go excVar ca
        pure (f <*> a, tidsf `dlistConcat` tidsa)
      go excVar (FlatApp (FlatLiftA2 f a b)) = do
        (a', tidsa) <- go excVar a
        (b', tidsb) <- go excVar b
        pure (liftA2 f a' b', tidsa `dlistConcat` tidsb)

      go excVar0 (FlatAlt x y z) = do
        -- As soon as one of the children finishes, we need to kill the siblings,
        -- we're going to create our own excVar here to pass to the children, so
        -- we can prevent the ThreadKilled exceptions we throw to the children
        -- here from propagating and taking down the whole system.
        excVar <- newEmptyTMVarIO
        resVar <- newEmptyTMVarIO
        pairs <- traverse (go excVar . FlatApp) (x:y:z)
        let (blockers, workerTids) = unzip pairs

        -- Fork a helper thread to wait for the first child to
        -- complete, or for one of them to die with an exception so we
        -- can propagate it to excVar0.
        helperTid <- C.forkIOWithUnmask $ \restore1 -> do
          -- First alternative to produce a result wins; the local excVar is
          -- only consulted if no alternative has a result yet.
          eres <- E.try $ restore1 $ atomically $ foldr
            (\blocker rest -> (Right <$> blocker) <|> rest)
            (Left <$> readTMVar excVar)
            blockers
          atomically $ do
            modifyTVar' resultCountVar (+ 1)
            case eres of
              -- NOTE: The child threads are spawned from @traverse go@ call above, they
              -- are _not_ children of this helper thread, and helper thread doesn't throw
              -- synchronous exceptions, so, any exception that the try above would catch
              -- must be an async exception.
              -- We were killed by an async exception, do nothing.
              Left (_ :: E.SomeException) -> pure ()
              -- Child thread died, propagate it
              Right (Left e)              -> void $ tryPutTMVar excVar0 e
              -- Successful result from one of the children
              Right (Right res)           -> putTMVar resVar res

          -- And kill all of the threads
          for_ workerTids $ \tids' ->
            -- NOTE: Replacing A.AsyncCancelled with KillThread as the
            -- 'A.AsyncCancelled' constructor is not exported in older versions
            -- of the async package
            -- for_ (tids' []) $ \workerTid -> E.throwTo workerTid A.AsyncCancelled
            for_ (dlistToList tids') $ \workerTid -> C.killThread workerTid

        -- The helper's ID is reported together with the workers' IDs, so the
        -- parent counts it, and kills it if needed, like any other child.
        pure ( readTMVar resVar
             , helperTid `dlistCons` dlistConcatAll workerTids
             )

  excVar <- newEmptyTMVarIO
  (getRes, tids0) <- go excVar f0
  let tids = dlistToList tids0
      tidCount = length tids
      -- True once every forked thread has bumped resultCountVar. A count
      -- above tidCount would mean the bookkeeping is broken, hence 'error'.
      allDone count =
        if count > tidCount
          then error ("allDone: count ("
                      <> show count
                      <> ") should never be greater than tidCount ("
                      <> show tidCount
                      <> ")")
          else count == tidCount

  -- Automatically retry if we get killed by a
  -- BlockedIndefinitelyOnSTM. For more information, see:
  --
  -- + https:\/\/github.com\/simonmar\/async\/issues\/14
  -- + https:\/\/github.com\/simonmar\/async\/pull\/15
  --
  let autoRetry action =
        action `E.catch`
        \E.BlockedIndefinitelyOnSTM -> autoRetry action

  -- Restore the original masking state while blocking and catch
  -- exceptions to allow the parent thread to be killed early.
  res <- E.try $ restore $ autoRetry $ atomically $
         (Left <$> readTMVar excVar) <|>
         (Right <$> getRes)

  count0 <- atomically $ readTVar resultCountVar
  unless (allDone count0) $ do
    -- Kill all of the threads
    -- NOTE: Replacing A.AsyncCancelled with KillThread as the
    -- 'A.AsyncCancelled' constructor is not exported in older versions
    -- of the async package
    -- for_ tids $ \tid -> E.throwTo tid A.AsyncCancelled
    for_ tids $ \tid -> C.killThread tid

    -- Wait for all of the threads to die. We're going to restore the original
    -- masking state here, just in case there's a bug in the cleanup code of a
    -- child thread, so that we can be killed by an async exception. We decided
    -- this is a better behavior than hanging indefinitely and wait for a SIGKILL.
    restore $ atomically $ do
      count <- readTVar resultCountVar
      -- retries until resultCountVar has increased to the threadId count returned by go
      check $ allDone count

  -- Return the result or throw an exception. Yes, we could use
  -- either or join, but explicit pattern matching is nicer here.
  case res of
    -- Parent thread was killed with an async exception
    Left e          -> E.throwIO (e :: E.SomeException)
    -- Some child thread died
    Right (Left e)  -> E.throwIO e
    -- Everything worked!
    Right (Right x) -> pure x
{-# INLINEABLE runFlat #-}
814
815--------------------------------------------------------------------------------
816#else
817--------------------------------------------------------------------------------
818
-- | Unlifted 'A.mapConcurrently' (fallback used when 'Conc' is not
-- available): each element's action is run through the 'withRunInIO'
-- runner.
--
-- @since 0.1.0.0
mapConcurrently :: MonadUnliftIO m => Traversable t => (a -> m b) -> t a -> m (t b)
mapConcurrently f t = withRunInIO $ \runInIO -> A.mapConcurrently (\x -> runInIO (f x)) t
{-# INLINE mapConcurrently #-}
825
-- | Unlifted 'A.mapConcurrently_' (fallback used when 'Conc' is not
-- available): each element's action is run through the 'withRunInIO'
-- runner and the results are discarded.
--
-- @since 0.1.0.0
mapConcurrently_ :: MonadUnliftIO m => Foldable f => (a -> m b) -> f a -> m ()
mapConcurrently_ f t = withRunInIO $ \runInIO -> A.mapConcurrently_ (\x -> runInIO (f x)) t
{-# INLINE mapConcurrently_ #-}
832
833--------------------------------------------------------------------------------
834#endif
835--------------------------------------------------------------------------------
836
837-- | Like 'mapConcurrently' from async, but instead of one thread per
838-- element, it does pooling from a set of threads. This is useful in
839-- scenarios where resource consumption is bounded and for use cases
840-- where too many concurrent tasks aren't allowed.
841--
842-- === __Example usage__
843--
844-- @
845-- import Say
846--
847-- action :: Int -> IO Int
848-- action n = do
849--   tid <- myThreadId
850--   sayString $ show tid
851--   threadDelay (2 * 10^6) -- 2 seconds
852--   return n
853--
854-- main :: IO ()
855-- main = do
856--   yx \<- pooledMapConcurrentlyN 5 (\\x -\> action x) [1..5]
857--   print yx
858-- @
859--
860-- On executing you can see that five threads have been spawned:
861--
862-- @
863-- \$ ./pool
864-- ThreadId 36
865-- ThreadId 38
866-- ThreadId 40
867-- ThreadId 42
868-- ThreadId 44
869-- [1,2,3,4,5]
870-- @
871--
872--
873-- Let's modify the above program such that there are less threads
874-- than the number of items in the list:
875--
876-- @
877-- import Say
878--
879-- action :: Int -> IO Int
880-- action n = do
881--   tid <- myThreadId
882--   sayString $ show tid
883--   threadDelay (2 * 10^6) -- 2 seconds
884--   return n
885--
886-- main :: IO ()
887-- main = do
888--   yx \<- pooledMapConcurrentlyN 3 (\\x -\> action x) [1..5]
889--   print yx
890-- @
--
-- On executing, you can see that only three threads are used in total:
892--
893-- @
894-- \$ ./pool
895-- ThreadId 35
896-- ThreadId 37
897-- ThreadId 39
898-- ThreadId 35
899-- ThreadId 39
900-- [1,2,3,4,5]
901-- @
902--
903-- @since 0.2.10
pooledMapConcurrentlyN :: (MonadUnliftIO m, Traversable t)
                      => Int -- ^ Max. number of threads. Should not be less than 1.
                      -> (a -> m b) -> t a -> m (t b)
pooledMapConcurrentlyN numProcs f xs =
  -- Lower each per-element action to IO and let the IO-level pool (which
  -- also validates the thread count) do the work.
  withRunInIO (\toIO -> pooledMapConcurrentlyIO numProcs (toIO . f) xs)
909
-- | Like 'pooledMapConcurrentlyN', but the size of the thread pool is
-- taken from 'getNumCapabilities'. This is usually a good fit for
-- CPU-bound tasks.
--
-- @since 0.2.10
pooledMapConcurrently :: (MonadUnliftIO m, Traversable t) => (a -> m b) -> t a -> m (t b)
pooledMapConcurrently f xs = withRunInIO $ \toIO -> do
  caps <- getNumCapabilities
  pooledMapConcurrentlyIO caps (toIO . f) xs
920
-- | 'pooledMapConcurrentlyN' with the container and the action swapped,
-- which reads nicely with a trailing lambda.
--
-- @since 0.2.10
pooledForConcurrentlyN :: (MonadUnliftIO m, Traversable t)
                      => Int -- ^ Max. number of threads. Should not be less than 1.
                      -> t a -> (a -> m b) -> m (t b)
pooledForConcurrentlyN numProcs xs action = pooledMapConcurrentlyN numProcs action xs
928
-- | 'pooledMapConcurrently' with the container and the action swapped;
-- the pool size comes from 'getNumCapabilities'. Usually this is useful
-- for CPU-bound tasks.
--
-- @since 0.2.10
pooledForConcurrently :: (MonadUnliftIO m, Traversable t) => t a -> (a -> m b) -> m (t b)
pooledForConcurrently xs action = pooledMapConcurrently action xs
936
-- | 'IO'-level core of 'pooledMapConcurrentlyN' and 'pooledMapConcurrently'.
-- A thread count below 1 is rejected with an 'error'; otherwise the work is
-- handed to @pooledMapConcurrentlyIO'@.
pooledMapConcurrentlyIO :: Traversable t => Int -> (a -> IO b) -> t a -> IO (t b)
pooledMapConcurrentlyIO numProcs f xs
  | numProcs >= 1 = pooledMapConcurrentlyIO' numProcs f xs
  | otherwise     = error "pooledMapconcurrentlyIO: number of threads < 1"
942
-- | Performs the actual pooling for the tasks. Each worker thread
-- repeatedly pops the next input off the shared task queue and runs the
-- task on it, stopping once it finds the queue empty. This function keeps
-- running until the task queue becomes empty.
--
-- At most @numProcs@ workers are started, and never more than the number
-- of jobs queued on entry. Only the first @numProcs@ queue entries are
-- inspected to decide this, so an infinite queue is still fine.
pooledConcurrently
  :: Int -- ^ Max. number of threads. Should not be less than 1.
  -> IORef [a] -- ^ Task queue. These are required as inputs for the jobs.
  -> (a -> IO ()) -- ^ The task which will be run concurrently (but
                 -- will be pooled properly).
  -> IO ()
pooledConcurrently numProcs jobsVar f = do
  -- Surplus workers would only find an empty queue and exit at once, so
  -- don't spawn them. 'take' bounds the work (and supports infinite queues).
  queued <- readIORef jobsVar
  let numWorkers = length (take numProcs queued)
      -- Atomically pop the head of the queue, if there is one.
      nextJob = atomicModifyIORef' jobsVar $ \case
        []         -> ([], Nothing)
        job : rest -> (rest, Just job)
      -- Keep running tasks until the queue is drained.
      worker = nextJob >>= maybe (return ()) (\job -> f job >> worker)
  unless (numWorkers < 1) $ replicateConcurrently_ numWorkers worker
965
-- | Runs @f@ over every element of @xs@ using a pool of at most @numProcs@
-- threads and returns the results in the shape of the input.
pooledMapConcurrentlyIO' ::
    Traversable t => Int  -- ^ Max. number of threads. Should not be less than 1.
                  -> (a -> IO b)
                  -> t a
                  -> IO (t b)
pooledMapConcurrentlyIO' numProcs f xs = do
  -- Pair every input with its own result slot. Traversing keeps the
  -- container's shape so the results can be read back in order.
  slots <- traverse withSlot xs
  -- The workers consume (input, slot) pairs from a shared queue and store
  -- each result into the matching slot.
  queue <- newIORef (toList slots)
  pooledConcurrently numProcs queue $ \(input, slot) ->
    atomicWriteIORef slot =<< f input
  -- Read every slot back, in the original order.
  traverse (readIORef . snd) slots
  where
    withSlot x = do
      slot <- newIORef (error "pooledMapConcurrentlyIO': empty IORef")
      return (x, slot)
982
-- | Queues every element of @jobs@ and lets the worker pool run @f@ on
-- each of them; nothing is collected.
pooledMapConcurrentlyIO_' ::
  Foldable t => Int -> (a -> IO ()) -> t a -> IO ()
pooledMapConcurrentlyIO_' numProcs f jobs =
  newIORef (toList jobs) >>= \queue -> pooledConcurrently numProcs queue f
988
-- | 'IO'-level core of the result-discarding pooled maps. A thread count
-- below 1 is rejected with an 'error'; otherwise each action's result is
-- dropped and the work goes to the pool.
pooledMapConcurrentlyIO_ :: Foldable t => Int -> (a -> IO b) -> t a -> IO ()
pooledMapConcurrentlyIO_ numProcs f xs
  | numProcs >= 1 = pooledMapConcurrentlyIO_' numProcs (void . f) xs
  | otherwise     = error "pooledMapconcurrentlyIO_: number of threads < 1"
994
-- | 'pooledMapConcurrentlyN' for when the results are not needed: each
-- action's return value is thrown away.
--
-- @since 0.2.10
pooledMapConcurrentlyN_ :: (MonadUnliftIO m, Foldable f)
                        => Int -- ^ Max. number of threads. Should not be less than 1.
                        -> (a -> m b) -> f a -> m ()
pooledMapConcurrentlyN_ numProcs f t =
  withRunInIO (\toIO -> pooledMapConcurrentlyIO_ numProcs (toIO . f) t)
1004
-- | 'pooledMapConcurrently' for when the results are not needed. The pool
-- size comes from 'getNumCapabilities'.
--
-- @since 0.2.10
pooledMapConcurrently_ :: (MonadUnliftIO m, Foldable f) => (a -> m b) -> f a -> m ()
pooledMapConcurrently_ f t = withRunInIO $ \toIO ->
  getNumCapabilities >>= \caps -> pooledMapConcurrentlyIO_ caps (toIO . f) t
1013
-- | 'pooledMapConcurrently_' with the container and the action swapped.
--
-- @since 0.2.10
pooledForConcurrently_ :: (MonadUnliftIO m, Foldable f) => f a -> (a -> m b) -> m ()
pooledForConcurrently_ xs action = pooledMapConcurrently_ action xs
1019
-- | 'pooledMapConcurrentlyN_' with the container and the action swapped.
--
-- @since 0.2.10
pooledForConcurrentlyN_ :: (MonadUnliftIO m, Foldable t)
                        => Int -- ^ Max. number of threads. Should not be less than 1.
                        -> t a -> (a -> m b) -> m ()
pooledForConcurrentlyN_ numProcs xs action = pooledMapConcurrentlyN_ numProcs action xs
1027
1028
-- | Pooled version of 'replicateConcurrently': runs the action the given
-- number of times on the pooled threads and collects the results in a
-- list. A count below 1 yields @[]@ without running the action.
--
-- @since 0.2.10
pooledReplicateConcurrentlyN :: (MonadUnliftIO m)
                             => Int -- ^ Max. number of threads. Should not be less than 1.
                             -> Int -- ^ Number of times to perform the action.
                             -> m a -> m [a]
pooledReplicateConcurrentlyN numProcs cnt task
  | cnt < 1   = return []
  | otherwise = pooledMapConcurrentlyN numProcs (const task) (replicate cnt ())
1041
-- | Like 'pooledReplicateConcurrentlyN', but the size of the thread pool
-- is taken from 'getNumCapabilities'. Usually this is useful for
-- CPU-bound tasks. A count below 1 yields @[]@.
--
-- @since 0.2.10
pooledReplicateConcurrently :: (MonadUnliftIO m)
                            => Int -- ^ Number of times to perform the action.
                            -> m a -> m [a]
pooledReplicateConcurrently cnt task
  | cnt < 1   = return []
  | otherwise = pooledMapConcurrently (const task) (replicate cnt ())
1054
-- | Pooled version of 'replicateConcurrently_': runs the action the given
-- number of times on the pooled threads, discarding the results. A count
-- below 1 does nothing.
--
-- @since 0.2.10
pooledReplicateConcurrentlyN_ :: (MonadUnliftIO m)
                              => Int -- ^ Max. number of threads. Should not be less than 1.
                              -> Int -- ^ Number of times to perform the action.
                              -> m a -> m ()
pooledReplicateConcurrentlyN_ numProcs cnt task
  | cnt < 1   = return ()
  | otherwise = pooledMapConcurrentlyN_ numProcs (const task) (replicate cnt ())
1067
-- | Like 'pooledReplicateConcurrentlyN_', but the size of the thread pool
-- is taken from 'getNumCapabilities'. Usually this is useful for
-- CPU-bound tasks. A count below 1 does nothing.
--
-- @since 0.2.10
pooledReplicateConcurrently_ :: (MonadUnliftIO m)
                             => Int -- ^ Number of times to perform the action.
                             -> m a -> m ()
pooledReplicateConcurrently_ cnt task
  | cnt < 1   = return ()
  | otherwise = pooledMapConcurrently_ (const task) (replicate cnt ())
1080