1module Contravariant.Extras.TH (
2    opContrazipDecs,
3    contrazipDecs,
4    contrazipExp,
5  ) where
6
7import Contravariant.Extras.Prelude
8import Data.Functor.Contravariant
9import Data.Functor.Contravariant.Divisible
10import Language.Haskell.TH.Syntax hiding (classP)
11import qualified TemplateHaskell.Compat.V0208 as Compat
12
13
{-|
Produces the type signature and the definition of a function
which zips a fixed number of 'Op' values into a single 'Op' over a tuple.
The declared name is the base name with the arity appended to it.

E.g., for the base name @tuple@ and the arity @3@
the declarations are in the spirit of the following:

@
tuple3 :: Monoid a => Op a b1 -> Op a b2 -> Op a b3 -> Op a ( b1 , b2 , b3 )
tuple3 ( Op op1 ) ( Op op2 ) ( Op op3 ) =
  Op $ \( v1 , v2 , v3 ) -> mconcat [ op1 v1 , op2 v2 , op3 v3 ]
@
-}
opContrazipDecs :: String -> Int -> [ Dec ]
opContrazipDecs baseName arity =
  [ SigD decName decType , FunD decName [ Clause opPats decBody [] ] ]
  where
    indices =
      enumFromTo 1 arity
    -- Name made of a prefix followed by a position, e.g. @b2@, @op2@, @v2@.
    indexedName prefix index =
      mkName (showString prefix (show index))
    decName =
      mkName (showString baseName (show arity))
    aName =
      mkName "a"
    bNames =
      map (indexedName "b") indices
    opNames =
      map (indexedName "op") indices
    vNames =
      map (indexedName "v") indices
    -- @Op a b@ for the given type @b@.
    opOf b =
      AppT (AppT (ConT ''Op) (VarT aName)) b
    -- @forall a b1 .. bN. Monoid a => Op a b1 -> .. -> Op a bN -> Op a ( b1 , .. , bN )@
    decType =
      ForallT binders [ monoidPred ] (foldr arrow resultType paramTypes)
      where
        binders =
          map Compat.specifiedPlainTV (aName : bNames)
        monoidPred =
          Compat.classP ''Monoid [ VarT aName ]
        arrow from to =
          AppT (AppT ArrowT from) to
        paramTypes =
          map (opOf . VarT) bNames
        resultType =
          opOf (foldl AppT (TupleT arity) (map VarT bNames))
    -- One @( Op opI )@ pattern per argument.
    -- NOTE(review): 'ConP' takes an extra @[Type]@ field starting with
    -- template-haskell 2.18; confirm the supported version range.
    opPats =
      map opPat opNames
      where
        opPat opName =
          ConP 'Op [ VarP opName ]
    -- @Op (\( v1 , .. , vN ) -> mconcat [ op1 v1 , .. , opN vN ])@
    decBody =
      NormalB (AppE (ConE 'Op) (LamE [ TupP (map VarP vNames) ] merged))
      where
        merged =
          AppE (VarE 'mconcat) (ListE (zipWith apply opNames vNames))
        apply opName vName =
          AppE (VarE opName) (VarE vName)
124
{-|
Produces the type signature and the definition of a contrazip function.
The declared name is the base name with the arity appended to it.
The signature is taken from 'contrazipType' and
the definition is a single argumentless clause bound to 'contrazipExp'.

E.g., for the base name @contrazip@ and the arity @4@
the declarations are in the spirit of the following,
where @splitN@ stands for a lambda
separating the first component of an N-tuple from the remaining ones:

@
contrazip4 :: Divisible f => f a1 -> f a2 -> f a3 -> f a4 -> f ( a1 , a2 , a3 , a4 )
contrazip4 f1 f2 f3 f4 =
  divide split4 f1 $
  divide split3 f2 $
  divide split2 f3 $
  f4
@
-}
contrazipDecs :: String -> Int -> [Dec]
contrazipDecs baseName arity =
  [ SigD decName (contrazipType arity)
  , FunD decName [Clause [] (NormalB (contrazipExp arity)) []]
  ]
  where
    decName = mkName (showString baseName (show arity))
145
{-|
Type of a contrazip function of the specified arity:

@
forall f a1 .. aN. Divisible f => f a1 -> .. -> f aN -> f ( a1 , .. , aN )
@
-}
contrazipType :: Int -> Type
contrazipType arity =
  ForallT binders [divisiblePred] (foldr arrow resultType paramTypes)
  where
    functorName = mkName "f"
    elementNames = [mkName (showString "a" (show index)) | index <- enumFromTo 1 arity]
    binders = map Compat.specifiedPlainTV (functorName : elementNames)
    divisiblePred = Compat.classP ''Divisible [VarT functorName]
    -- Application of the functor variable to the given type.
    inFunctor = AppT (VarT functorName)
    arrow from to = AppT (AppT ArrowT from) to
    paramTypes = map (inFunctor . VarT) elementNames
    resultType = inFunctor (foldl AppT (TupleT arity) (map VarT elementNames))
160
{-|
Contrazip lambda expression of specified arity.

Allows creating contrazip expressions of any positive arity:

>>>:t $(return (contrazipExp 2))
$(return (contrazipExp 2))
  :: Data.Functor.Contravariant.Divisible.Divisible f =>
     f a1 -> f a2 -> f (a1, a2)

The arity must be at least @1@.
For smaller values the result is an 'error' with a descriptive message
instead of a never-ending expression.
-}
contrazipExp :: Int -> Exp
contrazipExp arity
  -- The recursion below counts down to 1,
  -- so starting from anything smaller would never terminate.
  | arity < 1 =
      error (showString "Contravariant.Extras.TH.contrazipExp: arity must be at least 1, but got " (show arity))
  | otherwise =
      SigE (LamE pats body) (contrazipType arity)
  where
    fName index = mkName (showString "f" (show index))
    pats = map (VarP . fName) (enumFromTo 1 arity)
    body = zipRemaining arity
    -- Zips the last @remaining@ arguments:
    -- the last argument is used as is,
    -- any earlier one is combined with the zip of the arguments following it
    -- using @divide@ with a function splitting off the first tuple component.
    zipRemaining remaining = case remaining of
      1 -> VarE (fName arity)
      _ ->
        AppE
          ( AppE
              (AppE (VarE 'divide) (splitTupleAtExp remaining 1))
              (VarE (fName (arity - remaining + 1)))
          )
          (zipRemaining (pred remaining))
188
{-|
Lambda expression which matches a tuple of the specified arity and
regroups its components into a pair:
the components before the specified position and the ones starting from it.
The component variables are named @_0@, @_1@ and so on,
and all the tuple expressions are assembled using @Compat.tupE@.
-}
splitTupleAtExp :: Int -> Int -> Exp
splitTupleAtExp arity position =
  LamE [TupP (map VarP componentNames)] regrouped
  where
    componentNames =
      [Name (OccName ('_' : show index)) NameS | index <- enumFromTo 0 (pred arity)]
    regrouped =
      case splitAt position (map VarE componentNames) of
        (leading, trailing) ->
          Compat.tupE [Compat.tupE leading, Compat.tupE trailing]
199