1{-# LANGUAGE CPP #-}
2{-# LANGUAGE StrictData #-}
3{-# LANGUAGE OverloadedStrings #-}
4{-# LANGUAGE DeriveTraversable #-}
5{-# LANGUAGE DeriveLift #-}
6module Text.Collate.Trie
7  ( Trie
8  , empty
9  , insert
10  , alter
11  , unfoldTrie
12  , matchLongestPrefix
13  , lookupNonEmptyChild
14  )
15  where
16
17import Control.Monad (foldM)
18import qualified Data.IntMap as M
19import Data.Bifunctor (first)
20import Data.Binary (Binary(..))
21import Language.Haskell.TH.Syntax (Lift(..))
22import Instances.TH.Lift ()
23import Data.Maybe (fromMaybe)
24#if MIN_VERSION_base(4,11,0)
25#else
26import Data.Semigroup (Semigroup(..))
27#endif
28
-- | A trie whose keys are sequences of 'Int's.
--
-- Every node carries two optional components: the value stored for the
-- key that ends at this node, and a map from the next key element to
-- the child node reached through it.
data Trie a
  = Trie
      (Maybe a)                    -- value stored at this node, if any
      (Maybe (M.IntMap (Trie a)))  -- children by next key element, if any
  deriving (Show, Eq, Ord, Lift, Functor, Foldable, Traversable)
31
-- | Combine two tries by inserting every key/value pair of the right
-- operand into the left operand.  When both tries hold a value for the
-- same key, the value from the right operand replaces the left one.
instance Semigroup (Trie a) where
  base <> extra = foldr addEntry base (unfoldTrie extra)
    where
      -- Insert one unfolded (key, value) pair into the accumulated trie.
      addEntry (key, val) acc = insert key val acc
34
-- | The identity element is the trie with no value and no children;
-- 'mappend' is the 'Semigroup' operation.
instance Monoid (Trie a) where
  mempty = empty
  mappend left right = left <> right
38
-- | A node is serialized as the pair of its optional value and its
-- optional child map; decoding reads that pair back and rebuilds the
-- node from it.
instance Binary a => Binary (Trie a) where
  put (Trie nodeValue children) = put (nodeValue, children)
  get = uncurry Trie <$> get
44
-- | The trie that contains no keys: its root has neither a value nor
-- a child map.
empty :: Trie a
empty = Trie noValue noChildren
  where
    noValue    = Nothing
    noChildren = Nothing
47
-- | List every (key, value) pair stored in the trie.
--
-- A node's own entry is emitted before the entries of its children,
-- and children are visited in the order given by 'M.toList'.
unfoldTrie :: Trie a -> [([Int], a)]
unfoldTrie root = map (first reverse) (walk [] root)
  where
    -- The path is accumulated in reverse (most recent element first)
    -- so extending it is a cons; it is reversed once per emitted entry
    -- by the outer 'map'.
    walk revPath (Trie mbVal mbKids) = ownEntry ++ childEntries
      where
        ownEntry     = maybe [] (\val -> [(revPath, val)]) mbVal
        childEntries = maybe [] (concatMap descend . M.toList) mbKids
        descend (step, child) = walk (step : revPath) child
58
-- | Store a value under the given key, replacing any value already
-- stored there.  Nodes along the key's path that do not exist yet are
-- created; values and children of existing nodes are left untouched.
insert :: [Int] -> a -> Trie a -> Trie a
insert [] val (Trie _ kids) = Trie (Just val) kids
insert (step : rest) val (Trie mbVal mbKids) =
  Trie mbVal (Just (M.insert step grown siblings))
  where
    -- A node without a child map is treated as having an empty one.
    siblings = fromMaybe M.empty mbKids
    -- Continue below the existing child for this step, or below a
    -- fresh empty node when there is none.
    grown    = insert rest val (fromMaybe empty (M.lookup step siblings))
67
-- | Apply a function to the optional value stored under the given key.
--
-- The function receives 'Nothing' when the key has no value and may
-- return 'Nothing' to leave the key without one.  Nodes along the
-- key's path are created when missing, whatever the function returns.
alter :: (Maybe a -> Maybe a) -> [Int] -> Trie a -> Trie a
alter f key node =
  case key of
    [] -> Trie (f mbVal) mbKids
    step : rest ->
      let siblings = fromMaybe M.empty mbKids
          existing = fromMaybe empty (M.lookup step siblings)
      in  Trie mbVal (Just (M.insert step (alter f rest existing) siblings))
  where
    Trie mbVal mbKids = node
74
-- | State threaded through the fold in 'matchLongestPrefix':
-- the best match so far, the number of code points consumed, and the
-- current subtrie.
type MatchState a =
  ( Maybe (a, Int, Trie a)
  , Int
  , Trie a
  )
77
{-# SPECIALIZE matchLongestPrefix :: Trie a -> [Int] -> Maybe (a, Int, Trie a) #-}
-- | Follow the code points through the trie and report the deepest
-- node reached that carries a value.
--
-- Returns 'Nothing' for no match, or
-- @Just (value, number of code points consumed, subtrie)@ where the
-- count and subtrie are those of the node that supplied the value.
-- Only nodes reached by consuming at least one code point are
-- considered, so a value at the root of the given trie is never
-- reported.
matchLongestPrefix :: Foldable t => Trie a -> t Int -> Maybe (a, Int, Trie a)
matchLongestPrefix trie codes =
  case foldM step (Nothing, 0, trie) codes of
    -- A code point had no matching child: the walk stopped early.
    Left bestAtStop      -> bestAtStop
    -- Every code point was matched: take the best match recorded.
    Right (bestAtEnd, _, _) -> bestAtEnd
  where
    -- Left aborts the fold with the final answer; Right carries the
    -- updated state on to the next code point.
    step :: MatchState a -> Int -> Either (Maybe (a, Int, Trie a)) (MatchState a)
    step (best, consumed, Trie _ mbKids) code =
      maybe (Left best) (Right . advance) (mbKids >>= M.lookup code)
      where
        seen = consumed + 1
        -- Move into the matched child; it becomes the new best match
        -- only when it carries a value, otherwise the old best is kept.
        advance child@(Trie mbVal _) =
          (maybe best (\val -> Just (val, seen, child)) mbVal, seen, child)
97
-- | Look up the child at the given branch and, if it exists and holds
-- a value, return that value together with a trie that keeps the
-- child's own children but has no value at its root.
--
-- Returns 'Nothing' when the trie has no children, when there is no
-- child at the given branch, or when that child holds no value.
lookupNonEmptyChild :: Trie a -> Int -> Maybe (a, Trie a)
lookupNonEmptyChild (Trie _ mbKids) idx =
  case mbKids >>= M.lookup idx of
    Just (Trie (Just val) grandKids) -> Just (val, Trie Nothing grandKids)
    _                                -> Nothing
106