1module Exitify ( exitifyProgram ) where
2
3{-
4Note [Exitification]
5~~~~~~~~~~~~~~~~~~~~
6
7This module implements Exitification. The goal is to pull as much code out of
8recursive functions as possible, as the simplifier is better at inlining into
9call-sites that are not in recursive functions.
10
11Example:
12
13  let t = foo bar
14  joinrec go 0     x y = t (x*x)
15          go (n-1) x y = jump go (n-1) (x+y)
16  in …
17
18We’d like to inline `t`, but that does not happen: Because t is a thunk and is
19used in a recursive function, doing so might lose sharing in general. In
20this case, however, `t` is on the _exit path_ of `go`, so called at most once.
21How do we make this clearly visible to the simplifier?
22
23A code path (i.e., an expression in a tail-recursive position) in a recursive
24function is an exit path if it does not contain a recursive call. We can bind
25this expression outside the recursive function, as a join-point.
26
27Example result:
28
29  let t = foo bar
30  join exit x = t (x*x)
31  joinrec go 0     x y = jump exit x
32          go (n-1) x y = jump go (n-1) (x+y)
33  in …
34
35Now `t` is no longer in a recursive function, and good things happen!
36-}
37
38import GhcPrelude
39import Var
40import Id
41import IdInfo
42import CoreSyn
43import CoreUtils
44import State
45import Unique
46import VarSet
47import VarEnv
48import CoreFVs
49import FastString
50import Type
51import Util( mapSnd )
52
53import Data.Bifunctor
54import Control.Monad
55
-- | Walks over the whole program looking for recursive join-point groups and
-- hands each of them to 'exitifyRec'. Everything else is rebuilt unchanged;
-- the interesting work happens in 'exitifyRec'.
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
  where
    -- Top-level bindings are never join points, so only their right-hand
    -- sides are traversed.
    goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
    goTopLvl (Rec pairs)  = Rec (second (go in_scope_toplvl) <$> pairs)

    -- All top-level binders are in scope in every right-hand side.
    in_scope_toplvl = extendInScopeSetList emptyInScopeSet (bindersOfBinds binds)

    -- Rebuilds an expression, extending the in-scope set at every binder
    -- and exitifying every recursive join-point group found on the way.
    go :: InScopeSet -> CoreExpr -> CoreExpr
    go in_scope expr = case expr of
        -- Leaves: nothing to traverse
        Var {}      -> expr
        Lit {}      -> expr
        Type {}     -> expr
        Coercion {} -> expr

        Cast inner co      -> Cast (go in_scope inner) co
        Tick tickish inner -> Tick tickish (go in_scope inner)
        App fun arg        -> App (go in_scope fun) (go in_scope arg)

        Lam bndr inner
          -> Lam bndr (go (extendInScopeSet in_scope bndr) inner)

        Case scrut bndr ty alts
          -> Case (go in_scope scrut) bndr ty
                  [ (con, pats, go (extendInScopeSetList with_bndr pats) rhs)
                  | (con, pats, rhs) <- alts ]
          where
            -- The case binder scopes over all alternatives
            with_bndr = extendInScopeSet in_scope bndr

        Let (NonRec bndr rhs) body
          -> Let (NonRec bndr (go in_scope rhs))
                 (go (extendInScopeSet in_scope bndr) body)

        Let (Rec pairs) body
          | any (isJoinId . fst) pairs
          -> mkLets (exitifyRec rec_scope new_pairs) new_body
          | otherwise
          -> Let (Rec new_pairs) new_body
          where
            -- The binders of a recursive group scope over both the
            -- right-hand sides and the body
            rec_scope = extendInScopeSetList in_scope (bindersOf (Rec pairs))
            new_pairs = mapSnd (go rec_scope) pairs
            new_body  = go rec_scope body
