{-# LANGUAGE TemplateHaskell, GADTs #-}
{-# OPTIONS_GHC -fno-warn-type-defaults -fno-warn-missing-signatures #-}
-- | Template Haskell macros ('randomSource', 'monadRandom') that complete
-- partially-written 'RandomSource' and 'MonadRandom' instances.  Each missing
-- method is filled in with one of several candidate default implementations
-- listed in 'defaults', chosen via the FlexibleDefaults scoring machinery.
module Data.Random.Source.Internal.TH (monadRandom, randomSource) where

import Data.Bits
import Data.Generics
import Data.List
import Data.Maybe
import Data.Monoid
import Data.Random.Internal.Source (Prim(..), MonadRandom(..), RandomSource(..))
import Data.Random.Internal.Words
import Language.Haskell.TH
import Language.Haskell.TH.Extras
import qualified Language.Haskell.TH.FlexibleDefaults as FD

-- 'liftM', 'when' and 'unless' are imported explicitly because mtl >= 2.3
-- no longer re-exports "Control.Monad" from "Control.Monad.Reader".
-- ('liftM' is needed in scope for the [d| ... |] quotes in 'defaults'.)
import Control.Monad (liftM, unless, when)
import Control.Monad.Reader

-- | The class methods for which this module can derive defaults.
data Method
    = GetPrim
    | GetWord8
    | GetWord16
    | GetWord32
    | GetWord64
    | GetDouble
    | GetNByteInteger
    deriving (Eq, Ord, Enum, Bounded, Read, Show)

-- | Every 'Method', from 'minBound' to 'maxBound'.
allMethods :: [Method]
allMethods = [minBound .. maxBound]

-- | The naming scheme a method name belongs to.  @Generic@ names are the
-- placeholders used while writing the default implementations below; the
-- other two are the real method names of the corresponding classes.
data Context
    = Generic
    | RandomSource
    | MonadRandom
    deriving (Eq, Ord, Enum, Bounded, Read, Show)

-- | The unqualified name of a method in a given context, as a 'String'
-- (this is the key FlexibleDefaults uses to identify a function).
methodNameBase :: Context -> Method -> String
methodNameBase c n = nameBase (methodName c n)

-- | The TH 'Name' of a method in a given context.  @Generic@ names are
-- 'mkName' placeholders; the other contexts refer to the real class methods.
methodName :: Context -> Method -> Name
methodName Generic      GetPrim         = mkName "getPrim"
methodName Generic      GetWord8        = mkName "getWord8"
methodName Generic      GetWord16       = mkName "getWord16"
methodName Generic      GetWord32       = mkName "getWord32"
methodName Generic      GetWord64       = mkName "getWord64"
methodName Generic      GetDouble       = mkName "getDouble"
methodName Generic      GetNByteInteger = mkName "getNByteInteger"
methodName RandomSource GetPrim         = 'getRandomPrimFrom
methodName RandomSource GetWord8        = 'getRandomWord8From
methodName RandomSource GetWord16       = 'getRandomWord16From
methodName RandomSource GetWord32       = 'getRandomWord32From
methodName RandomSource GetWord64       = 'getRandomWord64From
methodName RandomSource GetDouble       = 'getRandomDoubleFrom
methodName RandomSource GetNByteInteger = 'getRandomNByteIntegerFrom
methodName MonadRandom  GetPrim         = 'getRandomPrim
methodName MonadRandom  GetWord8        = 'getRandomWord8
methodName MonadRandom  GetWord16       = 'getRandomWord16
methodName MonadRandom  GetWord32       = 'getRandomWord32
methodName MonadRandom  GetWord64       = 'getRandomWord64
methodName MonadRandom  GetDouble       = 'getRandomDouble
methodName MonadRandom  GetNByteInteger = 'getRandomNByteInteger

-- | Whether a 'Name' is the name of some 'Method' in the given context.
isMethodName :: Context -> Name -> Bool
isMethodName c n = isJust (nameToMethod c n)

-- | Inverse of 'methodName' for a fixed context: the 'Method' whose name in
-- that context equals the given 'Name', if any.
nameToMethod :: Context -> Name -> Maybe Method
nameToMethod c name
    = lookup name
        [ (n, m)
        | m <- allMethods
        , let n = methodName c m
        ]


-- 'Context'-sensitive version of the FlexibleDefaults DSL.  Each wrapper
-- below lifts the corresponding FD combinator through 'ReaderT Context' so
-- that method names can be resolved for the context being generated.

-- | Rescore a group of defaults (lifted 'FD.scoreBy').
scoreBy :: (a -> b) -> ReaderT Context (FD.Defaults a) t -> ReaderT Context (FD.Defaults b) t
scoreBy f = mapReaderT (FD.scoreBy f)

-- | Declare the candidate implementations of one method, registered under
-- that method's name in the current context (lifted 'FD.function').
method :: Method -> ReaderT Context (FD.Function s) t -> ReaderT Context (FD.Defaults s) t
method m f = do
    c <- ask
    mapReaderT (FD.function (methodNameBase c m)) f

-- | Declare one candidate implementation (lifted 'FD.implementation').
implementation :: ReaderT Context (FD.Implementation s) (Q [Dec]) -> ReaderT Context (FD.Function s) ()
implementation = mapReaderT FD.implementation

-- | Attach a cost to the current implementation (lifted 'FD.cost').
cost :: Num s => s -> ReaderT Context (FD.Implementation s) ()
cost = lift . FD.cost

-- | Record that the current implementation calls another method, resolved
-- to its name in the current context (lifted 'FD.dependsOn').
dependsOn :: Method -> ReaderT Context (FD.Implementation s) ()
dependsOn m = do
    c <- ask
    lift (FD.dependsOn (methodNameBase c m))

-- | Rename a method name from context @c1@ to context @c2@.  Names that are
-- not methods in @c1@ are presumably returned unchanged (per th-extras'
-- 'replace' when the function yields 'Nothing').
changeContext :: Context -> Context -> Name -> Name
changeContext c1 c2 = replace (fmap (methodName c2) . nameToMethod c1)

-- | Map all occurrences of generic method names to the proper local ones
-- and introduce a 'src' parameter where needed if the Context is RandomSource.
--
-- The quoted declarations are first passed through 'genericalizeDecs', then
-- every @Generic@ method name is rewritten to the current context.  For
-- 'RandomSource' a fresh @_src@ name is generated and threaded through each
-- declaration with 'addSrcParam'.
specialize :: Monad m => Q [Dec] -> ReaderT Context m (Q [Dec])
specialize futzedDecsQ = do
    let decQ = fmap genericalizeDecs futzedDecsQ
    c <- ask
    let specializeDec = everywhere (mkT (changeContext Generic c))
    if c == RandomSource
        then return $ do
            src <- newName "_src"
            decs <- decQ
            return (map (addSrcParam src) . specializeDec $ decs)
        else return (fmap specializeDec decQ)

-- | Drop all type signatures ('SigD') from a list of declarations.
stripTypeSigs :: Q [Dec] -> Q [Dec]
stripTypeSigs = fmap (filter (not . isSig))
    where isSig SigD{} = True; isSig _ = False

-- | Thread an explicit source argument through a declaration written for
-- the 'RandomSource' context:
--
-- * every expression that is a bare reference to a RandomSource method name
--   becomes that reference applied to @src@;
--
-- * every value or function binding of a RandomSource method name gains
--   @src@ as an extra leading parameter.
--
-- The expression pass runs first (it is the right operand of '.').
addSrcParam :: Name -> Dec -> Dec
addSrcParam src
    = everywhere (mkT expandDecs)
    . everywhere (mkT expandExps)
    where
        srcP = VarP src
        srcE = VarE src

        -- "m = body" becomes "m src = body".
        expandDecs (ValD (VarP n) body decs)
            | isMethodName RandomSource n
            = FunD n [Clause [srcP] body decs]
        -- "m p1 .. pk = body" becomes "m src p1 .. pk = body" in every clause.
        expandDecs (FunD n clauses)
            | isMethodName RandomSource n
            = FunD n [Clause (srcP : ps) body decs | Clause ps body decs <- clauses]
        expandDecs other = other

        -- A reference "m" becomes "m src".
        expandExps e@(VarE n)
            | isMethodName RandomSource n = AppE e srcE
        expandExps other = other

-- | Dummy expressions which will be remapped by 'specialize': a bare
-- reference to a method's @Generic@ placeholder name.
dummy :: Method -> ExpQ
dummy = return . VarE . methodName Generic

-- | Splice-able placeholders for each method, used inside the quoted
-- default implementations in 'defaults'.
getPrim, getWord8, getWord16,
    getWord32, getWord64, getDouble,
    getNByteInteger :: ExpQ
getPrim         = dummy GetPrim
getWord8        = dummy GetWord8
getWord16       = dummy GetWord16
getWord32       = dummy GetWord32
getWord64       = dummy GetWord64
getDouble       = dummy GetDouble
getNByteInteger = dummy GetNByteInteger

-- The defaulting rules for RandomSource and MonadRandom. Costs are rates of
-- entropy waste (bits discarded per bit requested) plus the occasional ad-hoc
-- penalty where it seems appropriate.

-- TODO: figure out a clean way to break these up for individual testing.
-- Also analyze to see which of these can never be selected (I suspect that set is non-empty)

-- | All candidate default implementations, specialized to the given context.
-- Within each @scoreBy (/k)@ group the raw costs (bits discarded) are divided
-- by @k@, the number of bits the method produces.
defaults :: Context -> FD.Defaults (Sum Double) ()
defaults = runReaderT $
    scoreBy Sum $ do
        method GetPrim $ do
            implementation $ do
                mapM_ dependsOn (allMethods \\ [GetPrim])

                -- GHC 6 requires type signatures for GADT matches, even
                -- inside [d||]. This code is evaluated at more than one type, though,
                -- and at its eventual splice site the signature actually isn't even allowed.
                -- So, there's a dummy signature here which is immediately stripped out.
                specialize . stripTypeSigs $
                    [d|
                        getPrim :: Prim a -> m a
                        getPrim PrimWord8            = $getWord8
                        getPrim PrimWord16           = $getWord16
                        getPrim PrimWord32           = $getWord32
                        getPrim PrimWord64           = $getWord64
                        getPrim PrimDouble           = $getDouble
                        getPrim (PrimNByteInteger n) = $getNByteInteger n
                      |]

        scoreBy (/8) $
            method GetWord8 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord8 = $getPrim PrimWord8 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord8 = liftM fromInteger ($getNByteInteger 1) |]

                implementation $ do
                    cost 8
                    dependsOn GetWord16
                    specialize [d| getWord8 = liftM fromIntegral $getWord16 |]

                implementation $ do
                    cost 24
                    dependsOn GetWord32
                    specialize [d| getWord8 = liftM fromIntegral $getWord32 |]

                implementation $ do
                    cost 56
                    dependsOn GetWord64
                    specialize [d| getWord8 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord8 = liftM (truncate . (256*)) $getDouble |]

        scoreBy (/16) $
            method GetWord16 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord16 = $getPrim PrimWord16 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord16 = liftM fromInteger ($getNByteInteger 2) |]

                implementation $ do
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord16 = do
                                a <- $getWord8
                                b <- $getWord8
                                return (buildWord16 a b)
                          |]

                implementation $ do
                    cost 16
                    dependsOn GetWord32
                    specialize [d| getWord16 = liftM fromIntegral $getWord32 |]

                implementation $ do
                    cost 48
                    dependsOn GetWord64
                    specialize [d| getWord16 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord16 = liftM (truncate . (65536*)) $getDouble |]

        scoreBy (/32) $
            method GetWord32 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord32 = $getPrim PrimWord32 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord32 = liftM fromInteger ($getNByteInteger 4) |]

                implementation $ do
                    cost 0.1
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord32 = do
                                a <- $getWord8
                                b <- $getWord8
                                c <- $getWord8
                                d <- $getWord8
                                return (buildWord32 a b c d)
                          |]

                implementation $ do
                    dependsOn GetWord16
                    specialize
                        [d|
                            getWord32 = do
                                a <- $getWord16
                                b <- $getWord16
                                return (buildWord32' a b)
                          |]

                implementation $ do
                    cost 32
                    dependsOn GetWord64
                    specialize [d| getWord32 = liftM fromIntegral $getWord64 |]

                implementation $ do
                    cost 64
                    dependsOn GetDouble
                    specialize [d| getWord32 = liftM (truncate . (4294967296*)) $getDouble |]

        scoreBy (/64) $
            method GetWord64 $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getWord64 = $getPrim PrimWord64 |]

                implementation $ do
                    cost 1
                    dependsOn GetNByteInteger
                    specialize [d| getWord64 = liftM fromInteger ($getNByteInteger 8) |]

                implementation $ do
                    cost 0.2
                    dependsOn GetWord8
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord8
                                b <- $getWord8
                                c <- $getWord8
                                d <- $getWord8
                                e <- $getWord8
                                f <- $getWord8
                                g <- $getWord8
                                h <- $getWord8
                                return (buildWord64 a b c d e f g h)
                          |]

                implementation $ do
                    cost 0.1
                    dependsOn GetWord16
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord16
                                b <- $getWord16
                                c <- $getWord16
                                d <- $getWord16
                                return (buildWord64' a b c d)
                          |]

                implementation $ do
                    dependsOn GetWord32
                    specialize
                        [d|
                            getWord64 = do
                                a <- $getWord32
                                b <- $getWord32
                                return (buildWord64'' a b)
                          |]

        scoreBy (/52) $
            method GetDouble $ do
                implementation $ do
                    dependsOn GetPrim
                    specialize [d| getDouble = $getPrim PrimDouble |]

                implementation $ do
                    cost 12
                    dependsOn GetWord64
                    specialize
                        [d|
                            getDouble = do
                                w <- $getWord64
                                return (wordToDouble w)
                          |]

        -- The two chunked implementations below assemble an n-byte Integer
        -- from fixed-size words, most significant chunk first.  The recursive
        -- calls use $(dummy GetNByteInteger) rather than $getNByteInteger
        -- because inside these quotes the name 'getNByteInteger' is bound by
        -- the quote's own declaration, which shadows the top-level splice.
        method GetNByteInteger $ do
            implementation $ do
                dependsOn GetPrim
                specialize [d| getNByteInteger n = $getPrim (PrimNByteInteger n) |]

            -- Chunks of at most 32 bits.  Slightly penalized when 'intIs64'
            -- holds, so that the 64-bit variant below is preferred there.
            implementation $ do
                when intIs64 (cost 1e-2)
                dependsOn GetWord8
                dependsOn GetWord16
                dependsOn GetWord32
                specialize
                    [d|
                        getNByteInteger 1 = do
                            x <- $getWord8
                            return $! toInteger x
                        getNByteInteger 2 = do
                            x <- $getWord16
                            return $! toInteger x
                        getNByteInteger 4 = do
                            x <- $getWord32
                            return $! toInteger x
                        getNByteInteger np4
                            | np4 > 4 = do
                                let n = np4 - 4
                                x <- $getWord32
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np2
                            | np2 > 2 = do
                                let n = np2 - 2
                                x <- $getWord16
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger _ = return 0
                      |]

            -- Chunks of up to 64 bits.  Slightly penalized unless 'intIs64'
            -- holds.
            implementation $ do
                unless intIs64 (cost 1e-2)
                dependsOn GetWord8
                dependsOn GetWord16
                dependsOn GetWord32
                dependsOn GetWord64
                specialize
                    [d|
                        getNByteInteger 1 = do
                            x <- $getWord8
                            return $! toInteger x
                        getNByteInteger 2 = do
                            x <- $getWord16
                            return $! toInteger x
                        getNByteInteger 4 = do
                            x <- $getWord32
                            return $! toInteger x
                        getNByteInteger 8 = do
                            x <- $getWord64
                            return $! toInteger x
                        getNByteInteger np8
                            | np8 > 8 = do
                                let n = np8 - 8
                                x <- $getWord64
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np4
                            | np4 > 4 = do
                                let n = np4 - 4
                                x <- $getWord32
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger np2
                            | np2 > 2 = do
                                let n = np2 - 2
                                x <- $getWord16
                                y <- $(dummy GetNByteInteger) n
                                return $! (toInteger x `shiftL` (n `shiftL` 3)) .|. y
                        getNByteInteger _ = return 0
                      |]


