1module Agda.Compiler.Treeless.Simplify (simplifyTTerm) where
2
3import Control.Arrow (second, (***))
4import Control.Monad.Reader
5import qualified Data.List as List
6
7import Agda.Syntax.Treeless
8import Agda.Syntax.Literal
9
10import Agda.TypeChecking.Monad
11import Agda.TypeChecking.Primitive
12import Agda.TypeChecking.Substitute
13
14import Agda.Compiler.Treeless.Compare
15
16import Agda.Utils.List
17import Agda.Utils.Maybe
18
19import Agda.Utils.Impossible
20
-- | Read-only environment threaded through the simplifier.
data SEnv = SEnv
  { envSubst   :: Substitution' TTerm
    -- ^ Substitution consulted by 'lookupVar'. A variable mapped to itself
    -- has no known value (this is how @maybeAddRewrite@ tests for it).
  , envRewrite :: [(TTerm, TTerm)]
    -- ^ Rewrite rules @(lhs, rhs)@; 'rewrite' uses the first matching one.
  }

-- | The simplification monad.
type S = Reader SEnv

-- | Run a simplifier computation from the identity substitution and an
-- empty rule set.
runS :: S a -> a
runS = flip runReader (SEnv { envSubst = IdS, envRewrite = [] })
29
-- | The term the current substitution assigns to variable @i@.
lookupVar :: Int -> S TTerm
lookupVar i = do
  rho <- asks envSubst
  pure (lookupS rho i)

-- | Run a computation with the substitution transformed by @f@.
onSubst :: (Substitution' TTerm -> Substitution' TTerm) -> S a -> S a
onSubst f = local (\env -> env { envSubst = f (envSubst env) })

-- | Run a computation with @rho@ applied to both sides of every rewrite rule.
onRewrite :: Substitution' TTerm -> S a -> S a
onRewrite rho = local $ \env ->
  env { envRewrite = map (applySubst rho *** applySubst rho) (envRewrite env) }

-- | Run a computation with an extra rewrite rule @lhs ↦ rhs@ that takes
-- precedence over the existing ones.
addRewrite :: TTerm -> TTerm -> S a -> S a
addRewrite lhs rhs =
  local (\env -> env { envRewrite = (lhs, rhs) : envRewrite env })

-- | Run a computation under @i@ new binders: lift the substitution and raise
-- the rewrite rules past them.
underLams :: Int -> S a -> S a
underLams i m = onRewrite (raiseS i) (onSubst (liftS i) m)

-- | 'underLams' for a single binder.
underLam :: S a -> S a
underLam = underLams 1

