1------------------------------------------------------------------------
2-- |
3-- Module           : Data.Parameterized.TH.GADT
4-- Copyright        : (c) Galois, Inc 2013-2019
5-- Maintainer       : Joe Hendrix <jhendrix@galois.com>
6-- Description : Template Haskell primitives for working with large GADTs
7--
-- This module declares Template Haskell primitives that make it easier
-- to work with GADTs that have many constructors.
10------------------------------------------------------------------------
11{-# LANGUAGE DoAndIfThenElse #-}
12{-# LANGUAGE GADTs #-}
13{-# LANGUAGE MultiParamTypeClasses #-}
14{-# LANGUAGE TemplateHaskell #-}
15{-# LANGUAGE TypeOperators #-}
16{-# LANGUAGE EmptyCase #-}
17module Data.Parameterized.TH.GADT
18  ( -- * Instance generators
19    -- $typePatterns
20  structuralEquality
21  , structuralTypeEquality
22  , structuralTypeOrd
23  , structuralTraversal
24  , structuralShowsPrec
25  , structuralHash
26  , structuralHashWithSalt
27  , PolyEq(..)
28    -- * Template haskell utilities that may be useful in other contexts.
29  , DataD
30  , lookupDataType'
31  , asTypeCon
32  , conPat
33  , TypePat(..)
34  , dataParamTypes
35  , assocTypePats
36  ) where
37
38import Control.Monad
39import Data.Maybe
40import Data.Set (Set)
41import qualified Data.Set as Set
42import Language.Haskell.TH
43import Language.Haskell.TH.Datatype
44
45
46import Data.Parameterized.Classes
47
48------------------------------------------------------------------------
49-- Template Haskell utilities
50
-- | Alias for th-abstraction's 'DatatypeInfo', exported under the name
-- 'DataD'.
type DataD = DatatypeInfo

-- | Reify the datatype with the given name.  This is exactly
-- th-abstraction's 'reifyDatatype'.
lookupDataType' :: Name -> Q DatatypeInfo
lookupDataType' name = reifyDatatype name
55
-- | Given a constructor and a name prefix, generate a pattern matching that
-- constructor together with the names of the variables it binds, in the
-- order the fields appear in the constructor.  The variables are named
-- from the prefix followed by @1@, @2@, ...
--
-- The pattern is built with the 'conP' smart constructor rather than the
-- raw 'ConP' data constructor: 'ConP' gained an extra field (explicit type
-- applications) in template-haskell 2.18, while 'conP' keeps the same
-- two-argument interface on both sides of that change.
conPat ::
  ConstructorInfo {- ^ constructor information -} ->
  String          {- ^ generated name prefix   -} ->
  Q (Pat, [Name]) {- ^ pattern and bound names -}
conPat con pre = do
  nms <- newNames pre (length (constructorFields con))
  pat <- conP (constructorName con) (varP <$> nms)
  return (pat, nms)
66
67
-- | The expression consisting of just the constructor's name.  Used as a
-- value, it is a function expecting the constructor's arguments.
conExpr :: ConstructorInfo -> Exp
conExpr con = ConE (constructorName con)
73
74------------------------------------------------------------------------
75-- TypePat
76
-- | A type used to describe (and match) types appearing in generated pattern
-- matches inside of the TH generators in this module ('structuralEquality',
-- 'structuralTypeEquality', 'structuralTypeOrd', 'structuralTraversal' and
-- 'structuralHashWithSalt').
data TypePat where
  -- | The application of one type pattern to another.
  TypeApp :: TypePat -> TypePat -> TypePat
  -- | Match any type.
  AnyType :: TypePat
  -- | Match the i'th (0-based) type argument of the data type being traversed.
  DataArg :: Int -> TypePat
  -- | Match a ground type, given as a Template Haskell type quotation.
  ConType :: TypeQ -> TypePat
85
-- | @matchTypePat dTypes pat tp@ decides whether the field type @tp@ matches
-- the pattern @pat@.  @dTypes@ are the type arguments of the datatype being
-- processed; they are consulted by 'DataArg'.
--
-- An out-of-range 'DataArg' index is reported with 'fail' so that it surfaces
-- as an ordinary Template Haskell splice error rather than an exception
-- thrown by 'error'.
matchTypePat :: [Type] -> TypePat -> Type -> Q Bool
matchTypePat d (TypeApp p q) (AppT x y) = do
  headMatches <- matchTypePat d p x
  if headMatches then matchTypePat d q y else return False
matchTypePat _ AnyType _ = return True
matchTypePat tps (DataArg i) tp =
  -- Total lookup of the i'th argument (no partial '!!').
  case drop i tps of
    (arg : _) | i >= 0 -> return (stripSigT arg == tp)
    _ -> fail ("Type pattern index " ++ show i ++ " out of bounds")
  where
    -- th-abstraction can annotate type parameters with their kinds,
    -- we ignore these for matching
    stripSigT (SigT t _) = t
    stripSigT t          = t
matchTypePat _ (ConType tpq) tp = do
  tp' <- tpq
  return (tp' == tp)
matchTypePat _ _ _ = return False
105
-- | The type arguments of the datatype described by a 'DatatypeInfo'
-- (th-abstraction's 'datatypeInstTypes').  For example, given the
-- 'DatatypeInfo' for @newtype Id a = MkId a@ this yields
-- @['SigT' ('VarT' a) 'StarT']@.
--
-- Note that the declaration may have type /variables/ that are not
-- referenced in the result; this returns only the type /arguments/.  See
-- th-abstraction's 'datatypeVars' if the variables themselves are needed.
dataParamTypes :: DatatypeInfo -> [Type]
dataParamTypes info = datatypeInstTypes info
114
-- | @assocTypePats dTypes pats tp@ returns the value paired with the first
-- pattern in @pats@ that matches @tp@, or 'Nothing' if none does.  Patterns
-- are tried in list order and matching stops at the first success.
assocTypePats :: [Type] -> [(TypePat, v)] -> Type -> Q (Maybe v)
assocTypePats dTypes pats tp = search pats
  where
    search [] = return Nothing
    search ((pat, val) : rest) = do
      hit <- matchTypePat dTypes pat tp
      if hit then return (Just val) else search rest
123
124------------------------------------------------------------------------
-- Constructor cases
126
-- | The set of free type variables of a type (or collection of types), as
-- computed by th-abstraction's 'freeVariables'.
typeVars :: TypeSubstitution a => a -> Set Name
typeVars x = Set.fromList (freeVariables x)
129
130
-- | @structuralEquality tp pats@ declares a structural equality predicate:
-- it runs the function generated by 'structuralTypeEquality' and reports
-- whether it produced a 'Just'.
structuralEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralEquality tpq pats = do
  let typeEq = structuralTypeEquality tpq pats
  [| \lhs rhs -> isJust ($typeEq lhs rhs) |]
135
-- | @joinEqMaybe x y r@ builds an expression that compares the variables
-- @x@ and @y@ with '==', continuing with @r@ when they are equal and
-- producing 'Nothing' otherwise.
joinEqMaybe :: Name -> Name -> ExpQ -> ExpQ
joinEqMaybe x y r =
  let lhs = varE x
      rhs = varE y
  in [| if $lhs == $rhs then $r else Nothing |]
139
-- | @joinTestEquality f x y r@ builds an expression that applies @f@ to the
-- variables @x@ and @y@ and continues with @r@ only when the result is
-- @Just Refl@; a 'Nothing' result is propagated.
joinTestEquality :: ExpQ -> Name -> Name -> ExpQ -> ExpQ
joinTestEquality f x y r =
  let test = [| $f $(varE x) $(varE y) |]
  in [| case $test of
          Nothing -> Nothing
          Just Refl -> $r
      |]
146
-- | @matchEqArguments dTypes pats cnm bnd tps xs ys@ builds an expression
-- (whose result is wrapped in 'Maybe', ending in @Just Refl@) that compares
-- the variables @xs@ from the first pattern with the variables @ys@ from the
-- second, field by field, where @tps@ are the field types:
--
-- * a field whose type matches one of @pats@ is compared with the associated
--   function, whose result must be @Nothing@ or @Just Refl@; if the field
--   type has the form @T ... v@, the variable @v@ is treated as bound for
--   the remaining fields;
--
-- * a field whose free type variables are all in @bnd@ is compared with '==';
--
-- * any other field is rejected with a compile-time error.
matchEqArguments :: [Type]
                    -- ^ Types bound by data arguments.
                 -> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments
                 -> Name
                    -- ^ Name of constructor.
                 -> Set Name
                    -- ^ Type variables treated as already bound
                 -> [Type]
                    -- ^ Types of the constructor's fields
                 -> [Name]
                    -- ^ Variables bound in first pattern
                 -> [Name]
                    -- ^ Variables bound in second pattern
                 -> ExpQ
matchEqArguments dTypes pats cnm bnd (tp:tpl) (x:xl) (y:yl) = do
  doesMatch <- assocTypePats dTypes pats tp
  case doesMatch of
    Just q -> do
      let bnd' =
            case tp of
              AppT _ (VarT nm) -> Set.insert nm bnd
              _ -> bnd
      joinTestEquality q x y (matchEqArguments dTypes pats cnm bnd' tpl xl yl)
    Nothing | typeVars tp `Set.isSubsetOf` bnd ->
      joinEqMaybe x y (matchEqArguments dTypes pats cnm bnd tpl xl yl)
    Nothing ->
      -- Pretty-print the offending type (as matchOrdArguments does) instead
      -- of dumping its raw Template Haskell AST.
      fail $ "Unsupported argument type " ++ show (ppr tp)
          ++ " in " ++ show (ppr cnm) ++ "."
matchEqArguments _ _ _ _ [] [] [] = [| Just Refl |]
-- The remaining cases are invariant violations: conPat binds exactly one
-- variable per constructor field.
matchEqArguments _ _ _ _ [] _  _  = error "Unexpected end of types."
matchEqArguments _ _ _ _ _  [] _  = error "Unexpected end of names."
matchEqArguments _ _ _ _ _  _  [] = error "Unexpected end of names."
175
-- | Build the inner @case@ expression on the second argument @yQ@, once the
-- first argument has matched constructor @con@ binding the variables @xv@.
-- When the second argument has the same constructor, the fields are
-- compared with 'matchEqArguments'.  When @multipleCases@ is set, a wildcard
-- alternative returning 'Nothing' covers every other constructor.
mkSimpleEqF :: [Type]           -- ^ Data declaration types
            -> Set Name         -- ^ Type variables treated as already bound
            -> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments
            -> ConstructorInfo  -- ^ Constructor matched by the first argument
            -> [Name]           -- ^ Variables bound by the first pattern
            -> ExpQ             -- ^ Expression for the second argument
            -> Bool             -- ^ wildcard case required
            -> ExpQ
mkSimpleEqF dTypes bnd pats con xv yQ multipleCases = do
  (yPat, yNames) <- conPat con "y"
  let fieldsEq  = matchEqArguments dTypes pats (constructorName con) bnd
                                   (constructorFields con) xv yNames
      sameCtor  = match (pure yPat) (normalB fieldsEq) []
      otherCtor = match wildP (normalB [| Nothing |]) []
  caseE yQ (if multipleCases then [sameCtor, otherCtor] else [sameCtor])
191
-- | Build the @case@ on the second argument for one constructor of the
-- datatype @d@, as used by 'structuralTypeEquality'.
--
-- Every type argument of @d@ except the last is treated as bound, so fields
-- mentioning only those arguments are compared with '=='.  The final
-- argument is the index of the higher-kinded equality being generated, so
-- it is left unbound.
mkEqF :: DatatypeInfo -- ^ Data declaration.
      -> [(TypePat,ExpQ)]
      -> ConstructorInfo
      -> [Name]
      -> ExpQ
      -> Bool -- ^ wildcard case required
      -> ExpQ
mkEqF d pats con = mkSimpleEqF params (typeVars (dropLast params)) pats con
  where
    -- the type arguments of the datatype
    params = dataParamTypes d
    -- all elements except the last one; empty for an empty list
    dropLast xs = zipWith const xs (drop 1 xs)
208
-- | @structuralTypeEquality f@ returns a function with the type:
--   @
--     forall x y . f x -> f y -> Maybe (x :~: y)
--   @
--
-- The generated function matches the first argument against every
-- constructor of @f@, and then, per constructor, the second argument
-- against the same constructor.  For a datatype without constructors an
-- empty case is produced.
structuralTypeEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralTypeEquality tpq pats = do
  d <- reifyDatatype =<< asTypeCon "structuralTypeEquality" =<< tpq
  case datatypeCons d of
    [] -> [| \x -> case x of {} |]
    cons -> do
      let needWildcard = length cons > 1
          altFor yQ con = do
            (xPat, xNames) <- conPat con "x"
            match (pure xPat) (normalB (mkEqF d pats con xNames yQ needWildcard)) []
      [| \x y -> $(caseE [| x |] (map (altFor [| y |]) cons)) |]
226
-- | @structuralTypeOrd f@ returns a function with the type:
--   @
--     forall x y . f x -> f y -> OrderingF x y
--   @
--
-- This implementation avoids matching on both the first and second
-- parameters in a simple case expression in order to avoid stressing
-- GHC's coverage checker.  When the datatype has more than one
-- constructor, the second argument's constructor is first mapped to its
-- index.  Values built from different constructors are then ordered by
-- comparing constructor indices.
structuralTypeOrd ::
  TypeQ ->
  [(TypePat,ExpQ)] {- ^ List of type patterns to match. -} ->
  ExpQ
structuralTypeOrd tpq l = do
  -- Name this generator (not structuralTypeEquality) in error messages.
  d <- reifyDatatype =<< asTypeCon "structuralTypeOrd" =<< tpq

  let -- For multi-constructor types, bind @yn@ to the constructor index of
      -- the second argument and hand it to the continuation.
      withNumber :: ExpQ -> (Maybe ExpQ -> ExpQ) -> ExpQ
      withNumber yQ k
        | null (drop 1 (datatypeCons d)) = k Nothing
        | otherwise =  [| let yn :: Int
                              yn = $(caseE yQ (constructorNumberMatches (datatypeCons d)))
                          in $(k (Just [| yn |])) |]

  if null (datatypeCons d)
    then [| \x -> case x of {} |]
    else [| \x y -> $(withNumber [|y|] $ \mbYn -> caseE [| x |] (outerOrdMatches d [|y|] mbYn)) |]
  where
    -- Map each constructor (matched with an empty record pattern, so its
    -- arity is irrelevant) to its 0-based position.
    constructorNumberMatches :: [ConstructorInfo] -> [MatchQ]
    constructorNumberMatches cons =
      [ match (recP (constructorName con) [])
              (normalB (litE (integerL i)))
              []
      | (i,con) <- zip [0..] cons ]

    -- One alternative per constructor of the first argument, each of which
    -- inspects the second argument via mkOrdF.
    outerOrdMatches :: DatatypeInfo -> ExpQ -> Maybe ExpQ -> [MatchQ]
    outerOrdMatches d yExp mbYn =
      [ do (pat,xv) <- conPat con "x"
           match (pure pat)
                 (normalB (do xs <- mkOrdF d l con i mbYn xv
                              caseE yExp xs))
                 []
      | (i,con) <- zip [0..] (datatypeCons d) ]
270
-- | Generate @n@ fresh names from a base name, suffixed @1@ to @n@ so
-- that they stay readable under @-dsuppress-uniques@.
newNames ::
  String   {- ^ base name                     -} ->
  Int      {- ^ quantity                      -} ->
  Q [Name] {- ^ list of names: @base1@, @base2@, ... -}
newNames base n = forM [1 .. n] (\i -> newName (base ++ show i))
279
280
-- | @joinCompareF f x y r@ builds an expression that compares the variables
-- @x@ and @y@ with @f@, whose result must be an 'OrderingF'.  'LTF' and
-- 'GTF' are returned as-is; on 'EQF' evaluation continues with @r@.
joinCompareF :: ExpQ -> Name -> Name -> ExpQ -> ExpQ
joinCompareF f x y r =
  let cmp = [| $f $(varE x) $(varE y) |]
  in [| case $cmp of
          LTF -> LTF
          GTF -> GTF
          EQF -> $r
      |]
288
-- | @joinCompareToOrdF x y r@ builds an expression that compares the
-- variables @x@ and @y@ with 'compare' and converts the result to an
-- 'OrderingF'.  'LT' becomes 'LTF', 'GT' becomes 'GTF', and on 'EQ'
-- evaluation continues with @r@.
joinCompareToOrdF :: Name -> Name -> ExpQ -> ExpQ
joinCompareToOrdF x y r =
  let lhs = varE x
      rhs = varE y
  in [| case compare $lhs $rhs of
          LT -> LTF
          GT -> GTF
          EQ -> $r
      |]
299
-- | Build an 'OrderingF'-valued expression that compares, field by field,
-- the variables of the first pattern with those of the second.  A field
-- whose type matches a user pattern is compared with the associated
-- function (and, for types of the form @T ... v@, binds @v@ for the
-- remaining fields).  A field whose free type variables are all bound is
-- compared with 'compare'.  Any other field type is rejected at compile
-- time.  An empty field list yields 'EQF'.
matchOrdArguments :: [Type]
                     -- ^ Types bound by data arguments
                  -> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments
                  -> Name
                     -- ^ Name of constructor.
                  -> Set Name
                     -- ^ Names bound in data declaration
                  -> [Type]
                     -- ^ Types for constructors
                  -> [Name]
                     -- ^ Variables bound in first pattern
                  -> [Name]
                     -- ^ Variables bound in second pattern
                  -> ExpQ
matchOrdArguments dTypes pats cnm bnd (tp : tpl) (x:xl) (y:yl) = do
  userCmp <- assocTypePats dTypes pats tp
  let continueWith bound = matchOrdArguments dTypes pats cnm bound tpl xl yl
      boundAfter = case tp of
        AppT _ (VarT nm) -> Set.insert nm bnd
        _                -> bnd
  case userCmp of
    Just f -> joinCompareF f x y (continueWith boundAfter)
    Nothing
      | typeVars tp `Set.isSubsetOf` bnd -> joinCompareToOrdF x y (continueWith bnd)
      | otherwise ->
          fail $ "Unsupported argument type " ++ show (ppr tp)
                 ++ " in " ++ show (ppr cnm) ++ "."
matchOrdArguments _ _ _ _ [] [] [] = [| EQF |]
matchOrdArguments _ _ _ _ [] _  _  = error "Unexpected end of types."
matchOrdArguments _ _ _ _ _  [] _  = error "Unexpected end of names."
matchOrdArguments _ _ _ _ _  _  [] = error "Unexpected end of names."
332
-- | Build the alternatives of the @case@ on the second argument for one
-- constructor.  The first alternative matches the same constructor and
-- compares fields with 'matchOrdArguments' (starting with no bound type
-- variables).  When the second constructor's index is supplied, a wildcard
-- alternative orders differing constructors by index.
mkSimpleOrdF :: [Type]           -- ^ Data declaration types
             -> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments
             -> ConstructorInfo  -- ^ Information about the second constructor
             -> Integer          -- ^ First constructor's index
             -> Maybe ExpQ       -- ^ Optional second constructor's index
             -> [Name]           -- ^ Name from first pattern
             -> Q [MatchQ]
mkSimpleOrdF dTypes pats con xnum mbYn xv = do
  (yp, yv) <- conPat con "y"
  let cmpFields = matchOrdArguments dTypes pats (constructorName con) Set.empty
                                    (constructorFields con) xv yv
      sameCtor = match (pure yp) (normalB cmpFields) []
      otherCtor yn = match wildP (normalB [| if xnum < $yn then LTF else GTF |]) []
  return (sameCtor : maybe [] (\yn -> [otherCtor yn]) mbYn)
348
-- | Build the alternatives of the @case@ on the second argument for one
-- constructor of the datatype @d@, as used by 'structuralTypeOrd'.
mkOrdF :: DatatypeInfo     -- ^ Data declaration.
       -> [(TypePat,ExpQ)] -- ^ Patterns for matching arguments
       -> ConstructorInfo
       -> Integer
       -> Maybe ExpQ       -- ^ optional right constructor index
       -> [Name]
       -> Q [MatchQ]
mkOrdF d pats = mkSimpleOrdF (dataParamTypes d) pats
358
-- | @genTraverseOfType dataArgs pats f var tp@ returns the effectful
-- expression that applies @f@ to @var@ (a field of type @tp@), or 'Nothing'
-- when the field is not traversed.
--
-- A user pattern matching @tp@ takes precedence: its function is applied to
-- @f@ and @var@.  Otherwise, a field of type @T (v ...)@ (a type constructor
-- applied to a type-variable application) is traversed with
-- @traverse f@, and a field of type @v ...@ has @f@ applied directly.
genTraverseOfType :: [Type]
                     -- ^ Argument types for the data declaration.
                  -> [(TypePat, ExpQ)]
                     -- ^ Patterns the user provided for overriding type lookup.
                  -> ExpQ -- ^ Function to apply
                  -> ExpQ -- ^ Expression denoting value of this constructor field.
                  -> Type -- ^ Type bound for this constructor field.
                  -> Q (Maybe Exp)
genTraverseOfType dataArgs pats f v tp = do
  override <- assocTypePats dataArgs pats tp
  maybe (defaultTraversal tp) (\g -> Just <$> [| $(g) $(f) $(v) |]) override
  where
    defaultTraversal (AppT (ConT _) (AppT (VarT _) _)) = Just <$> [| traverse $(f) $(v) |]
    defaultTraversal (AppT (VarT _) _) = Just <$> [| $(f) $(v) |]
    defaultTraversal _ = return Nothing
377
-- | @traverseAppMatch dataArgs pats f c@ builds a case alternative that
-- matches constructor @c@ and rebuilds it applicatively.  Each field that
-- 'genTraverseOfType' knows how to traverse is passed through @f@; every
-- other field is reused unchanged.  With no traversed fields the result is
-- @pure (C x1 ... xn)@.
traverseAppMatch :: [Type]
                    -- ^ Argument types for the data declaration.
                 -> [(TypePat, ExpQ)]
                    -- ^ Patterns the user provided for overriding type lookup.
                 -> ExpQ -- ^ Function @f@ given to `traverse`
                 -> ConstructorInfo -- ^ Constructor to match.
                 -> MatchQ
traverseAppMatch dataArgs pats fv c0 = do
  (pat, fieldVars) <- conPat c0 "p"
  traversed <- zipWithM (genTraverseOfType dataArgs pats fv) (varE <$> fieldVars) (constructorFields c0)
  let -- Curried constructor application: untraversed fields are applied
      -- directly, each traversed field becomes a fresh lambda parameter.
      buildCtor :: ExpQ -> [(Name, Maybe Exp)] -> ExpQ
      buildCtor acc [] = acc
      buildCtor acc ((fieldVar, Nothing) : rest) =
        buildCtor (appE acc (varE fieldVar)) rest
      buildCtor acc ((_, Just _) : rest) = do
        hole <- newName "r"
        lamE [varP hole] (buildCtor (appE acc (varE hole)) rest)

      ctorFn  = buildCtor (pure (conExpr c0)) (zip fieldVars traversed)
      effects = catMaybes traversed

      -- ctorFn <$> e1 <*> e2 <*> ... , or pure ctorFn when nothing is traversed
      rhs = case effects of
        [] -> [| pure $(ctorFn) |]
        (firstEff : restEffs) ->
          foldl (\acc eff -> [| $(acc) <*> $(pure eff) |])
                [| $(ctorFn) <$> $(pure firstEff) |]
                restEffs
  match (pure pat) (normalB rhs) []
411
-- | @structuralTraversal tp pats@ generates a function @\\f a -> ...@ that
-- applies the traversal @f@ to the subterms of @a@ whose types involve the
-- datatype's type variables, rebuilding the value applicatively.  @pats@
-- overrides how particular field types are traversed.
structuralTraversal :: TypeQ -> [(TypePat, ExpQ)] -> ExpQ
structuralTraversal tpq pats0 = do
  d <- reifyDatatype =<< asTypeCon "structuralTraversal" =<< tpq
  fName <- newName "f"
  aName <- newName "a"
  let alts = map (traverseAppMatch (datatypeInstTypes d) pats0 (varE fName)) (datatypeCons d)
  lamE [varP fName, varP aName] (caseE (varE aName) alts)
422
-- | Extract the name from a bare type constructor; any other type is
-- rejected, and the error names the calling generator @fn@.
asTypeCon :: String -> Type -> Q Name
asTypeCon fn tp =
  case tp of
    ConT nm -> return nm
    _       -> fail (fn ++ " expected type constructor.")
426
-- | @structuralHash tp@ generates a function with the type
-- @Int -> tp -> Int@ that hashes values of type @tp@.
--
-- Every field is hashed with `hashWithSalt`.  It is equivalent to
-- `structuralHashWithSalt` with no user-defined patterns, which should be
-- used instead.
structuralHash :: TypeQ -> ExpQ
structuralHash = flip structuralHashWithSalt []
{-# DEPRECATED structuralHash "Use structuralHashWithSalt" #-}
436
-- | @structuralHashWithSalt tp pats@ generates a function with the type
-- @Int -> tp -> Int@ that hashes values of type @tp@.
--
-- The salt is first combined with the constructor's index, then with each
-- field in turn.  @pats@ supplies user-defined replacements for
-- `hashWithSalt` at specific field types.
structuralHashWithSalt :: TypeQ -> [(TypePat, ExpQ)] -> ExpQ
structuralHashWithSalt tpq pats = do
  d <- reifyDatatype =<< asTypeCon "structuralHash" =<< tpq
  salt <- newName "s"
  val <- newName "a"
  let alts = zipWith (matchHashCtor d pats (varE salt)) [0 ..] (datatypeCons d)
  lamE [varP salt, varP val] (caseE (varE val) alts)
449
-- | Build the case alternative of the generated hash function for one
-- constructor.  The initial salt is hashed with the constructor index (as an
-- 'Int'), then each field is folded in from left to right.  The fold uses
-- the user function whose pattern matches the field type, or
-- `hashWithSalt` otherwise.
matchHashCtor :: DatatypeInfo
                 -- ^ Data declaration of type we are hashing.
              -> [(TypePat, ExpQ)]
                 -- ^ User provide type patterns
              -> ExpQ -- ^ Initial salt expression
              -> Integer -- ^ Index of constructor
              -> ConstructorInfo -- ^ Constructor information
              -> MatchQ
matchHashCtor d pats s0 i c = do
  (pat, vars) <- conPat c "x"
  let hashField acc (e, tp) = do
        custom <- assocTypePats (datatypeInstTypes d) pats tp
        maybe [| hashWithSalt $(acc) $(e) |] (\f -> [| $(f) $(acc) $(e) |]) custom
      seed = [| hashWithSalt $(s0) ($(litE (IntegerL i)) :: Int) |]
      body = foldl hashField seed (zip (varE <$> vars) (constructorFields c))
  match (pure pat) (normalB body) []
472
-- | @structuralShowsPrec tp@ generates a function with the type
-- @Int -> tp -> ShowS@.  It shows a value as its constructor name followed
-- by each field shown at precedence 11.  The output is wrapped in
-- parentheses when the precedence argument is at least 11 and the
-- constructor has fields.
structuralShowsPrec :: TypeQ -> ExpQ
structuralShowsPrec tpq = do
  -- Use this function's real name in the error message.
  d <- reifyDatatype =<< asTypeCon "structuralShowsPrec" =<< tpq
  -- Leading underscore: presumably to avoid unused-variable warnings when
  -- no constructor has fields (the precedence is then never consulted).
  p <- newName "_p"
  a <- newName "a"
  lamE [varP p, varP a] $
    caseE (varE a) (matchShowCtor (varE p) <$> datatypeCons d)
482
-- | @showCon p nm n@ builds the case alternative that shows constructor
-- @nm@ with @n@ fields.  The output is the constructor name followed by
-- each field rendered with @showsPrec 11@, wrapped in parentheses when the
-- precedence @p@ is at least 11.  Nullary constructors are never
-- parenthesised.
--
-- The pattern is built with the 'conP' smart constructor rather than the
-- raw 'ConP' data constructor, whose number of fields changed in
-- template-haskell 2.18.
showCon :: ExpQ -> Name -> Int -> MatchQ
showCon p nm n = do
  vars <- newNames "x" n
  let go s e = [| $(s) . showChar ' ' . showsPrec 11 $(varE e) |]
      ctor = [| showString $(litE (stringL (nameBase nm))) |]
      rhs | null vars = ctor
          | otherwise = [| showParen ($(p) >= 11) $(foldl go ctor vars) |]
  match (conP nm (varP <$> vars)) (normalB rhs) []
492
-- | The show alternative for one constructor, given the precedence
-- expression @p@.
matchShowCtor :: ExpQ -> ConstructorInfo -> MatchQ
matchShowCtor p con =
  let fieldCount = length (constructorFields con)
  in showCon p (constructorName con) fieldCount
495
496-- $typePatterns
497--
498-- The Template Haskell instance generators 'structuralEquality',
499-- 'structuralTypeEquality', 'structuralTypeOrd', and 'structuralTraversal'
500-- employ heuristics to generate valid instances in the majority of cases.  Most
501-- failures in the heuristics occur on sub-terms that are type indexed.  To
502-- handle cases where these functions fail to produce a valid instance, they
503-- take a list of exceptions in the form of their second parameter, which has
504-- type @[('TypePat', 'ExpQ')]@.  Each 'TypePat' is a /matcher/ that tells the
505-- TH generator to use the 'ExpQ' to process the matched sub-term.  Consider the
506-- following example:
507--
508-- > data T a b where
509-- >   C1 :: NatRepr n -> T () n
510-- >
511-- > instance TestEquality (T a) where
512-- >   testEquality = $(structuralTypeEquality [t|T|]
513-- >                    [ (ConType [t|NatRepr|] `TypeApp` AnyType, [|testEquality|])
514-- >                    ])
515--
516-- The exception list says that 'structuralTypeEquality' should use
517-- 'testEquality' to compare any sub-terms of type @'NatRepr' n@ in a value of
518-- type @T@.
519--
520-- * 'AnyType' means that the type parameter in that position can be instantiated as any type
521--
522-- * @'DataArg' n@ means that the type parameter in that position is the @n@-th
523--   type parameter of the GADT being traversed (@T@ in the example)
524--
525-- * 'TypeApp' is type application
526--
527-- * 'ConType' specifies a base type
528--
-- The exception list could equivalently (and more precisely) have been specified as:
530--
531-- > [(ConType [t|NatRepr|] `TypeApp` DataArg 1, [|testEquality|])]
532--
533-- The use of 'DataArg' says that the type parameter of the 'NatRepr' must
534-- be the same as the second type parameter of @T@.
535