1{-# LANGUAGE ParallelListComp #-}
2module Main where
3
import Prelude hiding ( (<>) )

import Text.PrettyPrint

import System.Environment ( getArgs )
import System.Exit ( exitFailure )
import System.IO ( hPutStrLn, stderr )
import Text.Read ( readMaybe )
7
-- | Print the generated code for every tuple arity from 2 up to the arity
-- given as the single command-line argument (one rendered 'generate'
-- block per arity).
main :: IO ()
main = do
  args <- getArgs
  case args of
    [s] | Just maxArity <- readMaybe s
        -> mapM_ (putStrLn . render . generate) [2 .. maxArity]
    _   -> do
      -- Anything other than exactly one integer argument gets a usage
      -- message instead of a partial-pattern / `read` failure.
      hPutStrLn stderr "usage: expected exactly one argument, the maximum tuple arity (an integer)"
      exitFailure
12
-- | Render the CPP-guarded Haskell source that adds unboxed-vector support
-- for @n@-tuples.
--
-- The output consists of three @#ifdef@ sections:
--
--   * @DEFINE_INSTANCES@: the @MVector s@ and @Vector@ data instances plus
--     the @Unbox@, @M.MVector MVector@ and @G.Vector Vector@ instances,
--     every method preceded by an @INLINE@ pragma;
--   * @DEFINE_MUTABLE@: @zip@ / @unzip@ for @MVector s@;
--   * @DEFINE_IMMUTABLE@: @zip@, a @stream/zip@ rewrite RULE, and @unzip@
--     for @Vector@.
--
-- Type variables are spelled @_a@, @_b@, ... and component vectors
-- @_as@, @_bs@, ..., drawn from @['a' ..]@, so the identifiers are only
-- valid for @n <= 26@. Function names carry the arity as a suffix
-- (@zip3@, @unzip3@) except for @n == 2@, which uses the bare names.
generate :: Int -> Doc
generate n =
  vcat [ text "#ifdef DEFINE_INSTANCES"
       , data_instance "MVector s" "MV"
       , data_instance "Vector" "V"
       , class_instance "Unbox"
       , class_instance "M.MVector MVector" <+> text "where"
       , nest 2 $ vcat $ map method methods_MVector
       , class_instance "G.Vector Vector" <+> text "where"
       , nest 2 $ vcat $ map method methods_Vector
       , text "#endif"
       , text "#ifdef DEFINE_MUTABLE"
       , define_zip "MVector s" "MV"
       , define_unzip "MVector s" "MV"
       , text "#endif"
       , text "#ifdef DEFINE_IMMUTABLE"
       , define_zip "Vector" "V"
       , define_zip_rule
       , define_unzip "Vector" "V"
       , text "#endif"
       ]
  where
    -- Note: chains mixing <> and <+> are parenthesised explicitly so their
    -- grouping does not depend on which (<>) is in scope (pretty's own
    -- left-associative one or Semigroup's right-associative one).

    -- Names for the tuple components (_a, _b, ...) and for the matching
    -- component vectors (_as, _bs, ...).
    vars  = [text ['_', c] | c <- take n ['a' ..]]
    varss = map (<> char 's') vars

    -- Comma-separated tuples: 'tuple' joins the elements with hsep,
    -- 'vtuple' uses sep and so may break across lines when too long.
    tuple  xs = parens (hsep (punctuate comma xs))
    vtuple xs = parens (sep (punctuate comma xs))

    -- Arity-tagged constructor name, e.g. MV_3 or V_3.
    con s = text s <> char '_' <> int n

    -- Bound-variable name wrapped in underscores, e.g. _n_, _i_, _m_.
    var c = text ('_' : c : "_")

    -- Arity-suffixed function name; arity 2 keeps the bare name.
    numbered s | n == 2    = text s
               | otherwise = text s <> int n

    -- The "(Unbox _a, Unbox _b, ...)" context used by every head/signature.
    unboxContext = vtuple [text "Unbox" <+> v | v <- vars]

    -- Append a prime to a name: _as becomes _as'.
    primed d = d <> char '\''

    -- Stack docs vertically, prefixing every line after the first with
    -- the given separator; an empty list renders as nothing.
    chain _      []       = empty
    chain sepDoc (d : ds) = vcat (d : [sepDoc <+> d' | d' <- ds])

    -- Data-family instance: a strict unpacked Int length field followed by
    -- one strict field per component vector.
    data_instance ty c
      = hang (hsep [text "data instance", text ty, tuple vars])
             4
             (hsep [ char '='
                   , con c
                   , text "{-# UNPACK #-} !Int"
                   , vcat [char '!' <> parens (text ty <+> v) | v <- vars]
                   ])

    -- Instance head "instance (Unbox _a, ...) => <cls> (_a, ...)"; callers
    -- append "where" themselves when methods follow.
    class_instance cls
      = text "instance" <+> unboxContext
                        <+> text "=>" <+> text cls <+> tuple vars

    -- zipN for containers of type ty (constructor prefix c): rebuilds the
    -- tuple vector from "unsafeSlice 0 len" of every input, where len is
    -- the delayed_min of the input lengths.
    define_zip ty c
      = sep [ text "-- | /O(1)/ Zip" <+> int n <+> text "vectors"
            , name <+> text "::"
                   <+> unboxContext
                   <+> text "=>"
                   <+> sep (punctuate (text " ->") [text ty <+> v | v <- vars])
                   <+> text "->"
                   <+> text ty <+> tuple vars
            , text "{-# INLINE_FUSED" <+> name <+> text "#-}"
            , name <+> sep varss
                   <+> text "="
                   <+> con c
                   <+> text "len"
                   <+> sep [ parens (text "unsafeSlice" <+> char '0'
                                                       <+> text "len"
                                                       <+> vs)
                           | vs <- varss ]
            , nest 2 $ hang (text "where")
                            2
                            (text "len ="
                             <+> sep (punctuate (text " `delayed_min`")
                                                [text "length" <+> vs | vs <- varss]))
            ]
      where
        name = numbered "zip"

    -- RULES pragma rewriting G.stream (zipN _as _bs ...) into
    -- Bundle.zipWithN over the component streams. A tuple of n empty docs
    -- renders as the bare tuple constructor, e.g. (,) or (, ,).
    define_zip_rule
      = hang (text "{-# RULES"
              <+> (text "\"stream/" <> numbered "zip"
                   <> text " [Vector.Unboxed]\" forall")
              <+> sep varss <+> char '.')
             2
             (text "G.stream" <+> parens (numbered "zip" <+> sep varss)
              <+> char '='
              <+> (text "Bundle." <> numbered "zipWith")
              <+> tuple (replicate n empty)
              <+> sep [parens (text "G.stream" <+> vs) | vs <- varss]
              $$ text "#-}")

    -- unzipN: match the tuple vector's constructor and return its
    -- component vectors as a tuple.
    define_unzip ty c
      = sep [ text "-- | /O(1)/ Unzip" <+> int n <+> text "vectors"
            , name <+> text "::"
                   <+> unboxContext
                   <+> text "=>"
                   <+> text ty <+> tuple vars
                   <+> text "->" <+> vtuple [text ty <+> v | v <- vars]
            , text "{-# INLINE" <+> name <+> text "#-}"
            , name <+> pat c <+> text "=" <+> vtuple varss
            ]
      where
        name = numbered "unzip"

    -- Constructor pattern binding the length (_n_) and every component
    -- vector, e.g. (MV_2 _n_ _as _bs).
    pat c = parens (con c <+> var 'n' <+> sep varss)

    -- The same pattern with every binder suffixed by k, so two vectors can
    -- be matched side by side, e.g. (MV_2 _n_1 _as1 _bs1).
    patn c k = parens (con c <+> (var 'n' <> int k)
                             <+> sep [vs <> int k | vs <- varss])

    -- Qualify a method name with the M. or G. module prefix.
    qM s = text "M." <> text s
    qG s = text "G." <> text s

    -- Method generators. Each receives the method name (so the generated
    -- code can call the same method on the component vectors) and returns
    -- (argument patterns, right-hand side).

    gen_length c _ = (pat c, var 'n')

    gen_unsafeSlice modName c rec
      = ( var 'i' <+> var 'm' <+> pat c
        , con c <+> var 'm'
                <+> vcat [ parens ((text modName <> char '.' <> text rec)
                                   <+> var 'i' <+> var 'm' <+> vs)
                         | vs <- varss ] )

    -- The generated code ORs basicOverlaps over each pair of components.
    gen_overlaps rec
      = ( patn "MV" 1 <+> patn "MV" 2
        , chain (text "||")
                [qM rec <+> (vs <> char '1') <+> (vs <> char '2') | vs <- varss] )

    gen_unsafeNew rec
      = ( var 'n'
        , mk_do [vs <+> text "<-" <+> qM rec <+> var 'n' | vs <- varss]
                (text "return $" <+> con "MV" <+> var 'n' <+> sep varss) )

    gen_unsafeReplicate rec
      = ( var 'n' <+> tuple vars
        , mk_do [ vs <+> text "<-" <+> qM rec <+> var 'n' <+> v
                | (v, vs) <- zip vars varss ]
                (text "return $" <+> con "MV" <+> var 'n' <+> sep varss) )

    gen_unsafeRead rec
      = ( pat "MV" <+> var 'i'
        , mk_do [ v <+> text "<-" <+> qM rec <+> vs <+> var 'i'
                | (v, vs) <- zip vars varss ]
                (text "return" <+> tuple vars) )

    gen_unsafeWrite rec
      = ( pat "MV" <+> var 'i' <+> tuple vars
        , mk_do [qM rec <+> vs <+> var 'i' <+> v | (v, vs) <- zip vars varss]
                empty )

    gen_clear rec = (pat "MV", mk_do [qM rec <+> vs | vs <- varss] empty)

    gen_set rec
      = ( pat "MV" <+> tuple vars
        , mk_do [qM rec <+> vs <+> v | (v, vs) <- zip vars varss] empty )

    -- Copy from a c-vector into a mutable vector, component by component,
    -- using the method qualified by q.
    gen_unsafeCopy c q rec
      = ( patn "MV" 1 <+> patn c 2
        , mk_do [q rec <+> (vs <> char '1') <+> (vs <> char '2') | vs <- varss]
                empty )

    -- basicUnsafeMove has exactly the mutable-to-mutable copy shape.
    gen_unsafeMove = gen_unsafeCopy "MV" qM

    gen_unsafeGrow rec
      = ( pat "MV" <+> var 'm'
        , mk_do [primed vs <+> text "<-" <+> qM rec <+> vs <+> var 'm' | vs <- varss]
                (text "return $" <+> con "MV"
                                 <+> parens (var 'm' <> char '+' <> var 'n')
                                 <+> sep (map primed varss)) )

    -- basicInitialize emits the same per-component calls as basicClear.
    gen_initialize = gen_clear

    -- Shared shape of basicUnsafeFreeze / basicUnsafeThaw: convert every
    -- component with the G. method of the same name, then rebuild with the
    -- target constructor.
    gen_convert fromC toC rec
      = ( pat fromC
        , mk_do [primed vs <+> text "<-" <+> qG rec <+> vs | vs <- varss]
                (text "return $" <+> con toC <+> var 'n'
                                 <+> sep (map primed varss)) )

    gen_unsafeFreeze = gen_convert "MV" "V"
    gen_unsafeThaw   = gen_convert "V" "MV"

    gen_basicUnsafeIndexM rec
      = ( pat "V" <+> var 'i'
        , mk_do [ v <+> text "<-" <+> qG rec <+> vs <+> var 'i'
                | (v, vs) <- zip vars varss ]
                (text "return" <+> tuple vars) )

    gen_elemseq rec
      = ( char '_' <+> tuple vars
        , chain (char '.')
                [qG rec <+> parens (text "undefined :: Vector" <+> v) <+> v | v <- vars] )

    -- A do block: the statements followed by `final`, hung 2 columns
    -- under "do".
    mk_do stmts final = hang (text "do") 2 (vcat (stmts ++ [final]))

    -- One instance method: its INLINE pragma, then the definition with the
    -- right-hand side hung 4 columns under the head. The pragma closer is
    -- "#-}" without a leading space because <+> already inserts one.
    method (s, gen) = text "{-# INLINE" <+> text s <+> text "#-}"
                      $$ hang (text s <+> lhs) 4 (char '=' <+> rhs)
      where
        (lhs, rhs) = gen s

    methods_MVector = [("basicLength",            gen_length "MV")
                      ,("basicUnsafeSlice",       gen_unsafeSlice "M" "MV")
                      ,("basicOverlaps",          gen_overlaps)
                      ,("basicUnsafeNew",         gen_unsafeNew)
                      ,("basicUnsafeReplicate",   gen_unsafeReplicate)
                      ,("basicUnsafeRead",        gen_unsafeRead)
                      ,("basicUnsafeWrite",       gen_unsafeWrite)
                      ,("basicClear",             gen_clear)
                      ,("basicSet",               gen_set)
                      ,("basicUnsafeCopy",        gen_unsafeCopy "MV" qM)
                      ,("basicUnsafeMove",        gen_unsafeMove)
                      ,("basicUnsafeGrow",        gen_unsafeGrow)
                      ,("basicInitialize",        gen_initialize)]

    methods_Vector  = [("basicUnsafeFreeze",      gen_unsafeFreeze)
                      ,("basicUnsafeThaw",        gen_unsafeThaw)
                      ,("basicLength",            gen_length "V")
                      ,("basicUnsafeSlice",       gen_unsafeSlice "G" "V")
                      ,("basicUnsafeIndexM",      gen_basicUnsafeIndexM)
                      ,("basicUnsafeCopy",        gen_unsafeCopy "V" qG)
                      ,("elemseq",                gen_elemseq)]
240