module Exitify ( exitifyProgram ) where

{-
Note [Exitification]
~~~~~~~~~~~~~~~~~~~~

This module implements Exitification. The goal is to pull as much code out of
recursive functions as possible, as the simplifier is better at inlining into
call-sites that are not in recursive functions.

Example:

  let t = foo bar
  joinrec go 0     x y = t (x*x)
          go (n-1) x y = jump go (n-1) (x+y)
  in …

We’d like to inline `t`, but that does not happen: because `t` is a thunk and is
used in a recursive function, doing so might lose sharing in general. In
this case, however, `t` is on the _exit path_ of `go`, so called at most once.
How do we make this clearly visible to the simplifier?

A code path (i.e., an expression in a tail-recursive position) in a recursive
function is an exit path if it does not contain a recursive call. We can bind
this expression outside the recursive function, as a join-point.

Example result:

  let t = foo bar
  join exit x = t (x*x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

Now `t` is no longer in a recursive function, and good things happen!
-}

import GhcPrelude
import Var
import Id
import IdInfo
import CoreSyn
import CoreUtils
import State
import Unique
import VarSet
import VarEnv
import CoreFVs
import FastString
import Type
import Util( mapSnd )

import Data.Bifunctor
import Control.Monad

-- | Traverses the AST, simply to find all joinrecs and call 'exitifyRec' on them.
-- The really interesting function is exitifyRec
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
  where
    goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
    goTopLvl (Rec pairs) = Rec (map (second (go in_scope_toplvl)) pairs)
      -- Top-level bindings are never join points

    in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds

    go :: InScopeSet -> CoreExpr -> CoreExpr
    go _ e@(Var{})       = e
    go _ e@(Lit {})      = e
    go _ e@(Type {})     = e
    go _ e@(Coercion {}) = e
    go in_scope (Cast e' c) = Cast (go in_scope e') c
    go in_scope (Tick t e') = Tick t (go in_scope e')
    go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)

    go in_scope (Lam v e')
      = Lam v (go in_scope' e')
      where in_scope' = in_scope `extendInScopeSet` v

    go in_scope (Case scrut bndr ty alts)
      = Case (go in_scope scrut) bndr ty (map go_alt alts)
      where
        in_scope1 = in_scope `extendInScopeSet` bndr
        go_alt (dc, pats, rhs) = (dc, pats, go in_scope' rhs)
           where in_scope' = in_scope1 `extendInScopeSetList` pats

    go in_scope (Let (NonRec bndr rhs) body)
      = Let (NonRec bndr (go in_scope rhs)) (go in_scope' body)
      where
        in_scope' = in_scope `extendInScopeSet` bndr

    go in_scope (Let (Rec pairs) body)
      | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
      | otherwise   = Let (Rec pairs') body'
      where
        is_join_rec = any (isJoinId . fst) pairs
        in_scope'   = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
        pairs'      = mapSnd (go in_scope') pairs
        body'       = go in_scope' body


-- | State monad used inside `exitifyRec`
type ExitifyM = State [(JoinId, CoreExpr)]

-- | Given a recursive group of a joinrec, identifies “exit paths” and binds them as
--   join-points outside the joinrec.
exitifyRec :: InScopeSet -> [(Var,CoreExpr)] -> [CoreBind]
exitifyRec in_scope pairs
  = [ NonRec xid rhs | (xid,rhs) <- exits ] ++ [Rec pairs']
  where
    -- We need the set of free variables of many subexpressions here, so
    -- annotate the AST with them
    -- see Note [Calculating free variables]
    ann_pairs = map (second freeVars) pairs

    -- Which are the recursive calls?
    recursive_calls = mkVarSet $ map fst pairs

    (pairs',exits) = (`runState` []) $ do
        forM ann_pairs $ \(x,rhs) -> do
            -- go past the lambdas of the join point
            let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
            body' <- go args body
            let rhs' = mkLams args body'
            return (x, rhs')

    ---------------------
    -- 'go' is the main working function.
    -- It goes through the RHS (tail-call positions only) and checks whether
    -- there are no more recursive calls; if so, it abstracts over the
    -- variables bound on the way and lifts the expression out as a join point.
    --
    -- ExitifyM is a state monad to keep track of floated binds
    go :: [Var]           -- ^ Variables that are in-scope here, but
                          -- not in scope at the joinrec; that is,
                          -- we must potentially abstract over them.
                          -- Invariant: they are kept in dependency order
       -> CoreExprWithFVs -- ^ Current expression in tail position
       -> ExitifyM CoreExpr

    -- We first look at the expression (no matter what its shape is)
    -- and determine if we can turn it into an exit join point
    go captured ann_e
        | -- An exit expression has no recursive calls
          let fvs = dVarSetToVarSet (freeVarsOf ann_e)
        , disjointVarSet fvs recursive_calls
        = go_exit captured (deAnnotate ann_e) fvs

    -- We could not turn it into an exit join point. So now recurse
    -- into all expressions where eligible exit join points might sit,
    -- i.e. into all tail-call positions:

    -- Case right hand sides are in tail-call position
    go captured (_, AnnCase scrut bndr ty alts) = do
        alts' <- forM alts $ \(dc, pats, rhs) -> do
            rhs' <- go (captured ++ [bndr] ++ pats) rhs
            return (dc, pats, rhs')
        return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
        -- join point, RHS and body are in tail-call position
        | AnnNonRec j rhs <- ann_bind
        , Just join_arity <- isJoinId_maybe j
        = do let (params, join_body) = collectNAnnBndrs join_arity rhs
             join_body' <- go (captured ++ params) join_body
             let rhs' = mkLams params join_body'
             body' <- go (captured ++ [j]) body
             return $ Let (NonRec j rhs') body'

        -- rec join point, RHSs and body are in tail-call position
        | AnnRec pairs <- ann_bind
        , isJoinId (fst (head pairs))
        = do let js = map fst pairs
             pairs' <- forM pairs $ \(j,rhs) -> do
                 let join_arity = idJoinArity j
                     (params, join_body) = collectNAnnBndrs join_arity rhs
                 join_body' <- go (captured ++ js ++ params) join_body
                 let rhs' = mkLams params join_body'
                 return (j, rhs')
             body' <- go (captured ++ js) body
             return $ Let (Rec pairs') body'

        -- normal Let, only the body is in tail-call position
        | otherwise
        = do body' <- go (captured ++ bindersOf bind ) body
             return $ Let bind body'
      where bind = deAnnBind ann_bind

    -- Cannot be turned into an exit join point, but also has no
    -- tail-call subexpression. Nothing to do here.
    go _ ann_e = return (deAnnotate ann_e)

    ---------------------
    go_exit :: [Var]      -- Variables captured locally
            -> CoreExpr   -- An exit expression
            -> VarSet     -- Free vars of the expression
            -> ExitifyM CoreExpr
    -- go_exit deals with a tail expression that is floatable
    -- out as an exit point; that is, it mentions no recursive calls
    go_exit captured e fvs
      -- Do not touch an expression that is already a join jump where all arguments
      -- are captured variables. See Note [Idempotency]
      -- But _do_ float join jumps with interesting arguments.
      -- See Note [Jumps can be interesting]
      | (Var f, args) <- collectArgs e
      , isJoinId f
      , all isCapturedVarArg args
      = return e

      -- Do not touch a boring expression (see Note [Interesting expression])
      | not is_interesting
      = return e

      -- Cannot float out if local join points are used, as
      -- we cannot abstract over them
      | captures_join_points
      = return e

      -- We have something to float out!
      | otherwise
      = do { -- Assemble the RHS of the exit join point
             let rhs   = mkLams abs_vars e
                 avoid = in_scope `extendInScopeSetList` captured
             -- Remember this binding under a suitable name
           ; v <- addExit avoid (length abs_vars) rhs
             -- And jump to it from here
           ; return $ mkVarApps (Var v) abs_vars }

      where
        -- Used to detect exit expressions that are already proper exit jumps
        isCapturedVarArg (Var v) = v `elem` captured
        isCapturedVarArg _       = False

        -- An interesting exit expression has free, non-imported
        -- variables from outside the recursive group
        -- See Note [Interesting expression]
        is_interesting = anyVarSet isLocalId $
                         fvs `minusVarSet` mkVarSet captured

        -- The arguments of this exit join point
        -- See Note [Picking arguments to abstract over]
        abs_vars = snd $ foldr pick (fvs, []) captured
          where
            pick v (fvs', acc) | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
                               | otherwise           = (fvs',               acc)

        -- We are going to abstract over these variables, so we must
        -- zap any IdInfo they have; see #15005
        -- cf. SetLevels.abstractVars
        zap v | isId v    = setIdInfo v vanillaIdInfo
              | otherwise = v

        -- We cannot abstract over join points
        captures_join_points = any isJoinId abs_vars


-- Picks a new unique, which is disjoint from
--  * the free variables of the whole joinrec
--  * any bound variables (captured)
--  * any exit join points created so far.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = do
    fs <- get
    let avoid = in_scope `extendInScopeSetList` (map fst fs)
                         `extendInScopeSet` exit_id_tmpl -- just cosmetics
    return (uniqAway avoid exit_id_tmpl)
  where
    exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique ty
                    `asJoinId` join_arity

addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope join_arity rhs = do
    -- Pick a suitable name
    let ty = exprType rhs
    v <- mkExitJoinId in_scope ty join_arity
    fs <- get
    put ((v,rhs):fs)
    return v

{-
Note [Interesting expression]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We do not want this to happen:

  joinrec go 0     x y = x
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = x
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

because the floated exit path (`x`) is simply a parameter of `go`; there are
no useful interactions exposed this way.

Neither do we want this to happen:

  joinrec go 0     x y = x+x
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = x+x
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

where the floated expression `x+x` is a bit more complicated, but still not
interesting: apart from the parameter `x`, its only free variable is the
imported `(+)`.

Expressions are interesting when they move an occurrence of a variable outside
the recursive `go` that can benefit from being obviously called once, for example:
 * a local thunk that can then be inlined (see the example in Note [Exitification])
 * the parameter of a function, where the demand analyzer can then
   see that it is called at most once, and hence improve the function’s
   strictness signature

So we only hoist an exit expression out if it mentions at least one free,
non-imported variable.
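
The rule above can be illustrated with a minimal, self-contained sketch that
models the `is_interesting` check of `go_exit` outside of GHC. All names in it
(`ToyVar`, `isLocal`, `isInteresting` and the example variables) are
hypothetical: `isLocal` stands in for `isLocalId`, and free-variable sets are
modelled as plain lists rather than `VarSet`s. It is meant to be loaded into
GHCi, not compiled as part of this module.

```haskell
-- Hypothetical toy model, not GHC's API: a variable is a name plus a flag
-- saying whether it is local to this module (like isLocalId) or imported.
data ToyVar = ToyVar { varName :: String, isLocal :: Bool }
  deriving (Eq, Show)

-- | Mirrors 'is_interesting' in go_exit: remove the variables captured on
-- the way down (we would abstract over those anyway) from the free
-- variables of the exit expression, and ask whether a local one remains.
isInteresting :: [ToyVar]  -- ^ captured variables
              -> [ToyVar]  -- ^ free variables of the exit expression
              -> Bool
isInteresting captured fvs =
  any isLocal (filter (`notElem` captured) fvs)

-- The variables of the examples: x and y are parameters of go,
-- t is a local thunk, and (+) and (*) are imported.
x, y, t, plus, times :: ToyVar
x     = ToyVar "x" True
y     = ToyVar "y" True
t     = ToyVar "t" True
plus  = ToyVar "+" False
times = ToyVar "*" False
```

For instance, with `x` and `y` captured, `isInteresting [x, y] [x]` (the first
example above) and `isInteresting [x, y] [plus, x]` (the second one) are both
`False`, whereas `isInteresting [x, y] [t, times, x]`, i.e. the exit path
`t (x*x)` from Note [Exitification], is `True`, because the local thunk `t`
survives the subtraction.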

Note [Jumps can be interesting]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A jump to a join point can be interesting, if its arguments contain free
non-imported variables (z in the following example):

  joinrec go 0     x y = jump j (x+z)
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = jump j (x+z)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

The join point itself can be interesting, even if none of its
arguments mention variables that are free in the joinrec. For example:

  join j p = case p of (x,y) -> x+y
  joinrec go 0     x y = jump j (x,y)
          go (n-1) x y = jump go (n-1) (x+y) y
  in …

Here, `j` would not be inlined because we do not inline something that looks
like an exit join point (see Note [Do not inline exit join points]). But
if we exitify the 'jump j (x,y)' we get

  join j p = case p of (x,y) -> x+y
  join exit x y = jump j (x,y)
  joinrec go 0     x y = jump exit x y
          go (n-1) x y = jump go (n-1) (x+y) y
  in …

and now `j` can be inlined, and we get rid of the pair. Here's another
example (assume `g` to be an imported function that, on its own,
does not make this interesting):

  join j y = map f y
  joinrec go 0     x y = jump j (map g x)
          go (n-1) x y = jump go (n-1) (x+y)
  in …

Again, `j` would not be inlined because we do not inline something that looks
like an exit join point (see Note [Do not inline exit join points]).

But after exitification we have

  join j y = map f y
  join exit x = jump j (map g x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

and now we can inline `j` and this will allow `map/map` to fire.
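
How the guards of `go_exit` treat jumps can be sketched on a toy expression
language, under these assumptions: `Expr`, `Decision`, `Env`, `callV`,
`freeVarsE`, `goExitDecision` and `exampleEnv` are hypothetical names;
variables are plain strings; `isJoin` and `isLocalV` stand in for `isJoinId`
and `isLocalId`; and the caller has already checked that the expression
contains no recursive call, as `go` does before it calls `go_exit`. The toy
`collectArgs` only mirrors the behaviour of the real one.

```haskell
import Data.List (nub)

-- Hypothetical toy expressions: only variables and applications, so a
-- jump `jump j a b` is modelled as App (App (Var "j") a) b.
data Expr = Var String | App Expr Expr
  deriving (Eq, Show)

data Decision = LeaveAlone | FloatOut
  deriving (Eq, Show)

-- What go_exit knows about its surroundings (hypothetical record).
data Env = Env
  { isJoin   :: String -> Bool  -- stands in for isJoinId
  , isLocalV :: String -> Bool  -- stands in for isLocalId
  , captured :: [String]        -- variables bound since the joinrec
  }

-- Build `f a1 ... an` (also used for jumps).
callV :: String -> [Expr] -> Expr
callV f = foldl App (Var f)

-- Split an application into its head and arguments.
collectArgs :: Expr -> (Expr, [Expr])
collectArgs = go []
  where
    go args (App f a) = go (a : args) f
    go args e         = (e, args)

freeVarsE :: Expr -> [String]
freeVarsE (Var v)   = [v]
freeVarsE (App f a) = nub (freeVarsE f ++ freeVarsE a)

-- The guards of go_exit, in the same order.
goExitDecision :: Env -> Expr -> Decision
goExitDecision env e
  -- Already a jump whose arguments are all captured variables
  | (Var f, args) <- collectArgs e
  , isJoin env f
  , all isCapturedVarArg args
  = LeaveAlone
  -- Boring: no local variable from outside the joinrec
  | not interesting
  = LeaveAlone
  -- We cannot abstract over captured join points
  | any (isJoin env) absVars
  = LeaveAlone
  | otherwise
  = FloatOut
  where
    fvs = freeVarsE e
    isCapturedVarArg (Var v) = v `elem` captured env
    isCapturedVarArg _       = False
    interesting = any (isLocalV env) (filter (`notElem` captured env) fvs)
    -- Simplified: ignores shadowing, unlike the foldr in go_exit
    absVars = filter (`elem` fvs) (captured env)

-- j, k and exit are join points; (+), (,), map and g are imported;
-- x and y are the parameters of go.
exampleEnv :: Env
exampleEnv = Env { isJoin   = (`elem` ["j", "k", "exit"])
                 , isLocalV = (`notElem` ["+", "(,)", "map", "g"])
                 , captured = ["x", "y"] }
```

With `x` and `y` captured, `goExitDecision exampleEnv (callV "j" [callV "+" [Var "x", Var "z"]])`
is `FloatOut` (the free local join point `j` makes the jump interesting),
while `goExitDecision exampleEnv (callV "exit" [Var "x"])` is `LeaveAlone`,
as Note [Idempotency] requires. The argument selection here is a plain
`filter`, which is only adequate because the toy has no shadowing.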


Note [Idempotency]
~~~~~~~~~~~~~~~~~~

We do not want this to happen, where we replace the floated expression with
essentially the same expression:

  join exit x = t (x*x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …
==>
  join exit x = t (x*x)
  join exit' x = jump exit x
  joinrec go 0     x y = jump exit' x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

So when the exit expression is a join jump, and all of its arguments are
captured variables, then we leave it in place.

Note that `jump exit x` in this example looks interesting, as `exit` is a free
variable. Therefore, idempotency does not simply follow from floating only
interesting expressions.

Note [Calculating free variables]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We have two options for where to annotate the tree with free variables:

 A) The whole tree.
 B) Each individual joinrec as we come across it.

Downside of A: We pay the price on the whole module, even outside any joinrecs.
Downside of B: We pay the price per joinrec, possibly multiple times when
joinrecs are nested.

Further downside of A: If the exitify function returns annotated expressions,
it would have to ensure that the annotations are correct.

We therefore choose B, and calculate the free variables in `exitifyRec`.


Note [Do not inline exit join points]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When we have

  let t = foo bar
  join exit x = t (x*x)
  joinrec go 0     x y = jump exit x
          go (n-1) x y = jump go (n-1) (x+y)
  in …

we do not want the simplifier to simply inline `exit` back in (which it happily
would).

To prevent this, we need to recognize exit join points, and then disable
inlining.

Exit join points are join points with an occurrence in a recursive group. They
can be recognized (after the occurrence analyzer has run!) using `isExitJoinId`,
which detects join points with `occ_in_lam (idOccInfo id) == True`. This works
because the lambdas of a non-recursive join point are not considered for
`occ_in_lam`. For example, in the following code, `j1` is /not/ marked
occ_in_lam, because `j2` is called only once.

  join j1 x = x+1
  join j2 y = jump j1 (y+2)

To prevent inlining, we check for isExitJoinId:
* In `preInlineUnconditionally` directly.
* In `simplLetUnfolding` we simply give exit join points no unfolding, which
  prevents inlining in `postInlineUnconditionally` and call sites.

Note [Placement of the exitification pass]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
I (Joachim) experimented with multiple positions for the Exitification pass in
the Core2Core pipeline:

 A) Before the `simpl_phases`
 B) Between the `simpl_phases` and the "main" simplifier pass
 C) After demand_analyser
 D) Before the final simplification phase

Here is the table (this is without inlining exit join points in the final
simplifier run):

        Program | Allocs                                            | Instrs
                |  ABCD.log     A.log     B.log     C.log     D.log |  ABCD.log     A.log     B.log     C.log     D.log
----------------|---------------------------------------------------|--------------------------------------------------
 fannkuch-redux |    -99.9%     +0.0%    -99.9%    -99.9%    -99.9% |     -3.9%     +0.5%     -3.0%     -3.9%     -3.9%
          fasta |     -0.0%     +0.0%     +0.0%     -0.0%     -0.0% |     -8.5%     +0.0%     +0.0%     -0.0%     -8.5%
            fem |      0.0%      0.0%      0.0%      0.0%     +0.0% |     -2.2%     -0.1%     -0.1%     -2.1%     -2.1%
           fish |      0.0%      0.0%      0.0%      0.0%     +0.0% |     -3.1%     +0.0%     -1.1%     -1.1%     -0.0%
   k-nucleotide |    -91.3%    -91.0%    -91.0%    -91.3%    -91.3% |     -6.3%    +11.4%    +11.4%     -6.3%     -6.2%
            scs |     -0.0%     -0.0%     -0.0%     -0.0%     -0.0% |     -3.4%     -3.0%     -3.1%     -3.3%     -3.3%
         simple |     -6.0%      0.0%     -6.0%     -6.0%     +0.0% |     -3.4%     +0.0%     -5.2%     -3.4%     -0.1%
  spectral-norm |     -0.0%      0.0%      0.0%     -0.0%     +0.0% |     -2.7%     +0.0%     -2.7%     -5.4%     -5.4%
----------------|---------------------------------------------------|--------------------------------------------------
            Min |    -95.0%    -91.0%    -95.0%    -95.0%    -95.0% |     -8.5%     -3.0%     -5.2%     -6.3%     -8.5%
            Max |     +0.2%     +0.2%     +0.2%     +0.2%     +1.5% |     +0.4%    +11.4%    +11.4%     +0.4%     +1.5%
 Geometric Mean |     -4.7%     -2.1%     -4.7%     -4.7%     -4.6% |     -0.4%     +0.1%     -0.1%     -0.3%     -0.2%

Position A is disqualified, as it does not get rid of the allocations in
fannkuch-redux.
Positions A and B are disqualified because they increase instructions in
k-nucleotide.
Positions C and D each have their advantages: C decreases allocations in simple,
while D decreases instructions in fasta.

Assuming we have a budget of _one_ run of Exitification, then C wins (but we
could get more from running it multiple times, as seen in fish).

Note [Picking arguments to abstract over]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When we create an exit join point, we need to abstract over those of its
free variables that are out of scope at the destination of the exit join
point. So we go through the list `captured` and pick those that are actually
free variables of the join point.

We do not just `filter (`elemVarSet` fvs) captured`, as there might be
shadowing, and `captured` may contain multiple variables with the same Unique.
In these cases we want to abstract only over the last occurrence, hence the
`foldr` (with emphasis on the `r`). This is #15110.

-}
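
{-
The `foldr` described in Note [Picking arguments to abstract over] can be
sketched outside of GHC as follows. `ToyVar`, `uniq`, `tag`, `pickAbsVars` and
`naiveAbsVars` are hypothetical names; membership is decided by the Unique
alone, as with `VarSet`; free variables are a plain list of Uniques; and
zapping the `IdInfo` (see #15005) is left out.

```haskell
import Data.List (delete)

-- Hypothetical stand-in for GHC's Var: sets compare variables by their
-- Unique only, while 'tag' lets us tell apart two binders that share one.
data ToyVar = ToyVar { uniq :: Int, tag :: String }
  deriving Show

-- | As in go_exit: walk 'captured' from the right, so that when several
-- binders share a Unique the last (innermost) one is picked, after which
-- that Unique is removed from the remaining free variables.
pickAbsVars :: [Int] -> [ToyVar] -> [ToyVar]
pickAbsVars fvs captured = snd (foldr pick (fvs, []) captured)
  where
    pick v (fvs', acc)
      | uniq v `elem` fvs' = (delete (uniq v) fvs', v : acc)
      | otherwise          = (fvs', acc)

-- | The rejected alternative: keeps every binder whose Unique is free,
-- so a shadowed binder is abstracted over as well.
naiveAbsVars :: [Int] -> [ToyVar] -> [ToyVar]
naiveAbsVars fvs = filter ((`elem` fvs) . uniq)
```

With `captured = [ToyVar 1 "x (outer)", ToyVar 2 "y", ToyVar 1 "x (inner)"]`
(in dependency order) and free Uniques `[1, 2]`,
`map tag (pickAbsVars [1, 2] captured)` is `["y", "x (inner)"]`: only the
innermost `x` is abstracted over, and the order of `captured` is preserved.
`naiveAbsVars` instead returns all three binders, so the shadowed outer `x`
would be abstracted over as well.
-}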