module Agda.Compiler.Treeless.Simplify (simplifyTTerm) where

import Control.Arrow (second, (***))
import Control.Monad.Reader
import qualified Data.List as List

import Agda.Syntax.Treeless
import Agda.Syntax.Literal

import Agda.TypeChecking.Monad
import Agda.TypeChecking.Primitive
import Agda.TypeChecking.Substitute

import Agda.Compiler.Treeless.Compare

import Agda.Utils.List
import Agda.Utils.Maybe

import Agda.Utils.Impossible

-- | Environment of the simplifier monad.
data SEnv = SEnv
  { envSubst   :: Substitution' TTerm
    -- ^ What is known about each variable in scope.  'lookupVar' returns
    --   the entry, which is the variable itself when nothing is known
    --   (the simplifier starts from 'IdS').
  , envRewrite :: [(TTerm, TTerm)]
    -- ^ Rewrite rules @(lhs, rhs)@, tried front to back by 'rewrite'.
  }

-- | The simplifier monad: a read-only 'SEnv' that is extended locally when
--   going under binders or into case branches.
type S = Reader SEnv

-- | Run a simplifier computation with the identity substitution and no
--   rewrite rules.
runS :: S a -> a
runS m = runReader m $ SEnv IdS []

-- | Look up variable @i@ in the current substitution.
lookupVar :: Int -> S TTerm
lookupVar i = asks $ (`lookupS` i) . envSubst

-- | Locally modify the substitution.
onSubst :: (Substitution' TTerm -> Substitution' TTerm) -> S a -> S a
onSubst f = local $ \ env -> env { envSubst = f (envSubst env) }

-- | Locally apply @rho@ to both sides of every rewrite rule.
onRewrite :: Substitution' TTerm -> S a -> S a
onRewrite rho = local $ \ env -> env { envRewrite = map (applySubst rho *** applySubst rho) (envRewrite env) }

-- | Locally add a rewrite rule.  It is put at the front of the list, so it
--   wins over previously added rules whose lhs also matches.
addRewrite :: TTerm -> TTerm -> S a -> S a
addRewrite lhs rhs = local $ \ env -> env { envRewrite = (lhs, rhs) : envRewrite env }

-- | Move the environment under @i@ lambda binders: the rewrite rules are
--   raised by @i@ and the substitution is lifted by @i@.
underLams :: Int -> S a -> S a
underLams i = onRewrite (raiseS i) . onSubst (liftS i)

-- | 'underLams' for a single binder.
underLam :: S a -> S a
underLam = underLams 1

