-- |
-- Module      : Crypto.KDF.BCryptPBKDF
-- License     : BSD-style
-- Stability   : experimental
-- Portability : Good
--
-- Port of the bcrypt_pbkdf key derivation function from OpenBSD
-- as described at <http://man.openbsd.org/bcrypt_pbkdf.3>.
module Crypto.KDF.BCryptPBKDF
    ( Parameters (..)
    , generate
    , hashInternal
    ) where

import           Control.Exception                (finally)
import           Control.Monad                    (when)
import           Data.Bits
import           Data.Foldable                    (forM_)
import           Data.Word
import           Foreign.Ptr                      (Ptr, castPtr)
import           Foreign.Storable                 (peekByteOff, pokeByteOff)

import           Basement.Block                   (MutableBlock)
import qualified Basement.Block                   as Block
import qualified Basement.Block.Mutable           as Block
import           Basement.Monad                   (PrimState)
import           Basement.Types.OffsetSize        (CountOf (..), Offset (..))
import qualified Crypto.Cipher.Blowfish.Box       as Blowfish
import qualified Crypto.Cipher.Blowfish.Primitive as Blowfish
import           Crypto.Hash.Algorithms           (SHA512 (..))
import           Crypto.Hash.Types                (Context,
                                                   hashDigestSize,
                                                   hashInternalContextSize,
                                                   hashInternalFinalize,
                                                   hashInternalInit,
                                                   hashInternalUpdate)
import           Crypto.Internal.Compat           (unsafeDoIO)
import qualified Data.ByteArray                   as B
import           Data.Memory.PtrMethods           (memCopy, memSet, memXor)

-- | Settings controlling a bcrypt_pbkdf derivation.
data Parameters = Parameters
    { iterCounts   :: Int
      -- ^ Number of rounds performed for every 32-byte output block
      -- (must be > 0).
    , outputLength :: Int
      -- ^ Number of key bytes to produce (must be in 1..1024).
    } deriving (Eq, Ord, Show)

-- | Derive a key of specified length using the bcrypt_pbkdf algorithm.
--
-- The password is hashed once with SHA-512. Then, for every 32-byte output
-- block, the salt followed by a big-endian 32-bit block counter is hashed
-- and run through the bcrypt core; 'iterCounts' rounds are performed per
-- block, each round hashing the previous round's result and XOR-ing the new
-- result into the block. Finally the bytes of all blocks are interleaved
-- over the output buffer.
--
-- Calls 'error' if 'iterCounts' is below 1 or if 'outputLength' lies
-- outside @1..1024@.
generate :: (B.ByteArray pass, B.ByteArray salt, B.ByteArray output)
         => Parameters
         -> pass
         -> salt
         -> output