-- |Complete a possibly-incomplete 'RandomSource' implementation. It is
-- recommended that this macro be used even if the implementation is currently
-- complete, as the 'RandomSource' class may be extended at any time.
--
-- To use 'randomSource', just wrap your instance declaration as follows (and
-- enable the TemplateHaskell, MultiParamTypeClasses and GADTs language
-- extensions, as well as any others required by your instances, such as
-- FlexibleInstances):
--
-- > $(randomSource [d|
-- >         instance RandomSource FooM Bar where
-- >             {- at least one RandomSource function... -}
-- >     |])
randomSource :: Q [Dec] -> Q [Dec]
randomSource = FD.withDefaults (defaults RandomSource)

-- |Complete a possibly-incomplete 'MonadRandom' implementation. It is
-- recommended that this macro be used even if the implementation is currently
-- complete, as the 'MonadRandom' class may be extended at any time.
--
-- To use 'monadRandom', just wrap your instance declaration as follows (and
-- enable the TemplateHaskell and GADTs language extensions):
--
-- > $(monadRandom [d|
-- >         instance MonadRandom FooM where
-- >             getRandomDouble = return pi
-- >             getRandomWord16 = return 4
-- >             {- etc... -}
-- >     |])
monadRandom :: Q [Dec] -> Q [Dec]
monadRandom = FD.withDefaults (defaults MonadRandom)

-- -- This is nice in theory, but under GHC 7 it never typechecks; without generalizing the let-bound
-- -- functions, it gets absurd errors like "cannot match 'm Int' with 'IO t'". Probably need
-- -- to mechanically specialize the supplied signature to create a signature for every other
-- -- let-bound function.
-- -- (Note: 'methodName' already returns a 'Name', so no 'mkName' is needed for the signature.)
-- primFunction :: Q Type -> Q [Dec] -> ExpQ
-- primFunction getPrimType decsQ = do
--     getPrimSig <- sigD (methodName Generic GetPrim) getPrimType
--     decs <- decsQ >>= FD.implementDefaults (defaults Generic)
--     f <- getPrim
--     return (LetE (getPrimSig : decs) f)