1{-# LANGUAGE TupleSections #-}
2
3module Floskell.Imports ( sortImports, groupImports, splitImports ) where
4
5import           Control.Monad.Trans.State    ( State, execState, gets, modify )
6
7import           Data.Function                ( on )
8import           Data.List
9                 ( groupBy, inits, intercalate, sortOn, sortOn, unfoldr )
10import qualified Data.Map                     as M
11import           Data.Monoid                  ( First(..) )
12
13import           Floskell.Config
14                 ( ImportsGroup(..), ImportsGroupOrder(..) )
15
16import           Language.Haskell.Exts.Syntax ( ImportDecl(..), ModuleName(..) )
17
-- | The name of the module an import declaration refers to, as a plain
-- dotted string.
moduleName :: ImportDecl a -> String
moduleName decl = name
  where
    ModuleName _ name = importModule decl
21
-- | Split a string on every occurrence of a separator character.
--
-- Interior empty fields are kept (@"a..b"@ gives @["a", "", "b"]@), but a
-- trailing separator does not produce a final empty field, and the empty
-- string yields no fields at all.
splitOn :: Char -> String -> [String]
splitOn _ [] = []
splitOn sep str = field : splitOn sep (drop 1 rest)
  where
    -- 'rest' starts with the separator (or is empty); 'drop 1' skips it.
    (field, rest) = break (== sep) str
27
-- | Every dot-separated prefix of a module name, longest first and ending
-- with the empty prefix: @"A.B.C"@ gives @["A.B.C", "A.B", "A", ""]@.
modulePrefixes :: String -> [String]
modulePrefixes name =
    [ intercalate "." segments | segments <- reverse (inits (splitOn '.' name)) ]
30
-- | State threaded through 'placeImport' while imports are distributed
-- among the configured groups.
data St a = St
    { stIndex  :: M.Map String Int
      -- ^ Module-name prefix to the index of the group that declared it.
    , stGroups :: M.Map Int (ImportsGroup, [ImportDecl a])
      -- ^ Group index to the group's configuration and the imports placed
      -- in it so far, in placement order.
    , stRest   :: [ImportDecl a]
      -- ^ Imports that matched no group prefix, in placement order.
    }
35
-- | Length of the longest prefix shared by every list in the input.
--
-- An empty input, or one containing an empty list, yields 0; a single list
-- yields its own length.
commonPrefixLength :: Eq a => [[a]] -> Int
commonPrefixLength = go 0
  where
    go l [] = l
    go l ([] : _) = l
    go l ((x : xs) : ys)
        -- Every other list must also start with 'x'. On success, advance
        -- ALL lists by one element, not just the first, so the next step
        -- compares corresponding positions.
        | all ((== [ x ]) . take 1) ys = go (l + 1) (xs : map (drop 1) ys)
        | otherwise = l
43
-- | Order imports alphabetically by module name.
--
-- 'sortOn' is stable, so several imports of the same module keep their
-- original relative order.
sortImports :: [ImportDecl a] -> [ImportDecl a]
sortImports decls = sortOn moduleName decls
46
-- | Split imports into runs of adjacent declarations that share the
-- module-name segment at 0-based position @depth@.
--
-- Only adjacent imports are merged, so the input should already be sorted
-- (as 'splitImports' does). Names with no segment at that position share an
-- empty key, so adjacent ones end up in the same run.
groupImports :: Int -> [ImportDecl a] -> [[ImportDecl a]]
groupImports depth = groupBy sameSegment
  where
    sameSegment a b = segmentAt a == segmentAt b

    segmentAt decl = take 1 (drop depth (splitOn '.' (moduleName decl)))
51
-- | Try the keys in order and return the value bound to the first key that
-- is present in the map, or 'Nothing' if none is. Keys after the first hit
-- are never looked up.
lookupFirst :: Ord a => [a] -> M.Map a b -> Maybe b
lookupFirst ks m = foldr tryKey Nothing ks
  where
    tryKey k fallback = maybe fallback Just (M.lookup k m)
54
-- | Record one import in the state.
--
-- The import is appended to the group that owns its longest matching
-- module-name prefix. If no configured prefix matches, it is appended to
-- the ungrouped rest.
placeImport :: ImportDecl a -> State (St a) ()
placeImport decl = do
    target <- gets (lookupFirst candidates . stIndex)
    modify (maybe appendToRest appendToGroup target)
  where
    -- 'modulePrefixes' lists the longest prefix first, so the most
    -- specific group wins.
    candidates = modulePrefixes (moduleName decl)

    appendToGroup k st =
        st { stGroups = M.adjust (fmap (++ [ decl ])) k (stGroups st) }

    appendToRest st = st { stRest = stRest st ++ [ decl ] }
63
-- | Partition imports into the configured groups.
--
-- Each import is assigned by 'placeImport' to the group owning its longest
-- matching module-name prefix. Imports matching no group are collected, in
-- input order, into a trailing block. Groups are emitted in configuration
-- order, each post-processed according to its 'importsOrder', and empty
-- blocks are dropped.
splitImports :: [ImportsGroup] -> [ImportDecl a] -> [[ImportDecl a]]
splitImports groups imports = extract $
    execState (mapM_ placeImport imports) initial
  where
    -- Index every configured prefix by the position of its group. If a
    -- prefix appears in several groups, 'M.fromList' keeps the last one.
    initial = St { stIndex  = M.fromList . concat $
                       zipWith (\n g -> map (, n) (importsPrefixes g))
                               [ 0 .. ]
                               groups
                 , stGroups = M.fromList $
                       zipWith (\n g -> (n, (g, []))) [ 0 .. ] groups
                 , stRest   = []
                 }

    extract s = filter (not . null) $
        concatMap maybeSortAndGroup (M.elems $ stGroups s) ++ [ stRest s ]

    maybeSortAndGroup (g, is) = case importsOrder g of
        ImportsGroupKeep -> [ is ]
        ImportsGroupSorted -> [ sortImports is ]
        ImportsGroupGrouped ->
            groupImports (sharedSegments g) $ sortImports is

    -- 'groupImports' takes a count of dot-separated module-name segments,
    -- so measure the prefixes' common part in segments, not characters.
    -- For example, prefix "Data" must give 1, not 4.
    sharedSegments = commonPrefixLength . map (splitOn '.') . importsPrefixes
85