-- | Run a computation under a let binding of @u@: the substitution is
-- extended with @u@ for the bound variable and weakened by one, and the
-- rewrite rules are raised by one.
underLet :: TTerm -> S a -> S a
underLet u m = onRewrite (raiseS 1) (onSubst (\rho -> wkS 1 (u :# rho)) m)

-- | Run a computation with variable @x@ mapped to @u@, composed in front of
-- the current substitution.
bindVar :: Int -> TTerm -> S a -> S a
bindVar x u = onSubst (\rho -> inplaceS x u `composeS` rho)
53
-- | Replace @t@ by the right-hand side of the first rewrite rule whose
-- left-hand side is equal to it according to 'equalTerms'. If no rule
-- applies, @t@ is returned unchanged.
rewrite :: TTerm -> S TTerm
rewrite t = do
  rules <- asks envRewrite
  pure $ maybe t snd $ List.find (equalTerms t . fst) rules
60
-- | Names of the builtins the simplifier's rules refer to, as returned by
-- 'getBuiltinName'. 'Nothing' presumably means the builtin is not bound.
data FunctionKit = FunctionKit
  { modAux, divAux, natMinus, true, false :: Maybe QName }

-- | Simplify a treeless term. The builtins needed by the rules are looked
-- up first.
simplifyTTerm :: TTerm -> TCM TTerm
simplifyTTerm t = do
  modSucAux <- getBuiltinName builtinNatModSucAux
  divSucAux <- getBuiltinName builtinNatDivSucAux
  minus     <- getBuiltinName builtinNatMinus
  trueName  <- getBuiltinName builtinTrue
  falseName <- getBuiltinName builtinFalse
  let kit = FunctionKit
        { modAux   = modSucAux
        , divAux   = divSucAux
        , natMinus = minus
        , true     = trueName
        , false    = falseName
        }
  pure $ runS $ simplify kit t
72
-- | Simplify a treeless term.
--
-- The 'FunctionKit' supplies the names of builtins that some rules
-- recognise. Its fields are brought into scope by the record wildcard, which
-- is why all helpers are local to this definition.
simplify :: FunctionKit -> TTerm -> S TTerm
simplify FunctionKit{..} = simpl
  where
    -- One simplification step: simplify primitive applications and apply the
    -- rewrite rules (rewrite'), collapse chained cases (unchainCase), then
    -- dispatch on the shape of the resulting term.
    simpl = rewrite' >=> unchainCase >=> \case

      t@TDef{}  -> pure t
      t@TPrim{} -> pure t
      t@TVar{}  -> pure t

      TApp (TDef f) [TLit (LitNat 0), m, n, m']
        -- div/mod are equivalent to quot/rem on natural numbers.
        | m == m', Just f == divAux -> simpl $ tOp PQuot n (tPlusK 1 m)
        | m == m', Just f == modAux -> simpl $ tOp PRem n (tPlusK 1 m)

      -- Word64 primitives --

      --  toWord (a ∙ b) == toWord a ∙64 toWord b
      TPFn PITo64 (TPOp op a b)
        | Just op64 <- opTo64 op -> simpl $ tOp op64 (TPFn PITo64 a) (TPFn PITo64 b)
        where
          opTo64 op = lookup op [(PAdd, PAdd64), (PSub, PSub64), (PMul, PMul64),
                                 (PQuot, PQuot64), (PRem, PRem64)]

      t@(TApp (TPrim _) _) -> pure t  -- taken care of by rewrite'

      TCoerce t -> TCoerce <$> simpl t

      -- General application: simplify head and arguments, then let
      -- maybeMinusToPrim / tApp reduce the application.
      TApp f es -> do
        f  <- simpl f
        es <- traverse simpl es
        maybeMinusToPrim f es
      TLam b    -> TLam <$> underLam (simpl b)
      t@TLit{}  -> pure t
      t@TCon{}  -> pure t
      TLet e b  -> do
        simpl e >>= \case
          TPFn P64ToI a -> do
            -- Inline calls to P64ToI since these trigger optimisations.
            -- Ideally, the optimisations would trigger anyway, but at the
            -- moment they only do if inlining the entire let looks like a
            -- good idea.
            let rho = inplaceS 0 (TPFn P64ToI (TVar 0))
            tLet a <$> underLet a (simpl (applySubst rho b))
          e -> tLet e <$> underLet e (simpl b)

      TCase x t d bs -> do
        -- v is what the scrutinee x is known to be; (lets, u) splits off the
        -- lets in front of it.
        v <- lookupVar x
        let (lets, u) = tLetView v
        -- Simplify the branches (each one learns what x is inside it) and
        -- drop branches guarded by the literal false/true constructors.
        (d, bs) <- pruneBoolGuards d <$> traverse (simplAlt x) bs
        case u of                          -- TODO: also for literals
          -- Known constructor application: select the matching branch.
          _ | Just (c, as)     <- conView u   -> simpl $ matchCon lets c as d bs
            -- Scrutinee is y + k for a variable y: case on y instead.
            | Just (k, TVar y) <- plusKView u -> simpl . mkLets lets . TCase y t d =<< mapM (matchPlusK y x k) bs
          -- Scrutinee is itself a case: push this case into the default and
          -- each branch of that case (case-of-case).
          TCase y t1 d1 bs1 -> simpl $ mkLets lets $ TCase y t1 (distrDef case1 d1) $
                                       map (distrCase case1) bs1
            where
              -- Γ x Δ -> Γ _ Δ Θ y, where x maps to y and Θ are the lets
              n     = length lets
              rho   = liftS (x + n + 1) (raiseS 1)    `composeS`
                      singletonS (x + n + 1) (TVar 0) `composeS`
                      raiseS (n + 1)
              case1 = applySubst rho (TCase x t d bs)

              -- An unreachable inner default stays unreachable; otherwise bind
              -- its value and continue with the outer case.
              distrDef v d | isUnreachable d = tUnreachable
                           | otherwise       = tLet d v

              -- Bind each inner branch's result and continue with the outer
              -- case (raised past the constructor's a bound variables).
              distrCase v (TACon c a b) = TACon c a $ TLet b $ raiseFrom 1 a v
              distrCase v (TALit l b)   = TALit l   $ TLet b v
              distrCase v (TAGuard g b) = TAGuard g $ TLet b v

          -- Nothing known: simplify the default and rebuild the case.
          _ -> do
            d <- simpl d
            tCase x t d bs

      t@TUnit    -> pure t
      t@TSort    -> pure t
      t@TErased  -> pure t
      t@TError{} -> pure t
150
151    conView (TCon c)    = Just (c, [])
152    conView (TApp f as) = second (++ as) <$> conView f
153    conView e           = Nothing
154
155    -- Collapse chained cases (case x of bs -> vs; _ -> case x of bs' -> vs'  ==>
156    --                         case x of bs -> vs; bs' -> vs')
157    unchainCase :: TTerm -> S TTerm
158    unchainCase e@(TCase x t d bs) = do
159      let (lets, u) = tLetView d
160          k = length lets
161      return $ case u of
162        TCase y _ d' bs' | x + k == y ->
163          mkLets lets $ TCase y t d' $ raise k bs ++ bs'
164        _ -> e
165    unchainCase e = return e
166
167
168    mkLets es b = foldr TLet b es
169
170    matchCon _ _ _ d [] = d
171    matchCon lets c as d (TALit{}   : bs) = matchCon lets c as d bs
172    matchCon lets c as d (TAGuard{} : bs) = matchCon lets c as d bs
173    matchCon lets c as d (TACon c' a b : bs)
174      | c == c'        = flip (foldr TLet) lets $ mkLet 0 as (raiseFrom a (length lets) b)
175      | otherwise      = matchCon lets c as d bs
176      where
177        mkLet _ []       b = b
178        mkLet i (a : as) b = TLet (raise i a) $ mkLet (i + 1) as b
179
180    -- Simplify let y = x + k in case y of j     -> u; _ | g[y]     -> v
181    -- to       let y = x + k in case x of j - k -> u; _ | g[x + k] -> v
182    matchPlusK :: Int -> Int -> Integer -> TAlt -> S TAlt
183    matchPlusK x y k (TALit (LitNat j) b) = return $ TALit (LitNat (j - k)) b
184    matchPlusK x y k (TAGuard g b) = flip TAGuard b <$> simpl (applySubst (inplaceS y (tPlusK k (TVar x))) g)
185    matchPlusK x y k TACon{} = __IMPOSSIBLE__
186    matchPlusK x y k TALit{} = __IMPOSSIBLE__
187
    -- Simplify a primitive application. The arguments are simplified. A copy
    -- in which variables are replaced by their known values ('inline') is
    -- then rewritten by simplPrim'. That result is kept only if it is no more
    -- expensive ('betterThan') than the application with simplified
    -- arguments. Terms that are not primitive applications are returned
    -- unchanged.
    simplPrim (TApp f@TPrim{} args) = do
        args    <- mapM simpl args
        inlined <- mapM inline args
        let u = TApp f args
            v = simplPrim' (TApp f inlined)
        pure $ if v `betterThan` u then v else u
      where
        -- Follow variables to their known values, look inside nested
        -- primitive applications, and substitute let-bound values into the
        -- body, except for lets whose body is a case on the bound variable.
        inline (TVar x)                   = do
          v <- lookupVar x
          if v == TVar x then pure v else inline v
        inline (TApp f@TPrim{} args)      = TApp f <$> mapM inline args
        inline u@(TLet _ (TCase 0 _ _ _)) = pure u
        inline (TLet e b)                 = inline (subst 0 e b)
        inline u                          = pure u
    simplPrim t = pure t
203
204    simplPrim' :: TTerm -> TTerm
205    simplPrim' (TApp (TPrim PSeq) (u : v : vs))
206      | u == v             = mkTApp v vs
207      | TApp TCon{} _ <- u = mkTApp v vs
208      | TApp TLit{} _ <- u = mkTApp v vs
209    simplPrim' (TApp (TPrim PLt) [u, v])
210      | Just (PAdd, k, u) <- constArithView u,
211        Just (PAdd, j, v) <- constArithView v,
212        k == j = tOp PLt u v
213      | Just (PSub, k, u) <- constArithView u,
214        Just (PSub, j, v) <- constArithView v,
215        k == j = tOp PLt v u
216      | Just (PAdd, k, v) <- constArithView v,
217        TApp (TPrim P64ToI) [u] <- u,
218        k >= 2^64, Just trueCon <- true = TCon trueCon
219      | Just k <- intView u
220      , Just j <- intView v
221      , Just trueCon <- true
222      , Just falseCon <- false = if k < j then TCon trueCon else TCon falseCon
223    simplPrim' (TApp (TPrim PGeq) [u, v])
224      | Just (PAdd, k, u) <- constArithView u,
225        Just (PAdd, j, v) <- constArithView v,
226        k == j = tOp PGeq u v
227      | Just (PSub, k, u) <- constArithView u,
228        Just (PSub, j, v) <- constArithView v,
229        k == j = tOp PGeq v u
230      | Just k <- intView u
231      , Just j <- intView v
232      , Just trueCon <- true
233      , Just falseCon <- false = if k >= j then TCon trueCon else TCon falseCon
234    simplPrim' (TApp (TPrim op) [u, v])
235      | op `elem` [PGeq, PLt, PEqI]
236      , Just (PAdd, k, u) <- constArithView u
237      , Just j <- intView v = TApp (TPrim op) [u, tInt (j - k)]
238    simplPrim' (TApp (TPrim PEqI) [u, v])
239      | Just (op1, k, u) <- constArithView u,
240        Just (op2, j, v) <- constArithView v,
241        op1 == op2, k == j,
242        op1 `elem` [PAdd, PSub] = tOp PEqI u v
243    simplPrim' (TPOp op u v)
244      | zeroL, isMul || isDiv = tInt 0
245      | zeroL, isAdd          = v
246      | zeroR, isMul          = tInt 0
247      | zeroR, isAdd || isSub = u
248      where zeroL = Just 0 == intView u || Just 0 == word64View u
249            zeroR = Just 0 == intView v || Just 0 == word64View v
250            isAdd = op `elem` [PAdd, PAdd64]
251            isSub = op `elem` [PSub, PSub64]
252            isMul = op `elem` [PMul, PMul64]
253            isDiv = op `elem` [PQuot, PQuot64, PRem, PRem64]
254    simplPrim' (TApp (TPrim op) [u, v])
255      | Just u <- negView u,
256        Just v <- negView v,
257        op `elem` [PMul, PQuot] = tOp op u v
258      | Just u <- negView u,
259        op `elem` [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
260      | Just v <- negView v,
261        op `elem` [PMul, PQuot] = simplArith $ tOp PSub (tInt 0) (tOp op u v)
262    simplPrim' (TApp (TPrim PRem) [u, v])
263      | Just u <- negView u  = simplArith $ tOp PSub (tInt 0) (tOp PRem u (unNeg v))
264      | Just v <- negView v  = tOp PRem u v
265
266      -- (fromWord a == fromWord b) = (a ==64 b)
267    simplPrim' (TPOp op (TPFn P64ToI a) (TPFn P64ToI b))
268        | Just op64 <- opTo64 op = tOp op64 a b
269        where
270          opTo64 op = lookup op [(PEqI, PEq64), (PLt, PLt64)]
271
272      -- toWord/fromWord k == fromIntegral k
273    simplPrim' (TPFn PITo64 (TLit (LitNat n)))    = TLit (LitWord64 (fromIntegral n))
274    simplPrim' (TPFn P64ToI (TLit (LitWord64 n))) = TLit (LitNat    (fromIntegral n))
275
276      -- toWord (fromWord a) == a
277    simplPrim' (TPFn PITo64 (TPFn P64ToI a)) = a
278
279    simplPrim' (TApp f@(TPrim op) [u, v]) = simplArith $ TApp f [simplPrim' u, simplPrim' v]
280    simplPrim' u = u
281
282    unNeg u | Just v <- negView u = v
283            | otherwise           = u
284
285    negView (TApp (TPrim PSub) [a, b])
286      | Just 0 <- intView a = Just b
287    negView _ = Nothing
288
289    -- Count arithmetic operations
290    betterThan u v = operations u <= operations v
291      where
292        operations (TApp (TPrim _) [a, b]) = 1 + operations a + operations b
293        operations (TApp (TPrim PSeq) (a : _))
294          | notVar a                       = 1000000  -- only seq on variables!
295        operations (TApp (TPrim _) [a])    = 1 + operations a
296        operations TVar{}                  = 0
297        operations TLit{}                  = 0
298        operations TCon{}                  = 0
299        operations TDef{}                  = 0
300        operations _                       = 1000
301
302        notVar TVar{} = False
303        notVar _      = True
304
305    rewrite' t = rewrite =<< simplPrim t
306
307    constArithView :: TTerm -> Maybe (TPrim, Integer, TTerm)
308    constArithView (TApp (TPrim op) [TLit (LitNat k), u])
309      | op `elem` [PAdd, PSub] = Just (op, k, u)
310    constArithView (TApp (TPrim op) [u, TLit (LitNat k)])
311      | op == PAdd = Just (op, k, u)
312      | op == PSub = Just (PAdd, -k, u)
313    constArithView _ = Nothing
314
315    simplAlt x (TACon c a b) = TACon c a <$> underLams a (maybeAddRewrite (x + a) conTerm $ simpl b)
316      where conTerm = mkTApp (TCon c) $ map TVar $ downFrom a
317    simplAlt x (TALit l b)   = TALit l   <$> maybeAddRewrite x (TLit l) (simpl b)
318    simplAlt x (TAGuard g b) = TAGuard   <$> simpl g <*> simpl b
319
320    -- If x is already bound we add a rewrite, otherwise we bind x to rhs.
321    maybeAddRewrite x rhs cont = do
322      v <- lookupVar x
323      case v of
324        TVar y | x == y -> bindVar x rhs $ cont
325        _ -> addRewrite v rhs cont
326
327    isTrue (TCon c) = Just c == true
328    isTrue _        = False
329
330    isFalse (TCon c) = Just c == false
331    isFalse _        = False
332
333    maybeMinusToPrim f@(TDef minus) es@[a, b]
334      | Just minus == natMinus = do
335      leq  <- checkLeq b a
336      if leq then pure $ tOp PSub a b
337             else tApp f es
338
339    maybeMinusToPrim f es = tApp f es
340
341    tLet (TVar x) b = subst 0 (TVar x) b
342    tLet e (TVar 0) = e
343    tLet e b        = TLet e b
344
    -- Build a case on variable x, cleaning up its branches:
    --   * a case without branches is just its default;
    --   * unreachable branches are dropped;
    --   * if the default is unreachable, a trailing literal or guard branch
    --     becomes the new default, since it must be taken whenever all
    --     earlier branches fail;
    --   * if the default is (a variable whose known value is) another case on
    --     x, that case's branches not already handled here are appended, and
    --     its default becomes the default.
    -- The remaining work is done by tCase'.
    tCase :: Int -> CaseInfo -> TTerm -> [TAlt] -> S TTerm
    tCase x t d [] = pure d
    tCase x t d bs
      | isUnreachable d =
        case reverse bs' of
          [] -> pure d
          TALit _ b   : as  -> tCase x t b (reverse as)
          TAGuard _ b : as  -> tCase x t b (reverse as)
          TACon c a b : _   -> tCase' x t d bs'
      | otherwise = do
        d' <- lookupIfVar d
        case d' of
          TCase y _ d bs'' | x == y ->
            tCase x t d (bs' ++ filter noOverlap bs'')
          _ -> tCase' x t d bs'
      where
        bs' = filter (not . isUnreachable) bs

        lookupIfVar (TVar i) = lookupVar i
        lookupIfVar t = pure t

        -- Keep an inner branch only if no outer branch already handles the
        -- same constructor or literal.
        noOverlap b = not $ any (overlapped b) bs'
        overlapped (TACon c _ _)  (TACon c' _ _) = c == c'
        overlapped (TALit l _)    (TALit l' _)   = l == l'
        overlapped _              _              = False
370
    -- Drop unreachable branches of cases on natural numbers. Suppose the
    -- literal branches, together with guards of the form x >= j, cover every
    -- natural number. Then the default and all later branches are
    -- unreachable. Cases on integers are not pruned yet (TODO), and all other
    -- cases are returned unchanged.
    pruneLitCases :: Int -> CaseInfo -> TTerm -> [TAlt] -> S TTerm
    pruneLitCases x t d bs | CTNat == caseType t =
      case complete bs [] Nothing of
        Just bs' -> tCase x t tUnreachable bs'
        Nothing  -> return $ TCase x t d bs
      where
        -- complete bs small upper: small holds the literals seen so far and
        -- upper the smallest j seen in a guard x >= j. Returns the branches
        -- up to the point where every natural number is covered, or Nothing
        -- if that point is never reached.
        complete bs small (Just upper)
          | null $ [0..upper - 1] List.\\ small = Just []
        complete (b@(TALit (LitNat n) _) : bs) small upper =
          (b :) <$> complete bs (n : small) upper
        complete (b@(TAGuard (TApp (TPrim PGeq) [TVar y, TLit (LitNat j)]) _) : bs) small upper | x == y =
          (b :) <$> complete bs small (Just $ maybe j (min j) upper)
        complete _ _ _ = Nothing

    pruneLitCases x t d bs
      | CTInt == caseType t = return $ TCase x t d bs -- TODO
      | otherwise           = return $ TCase x t d bs
389
390    -- Drop 'false' branches and drop everything after 'true' branches (including the default
391    -- branch)
392    pruneBoolGuards d [] = (d, [])
393    pruneBoolGuards d (b@(TAGuard (TCon c) _) : bs)
394      | Just c == true  = (tUnreachable, [b])
395      | Just c == false = pruneBoolGuards d bs
396    pruneBoolGuards d (b : bs) =
397      second (b :) $ pruneBoolGuards d bs
398
399    tCase' x t d [] = return d
400    tCase' x t d bs = pruneLitCases x t d bs
401
    -- Apply a term to arguments, reducing where possible:
    --   * the application is pushed under lets and into every case branch,
    --     and the resulting case is simplified again;
    --   * a variable whose known value is atomic or a lambda is replaced by
    --     that value;
    --   * a lambda applied to a variable substitutes it directly, and a lambda
    --     applied to any other argument becomes a let.
    -- Clause order matters: e.g. a let or case applied to no arguments is
    -- still handled by the first two clauses.
    tApp :: TTerm -> [TTerm] -> S TTerm
    tApp (TLet e b) es = TLet e <$> underLet e (tApp b (raise 1 es))
    tApp (TCase x t d bs) es = do
      d  <- tApp d es
      bs <- mapM (`tAppAlt` es) bs
      simpl $ TCase x t d bs    -- will resimplify branches
    tApp (TVar x) es = do
      v <- lookupVar x
      case v of
        _ | v /= TVar x && isAtomic v -> tApp v es
        TLam{} -> tApp v es   -- could blow up the code
        _      -> pure $ mkTApp (TVar x) es
    tApp f [] = pure f
    tApp (TLam b) (TVar i : es) = tApp (subst 0 (TVar i) b) es
    tApp (TLam b) (e : es) = tApp (TLet e b) es
    tApp f es = pure $ TApp f es
418
419    tAppAlt (TACon c a b) es = TACon c a <$> underLams a (tApp b (raise a es))
420    tAppAlt (TALit l b) es   = TALit l   <$> tApp b es
421    tAppAlt (TAGuard g b) es = TAGuard g <$> tApp b es
422
423    isAtomic = \case
424      TVar{}    -> True
425      TCon{}    -> True
426      TPrim{}   -> True
427      TDef{}    -> True
428      TLit{}    -> True
429      TSort{}   -> True
430      TErased{} -> True
431      TError{}  -> True
432      _         -> False
433
    -- Try to prove a ≤ b. Facts come from rewrite rules that record the
    -- outcome of earlier comparisons: (a < b) rewritten to true gives a < b,
    -- and rewritten to false gives b ≤ a. Two more sources are comparisons of
    -- literals and the bounds 0 ≤ fromWord w ≤ 2^64 - 1. Both sides are put
    -- in linear form (toArith) after applying the current substitution.
    -- False only means the inequality could not be established.
    checkLeq a b = do
      rho  <- asks envSubst
      rwr  <- asks envRewrite
      let nf = toArith . applySubst rho
          less = [ (nf a, nf b) | (TPOp PLt a b, rhs) <- rwr, isTrue  rhs ]
          leq  = [ (nf b, nf a) | (TPOp PLt a b, rhs) <- rwr, isFalse rhs ]

          -- Difference of the constants of two linear forms with the same atoms.
          match (j, as) (k, bs)
            | as == bs  = Just (j - k)
            | otherwise = Nothing

          -- Do we have x ≤ y given x' + d ≤ y' ?
          -- (d = 1 for the strict facts in less, d = 0 for those in leq)
          matchEqn d x y (x', y') = isJust $ do
            k <- match x x'     -- x = x' + k
            j <- match y y'     -- y = y' + j
            guard (k <= j + d)  -- x ≤ y if k ≤ j + d

          matchLess = matchEqn 1
          matchLeq  = matchEqn 0

          literal (j, []) (k, []) = j <= k
          literal _ _ = False

          -- k + fromWord x ≤ y  if  k + 2^64 - 1 ≤ y
          wordUpperBound (k, [Pos (TApp (TPrim P64ToI) _)]) y = go (k + 2^64 - 1, []) y
          wordUpperBound _ _ = False

          -- x ≤ k + fromWord y  if  x ≤ k
          wordLowerBound a (k, [Pos (TApp (TPrim P64ToI) _)]) = go a (k, [])
          wordLowerBound _ _ = False

          go x y = or
            [ literal x y
            , wordUpperBound x y
            , wordLowerBound x y
            , any (matchLess x y) less
            , any (matchLeq x y) leq ]

      return $ go (nf a) (nf b)
473
-- | A linear expression: an integer constant plus a sum of signed atoms.
type Arith = (Integer, [Atom])

-- | A term occurring positively or negatively in an 'Arith' sum.
data Atom = Pos TTerm | Neg TTerm
  deriving (Show, Eq, Ord)

-- | Flip the sign of an atom.
aNeg :: Atom -> Atom
aNeg = \case
  Pos a -> Neg a
  Neg a -> Pos a
482
-- | Cancel opposite atoms. An atom whose negation occurs later in the list is
-- removed, together with the first such occurrence.
aCancel :: [Atom] -> [Atom]
aCancel []       = []
aCancel (a : as) =
  case break (== aNeg a) as of
    (before, _ : after) -> aCancel (before ++ after)
    (_, [])             -> a : aCancel as
488
-- | Sort in descending order. The sort is stable, so elements that compare
-- equal keep their relative order.
sortR :: Ord a => [a] -> [a]
sortR = List.sortBy (\x y -> compare y x)
491
-- | Sum of two linear expressions; opposite atoms cancel.
aAdd :: Arith -> Arith -> Arith
aAdd (a, xs) (b, ys) = (a + b, aCancel (sortR (xs ++ ys)))

-- | Difference of two linear expressions: add the negated second operand.
aSub :: Arith -> Arith -> Arith
aSub x (b, ys) = aAdd x (negate b, map aNeg ys)
497
-- | Turn a linear expression back into a term. When the constant is zero or
-- negative and there is a positive atom, the sum starts from that atom, and
-- a negative constant is subtracted at the end. Otherwise the sum starts
-- from the constant.
fromArith :: Arith -> TTerm
fromArith (n, xs)
  | null xs = tInt n
  | n <= 0, (ys, Pos a : zs) <- break isPos xs =
      let rest = foldl addAtom a (ys ++ zs)
      in  if n == 0 then rest else tOp PSub rest (tInt (negate n))
  | otherwise = foldl addAtom (tInt n) xs

-- | Is the atom positive?
isPos :: Atom -> Bool
isPos = \case
  Pos{} -> True
  Neg{} -> False
510
-- | Extend a sum with one signed atom.
addAtom :: TTerm -> Atom -> TTerm
addAtom t atom = case atom of
  Pos a -> tOp PAdd t a
  Neg a -> tOp PSub t a
514
-- | Linear form of a term: integer literals become constants, PAdd and PSub
-- are flattened, and any other term is a single positive atom.
toArith :: TTerm -> Arith
toArith t
  | Just n <- intView t           = (n, [])
  | TApp (TPrim PAdd) [a, b] <- t = toArith a `aAdd` toArith b
  | TApp (TPrim PSub) [a, b] <- t = toArith a `aSub` toArith b
  | otherwise                     = (0, [Pos t])

-- | Normalise linear integer arithmetic in a term.
simplArith :: TTerm -> TTerm
simplArith t = fromArith (toArith t)
523