generate params pass salt
    | iterCounts params < 1 = error "BCryptPBKDF: iterCounts must be > 0"
    | keyLen < 1 || keyLen > 1024 = error "BCryptPBKDF: outputLength must be in 1..1024"
    | otherwise = B.unsafeCreate keyLen deriveKey
  where
    outLen, tmpLen, blkLen, keyLen, passLen, saltLen, ctxLen, hashLen, blocks :: Int
    outLen = 32
    tmpLen = 32
    blkLen = 4
    passLen = B.length pass
    saltLen = B.length salt
    keyLen = outputLength params
    ctxLen = hashInternalContextSize SHA512
    hashLen = hashDigestSize SHA512 -- 64
    -- Number of 32-byte blocks needed to cover keyLen, rounded up.
    blocks = (keyLen + outLen - 1) `div` outLen

    deriveKey :: Ptr Word8 -> IO ()
    deriveKey keyPtr = do
        -- Allocate all necessary memory. The algorithm shall not allocate
        -- any more dynamic memory after this point. Blocks need to be pinned
        -- as pointers to them are passed to the SHA512 implementation.
        ksClean <- Blowfish.createKeySchedule
        ksDirty <- Blowfish.createKeySchedule
        ctxMBlock <- Block.newPinned (CountOf ctxLen :: CountOf Word8)
        outMBlock <- Block.newPinned (CountOf outLen :: CountOf Word8)
        tmpMBlock <- Block.newPinned (CountOf tmpLen :: CountOf Word8)
        blkMBlock <- Block.newPinned (CountOf blkLen :: CountOf Word8)
        passHashMBlock <- Block.newPinned (CountOf hashLen :: CountOf Word8)
        saltHashMBlock <- Block.newPinned (CountOf hashLen :: CountOf Word8)
        -- Finally (even if the derivation throws) erase every scratch area
        -- that ends up holding secret-dependent data:
        --   * passHash - SHA-512 digest of the password;
        --   * out      - the last output block, i.e. derived key bytes;
        --   * tmp      - the last round's result; identical to out when
        --                iterCounts == 1;
        --   * saltHash - the digest fed into the last round;
        --   * ctx      - the SHA-512 context; depending on the hash
        --                implementation it may still buffer the previous
        --                round's tmp.
        -- ksDirty, the keyed Blowfish state that encrypts to tmp, is reset
        -- by copying ksClean (which this function never keys) over it.
        -- As all MutableBlocks are pinned it shall be guaranteed that
        -- no temporary trampoline buffers are allocated.
        let scrubbing action =
                foldr finallyErase action
                    [outMBlock, passHashMBlock, tmpMBlock, saltHashMBlock, ctxMBlock]
                    `finally` Blowfish.copyKeySchedule ksDirty ksClean
        scrubbing $
            B.withByteArray pass $ \passPtr ->
            B.withByteArray salt $ \saltPtr ->
            Block.withMutablePtr ctxMBlock $ \ctxPtr ->
            Block.withMutablePtr outMBlock $ \outPtr ->
            Block.withMutablePtr tmpMBlock $ \tmpPtr ->
            Block.withMutablePtr blkMBlock $ \blkPtr ->
            Block.withMutablePtr passHashMBlock $ \passHashPtr ->
            Block.withMutablePtr saltHashMBlock $ \saltHashPtr -> do
                -- Hash the password.
                let shaPtr = castPtr ctxPtr :: Ptr (Context SHA512)
                hashInternalInit shaPtr
                hashInternalUpdate shaPtr passPtr (fromIntegral passLen)
                hashInternalFinalize shaPtr (castPtr passHashPtr)
                passHashBlock <- Block.unsafeFreeze passHashMBlock
                forM_ [1..blocks] $ \block -> do
                    -- Poke the increased block counter (big-endian).
                    Block.unsafeWrite blkMBlock 0 (fromIntegral $ block `shiftR` 24)
                    Block.unsafeWrite blkMBlock 1 (fromIntegral $ block `shiftR` 16)
                    Block.unsafeWrite blkMBlock 2 (fromIntegral $ block `shiftR` 8)
                    Block.unsafeWrite blkMBlock 3 (fromIntegral $ block `shiftR` 0)
                    -- First round (slightly different): hash salt || counter.
                    hashInternalInit shaPtr
                    hashInternalUpdate shaPtr saltPtr (fromIntegral saltLen)
                    hashInternalUpdate shaPtr blkPtr (fromIntegral blkLen)
                    hashInternalFinalize shaPtr (castPtr saltHashPtr)
                    Block.unsafeFreeze saltHashMBlock >>= \x -> do
                        Blowfish.copyKeySchedule ksDirty ksClean
                        hashInternalMutable ksDirty passHashBlock x tmpMBlock
                    memCopy outPtr tmpPtr outLen
                    -- Remaining rounds: hash the previous round's result and
                    -- XOR the new result into out.
                    forM_ [2..iterCounts params] $ const $ do
                        hashInternalInit shaPtr
                        hashInternalUpdate shaPtr tmpPtr (fromIntegral tmpLen)
                        hashInternalFinalize shaPtr (castPtr saltHashPtr)
                        Block.unsafeFreeze saltHashMBlock >>= \x -> do
                            Blowfish.copyKeySchedule ksDirty ksClean
                            hashInternalMutable ksDirty passHashBlock x tmpMBlock
                        memXor outPtr outPtr tmpPtr outLen
                    -- Spread the current out buffer evenly over the key buffer:
                    -- key byte k is taken from byte (k `div` blocks) of block
                    -- (k `mod` blocks) + 1. After all blocks have run, every
                    -- byte of the key buffer has been written exactly once;
                    -- out bytes whose destination is at or beyond keyLen are
                    -- dropped.
                    forM_ [0..outLen - 1] $ \outIdx -> do
                        let keyIdx = outIdx * blocks + block - 1
                        when (keyIdx < keyLen) $ do
                            w8 <- peekByteOff outPtr outIdx :: IO Word8
                            pokeByteOff keyPtr keyIdx w8

