1{-# LANGUAGE TemplateHaskell, GADTs #-}
2{-# OPTIONS_GHC -fno-warn-type-defaults -fno-warn-missing-signatures #-}
3module Data.Random.Source.Internal.TH (monadRandom, randomSource) where
4
5import Data.Bits
6import Data.Generics
7import Data.List
8import Data.Maybe
9import Data.Monoid
10import Data.Random.Internal.Source (Prim(..), MonadRandom(..), RandomSource(..))
11import Data.Random.Internal.Words
12import Language.Haskell.TH
13import Language.Haskell.TH.Extras
14import qualified Language.Haskell.TH.FlexibleDefaults as FD
15
16import Control.Monad.Reader
17
-- |The primitive operations this module knows how to name and to generate
-- default implementations for.  The constructor order matters: 'allMethods'
-- enumerates the type through its derived 'Enum' and 'Bounded' instances.
data Method
    = GetPrim           -- ^ the general, 'Prim'-indexed entry point
    | GetWord8
    | GetWord16
    | GetWord32
    | GetWord64
    | GetDouble
    | GetNByteInteger   -- ^ takes a byte-count argument in the generated code
    deriving (Eq, Ord, Enum, Bounded, Read, Show)
27
-- |Every 'Method', in constructor order.
allMethods :: [Method]
allMethods = enumFromTo minBound maxBound
30
-- |The naming scheme under which a 'Method' is being referred to; see
-- 'methodName' for the concrete name each combination maps to.
data Context
    = Generic       -- ^ placeholder names (@getWord8@, ...) used inside this module's quotes
    | RandomSource  -- ^ the @getRandom...From@ names
    | MonadRandom   -- ^ the @getRandom...@ names
    deriving (Eq, Ord, Enum, Bounded, Read, Show)
36
-- |The unqualified string form of 'methodName'.
methodNameBase :: Context -> Method -> String
methodNameBase ctx = nameBase . methodName ctx
39
-- |The Template Haskell 'Name' under which a 'Method' is known in a given
-- 'Context'.  'Generic' names are plain 'mkName' placeholders; the other two
-- contexts use quoted names of the corresponding class methods.
methodName :: Context -> Method -> Name
methodName ctx meth = case ctx of
    Generic -> mkName $ case meth of
        GetPrim         -> "getPrim"
        GetWord8        -> "getWord8"
        GetWord16       -> "getWord16"
        GetWord32       -> "getWord32"
        GetWord64       -> "getWord64"
        GetDouble       -> "getDouble"
        GetNByteInteger -> "getNByteInteger"
    RandomSource -> case meth of
        GetPrim         -> 'getRandomPrimFrom
        GetWord8        -> 'getRandomWord8From
        GetWord16       -> 'getRandomWord16From
        GetWord32       -> 'getRandomWord32From
        GetWord64       -> 'getRandomWord64From
        GetDouble       -> 'getRandomDoubleFrom
        GetNByteInteger -> 'getRandomNByteIntegerFrom
    MonadRandom -> case meth of
        GetPrim         -> 'getRandomPrim
        GetWord8        -> 'getRandomWord8
        GetWord16       -> 'getRandomWord16
        GetWord32       -> 'getRandomWord32
        GetWord64       -> 'getRandomWord64
        GetDouble       -> 'getRandomDouble
        GetNByteInteger -> 'getRandomNByteInteger
62
-- |Does the 'Name' denote one of the known methods in the given 'Context'?
isMethodName :: Context -> Name -> Bool
isMethodName ctx = isJust . nameToMethod ctx
65
-- |Inverse of 'methodName' for a fixed 'Context': the first 'Method' (in
-- 'allMethods' order) whose name in that context equals the given 'Name',
-- or 'Nothing' when there is none.
nameToMethod :: Context -> Name -> Maybe Method
nameToMethod ctx wanted =
    listToMaybe [ m | m <- allMethods, methodName ctx m == wanted ]
73
74
75-- 'Context'-sensitive version of the FlexibleDefaults DSL
-- |'FD.scoreBy' applied underneath the 'Context' reader.
scoreBy :: (a -> b) -> ReaderT Context (FD.Defaults a) t -> ReaderT Context (FD.Defaults b) t
scoreBy rescale body = ReaderT (FD.scoreBy rescale . runReaderT body)
78
-- |'FD.function' applied underneath the 'Context' reader; the function is
-- registered under the base name the 'Method' has in the current 'Context'.
method :: Method -> ReaderT Context (FD.Function s) t -> ReaderT Context (FD.Defaults s) t
method m body = ReaderT $ \ctx ->
    FD.function (methodNameBase ctx m) (runReaderT body ctx)
83
-- |'FD.implementation' applied underneath the 'Context' reader.
implementation :: ReaderT Context (FD.Implementation s) (Q [Dec]) -> ReaderT Context (FD.Function s) ()
implementation body = ReaderT (FD.implementation . runReaderT body)
86
-- |'FD.cost' lifted into the 'Context' reader (the context is not consulted).
cost :: Num s => s -> ReaderT Context (FD.Implementation s) ()
cost amount = lift (FD.cost amount)
89
-- |'FD.dependsOn' for a 'Method', using the base name the method has in the
-- current 'Context'.
dependsOn :: Method -> ReaderT Context (FD.Implementation s) ()
dependsOn m = ReaderT $ \ctx -> FD.dependsOn (methodNameBase ctx m)
94
-- |Translate a 'Name' that denotes a method in the first 'Context' into the
-- name of the same method in the second; the translation is handed to
-- 'replace', which decides what happens to names that are not methods.
changeContext :: Context -> Context -> Name -> Name
changeContext fromCtx toCtx = replace translate
    where
        translate n = fmap (methodName toCtx) (nameToMethod fromCtx n)
97
-- Rewrite every generic method name in the quoted declarations to the name it
-- has in the current 'Context'.  When that context is 'RandomSource', also
-- thread a freshly named "_src" parameter through via 'addSrcParam'.
specialize :: Monad m => Q [Dec] -> ReaderT Context m (Q [Dec])
specialize futzedDecsQ = do
    ctx <- ask
    let genericDecsQ = fmap genericalizeDecs futzedDecsQ
        renameMethods = everywhere (mkT (changeContext Generic ctx))
    return $ case ctx of
        RandomSource -> do
            src <- newName "_src"
            decs <- genericDecsQ
            return (map (addSrcParam src) (renameMethods decs))
        _ -> fmap renameMethods genericDecsQ
111
-- |Drop every type-signature declaration ('SigD') from the list, keeping all
-- other declarations in their original order.
stripTypeSigs :: Q [Dec] -> Q [Dec]
stripTypeSigs decsQ = do
    decs <- decsQ
    return [ d | d <- decs, keep d ]
    where
        keep SigD{} = False
        keep _      = True
115
-- |Thread the variable @src@ through a declaration.  First every variable
-- expression naming a 'RandomSource' method is applied to @src@; then every
-- binding of such a method gets @src@ as a new leading parameter.  Both passes
-- are bottom-up 'everywhere' traversals.
addSrcParam :: Name -> Dec -> Dec
addSrcParam src dec = everywhere (mkT rewriteDec) (everywhere (mkT rewriteExp) dec)
    where
        needsSrc = isMethodName RandomSource

        rewriteDec :: Dec -> Dec
        rewriteDec d = case d of
            -- a plain value binding becomes a one-clause function of src
            ValD (VarP n) body wheres
                | needsSrc n -> FunD n [Clause [VarP src] body wheres]
            FunD n clauses
                | needsSrc n -> FunD n (map withSrc clauses)
            _ -> d

        withSrc (Clause ps body wheres) = Clause (VarP src : ps) body wheres

        rewriteExp :: Exp -> Exp
        rewriteExp e = case e of
            VarE n
                | needsSrc n -> AppE e (VarE src)
            _ -> e
136
-- placeholder expression: a variable carrying the method's 'Generic' name,
-- to be renamed later by 'specialize'
dummy :: Method -> ExpQ
dummy m = varE (methodName Generic m)
140
-- Splice-ready placeholders, one per 'Method'; each is just 'dummy' applied to
-- the corresponding constructor.
getPrim :: ExpQ
getPrim = dummy GetPrim

getWord8 :: ExpQ
getWord8 = dummy GetWord8

getWord16 :: ExpQ
getWord16 = dummy GetWord16

getWord32 :: ExpQ
getWord32 = dummy GetWord32

getWord64 :: ExpQ
getWord64 = dummy GetWord64

getDouble :: ExpQ
getDouble = dummy GetDouble

getNByteInteger :: ExpQ
getNByteInteger = dummy GetNByteInteger
151
-- The defaulting rules for RandomSource and MonadRandom.  Costs are rates of
-- entropy waste (bits discarded per bit requested) plus the occasional ad-hoc
-- penalty where it seems appropriate.
--
-- Each 'method' block lists the candidate 'implementation's for one method.
-- A candidate names the methods it is built from with 'dependsOn', optionally
-- declares a 'cost', and ends with the quoted declarations to splice in, which
-- 'specialize' adapts to the requested 'Context'.  The enclosing 'scoreBy'
-- calls rescale the declared costs (dividing by 8, 16, 32, 64 or 52 --
-- presumably the number of bits the method yields) and wrap them in 'Sum'.
-- How the resulting scores are used to pick a candidate is decided by
-- FlexibleDefaults, not here.

-- TODO: figure out a clean way to break these up for individual testing.
-- Also analyze to see which of these can never be selected (I suspect that set is non-empty)
defaults :: Context -> FD.Defaults (Sum Double) ()
defaults = runReaderT $
    scoreBy Sum $ do
        -- getPrim: dispatch on the 'Prim' constructor to every other method.
        method GetPrim $ do
            implementation $ do
                mapM_ dependsOn (allMethods \\ [GetPrim])

                -- GHC 6 requires type signatures for GADT matches, even
                -- inside [d||].  This code is evaluated at more than one type, though,
                -- and at its eventual splice site the signature actually isn't even allowed.
                -- So, there's a dummy signature here which is immediately stripped out.
                specialize . stripTypeSigs $
                    [d| getPrim :: Prim a -> m a
                        getPrim PrimWord8               = $getWord8
                        getPrim PrimWord16              = $getWord16
                        getPrim PrimWord32              = $getWord32
                        getPrim PrimWord64              = $getWord64
                        getPrim PrimDouble              = $getDouble
                        getPrim (PrimNByteInteger n)    = $getNByteInteger n
                     |]

        -- getWord8: 8 bits requested.  The costs of the wider-word candidates
        -- equal the number of surplus bits in the source word (16-8, 32-8, 64-8).
        scoreBy (/8) $
            method GetWord8 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord8 = $getPrim PrimWord8 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord8 = liftM fromInteger ($getNByteInteger 1) |]

                implementation $ do
                    cost 8
                    dependsOn GetWord16
                    specialize [d| getWord8 = liftM fromIntegral $getWord16 |]

                implementation $ do
                    cost 24
                    dependsOn GetWord32
                    specialize [d| getWord8 = liftM fromIntegral $getWord32 |]

                implementation $ do
                    cost 56
                    dependsOn GetWord64
                    specialize [d| getWord8 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    -- assumes getDouble yields values in [0,1) -- TODO confirm
                    specialize [d| getWord8 = liftM (truncate . (256*)) $getDouble |]

        -- getWord16: 16 bits requested.
        scoreBy (/16) $
            method GetWord16 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord16 = $getPrim PrimWord16 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord16 = liftM fromInteger ($getNByteInteger 2) |]

                -- assemble from two bytes; no cost is declared for this candidate
                implementation $ do
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord16 = do
                                a <- $getWord8
                                b <- $getWord8
                                return (buildWord16 a b)
                         |]

                implementation $ do
                    cost 16
                    dependsOn GetWord32
                    specialize [d| getWord16 = liftM fromIntegral $getWord32 |]

                implementation $ do
                    cost 48
                    dependsOn GetWord64
                    specialize [d| getWord16 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord16 = liftM (truncate . (65536*)) $getDouble |]

        -- getWord32: 32 bits requested.
        scoreBy (/32) $
            method GetWord32 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord32 = $getPrim PrimWord32 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord32 = liftM fromInteger ($getNByteInteger 4) |]

                -- the small cost is not a count of wasted bits; presumably one
                -- of the ad-hoc penalties mentioned in the header comment
                implementation $ do
                    cost 0.1
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord32 = do
                                a <- $getWord8
                                b <- $getWord8
                                c <- $getWord8
                                d <- $getWord8
                                return (buildWord32 a b c d)
                         |]

                implementation $ do
                    dependsOn GetWord16
                    specialize
                        [d|
                            getWord32 = do
                                a <- $getWord16
                                b <- $getWord16
                                return (buildWord32' a b)
                         |]

                implementation $ do
                    cost 32
                    dependsOn GetWord64
                    specialize [d| getWord32 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord32 = liftM (truncate . (4294967296*)) $getDouble |]

        -- getWord64: 64 bits requested.  There is no candidate built from
        -- getDouble here.
        scoreBy (/64) $
            method GetWord64 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord64 = $getPrim PrimWord64 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord64 = liftM fromInteger ($getNByteInteger 8) |]

                implementation $ do
                    cost 0.2
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord8
                                b <- $getWord8
                                c <- $getWord8
                                d <- $getWord8
                                e <- $getWord8
                                f <- $getWord8
                                g <- $getWord8
                                h <- $getWord8
                                return (buildWord64 a b c d e f g h)
                         |]

                implementation $ do
                    cost 0.1
                    dependsOn GetWord16
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord16
                                b <- $getWord16
                                c <- $getWord16
                                d <- $getWord16
                                return (buildWord64' a b c d)
                         |]

                implementation $ do
                    dependsOn GetWord32
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord32
                                b <- $getWord32
                                return (buildWord64'' a b)
                         |]

        -- getDouble: scored against 52 bits; the Word64-based candidate costs
        -- 12 (= 64 - 52).  NOTE(review): presumably 52 is the number of random
        -- bits 'wordToDouble' keeps -- confirm against its definition.
        scoreBy (/52) $
            method GetDouble $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getDouble = $getPrim PrimDouble |]

                implementation $ do
                    cost 12
                    dependsOn GetWord64
                    specialize
                        [d|
                            getDouble = do
                                w <- $getWord64
                                return (wordToDouble w)
                         |]

        -- getNByteInteger: costs here are not rescaled by a bit count.  The two
        -- word-based candidates differ only in whether they use getWord64;
        -- 'intIs64' (defined outside this file) decides which one is penalised.
        method GetNByteInteger $ do
            implementation $ do
                dependsOn GetPrim
                specialize [d| getNByteInteger n = $getPrim (PrimNByteInteger n) |]

            -- Variant using chunks of at most 32 bits; penalised when 'intIs64'.
            -- Exact sizes 1, 2 and 4 map to a single word; larger sizes take a
            -- leading word, shift it left by the remaining byte count times 8
            -- (n `shiftL` 3) and recurse for the rest; anything else yields 0.
            -- NOTE(review): the recursive call is spelled $(dummy GetNByteInteger)
            -- rather than $getNByteInteger, presumably because the quote itself
            -- binds the name getNByteInteger -- confirm.
            implementation $ do
                when intIs64 (cost 1e-2)
                dependsOn GetWord8
                dependsOn GetWord16
                dependsOn GetWord32
                specialize
                    [d|
                        getNByteInteger 1 = do
                            x <- $getWord8
                            return $! toInteger x
                        getNByteInteger 2 = do
                            x <- $getWord16
                            return $! toInteger x
                        getNByteInteger 4 = do
                            x <- $getWord32
                            return $! toInteger x
                        getNByteInteger np4
                            | np4 > 4 = do
                                let n = np4 - 4
                                x <- $getWord32
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np2
                            | np2 > 2 = do
                                let n = np2 - 2
                                x <- $getWord16
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger _ = return 0
                      |]

            -- Same scheme with an additional 64-bit chunk size; penalised when
            -- 'intIs64' is False.
            implementation $ do
                when (not intIs64) (cost 1e-2)
                dependsOn GetWord8
                dependsOn GetWord16
                dependsOn GetWord32
                dependsOn GetWord64
                specialize
                    [d|
                        getNByteInteger 1 = do
                            x <- $getWord8
                            return $! toInteger x
                        getNByteInteger 2 = do
                            x <- $getWord16
                            return $! toInteger x
                        getNByteInteger 4 = do
                            x <- $getWord32
                            return $! toInteger x
                        getNByteInteger 8 = do
                            x <- $getWord64
                            return $! toInteger x
                        getNByteInteger np8
                            | np8 > 8 = do
                                let n = np8 - 8
                                x <- $getWord64
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np4
                            | np4 > 4 = do
                                let n = np4 - 4
                                x <- $getWord32
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np2
                            | np2 > 2 = do
                                let n = np2 - 2
                                x <- $getWord16
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger _ = return 0
                      |]
