1{-# LANGUAGE PatternGuards, ViewPatterns #-}
2{-# LANGUAGE RecordWildCards #-}
3
4{-
5map f [] = []
6map f (x:xs) = f x : map f xs
7
8foldr f z [] = z
9foldr f z (x:xs) = f x (foldr f z xs)
10
11foldl f z [] = z
12foldl f z (x:xs) = foldl f (f z x) xs
13-}
14
15{-
16<TEST>
17f (x:xs) = negate x + f xs ; f [] = 0 -- f xs = foldr ((+) . negate) 0 xs
18f (x:xs) = x + 1 : f xs ; f [] = [] -- f xs = map (+ 1) xs
19f z (x:xs) = f (z*x) xs ; f z [] = z -- f z xs = foldl (*) z xs
20f a (x:xs) b = x + a + b : f a xs b ; f a [] b = [] -- f a xs b = map (\ x -> x + a + b) xs
21f [] a = return a ; f (x:xs) a = a + x >>= \fax -> f xs fax -- f xs a = foldM (+) a xs
22f (x:xs) a = a + x >>= \fax -> f xs fax ; f [] a = pure a -- f xs a = foldM (+) a xs
23foos [] x = x; foos (y:ys) x = foo y $ foos ys x -- foos ys x = foldr foo x ys
24f [] y = y; f (x:xs) y = f xs $ g x y -- f xs y = foldl (flip g) y xs
25f [] y = y; f (x : xs) y = let z = g x y in f xs z -- f xs y = foldl (flip g) y xs
26f [] y = y; f (x:xs) y = f xs (f xs z)
27fun [] = []; fun (x:xs) = f x xs ++ fun xs
28</TEST>
29-}
30
31
32module Hint.ListRec(listRecHint) where
33
34import Hint.Type (DeclHint, Severity(Suggestion, Warning), idea, toSS)
35
36import Data.Generics.Uniplate.DataOnly
37import Data.List.Extra
38import Data.Maybe
39import Data.Either.Extra
40import Control.Monad
41import Refact.Types hiding (RType(Match))
42
43import SrcLoc
44import GHC.Hs.Extension
45import GHC.Hs.Pat
46import GHC.Hs.Types
47import TysWiredIn
48import RdrName
49import GHC.Hs.Binds
50import GHC.Hs.Expr
51import GHC.Hs.Decls
52import BasicTypes
53
54import GHC.Util
55import Language.Haskell.GhclibParserEx.GHC.Hs.Pat
56import Language.Haskell.GhclibParserEx.GHC.Hs.Expr
57import Language.Haskell.GhclibParserEx.GHC.Hs.ExtendInstances
58import Language.Haskell.GhclibParserEx.GHC.Utils.Outputable
59import Language.Haskell.GhclibParserEx.GHC.Types.Name.Reader
60
-- Look at every declaration nested inside the given one and, where a
-- declaration is a two-equation list recursion that matches one of
-- the known shapes, suggest the equivalent library function.
listRecHint :: DeclHint
listRecHint _ _ = concatMap (maybeToList . suggest) . universe
    where
        suggest original = do
            -- Split the declaration into its nil/cons analysis plus a
            -- function that rebuilds a declaration from a new body.
            (listCase, rebuild) <- findCase original
            (use, severity, body) <- matchListRec listCase
            let replacement = rebuild body
            -- A leftover marker means some recursive call could not
            -- be expressed by the suggested combinator.
            guard $ recursiveStr `notElem` varss replacement
            -- Maybe we can do better here maintaining source
            -- formatting?
            pure $
                idea severity ("Use " ++ use) original replacement
                    [Replace Decl (toSS original) [] (unsafePrettyPrint replacement)]
73
-- Name of the placeholder variable that stands in for a recursive
-- call while the cons branch is being analysed.
recursiveStr :: String
recursiveStr = "_recursive_"

-- The placeholder as an expression, for comparing against (and
-- splicing into) the cons branch.
recursive :: LHsExpr GhcPs
recursive = strToVar recursiveStr
77
-- A function analysed as a case split on one list argument: the
-- remaining parameters, the body of the '[]' equation, and the body
-- of the '(x:xs)' equation together with the names bound to the head
-- and the tail.
data ListCase =
  ListCase
    [String] -- recursion parameters (the non-list arguments that remain)
    (LHsExpr GhcPs)  -- nil case: body of the '[]' equation
    (String, String, LHsExpr GhcPs) -- cons case: head name, tail name, body
-- In the cons-case body, recursive calls on the tail 'xs' have the
-- 'xs' argument deleted and the function name replaced by the marker
-- variable "_recursive_".
85
-- The list pattern found in an equation: either '[]' or 'x:xs' with
-- the names of the head and tail variables. The derived 'Ord' places
-- 'BNil' before 'BCons', which 'findCase' uses (via 'sortOn') to
-- put the nil branch first.
data BList = BNil | BCons String String
             deriving (Eq, Ord, Show)
88
-- One equation of a candidate function, as produced by 'findBranch'.
data Branch =
  Branch
    String  -- function name
    [String]  -- parameters: names of the plain-variable arguments, in order
    Int -- list position: index of the list pattern among the arguments
    BList (LHsExpr GhcPs) -- list type/body: which list pattern was matched, and the equation's right-hand side
95
96
97---------------------------------------------------------------------
98-- MATCH THE RECURSION
99
100
-- Try to recognise the analysed recursion as one of 'map', 'foldr',
-- 'foldl' or 'foldM'/'foldM_'. On success return the name of the
-- function to suggest, the severity of the hint, and the replacement
-- body. The alternatives are tried in the order listed.
matchListRec :: ListCase -> Maybe (String, Severity, LHsExpr GhcPs)
matchListRec (ListCase vs nil (x, xs, cons))
    -- 'map': no extra parameters, nil case is [], and the cons case
    -- is @e : <recursive>@ where @e@ does not mention the tail.
    | null vs, varToStr nil == "[]"
    , L _ (OpApp _ hd colon tl) <- cons, varToStr colon == ":"
    , fromParen tl `astEq` recursive, xs `notElem` vars hd
    = Just ( "map"
           , Hint.Type.Warning
           , appsBracket [strToVar "map", niceLambda [x] hd, tailVar] )
    -- 'foldr': no extra parameters and the cons case is a binary
    -- application whose second argument is the recursive call.
    | null vs, App2 op hd rest <- view cons
    , xs `notElem` (vars op ++ vars hd) -- the meaning of xs changes, see #793
    , fromParen rest `astEq` recursive
    = Just ( "foldr"
           , Suggestion
           , appsBracket [strToVar "foldr", niceLambda [x] (appsBracket [op, hd]), nil, tailVar] )
    -- 'foldl': one accumulator, returned unchanged by the nil case,
    -- and the cons case is the recursive call applied to the new
    -- accumulator.
    | [acc] <- vs, view nil == Var_ acc
    , L _ (HsApp _ call step) <- cons
    , fromParen call `astEq` recursive
    , xs `notElem` vars step
    = Just ( "foldl"
           , Suggestion
           , appsBracket [strToVar "foldl", niceLambda [acc, x] step, strToVar acc, tailVar] )
    -- 'foldM' / 'foldM_': one accumulator, the nil case returns it
    -- (or unit), and the cons case binds an action's result and
    -- passes exactly that result to the recursive call.
    | [acc] <- vs
    , L _ (HsApp _ ret res) <- nil, isReturn ret
    , varToStr res == "()" || view res == Var_ acc
    , [ L _ (BindStmt _ (view -> PVar_ bound) action _ _)
      , L _ (BodyStmt _ (fromParen -> (L _ (HsApp _ call (view -> Var_ passed)))) _ _)
      ] <- asDo cons
    , bound == passed, call `astEq` recursive, xs `notElem` vars action
    , let name = if varToStr res == "()" then "foldM_" else "foldM"
    = Just ( name
           , Suggestion
           , appsBracket [strToVar name, niceLambda [acc, x] action, strToVar acc, tailVar] )
    -- Nope, I got nothing ¯\_(ツ)_/¯.
    | otherwise = Nothing
  where
    -- The tail of the list, as the expression handed to the combinator.
    tailVar = strToVar xs
129
-- Very limited attempt to convert >>= to do, only useful for
-- 'foldM' / 'foldM_'.
--
-- * @lhs >>= \v -> rhs@ (a single variable pattern, no guards, no
--   where clause) becomes the two statements @v <- lhs@ and @rhs@.
-- * A @do@ block yields its own statements.
-- * Anything else is wrapped as a single body statement.
asDo :: LHsExpr GhcPs -> [LStmt GhcPs (LHsExpr GhcPs)]
asDo (view ->
       App2 bind lhs
         (L _ (HsLam _ MG {
              mg_origin=FromSource
            , mg_alts=L _ [
                 L _ Match {  m_ctxt=LambdaExpr
                            , m_pats=[v@(L _ VarPat{})]
                            , m_grhss=GRHSs _
                                        [L _ (GRHS _ [] rhs)]
                                        (L _ (EmptyLocalBinds _))}]}))
      )
  -- Only a genuine bind may be read as a generator; any other binary
  -- function applied to a lambda falls through to the cases below.
  | varToStr bind == ">>=" =
  [ noLoc $ BindStmt noExtField v lhs noSyntaxExpr noSyntaxExpr
  , noLoc $ BodyStmt noExtField rhs noSyntaxExpr noSyntaxExpr ]
asDo (L _ (HsDo _ DoExpr (L _ stmts))) = stmts
asDo x = [noLoc $ BodyStmt noExtField x noSyntaxExpr noSyntaxExpr]
148
149
150---------------------------------------------------------------------
151-- FIND THE CASE ANALYSIS
152
153
-- Recognise a declaration that is a function binding made of exactly
-- two equations, one matching '[]' and one matching '(x:xs)' in the
-- same argument position, with identically named remaining
-- parameters and the same function name.
--
-- On success returns the 'ListCase' describing the recursion, and a
-- function that, given a replacement body, rebuilds a single-equation
-- declaration of the same function.
findCase :: LHsDecl GhcPs -> Maybe (ListCase, LHsExpr GhcPs -> LHsDecl GhcPs)
findCase x = do
  -- Match a function binding with two alternatives. The record
  -- wildcards bring the remaining fields of the first match, the
  -- match group and the binding into scope; they are reused below
  -- when the replacement declaration is assembled.
  (L _ (ValD _ FunBind {fun_matches=
              MG{mg_origin=FromSource, mg_alts=
                     (L _
                            [ x1@(L _ Match{..}) -- Match fields.
                            , x2]), ..} -- Match group fields.
          , ..} -- Fun. bind fields.
      )) <- pure x

  Branch name1 ps1 p1 c1 b1 <- findBranch x1
  Branch name2 ps2 p2 c2 b2 <- findBranch x2
  -- Both equations must agree on function name, parameter names and
  -- the position of the list argument.
  guard (name1 == name2 && ps1 == ps2 && p1 == p2)
  -- Put the nil branch first whatever the source order was; this
  -- fails unless there is exactly one nil and one cons branch. Note
  -- that 'x' is rebound here to the name of the list head.
  [(BNil, b1), (BCons x xs, b2)] <- pure $ sortOn fst [(c1, b1), (c2, b2)]
  -- Rewrite self-calls in the cons body with 'delCons'; a 'Nothing'
  -- from 'delCons' abandons the whole match.
  b2 <- transformAppsM (delCons name1 p1 xs) b2
  -- Drop the parameters that 'eliminateArgs' finds are passed
  -- unchanged to every recursive call.
  (ps, b2) <- pure $ eliminateArgs ps1 b2

  let ps12 = let (a, b) = splitAt p1 ps1 in map strToPat (a ++ xs : b) -- Function arguments: ps1 with xs put back at the list position.
      emptyLocalBinds = noLoc $ EmptyLocalBinds noExtField -- Empty where clause.
      gRHS e = noLoc $ GRHS noExtField [] e :: LGRHS GhcPs (LHsExpr GhcPs) -- Guarded rhs.
      gRHSSs e = GRHSs noExtField [gRHS e] emptyLocalBinds -- Guarded rhs set.
      match e = Match{m_ext=noExtField,m_pats=ps12, m_grhss=gRHSSs e, ..} -- Match; other fields come from the first original match.
      matchGroup e = MG{mg_alts=noLoc [noLoc $ match e], mg_origin=Generated, ..} -- Match group; other fields come from the original group.
      funBind e = FunBind {fun_matches=matchGroup e, ..} :: HsBindLR GhcPs GhcPs -- Fun bind; other fields come from the original binding.

  pure (ListCase ps b1 (x, xs, b2), noLoc . ValD noExtField . funBind)
181
-- Rewrite a call to the function being analysed. For an application
-- headed by @func@, the argument at index @pos@ must be exactly the
-- variable @var@ (the list tail); that argument is removed and the
-- head is replaced by the 'recursive' marker. A call to @func@ with
-- anything else in that position yields 'Nothing'. Every other
-- expression is returned untouched.
delCons :: String -> Int -> String -> LHsExpr GhcPs -> Maybe (LHsExpr GhcPs)
delCons func pos var expr
  | (view -> Var_ callee) : args <- fromApps expr
  , callee == func
  = case splitAt pos args of
      (before, (view -> Var_ v) : after)
        | v == var -> Just $ apps (recursive : before ++ after)
      _ -> Nothing
  | otherwise = Just expr
188
-- Remove the parameters that every recursive call passes along
-- unchanged. Returns the parameters that remain, and the cons body
-- with the corresponding arguments dropped from each recursive call.
eliminateArgs :: [String] -> LHsExpr GhcPs -> ([String], LHsExpr GhcPs)
eliminateArgs ps cons = (dropConstant ps, transform stripCall cons)
  where
    -- Argument lists of all the marked recursive calls in the body.
    recCalls =
      [ callArgs
      | hd : callArgs <- map fromApps (universeApps cons)
      , astEq hd recursive
      ]

    -- Does this call pass parameter 'p' as-is at index 'i'?
    passesThrough i p callArgs =
      length callArgs > i && view (callArgs !! i) == Var_ p

    -- One flag per parameter position, padded with False so that any
    -- surplus arguments of a call are always kept.
    constant =
      zipWith (\i p -> all (passesThrough i p) recCalls) [0 ..] ps
        ++ repeat False

    -- Keep only the elements at non-constant positions.
    dropConstant :: [a] -> [a]
    dropConstant ys = [y | (isConst, y) <- zip constant ys, not isConst]

    stripCall (fromApps -> hd : callArgs)
      | astEq hd recursive = apps $ hd : dropConstant callArgs
    stripCall e = e
198
199
200---------------------------------------------------------------------
201-- FIND A BRANCH
202
203
-- Analyse one equation of a function. The equation must have a
-- single unguarded right-hand side and no where clause, and its
-- argument patterns must be accepted by 'findPat'.
findBranch :: LMatch GhcPs (LHsExpr GhcPs) -> Maybe Branch
findBranch
  (L _ Match
    { m_ctxt = FunRhs {mc_fun = L _ name}
    , m_pats = pats
    , m_grhss = GRHSs
        { grhssGRHSs = [L _ (GRHS _ [] body)]
        , grhssLocalBinds = L _ (EmptyLocalBinds _)
        }
    }) = do
  (params, listPos, listKind) <- findPat pats
  pure $ Branch (occNameStr name) params listPos listKind (simplifyExp body)
findBranch _ = Nothing
215
-- Classify the argument patterns of an equation. Succeeds only when
-- every pattern is understood by 'readPat' and exactly one of them is
-- a list pattern; returns the names of the variable patterns, the
-- index of the list pattern, and the list pattern itself.
findPat :: [LPat GhcPs] -> Maybe ([String], Int, BList)
findPat pats = do
  classified <- traverse readPat pats
  case (findIndices isRight classified, partitionEithers classified) of
    ([listPos], (params, [listKind])) -> Just (params, listPos, listKind)
    _ -> Nothing
223
-- Read a single argument pattern: a plain variable gives 'Left' its
-- name, a parenthesised @(x:xs)@ with variable head and tail gives
-- 'Right' 'BCons', and @[]@ gives 'Right' 'BNil'. Any other pattern
-- is rejected.
readPat :: LPat GhcPs -> Maybe (Either String BList)
readPat pat
  | PVar_ v <- view pat
  = Just (Left v)
  | L _ (ParPat _ (L _ (ConPatIn (L _ con) (InfixCon (view -> PVar_ hd) (view -> PVar_ tl))))) <- pat
  , con == consDataCon_RDR
  = Just (Right (BCons hd tl))
  | L _ (ConPatIn (L _ con) (PrefixCon [])) <- pat
  , con == nameRdrName nilDataConName
  = Just (Right BNil)
  | otherwise
  = Nothing
231