1{-# Language CPP, TemplateHaskell #-}
2
3{-|
4Module      : Harness
5Description : Comparison functions for data type info used in tests
6Copyright   : Eric Mertens 2017
7License     : ISC
8Maintainer  : emertens@gmail.com
9
10This module provides comparison functions that are able to check
11that the computed 'DatatypeInfo' values match the expected ones
12up to alpha renaming.
13
14-}
15module Harness
16  ( validateDI
17  , validateCI
18  , equateCxt
19
20    -- * Utilities
21  , varKCompat
22  ) where
23
24import           Control.Monad
25import qualified Data.Map as Map
26import           Data.Map (Map)
27import           Data.Maybe
28import           Language.Haskell.TH
29import           Language.Haskell.TH.Datatype
30import           Language.Haskell.TH.Datatype.TyVarBndr
31import           Language.Haskell.TH.Lib (starK)
32
-- | Compile-time assertion that two 'DatatypeInfo' values are equal up to
-- renaming of type variables (as decided by 'equateDI'). The resulting
-- splice is @return ()@ on success; on mismatch the splice fails with a
-- description of the first difference found.
validateDI :: DatatypeInfo -> DatatypeInfo -> ExpQ
validateDI info1 info2 = validate equateDI info1 info2
35
-- | Compile-time assertion that two 'ConstructorInfo' values are equal up
-- to renaming of type variables (as decided by 'equateCI'). The resulting
-- splice is @return ()@ on success; on mismatch the splice fails with a
-- description of the first difference found.
validateCI :: ConstructorInfo -> ConstructorInfo -> ExpQ
validateCI info1 info2 = validate equateCI info1 info2
38
-- | Turn a pure equivalence check into a splice. When the check reports
-- 'Left', the message is raised with 'fail' in 'Q'. Otherwise the splice
-- expands to the expression @return ()@.
validate :: (a -> a -> Either String ()) -> a -> a -> ExpQ
validate equate x y =
  case equate x y of
    Left msg -> fail msg
    -- The payload is deliberately not inspected (not forced).
    Right _  -> [| return () |]
41
-- | If the arguments are equal up to renaming return @'Right' ()@,
-- otherwise return a string explaining the mismatch.
--
-- The free variables of the second datatype's binders are renamed onto
-- those of the first (paired in order). The renaming is applied before the
-- binders, instance types and context are compared. Constructors are
-- compared pairwise with 'equateCI', which builds its own renaming.
equateDI :: DatatypeInfo -> DatatypeInfo -> Either String ()
equateDI dat1 dat2 =
  do check "datatypeName"          (nameBase . datatypeName)    dat1 dat2
     check "datatypeVars len"      (length . datatypeVars)      dat1 dat2
     check "datatypeInstTypes len" (length . datatypeInstTypes) dat1 dat2
     check "datatypeVariant"       datatypeVariant              dat1 dat2
     check "datatypeCons len"      (length . datatypeCons)      dat1 dat2
     -- The contexts are only compared element-wise with zipWithM_ below,
     -- which silently stops at the shorter list. Compare the lengths first
     -- so that an extra or missing constraint is reported.
     check "datatypeContext len"   (length . datatypeContext)   dat1 dat2

     let sub = Map.fromList (zip (freeVariables (bndrParams (datatypeVars dat2)))
                                 (map VarT (freeVariables (bndrParams (datatypeVars dat1)))))

     check "datatypeVars" id
       (datatypeVars dat1)
       (substIntoTyVarBndrs sub (datatypeVars dat2))

     check "datatypeInstTypes" id
       (datatypeInstTypes dat1)
       (applySubstitution sub (datatypeInstTypes dat2))

     zipWithM_ (equateCxt "datatypeContext")
       (datatypeContext dat1)
       (applySubstitution sub (datatypeContext dat2))

     -- The "datatypeCons len" check above ensures every constructor pair is
     -- visited. Don't bother applying the substitution here, as equateCI
     -- takes care of this for us.
     zipWithM_ equateCI
       (datatypeCons dat1)
       (datatypeCons dat2)
71
-- | Compare two (already renamed) predicates through two views: first as a
-- class constraint ('asClassPred'), then as an equality constraint
-- ('asEqualPred'). The label prefixes the error message of whichever view
-- differs first.
--
-- NOTE(review): a predicate that matches neither view presumably yields the
-- same (empty) view on both sides and so compares equal; confirm this is
-- intended.
equateCxt :: String -> Pred -> Pred -> Either String ()
equateCxt lbl p1 p2 =
  check (lbl ++ " class") asClassPred p1 p2
    *> check (lbl ++ " equality") asEqualPred p1 p2
76
-- | If the arguments are equal up to renaming return @'Right' ()@,
-- otherwise return a string explaining the mismatch.
--
-- Constructor names and record field names are compared by 'nameBase'
-- only. Type variables of the second constructor are renamed onto those of
-- the first before the context, binders and field types are compared.
equateCI :: ConstructorInfo -> ConstructorInfo -> Either String ()
equateCI con1 con2 =
  do check "constructorName"       (nameBase . constructorName) con1 con2
     check "constructorVariant"    constructorVariantBase       con1 con2
     -- Contexts and strictness annotations are only compared element-wise
     -- with zipWithM_ below, which silently stops at the shorter list. Their
     -- lengths must therefore be compared explicitly.
     check "constructorContext len"    (length . constructorContext)    con1 con2
     check "constructorStrictness len" (length . constructorStrictness) con1 con2

     -- sub1 renames the free variables of the constructor's own binders, and
     -- sub2 renames the free variables of the whole constructor. Map.unions
     -- is left-biased, so sub1 takes precedence where both map the same name.
     let sub1 = Map.fromList (zip (freeVariables (bndrParams (constructorVars con2)))
                                  (map VarT (freeVariables (bndrParams (constructorVars con1)))))
         sub2 = Map.fromList (zip (freeVariables con2)
                                  (map VarT (freeVariables con1)))
         sub  = Map.unions [sub1, sub2]

     zipWithM_ (equateCxt "constructorContext")
        (constructorContext con1)
        (applySubstitution sub (constructorContext con2))

     check "constructorVars" id
        (constructorVars con1)
        (substIntoTyVarBndrs sub (constructorVars con2))

     check "constructorFields" id
        (constructorFields con1)
        (applySubstitution sub (constructorFields con2))

     zipWithM_ equateStrictness
        (constructorStrictness con1)
        (constructorStrictness con2)
  where
    -- Reduce record field names to their base names, so that they compare
    -- like the constructor name does.
    constructorVariantBase :: ConstructorInfo -> ConstructorVariant
    constructorVariantBase con =
      case constructorVariant con of
        NormalConstructor        -> NormalConstructor
        i@InfixConstructor{}     -> i
        RecordConstructor fields -> RecordConstructor $ map (mkName . nameBase) fields
112
-- | Apply a renaming substitution to a list of type variable binders.
-- Both the binder names and their kinds are substituted, and flags are
-- kept unchanged. A binder name is renamed only when the substitution maps
-- it to a plain type variable ('VarT'). Any other mapping, or no mapping,
-- leaves the name as it is.
substIntoTyVarBndrs :: Map Name Type -> [TyVarBndr_ flag] -> [TyVarBndr_ flag]
substIntoTyVarBndrs subst = map (mapTV substName id (applySubstitution subst))
  where
    -- Closes over the outer 'subst' instead of re-binding (and shadowing)
    -- it, which -Wname-shadowing would warn about.
    substName :: Name -> Name
    substName n =
      case Map.lookup n subst of
        Just (VarT n') -> n'
        _              -> n
125
-- | View each binder as a type. A plain binder @n@ becomes @'VarT' n@, and a
-- kinded binder @(n :: k)@ becomes @'SigT' ('VarT' n) k@, so the kind
-- annotation is preserved in the result.
bndrParams :: [TyVarBndr_ flag] -> [Type]
bndrParams tvbs = [ elimTV VarT (SigT . VarT) tvb | tvb <- tvbs ]
128
-- | Compare two field strictness annotations, reporting a mismatch under
-- the label @constructorStrictness@. On old template-haskell versions the
-- annotations are normalized first (see below).
equateStrictness :: FieldStrictness -> FieldStrictness -> Either String ()
equateStrictness = check "constructorStrictness" normalize
  where
    normalize :: FieldStrictness -> FieldStrictness
#if MIN_VERSION_template_haskell(2,7,0)
    normalize = id
#else
    -- GHC 7.0 and 7.2 didn't have an Unpacked TH constructor, so as a
    -- simple workaround, we will treat unpackedAnnot as isStrictAnnot
    -- (the closest equivalent).
    normalize fs = if fs == unpackedAnnot then isStrictAnnot else fs
#endif
143
-- | @check lbl f x y@ compares the projections @f x@ and @f y@. It succeeds
-- with @'Right' ()@ when they are equal. Otherwise it returns a message made
-- of the label, a colon, and the two shown projections, separated by blank
-- lines.
check :: (Show b, Eq b) => String -> (a -> b) -> a -> a -> Either String ()
check lbl f x y =
  if lhs == rhs
    then Right ()
    else Left (concat [lbl, ":\n\n", show lhs, "\n\n", show rhs])
  where
    lhs = f x
    rhs = f y
148
-- | Build a kind variable with the given name. With template-haskell older
-- than 2.8 the name is ignored and 'starK' is returned instead.
varKCompat :: Name -> Kind
#if MIN_VERSION_template_haskell(2,8,0)
varKCompat name = VarT name
#else
varKCompat = const starK
#endif
157