100
101
-- | State monad used inside 'exitifyRec'. The state is the list of exit join
-- points created so far, each paired with its right-hand side.
type ExitifyM =  State [(JoinId, CoreExpr)]
104
-- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
--   join-points outside the joinrec.
--
--   The result consists of the newly created exit join points (one 'NonRec'
--   binding each) followed by the rewritten recursive group. The 'InScopeSet'
--   is used, together with the locally captured variables, to pick names for
--   the new exit join points.
exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
exitifyRec in_scope pairs
  = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
  where
    -- We need the set of free variables of many subexpressions here, so
    -- annotate the AST with them
    -- see Note [Calculating free variables]
    ann_pairs = map (second freeVars) pairs

    -- Which are the recursive calls?
    recursive_calls = mkVarSet $ map fst pairs

    -- Rewrite every right-hand side of the group; the state (initially
    -- empty) collects the exit join points created along the way.
    (pairs',exits) = (`runState` []) $ do
        forM ann_pairs $ \(x,rhs) -> do
            -- go past the lambdas of the join point
            let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
            body' <- go args body
            let rhs' = mkLams args body'
            return (x, rhs')

    ---------------------
    -- 'go' is the main working function.
    -- It goes through the RHS (tail-call positions only),
    -- checks if there are no more recursive calls, if so, abstracts over
    -- variables bound on the way and lifts it out as a join point.
    --
    -- ExitifyM is a state monad to keep track of floated binds
    go :: [Var]           -- ^ Variables that are in-scope here, but
                          -- not in scope at the joinrec; that is,
                          -- we must potentially abstract over them.
                          -- Invariant: they are kept in dependency order
       -> CoreExprWithFVs -- ^ Current expression in tail position
       -> ExitifyM CoreExpr

    -- We first look at the expression (no matter what its shape is)
    -- and determine if we can turn it into an exit join point
    go captured ann_e
        | -- An exit expression has no recursive calls
          let fvs = dVarSetToVarSet (freeVarsOf ann_e)
        , disjointVarSet fvs recursive_calls
        = go_exit captured (deAnnotate ann_e) fvs

    -- We could not turn it into an exit join point. So now recurse
    -- into all expressions where eligible exit join points might sit,
    -- i.e. into all tail-call positions:

    -- Case right hand sides are in tail-call position
    go captured (_, AnnCase scrut bndr ty alts) = do
        alts' <- forM alts $ \(dc, pats, rhs) -> do
            rhs' <- go (captured ++ [bndr] ++ pats) rhs
            return (dc, pats, rhs')
        return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
        -- join point, RHS and body are in tail-call position
        | AnnNonRec j rhs <- ann_bind
        , Just join_arity <- isJoinId_maybe j
        = do let (params, join_body) = collectNAnnBndrs join_arity rhs
             join_body' <- go (captured ++ params) join_body
             let rhs' = mkLams params join_body'
             body' <- go (captured ++ [j]) body
             return $ Let (NonRec j rhs') body'

        -- rec join point, RHSs and body are in tail-call position
        -- Matching on the first pair (rather than calling 'head') keeps this
        -- guard total: an empty group simply falls through to the
        -- ordinary-let case below.
        | AnnRec pairs@((first_j, _) : _) <- ann_bind
        , isJoinId first_j
        = do let js = map fst pairs
             pairs' <- forM pairs $ \(j,rhs) -> do
                 let join_arity = idJoinArity j
                     (params, join_body) = collectNAnnBndrs join_arity rhs
                 join_body' <- go (captured ++ js ++ params) join_body
                 let rhs' = mkLams params join_body'
                 return (j, rhs')
             body' <- go (captured ++ js) body
             return $ Let (Rec pairs') body'

        -- normal Let, only the body is in tail-call position
        | otherwise
        = do body' <- go (captured ++ bindersOf bind ) body
             return $ Let bind body'
      where bind = deAnnBind ann_bind

    -- Cannot be turned into an exit join point, but also has no
    -- tail-call subexpression. Nothing to do here.
    go _ ann_e = return (deAnnotate ann_e)

    ---------------------
    go_exit :: [Var]      -- Variables captured locally
            -> CoreExpr   -- An exit expression
            -> VarSet     -- Free vars of the expression
            -> ExitifyM CoreExpr
    -- go_exit deals with a tail expression that is floatable
    -- out as an exit point; that is, it mentions no recursive calls
    go_exit captured e fvs
      -- Do not touch an expression that is already a join jump where all arguments
      -- are captured variables. See Note [Idempotency]
      -- But _do_ float join jumps with interesting arguments.
      -- See Note [Jumps can be interesting]
      | (Var f, args) <- collectArgs e
      , isJoinId f
      , all isCapturedVarArg args
      = return e

      -- Do not touch a boring expression (see Note [Interesting expression])
      | not is_interesting
      = return e

      -- Cannot float out if local join points are used, as
      -- we cannot abstract over them
      | captures_join_points
      = return e

      -- We have something to float out!
      | otherwise
      = do { -- Assemble the RHS of the exit join point
             let rhs   = mkLams abs_vars e
                 avoid = in_scope `extendInScopeSetList` captured
             -- Remember this binding under a suitable name
           ; v <- addExit avoid (length abs_vars) rhs
             -- And jump to it from here
           ; return $ mkVarApps (Var v) abs_vars }

      where
        -- Used to detect exit expressions that are already proper exit jumps
        isCapturedVarArg (Var v) = v `elem` captured
        isCapturedVarArg _ = False

        -- An interesting exit expression has free, non-imported
        -- variables from outside the recursive group
        -- See Note [Interesting expression]
        is_interesting = anyVarSet isLocalId $
                         fvs `minusVarSet` mkVarSet captured

        -- The arguments of this exit join point
        -- See Note [Picking arguments to abstract over]
        abs_vars = snd $ foldr pick (fvs, []) captured
          where
            pick v (fvs', acc) | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
                               | otherwise           = (fvs',               acc)

        -- We are going to abstract over these variables, so we must
        -- zap any IdInfo they have; see #15005
        -- cf. SetLevels.abstractVars
        zap v | isId v = setIdInfo v vanillaIdInfo
              | otherwise = v

        -- We cannot abstract over join points
        captures_join_points = any isJoinId abs_vars
256
-- Chooses the identifier for a new exit join point. Its unique is picked
-- away from
--  * everything in the supplied in-scope set (the caller in this module
--    passes the scope of the joinrec extended with the captured variables)
--  * any exit join points created so far.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = freshen <$> get
  where
    -- Template identifier that is handed to 'uniqAway'
    template = asJoinId (mkSysLocal (fsLit "exit") initExitJoinUnique ty)
                        join_arity

    -- Given the exits recorded so far, move the template away from all
    -- names that are already taken
    freshen exits_so_far = uniqAway taken template
      where
        taken = extendInScopeSet
                  (extendInScopeSetList in_scope (map fst exits_so_far))
                  template -- just cosmetics
270
-- Records a new exit join point with the given right-hand side in the state
-- and returns the identifier chosen for it.
addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope join_arity rhs = do
    -- Pick a suitable name, typed after the right-hand side
    exit_id <- mkExitJoinId in_scope (exprType rhs) join_arity
    -- Put the new binding in front of the ones recorded earlier
    earlier <- get
    exit_id <$ put ((exit_id, rhs) : earlier)
279
280{-
281Note [Interesting expression]
282~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
283We do not want this to happen:
284
285  joinrec go 0     x y = x
286          go (n-1) x y = jump go (n-1) (x+y)
287  in …
288==>
289  join exit x = x
290  joinrec go 0     x y = jump exit x
291          go (n-1) x y = jump go (n-1) (x+y)
292  in …
293
because the floated exit path (`x`) is simply a parameter of `go`; there are
no useful interactions exposed this way.
296
297Neither do we want this to happen
298
299  joinrec go 0     x y = x+x
300          go (n-1) x y = jump go (n-1) (x+y)
301  in …
302==>
303  join exit x = x+x
304  joinrec go 0     x y = jump exit x
305          go (n-1) x y = jump go (n-1) (x+y)
306  in …
307
where the floated expression `x+x` is a bit more complicated, but still not
interesting.
310
311Expressions are interesting when they move an occurrence of a variable outside
312the recursive `go` that can benefit from being obviously called once, for example:
313 * a local thunk that can then be inlined (see example in note [Exitification])
 * the parameter of a function, where the demand analyzer can then
   see that it is called at most once, and hence improve the function’s
   strictness signature
317
So we only hoist an exit expression out if it mentions at least one free,
non-imported variable.
320
321Note [Jumps can be interesting]
322~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A jump to a join point can be interesting, if its arguments contain free
non-imported variables (z in the following example):
325
326  joinrec go 0     x y = jump j (x+z)
327          go (n-1) x y = jump go (n-1) (x+y)
328  in …
329==>
  join exit x = jump j (x+z)
331  joinrec go 0     x y = jump exit x
332          go (n-1) x y = jump go (n-1) (x+y)
333
334
The join point itself can be interesting, even if none of its
arguments have free variables free in the joinrec.  For example
337
338  join j p = case p of (x,y) -> x+y
339  joinrec go 0     x y = jump j (x,y)
340          go (n-1) x y = jump go (n-1) (x+y) y
341  in …
342
343Here, `j` would not be inlined because we do not inline something that looks
344like an exit join point (see Note [Do not inline exit join points]). But
345if we exitify the 'jump j (x,y)' we get
346
347  join j p = case p of (x,y) -> x+y
348  join exit x y = jump j (x,y)
349  joinrec go 0     x y = jump exit x y
350          go (n-1) x y = jump go (n-1) (x+y) y
351  in …
352
353and now 'j' can inline, and we get rid of the pair. Here's another
354example (assume `g` to be an imported function that, on its own,
355does not make this interesting):
356
357  join j y = map f y
358  joinrec go 0     x y = jump j (map g x)
359          go (n-1) x y = jump go (n-1) (x+y)
360  in …
361
362Again, `j` would not be inlined because we do not inline something that looks
363like an exit join point (see Note [Do not inline exit join points]).
364
365But after exitification we have
366
367  join j y = map f y
368  join exit x = jump j (map g x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
371  in …
372
373and now we can inline `j` and this will allow `map/map` to fire.
374
375
376Note [Idempotency]
377~~~~~~~~~~~~~~~~~~
378
379We do not want this to happen, where we replace the floated expression with
380essentially the same expression:
381
382  join exit x = t (x*x)
383  joinrec go 0     x y = jump exit x
384          go (n-1) x y = jump go (n-1) (x+y)
385  in …
386==>
387  join exit x = t (x*x)
388  join exit' x = jump exit x
389  joinrec go 0     x y = jump exit' x
390          go (n-1) x y = jump go (n-1) (x+y)
391  in …
392
393So when the RHS is a join jump, and all of its arguments are captured variables,
394then we leave it in place.
395
396Note that `jump exit x` in this example looks interesting, as `exit` is a free
397variable. Therefore, idempotency does not simply follow from floating only
398interesting expressions.
399
400Note [Calculating free variables]
401~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
402We have two options where to annotate the tree with free variables:
403
404 A) The whole tree.
405 B) Each individual joinrec as we come across it.
406
407Downside of A: We pay the price on the whole module, even outside any joinrecs.
408Downside of B: We pay the price per joinrec, possibly multiple times when
409joinrecs are nested.
410
411Further downside of A: If the exitify function returns annotated expressions,
412it would have to ensure that the annotations are correct.
413
We therefore choose B, and calculate the free variables in `exitifyRec`.
415
416
417Note [Do not inline exit join points]
418~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
419When we have
420
421  let t = foo bar
422  join exit x = t (x*x)
423  joinrec go 0     x y = jump exit x
424          go (n-1) x y = jump go (n-1) (x+y)
425  in …
426
427we do not want the simplifier to simply inline `exit` back in (which it happily
428would).
429
430To prevent this, we need to recognize exit join points, and then disable
431inlining.
432
Exit join points are join points with an occurrence in a recursive group. They
can be recognized (after the occurrence analyzer ran!) using `isExitJoinId`.
This function detects joinpoints with `occ_in_lam (idOccInfo id) == True`,
437because the lambdas of a non-recursive join point are not considered for
438`occ_in_lam`.  For example, in the following code, `j1` is /not/ marked
439occ_in_lam, because `j2` is called only once.
440
441  join j1 x = x+1
  join j2 y = jump j1 (y+2)
443
444To prevent inlining, we check for isExitJoinId
445* In `preInlineUnconditionally` directly.
446* In `simplLetUnfolding` we simply give exit join points no unfolding, which
447  prevents inlining in `postInlineUnconditionally` and call sites.
448
449Note [Placement of the exitification pass]
450~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
451I (Joachim) experimented with multiple positions for the Exitification pass in
452the Core2Core pipeline:
453
454 A) Before the `simpl_phases`
455 B) Between the `simpl_phases` and the "main" simplifier pass
456 C) After demand_analyser
457 D) Before the final simplification phase
458
459Here is the table (this is without inlining join exit points in the final
460simplifier run):
461
462        Program |                       Allocs                      |                      Instrs
463                | ABCD.log     A.log     B.log     C.log     D.log  | ABCD.log     A.log     B.log     C.log     D.log
464----------------|---------------------------------------------------|-------------------------------------------------
465 fannkuch-redux |   -99.9%     +0.0%    -99.9%    -99.9%    -99.9%  |    -3.9%     +0.5%     -3.0%     -3.9%     -3.9%
466          fasta |    -0.0%     +0.0%     +0.0%     -0.0%     -0.0%  |    -8.5%     +0.0%     +0.0%     -0.0%     -8.5%
467            fem |     0.0%      0.0%      0.0%      0.0%     +0.0%  |    -2.2%     -0.1%     -0.1%     -2.1%     -2.1%
468           fish |     0.0%      0.0%      0.0%      0.0%     +0.0%  |    -3.1%     +0.0%     -1.1%     -1.1%     -0.0%
469   k-nucleotide |   -91.3%    -91.0%    -91.0%    -91.3%    -91.3%  |    -6.3%    +11.4%    +11.4%     -6.3%     -6.2%
470            scs |    -0.0%     -0.0%     -0.0%     -0.0%     -0.0%  |    -3.4%     -3.0%     -3.1%     -3.3%     -3.3%
471         simple |    -6.0%      0.0%     -6.0%     -6.0%     +0.0%  |    -3.4%     +0.0%     -5.2%     -3.4%     -0.1%
472  spectral-norm |    -0.0%      0.0%      0.0%     -0.0%     +0.0%  |    -2.7%     +0.0%     -2.7%     -5.4%     -5.4%
473----------------|---------------------------------------------------|-------------------------------------------------
474            Min |   -95.0%    -91.0%    -95.0%    -95.0%    -95.0%  |    -8.5%     -3.0%     -5.2%     -6.3%     -8.5%
475            Max |    +0.2%     +0.2%     +0.2%     +0.2%     +1.5%  |    +0.4%    +11.4%    +11.4%     +0.4%     +1.5%
476 Geometric Mean |    -4.7%     -2.1%     -4.7%     -4.7%     -4.6%  |    -0.4%     +0.1%     -0.1%     -0.3%     -0.2%
477
478Position A is disqualified, as it does not get rid of the allocations in
479fannkuch-redux.
Positions A and B are disqualified because they increase instructions in k-nucleotide.
Positions C and D have their advantages: C decreases allocations in simple, while D decreases instructions in fasta.
482
483Assuming we have a budget of _one_ run of Exitification, then C wins (but we
484could get more from running it multiple times, as seen in fish).
485
486Note [Picking arguments to abstract over]
487~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
488
When we create an exit join point, we need to abstract over those of its
free variables that would be out-of-scope at the destination of the exit join
point. So we go through the list `captured` and pick those that are actually
free variables of the join point.
493
We do not just `filter (`elemVarSet` fvs) captured`, as there might be
shadowing, and `captured` may contain multiple variables with the same Unique. In
these cases we want to abstract only over the last occurrence, hence the `foldr`
(with emphasis on the `r`). This is #15110.
498
499-}
500