-- | Move the environment under a @let@ whose bound value is @u@.  Variable 0
--   is mapped to @u@, and the substitution is weakened by one binder
--   (@wkS 1@), so later lookups of variable 0 see the let-bound value.
underLet :: TTerm -> S a -> S a
underLet u = onRewrite (raiseS 1) . onSubst (\rho -> wkS 1 $ u :# rho)

-- | Locally record that variable @x@ stands for @u@ by composing
--   @inplaceS x u@ onto the current substitution.
bindVar :: Int -> TTerm -> S a -> S a
bindVar x u = onSubst (inplaceS x u `composeS`)

-- | Replace @t@ by the rhs of the first rewrite rule whose lhs is
--   'equalTerms' to @t@.  If no rule matches, @t@ is returned unchanged.
rewrite :: TTerm -> S TTerm
rewrite t = do
  rules <- asks envRewrite
  case [ rhs | (lhs, rhs) <- rules, equalTerms t lhs ] of
    rhs : _ -> pure rhs
    []      -> pure t

-- | Names of the builtins the simplifier treats specially.  Each field is
--   the result of 'getBuiltinName', so it may be 'Nothing'.
data FunctionKit = FunctionKit
  { modAux, divAux, natMinus, true, false :: Maybe QName }

-- | Simplify a treeless term.  The builtins in 'FunctionKit' are looked up
--   once, and then the pure simplifier is run.
simplifyTTerm :: TTerm -> TCM TTerm
simplifyTTerm t = do
  kit <- FunctionKit <$> getBuiltinName builtinNatModSucAux
                     <*> getBuiltinName builtinNatDivSucAux
                     <*> getBuiltinName builtinNatMinus
                     <*> getBuiltinName builtinTrue
                     <*> getBuiltinName builtinFalse
  return $ runS $ simplify kit t

-- | The simplifier proper.  Every term first goes through @rewrite'@
--   ('simplPrim' followed by 'rewrite') and 'unchainCase'.  It is then
--   simplified according to its head constructor.
simplify :: FunctionKit -> TTerm -> S TTerm
simplify FunctionKit{..} = simpl
  where
    simpl = rewrite' >=> unchainCase >=> \case

      t@TDef{}  -> pure t
      t@TPrim{} -> pure t
      t@TVar{}  -> pure t

      TApp (TDef f) [TLit (LitNat 0), m, n, m']
        -- div/mod are equivalent to quot/rem on natural numbers.
        | m == m', Just f == divAux -> simpl $ tOp PQuot n (tPlusK 1 m)
        | m == m', Just f == modAux -> simpl $ tOp PRem n (tPlusK 1 m)

      -- Word64 primitives --

      -- toWord (a ∙ b) == toWord a ∙64 toWord b
      TPFn PITo64 (TPOp op a b)
        | Just op64 <- opTo64 op -> simpl $ tOp op64 (TPFn PITo64 a) (TPFn PITo64 b)
        where
          opTo64 op = lookup op [(PAdd, PAdd64), (PSub, PSub64), (PMul, PMul64),
                                 (PQuot, PQuot64), (PRem, PRem64)]

      t@(TApp (TPrim _) _) -> pure t  -- taken care of by rewrite'

      TCoerce t -> TCoerce <$> simpl t

      TApp f es -> do
        f  <- simpl f
        es <- traverse simpl es
        maybeMinusToPrim f es
      TLam b -> TLam <$> underLam (simpl b)
      t@TLit{} -> pure t
      t@TCon{} -> pure t
      TLet e b -> do
        simpl e >>= \case
          TPFn P64ToI a -> do
            -- Inline calls to P64ToI since these trigger optimisations.
            -- Ideally, the optimisations would trigger anyway, but at the
            -- moment they only do if inlining the entire let looks like a
            -- good idea.
            let rho = inplaceS 0 (TPFn P64ToI (TVar 0))
            tLet a <$> underLet a (simpl (applySubst rho b))
          e -> tLet e <$> underLet e (simpl b)

      TCase x t d bs -> do
        v <- lookupVar x
        let (lets, u) = tLetView v
        (d, bs) <- pruneBoolGuards d <$> traverse (simplAlt x) bs
        case u of                          -- TODO: also for literals
          _ | Just (c, as) <- conView u -> simpl $ matchCon lets c as d bs
            -- d and bs live in the context of this case.  To wrap the
            -- rewritten case in the value's lets, they would first have to
            -- be raised past those lets, as matchCon does with raiseFrom.
            -- matchPlusK would also mix an index of this context (x) with
            -- one of the let body (y).  So only rewrite when there are no
            -- lets; otherwise fall through to the generic path below.
            | null lets
            , Just (k, TVar y) <- plusKView u -> simpl . TCase y t d =<< mapM (matchPlusK y x k) bs
          TCase y t1 d1 bs1 -> simpl $ mkLets lets $ TCase y t1 (distrDef case1 d1) $
                                 map (distrCase case1) bs1
            where
              -- Γ x Δ -> Γ _ Δ Θ y, where x maps to y and Θ are the lets
              n     = length lets
              rho   = liftS (x + n + 1) (raiseS 1)    `composeS`
                      singletonS (x + n + 1) (TVar 0) `composeS`
                      raiseS (n + 1)
              case1 = applySubst rho (TCase x t d bs)

              distrDef v d | isUnreachable d = tUnreachable
                           | otherwise       = tLet d v

              distrCase v (TACon c a b) = TACon c a $ TLet b $ raiseFrom 1 a v
              distrCase v (TALit l b)   = TALit l   $ TLet b v
              distrCase v (TAGuard g b) = TAGuard g $ TLet b v

          _ -> do
            d <- simpl d
            tCase x t d bs

      t@TUnit    -> pure t
      t@TSort    -> pure t
      t@TErased  -> pure t
      t@TError{} -> pure t

    -- View a term as a constructor applied to a list of arguments.
    conView (TCon c)    = Just (c, [])
    conView (TApp f as) = second (++ as) <$> conView f
    conView e           = Nothing

    -- Collapse chained cases (case x of bs -> vs; _ -> case x of bs' -> vs'  ==>
    --                         case x of bs -> vs; bs' -> vs')
    unchainCase :: TTerm -> S TTerm
    unchainCase e@(TCase x t d bs) = do
      let (lets, u) = tLetView d
          k = length lets
      return $ case u of
        TCase y _ d' bs' | x + k == y ->
          mkLets lets $ TCase y t d' $ raise k bs ++ bs'
        _ -> e
    unchainCase e = return e

    -- Wrap a body in a sequence of lets (outermost first).
    mkLets es b = foldr TLet b es

    -- Reduce a case whose scrutinee's value is the constructor c applied to
    -- as (under lets): select the matching constructor alternative and bind
    -- its pattern variables to the arguments.  If none matches, return d.
    matchCon _ _ _ d [] = d
    matchCon lets c as d (TALit{} : bs) = matchCon lets c as d bs
    matchCon lets c as d (TAGuard{} : bs) = matchCon lets c as d bs
    matchCon lets c as d (TACon c' a b : bs)
      | c == c'   = flip (foldr TLet) lets $ mkLet 0 as (raiseFrom a (length lets) b)
      | otherwise = matchCon lets c as d bs
      where
        mkLet _ []       b = b
        mkLet i (a : as) b = TLet (raise i a) $ mkLet (i + 1) as b

    -- Simplify let y = x + k in case y of j     -> u; _ | g[y]     -> v
    -- to       let y = x + k in case x of j - k -> u; _ | g[x + k] -> v
    matchPlusK :: Int -> Int -> Integer -> TAlt -> S TAlt
    matchPlusK x y k (TALit (LitNat j) b) = return $ TALit (LitNat (j - k)) b
    matchPlusK x y k (TAGuard g b) = flip TAGuard b <$> simpl (applySubst (inplaceS y (tPlusK k (TVar x))) g)
    matchPlusK x y k TACon{} = __IMPOSSIBLE__
    matchPlusK x y k TALit{} = __IMPOSSIBLE__

    -- Simplify a primitive application.  The arguments are simplified, and
    -- simplPrim' is tried on a copy whose variables are inlined with their
    -- known values.  The result is kept only if betterThan says it is no
    -- worse than the plain simplified application.
    simplPrim (TApp f@TPrim{} args) = do
        args    <- mapM simpl args
        inlined <- mapM inline args
        let u = TApp f args
            v = simplPrim' (TApp f inlined)
        pure $ if v `betterThan` u then v else u
      where
        inline (TVar x) = do
          v <- lookupVar x
          if v == TVar x then pure v else inline v
        inline (TApp f@TPrim{} args)      = TApp f <$> mapM inline args
        inline u@(TLet _ (TCase 0 _ _ _)) = pure u
        inline (TLet e b)                 = inline (subst 0 e b)
        inline u                          = pure u
    simplPrim t = pure t

    -- Algebraic simplifications of primitive applications.
    simplPrim' :: TTerm -> TTerm
    simplPrim' (TApp (TPrim PSeq) (u : v : vs))
      | u == v             = mkTApp v vs
      | TApp TCon{} _ <- u = mkTApp v vs
      | TApp TLit{} _ <- u = mkTApp v vs
    simplPrim' (TApp (TPrim PLt) [u, v])
      | Just (PAdd, k, u) <- constArithView u,
        Just (PAdd, j, v) <- constArithView v,
        k == j = tOp PLt u v
      -- (k - u < k - v) == (v < u)
      | Just (PSub, k, u) <- constArithView u,
        Just (PSub, j, v) <- constArithView v,
        k == j = tOp PLt v u
      -- fromWord u < v + k is true when k >= 2^64
      | Just (PAdd, k, v) <- constArithView v,
        TApp (TPrim P64ToI) [u] <- u,
        k >= 2^64, Just trueCon <- true = TCon trueCon
      | Just k <- intView u
      , Just j <- intView v
      , Just trueCon <- true
      , Just falseCon <- false = if k < j then TCon trueCon else TCon falseCon
    simplPrim' (TApp (TPrim PGeq) [u, v])
      | Just (PAdd, k, u) <- constArithView u,
        Just (PAdd, j, v) <- constArithView v,
        k == j = tOp PGeq u v
      -- (k - u >= k - v) == (v >= u)
      | Just (PSub, k, u) <- constArithView u,
        Just (PSub, j, v) <- constArithView v,
        k == j = tOp PGeq v u
      | Just k <- intView u
      , Just j <- intView v
      , Just trueCon <- true
      , Just falseCon <- false = if k >= j then TCon trueCon else TCon falseCon
    -- (u + k `op` j) == (u `op` j - k)
    simplPrim' (TApp (TPrim op) [u, v])
      | op `elem` [PGeq, PLt, PEqI]
      , Just (PAdd, k, u) <- constArithView u
      , Just j <- intView v = TApp (TPrim op) [u, tInt (j - k)]
    simplPrim' (TApp (TPrim PEqI) [u, v])
      | Just (op1, k, u) <- constArithView u,
        Just (op2, j, v) <- constArithView v,
        op1 == op2, k == j,
        op1 `elem` [PAdd, PSub] = tOp PEqI u v
    simplPrim' (TPOp op u v)
      -- Return the zero operand itself rather than a fresh `tInt 0`.  The
      -- operand already has the operation's literal type (Nat or Word64).
      -- `tInt 0` is an Integer literal, which is the wrong type for the
      -- Word64 operations (e.g. 0 *64 x).
      | zeroL, isMul || isDiv = u
      | zeroL, isAdd          = v
      | zeroR, isMul          = v
      | zeroR, isAdd || isSub = u
      where
        zeroL = Just 0 == intView u || Just 0 == word64View u
        zeroR = Just 0 == intView v || Just 0 == word64View v
        isAdd = op `elem` [PAdd, PAdd64]
        isSub = op `elem` [PSub, PSub64]
        isMul = op `elem` [PMul, PMul64]
        isDiv = op `elem` [PQuot, PQuot64, PRem, PRem64]
    -- Move negations (0 - u) out of multiplication and quotient.
    simplPrim' (TApp (TPrim op) [u, v])
      | Just u <- negView u,
        Just v <- negView v,
        op `elem` [PMul, PQuot] = tOp op u v
      | Just u <- negView u,
        op `elem` [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
      | Just v <- negView v,
        op `elem` [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
    simplPrim' (TApp (TPrim PRem) [u, v])
      | Just u <- negView u = simplArith $ tOp PSub (tInt 0) (tOp PRem u (unNeg v))
      | Just v <- negView v = tOp PRem u v

    -- (fromWord a == fromWord b) = (a ==64 b)
    simplPrim' (TPOp op (TPFn P64ToI a) (TPFn P64ToI b))
      | Just op64 <- opTo64 op = tOp op64 a b
      where
        opTo64 op = lookup op [(PEqI, PEq64), (PLt, PLt64)]

    -- toWord/fromWord k == fromIntegral k
    simplPrim' (TPFn PITo64 (TLit (LitNat n)))    = TLit (LitWord64 (fromIntegral n))
    simplPrim' (TPFn P64ToI (TLit (LitWord64 n))) = TLit (LitNat    (fromIntegral n))

    -- toWord (fromWord a) == a
    simplPrim' (TPFn PITo64 (TPFn P64ToI a)) = a

    simplPrim' (TApp f@(TPrim op) [u, v]) = simplArith $ TApp f [simplPrim' u, simplPrim' v]
    simplPrim' u = u

    -- Strip a negation (0 - v) if present.
    unNeg u | Just v <- negView u = v
            | otherwise           = u

    -- View (0 - b) as the negation of b.
    negView (TApp (TPrim PSub) [a, b])
      | Just 0 <- intView a = Just b
    negView _ = Nothing

    -- Count arithmetic operations
    betterThan u v = operations u <= operations v
      where
        -- NOTE(review): a two-argument PSeq is matched by the first equation,
        -- so the seq penalty below only applies to seq of other arities —
        -- confirm this is intended.
        operations (TApp (TPrim _) [a, b]) = 1 + operations a + operations b
        operations (TApp (TPrim PSeq) (a : _))
          | notVar a                       = 1000000  -- only seq on variables!
        operations (TApp (TPrim _) [a])    = 1 + operations a
        operations TVar{}                  = 0
        operations TLit{}                  = 0
        operations TCon{}                  = 0
        operations TDef{}                  = 0
        operations _                       = 1000

        notVar TVar{} = False
        notVar _      = True

    rewrite' t = rewrite =<< simplPrim t

    -- View k + u and u + k as (PAdd, k, u), k - u as (PSub, k, u) and
    -- u - k as (PAdd, -k, u).
    constArithView :: TTerm -> Maybe (TPrim, Integer, TTerm)
    constArithView (TApp (TPrim op) [TLit (LitNat k), u])
      | op `elem` [PAdd, PSub] = Just (op, k, u)
    constArithView (TApp (TPrim op) [u, TLit (LitNat k)])
      | op == PAdd = Just (op, k, u)
      | op == PSub = Just (PAdd, -k, u)
    constArithView _ = Nothing

    -- Simplify an alternative of a case on variable x.  Inside a constructor
    -- or literal branch, x is recorded (see maybeAddRewrite) to equal the
    -- constructor applied to its pattern variables, or the literal.
    simplAlt x (TACon c a b) = TACon c a <$> underLams a (maybeAddRewrite (x + a) conTerm $ simpl b)
      where conTerm = mkTApp (TCon c) $ map TVar $ downFrom a
    simplAlt x (TALit l b)   = TALit l   <$> maybeAddRewrite x (TLit l) (simpl b)
    simplAlt x (TAGuard g b) = TAGuard   <$> simpl g <*> simpl b

    -- If x is already bound we add a rewrite, otherwise we bind x to rhs.
    maybeAddRewrite x rhs cont = do
      v <- lookupVar x
      case v of
        TVar y | x == y -> bindVar x rhs $ cont
        _ -> addRewrite v rhs cont

    isTrue (TCon c) = Just c == true
    isTrue _        = False

    isFalse (TCon c) = Just c == false
    isFalse _        = False

    -- Builtin natural-number minus applied to a and b becomes the primitive
    -- PSub when checkLeq proves b <= a, where it agrees with integer
    -- subtraction.  Otherwise it is an ordinary application.
    maybeMinusToPrim f@(TDef minus) es@[a, b]
      | Just minus == natMinus = do
          leq <- checkLeq b a
          if leq then pure $ tOp PSub a b
                 else tApp f es

    maybeMinusToPrim f es = tApp f es

    -- Smart let: variables are substituted, and let x = e in x is just e.
    tLet (TVar x) b = subst 0 (TVar x) b
    tLet e (TVar 0) = e
    tLet e b        = TLet e b

    -- Smart case.  Unreachable alternatives are dropped.  If the default is
    -- unreachable, a trailing literal/guard alternative becomes the default.
    -- A default that is (a variable bound to) a case on the same variable
    -- has its non-overlapping alternatives merged in.
    tCase :: Int -> CaseInfo -> TTerm -> [TAlt] -> S TTerm
    tCase x t d [] = pure d
    tCase x t d bs
      | isUnreachable d =
          case reverse bs' of
            []               -> pure d
            TALit _ b   : as -> tCase x t b (reverse as)
            TAGuard _ b : as -> tCase x t b (reverse as)
            TACon c a b : _  -> tCase' x t d bs'
      | otherwise = do
          d' <- lookupIfVar d
          case d' of
            TCase y _ d bs'' | x == y ->
              tCase x t d (bs' ++ filter noOverlap bs'')
            _ -> tCase' x t d bs'
      where
        bs' = filter (not . isUnreachable) bs

        lookupIfVar (TVar i) = lookupVar i
        lookupIfVar t        = pure t

        noOverlap b = not $ any (overlapped b) bs'
        overlapped (TACon c _ _) (TACon c' _ _) = c == c'
        overlapped (TALit l _)   (TALit l' _)   = l == l'
        overlapped _             _              = False

    -- Drop unreachable cases for Nat and Int cases.
    pruneLitCases :: Int -> CaseInfo -> TTerm -> [TAlt] -> S TTerm
    pruneLitCases x t d bs | CTNat == caseType t =
      case complete bs [] Nothing of
        Just bs' -> tCase x t tUnreachable bs'
        Nothing  -> return $ TCase x t d bs
      where
        -- Succeeds once the literal alternatives seen so far cover every
        -- value below the smallest bound j of an `x >= j` guard seen so far.
        complete bs small (Just upper)
          | null $ [0..upper - 1] List.\\ small = Just []
        complete (b@(TALit (LitNat n) _) : bs) small upper =
          (b :) <$> complete bs (n : small) upper
        complete (b@(TAGuard (TApp (TPrim PGeq) [TVar y, TLit (LitNat j)]) _) : bs) small upper | x == y =
          (b :) <$> complete bs small (Just $ maybe j (min j) upper)
        complete _ _ _ = Nothing

    pruneLitCases x t d bs
      | CTInt == caseType t = return $ TCase x t d bs -- TODO
      | otherwise           = return $ TCase x t d bs

    -- Drop 'false' branches and drop everything after 'true' branches (including the default
    -- branch)
    pruneBoolGuards d [] = (d, [])
    pruneBoolGuards d (b@(TAGuard (TCon c) _) : bs)
      | Just c == true  = (tUnreachable, [b])
      | Just c == false = pruneBoolGuards d bs
    pruneBoolGuards d (b : bs) =
      second (b :) $ pruneBoolGuards d bs

    tCase' x t d [] = return d
    tCase' x t d bs = pruneLitCases x t d bs

    -- Smart application.  Arguments are pushed into lets and case branches,
    -- lambdas are beta-reduced, and a variable whose known value is atomic
    -- or a lambda is replaced by that value.
    tApp :: TTerm -> [TTerm] -> S TTerm
    tApp (TLet e b) es = TLet e <$> underLet e (tApp b (raise 1 es))
    tApp (TCase x t d bs) es = do
      d  <- tApp d es
      bs <- mapM (`tAppAlt` es) bs
      simpl $ TCase x t d bs    -- will resimplify branches
    tApp (TVar x) es = do
      v <- lookupVar x
      case v of
        _ | v /= TVar x && isAtomic v -> tApp v es
        TLam{} -> tApp v es   -- could blow up the code
        _      -> pure $ mkTApp (TVar x) es
    tApp f [] = pure f
    tApp (TLam b) (TVar i : es) = tApp (subst 0 (TVar i) b) es
    tApp (TLam b) (e : es) = tApp (TLet e b) es
    tApp f es = pure $ TApp f es

    tAppAlt (TACon c a b) es = TACon c a <$> underLams a (tApp b (raise a es))
    tAppAlt (TALit l b) es   = TALit l   <$> tApp b es
    tAppAlt (TAGuard g b) es = TAGuard g <$> tApp b es

    isAtomic = \case
      TVar{}    -> True
      TCon{}    -> True
      TPrim{}   -> True
      TDef{}    -> True
      TLit{}    -> True
      TSort{}   -> True
      TErased{} -> True
      TError{}  -> True
      _         -> False

    -- Conservatively decide whether a <= b.  Both sides are normalised with
    -- toArith.  The check uses literal comparison, the range of fromWord,
    -- and facts `x < y` recorded as rewrite rules to true/false.
    checkLeq a b = do
      rho <- asks envSubst
      rwr <- asks envRewrite
      let nf   = toArith . applySubst rho
          less = [ (nf a, nf b) | (TPOp PLt a b, rhs) <- rwr, isTrue rhs ]
          leq  = [ (nf b, nf a) | (TPOp PLt a b, rhs) <- rwr, isFalse rhs ]

          match (j, as) (k, bs)
            | as == bs  = Just (j - k)
            | otherwise = Nothing

          -- Do we have x ≤ y given x' < y' + d ?
          matchEqn d x y (x', y') = isJust $ do
            k <- match x x'     -- x = x' + k
            j <- match y y'     -- y = y' + j
            guard (k <= j + d)  -- x ≤ y if k ≤ j + d

          matchLess = matchEqn 1
          matchLeq  = matchEqn 0

          literal (j, []) (k, []) = j <= k
          literal _ _             = False

          -- k + fromWord x ≤ y if k + 2^64 - 1 ≤ y
          wordUpperBound (k, [Pos (TApp (TPrim P64ToI) _)]) y = go (k + 2^64 - 1, []) y
          wordUpperBound _ _ = False

          -- x ≤ k + fromWord y if x ≤ k
          wordLowerBound a (k, [Pos (TApp (TPrim P64ToI) _)]) = go a (k, [])
          wordLowerBound _ _ = False

          go x y = or
            [ literal x y
            , wordUpperBound x y
            , wordLowerBound x y
            , any (matchLess x y) less
            , any (matchLeq x y) leq ]

      return $ go (nf a) (nf b)

-- | Linear arithmetic normal form: a constant plus a list of signed atoms.
type Arith = (Integer, [Atom])

-- | A term occurring positively or negatively in an 'Arith' sum.
data Atom = Pos TTerm | Neg TTerm
  deriving (Show, Eq, Ord)

-- | Flip the sign of an atom.
aNeg :: Atom -> Atom
aNeg (Pos a) = Neg a
aNeg (Neg a) = Pos a

-- | Remove pairs of an atom and its negation.
aCancel :: [Atom] -> [Atom]
aCancel (a : as)
  | (aNeg a) `elem` as = aCancel (List.delete (aNeg a) as)
  | otherwise          = a : aCancel as
aCancel [] = []

-- | Sort in descending order.
sortR :: Ord a => [a] -> [a]
sortR = List.sortBy (flip compare)

-- | Sum of two 'Arith' values.
aAdd :: Arith -> Arith -> Arith
aAdd (a, xs) (b, ys) = (a + b, aCancel $ sortR $ xs ++ ys)

-- | Difference of two 'Arith' values.
aSub :: Arith -> Arith -> Arith
aSub (a, xs) (b, ys) = (a - b, aCancel $ sortR $ xs ++ map aNeg ys)
-- | Turn an 'Arith' normal form back into a term.  If the constant is zero
--   or negative and some atom is positive, the first positive atom heads the
--   chain of additions/subtractions.  That way no leading zero literal is
--   produced, and a negative constant becomes a final subtraction.  In every
--   other case the chain starts from the constant.
fromArith :: Arith -> TTerm
fromArith (n, atoms)
  | null atoms = tInt n
  | otherwise  =
      case break isPos atoms of
        (before, Pos hd : after)
          | n == 0 -> chainFrom hd
          | n < 0  -> tOp PSub (chainFrom hd) (tInt (negate n))
          where
            chainFrom h = foldl addAtom h (before ++ after)
        _ -> foldl addAtom (tInt n) atoms

-- | Is this a positive atom?
isPos :: Atom -> Bool
isPos atom = case atom of
  Pos{} -> True
  Neg{} -> False

-- | Extend a term by one atom: add a positive atom, subtract a negative one.
addAtom :: TTerm -> Atom -> TTerm
addAtom acc atom = case atom of
  Pos a -> tOp PAdd acc a
  Neg a -> tOp PSub acc a

-- | View a term as a constant plus a sum of signed atoms.  Anything
--   'intView' recognises becomes the constant.  'PAdd' and 'PSub' are
--   flattened with 'aAdd' and 'aSub'.  Any other term is a single positive
--   atom.
toArith :: TTerm -> Arith
toArith t = case t of
  _ | Just n <- intView t  -> (n, [])
  TApp (TPrim PAdd) [a, b] -> aAdd (toArith a) (toArith b)
  TApp (TPrim PSub) [a, b] -> aSub (toArith a) (toArith b)
  _                        -> (0, [Pos t])

-- | Normalise the additive structure of a term by a round trip through
--   'Arith'.
simplArith :: TTerm -> TTerm
simplArith t = fromArith (toArith t)