433
434
-- |Fill in whatever a 'RandomSource' instance leaves undefined.  Using this
-- macro is recommended even for an instance that is complete today, since the
-- 'RandomSource' class may be extended at any time.
--
-- Wrap the instance declaration as shown below.  The TemplateHaskell,
-- MultiParamTypeClasses and GADTs language extensions must be enabled, along
-- with any others your instances need (FlexibleInstances, for example):
--
-- > $(randomSource [d|
-- >         instance RandomSource FooM Bar where
-- >             {- at least one RandomSource function... -}
-- >     |])
randomSource :: Q [Dec] -> Q [Dec]
randomSource instanceDecs = FD.withDefaults (defaults RandomSource) instanceDecs
450
-- |Fill in whatever a 'MonadRandom' instance leaves undefined.  Using this
-- macro is recommended even for an instance that is complete today, since the
-- 'MonadRandom' class may be extended at any time.
--
-- Wrap the instance declaration as shown below, with the TemplateHaskell and
-- GADTs language extensions enabled:
--
-- > $(monadRandom [d|
-- >         instance MonadRandom FooM where
-- >             getRandomDouble = return pi
-- >             getRandomWord16 = return 4
-- >             {- etc... -}
-- >     |])
monadRandom :: Q [Dec] -> Q [Dec]
monadRandom instanceDecs = FD.withDefaults (defaults MonadRandom) instanceDecs
466
467-- -- This is nice in theory, but under GHC 7 it never typechecks; without generalizing the let-bound
468-- -- functions, it gets absurd errors like "cannot match 'm Int' with 'IO t'".  Probably need
469-- -- to mechanically specialize the supplied signature to create a signature for every other
470-- -- let-bound function.
471-- primFunction :: Q Type -> Q [Dec] -> ExpQ
472-- primFunction getPrimType decsQ = do
473--     getPrimSig <- sigD (mkName (methodName Generic GetPrim)) getPrimType
474--     decs <- decsQ >>= FD.implementDefaults (defaults Generic)
475--     f <- getPrim
476--     return (LetE (getPrimSig : decs) f)
477