-- | Internal hash function used by `generate`.
--
-- Normal users should not need this.
--
-- Runs the bcrypt core once on a freshly created Blowfish key schedule and
-- returns its 32-byte result. Calls 'error' unless both @passHash@ and
-- @saltHash@ are exactly 64 bytes long.
hashInternal :: (B.ByteArrayAccess pass, B.ByteArrayAccess salt, B.ByteArray output)
             => pass
             -> salt
             -> output
hashInternal passHash saltHash
    | B.length passHash /= 64 = error "passHash must be 512 bits"
    | B.length saltHash /= 64 = error "saltHash must be 512 bits"
    | otherwise = unsafeDoIO $ do
        ks0 <- Blowfish.createKeySchedule
        outMBlock <- Block.newPinned 32
        hashInternalMutable ks0 passHash saltHash outMBlock
        B.convert `fmap` Block.freeze outMBlock

-- | The bcrypt core used by bcrypt_pbkdf: writes a 32-byte result into the
-- first 32 bytes of @outMBlock@.
--
-- The key schedule @bfks@ is expanded with @passHash@ and @saltHash@, then
-- 64 times alternately re-expanded with @saltHash@ and @passHash@. It is
-- then used to encrypt the constant "OxychromaticBlowfishSwatDynamite"
-- 64 times. @bfks@ is modified in place, so callers that reuse it must
-- reset it first (as 'generate' does with 'Blowfish.copyKeySchedule').
-- @outMBlock@ must hold at least 32 bytes; it is written with
-- 'Block.unsafeWrite'.
hashInternalMutable :: (B.ByteArrayAccess pass, B.ByteArrayAccess salt)
                    => Blowfish.KeySchedule
                    -> pass
                    -> salt
                    -> MutableBlock Word8 (PrimState IO)
                    -> IO ()
hashInternalMutable bfks passHash saltHash outMBlock = do
    Blowfish.expandKeyWithSalt bfks passHash saltHash
    forM_ [0..63 :: Int] $ const $ do
        Blowfish.expandKey bfks saltHash
        Blowfish.expandKey bfks passHash
    -- "OxychromaticBlowfishSwatDynamite" represented as 4 Word64 in big-endian.
    store 0 =<< cipher 64 0x4f78796368726f6d
    store 8 =<< cipher 64 0x61746963426c6f77
    store 16 =<< cipher 64 0x6669736853776174
    store 24 =<< cipher 64 0x44796e616d697465
  where
    -- Write the two 32-bit halves of a Word64 (high half first), each half
    -- in little-endian byte order, at offset o.
    store :: Offset Word8 -> Word64 -> IO ()
    store o w64 = do
        Block.unsafeWrite outMBlock (o + 0) (fromIntegral $ w64 `shiftR` 32)
        Block.unsafeWrite outMBlock (o + 1) (fromIntegral $ w64 `shiftR` 40)
        Block.unsafeWrite outMBlock (o + 2) (fromIntegral $ w64 `shiftR` 48)
        Block.unsafeWrite outMBlock (o + 3) (fromIntegral $ w64 `shiftR` 56)
        Block.unsafeWrite outMBlock (o + 4) (fromIntegral $ w64 `shiftR` 0)
        Block.unsafeWrite outMBlock (o + 5) (fromIntegral $ w64 `shiftR` 8)
        Block.unsafeWrite outMBlock (o + 6) (fromIntegral $ w64 `shiftR` 16)
        Block.unsafeWrite outMBlock (o + 7) (fromIntegral $ w64 `shiftR` 24)
    -- Encrypt a 64-bit block i times in a row with bfks.
    cipher :: Int -> Word64 -> IO Word64
    cipher 0 block = return block
    cipher i block = Blowfish.cipherBlockMutable bfks block >>= cipher (i - 1)

-- | Run an action and afterwards (also when it throws) overwrite every byte
-- of the given mutable block with zero.
finallyErase :: MutableBlock Word8 (PrimState IO) -> IO () -> IO ()
finallyErase mblock action =
    action `finally` Block.withMutablePtr mblock (\ptr -> memSet ptr 0 len)
  where
    CountOf len = Block.mutableLengthBytes mblock