{-# LANGUAGE MagicHash #-}
{-# LANGUAGE BangPatterns #-}
module Crypto.Cipher.Twofish.Primitive
    ( Twofish
    , initTwofish
    , encrypt
    , decrypt
    ) where

import Crypto.Error
import Crypto.Internal.ByteArray (ByteArray)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.WordArray
import Data.Word
import Data.Bits
import Data.List

-- Based on the Golang reference implementation
-- https://github.com/golang/crypto/blob/master/twofish/twofish.go


-- Twofish block size in bytes.
blockSize :: Int
blockSize = 16

-- Reduction polynomials for GF(2^8) multiplication: one for the MDS matrix,
-- one for the RS matrix used when deriving the S vector from the key.
mdsPolynomial, rsPolynomial :: Word32
mdsPolynomial = 0x169 -- x^8 + x^6 + x^5 + x^3 + 1, see [TWOFISH] 4.2
rsPolynomial  = 0x14d -- x^8 + x^6 + x^3 + x^2 + 1, see [TWOFISH] 4.3

-- Expanded key: the four key-dependent lookup tables used by the round
-- function (s) and the subkey words (k).
data Twofish = Twofish { s :: (Array32, Array32, Array32, Array32)
                       , k :: Array32 }

-- The three accepted key lengths: 16, 24 and 32 bytes.
data ByteSize = Bytes16 | Bytes24 | Bytes32 deriving (Eq)

-- A raw key paired with its validated length class.
data KeyPackage ba = KeyPackage { rawKeyBytes :: ba
                                , byteSize :: ByteSize }

-- Classify a key by its length; any length other than 16, 24 or 32 bytes
-- is rejected with 'Nothing'.
buildPackage :: ByteArray ba => ba -> Maybe (KeyPackage ba)
buildPackage key = fmap (KeyPackage key) (sizeFor (B.length key))
  where
    sizeFor :: Int -> Maybe ByteSize
    sizeFor 16 = Just Bytes16
    sizeFor 24 = Just Bytes24
    sizeFor 32 = Just Bytes32
    sizeFor _  = Nothing

-- | Initialize a 128-bit, 192-bit, or 256-bit key
--
-- Return the initialized key, or 'CryptoError_KeySizeInvalid' if the given
-- key is not 16, 24, or 32 bytes in length.
initTwofish :: ByteArray key
            => key -- ^ The key to create the twofish context
            -> CryptoFailable Twofish
initTwofish key =
    maybe (CryptoFailed CryptoError_KeySizeInvalid) (CryptoPassed . schedule) (buildPackage key)
  where
    -- 40 subkey words (20 pairs produced by genK) and the four key-dependent
    -- round-function tables.
    schedule keyPackage = Twofish { k = array32 40 (genK keyPackage)
                                  , s = genSboxes keyPackage (sWords key) }

-- Apply a block transformation to each consecutive 'blockSize'-byte chunk of
-- the input and concatenate the results. Empty input yields empty output.
-- A trailing chunk shorter than 'blockSize' is still handed to the operation;
-- encryptBlock/decryptBlock read its missing bytes as zero (see load32ls), so
-- callers presumably should supply whole blocks.
mapBlocks :: ByteArray ba => (ba -> ba) -> ba -> ba
mapBlocks operation = B.concat . chunksOut
  where
    -- Collect the transformed chunks, then join them once, instead of
    -- appending pairwise (which is quadratic in the input length).
    chunksOut input
        | B.null input = []
        | otherwise    = let (block, rest) = B.splitAt blockSize input
                         in operation block : chunksOut rest

-- | Encrypts the given ByteString using the given Key
encrypt :: ByteArray ba
        => Twofish     -- ^ The key to use
        -> ba          -- ^ The data to encrypt
        -> ba
encrypt cipher = mapBlocks (encryptBlock cipher)

-- Encrypt one 16-byte block. The block is read as four little-endian words,
-- XORed with subkeys 0-3, run through 8 iterations of two Feistel rounds
-- (subkeys 8-39), then output with the halves swapped and XORed with
-- subkeys 4-7.
encryptBlock :: ByteArray ba => Twofish -> ba -> ba
encryptBlock Twofish { s = boxes, k = ks } message =
    store32ls (c' `xor` subkey 4, d' `xor` subkey 5, a' `xor` subkey 6, b' `xor` subkey 7)
  where
    subkey = arrayRead32 ks
    (a, b, c, d) = load32ls message
    (!a', !b', !c', !d') =
        foldl' doubleRound (a `xor` subkey 0, b `xor` subkey 1, c `xor` subkey 2, d `xor` subkey 3) [0 .. 7]

    doubleRound :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32)
    doubleRound (!ra, !rb, !rc, !rd) i = (ra', rb', rc', rd')
      where
        -- the four subkeys consumed by iteration i: 8+4i .. 11+4i
        rk j = subkey (8 + 4 * i + j)
        -- first round: (ra, rb) update (rc, rd)
        u2 = gWordRot8 boxes rb
        u1 = gWord boxes ra + u2
        rc' = rotateR (rc `xor` (u1 + rk 0)) 1
        rd' = rotateL rd 1 `xor` (u1 + u2 + rk 1)
        -- second round: the new (rc', rd') update (ra, rb)
        v2 = gWordRot8 boxes rd'
        v1 = gWord boxes rc' + v2
        ra' = rotateR (ra `xor` (v1 + rk 2)) 1
        rb' = rotateL rb 1 `xor` (v1 + v2 + rk 3)

-- g evaluated through the precomputed tables: byte i (little-endian) of the
-- word indexes table i, and the four lookups are XORed together.
gWord :: (Array32, Array32, Array32, Array32) -> Word32 -> Word32
gWord (sa, sb, sc, sd) x =
    byteIndex sa x `xor` byteIndex sb (shiftR x 8) `xor` byteIndex sc (shiftR x 16) `xor` byteIndex sd (shiftR x 24)

-- Equal to @gWord tables (rotateL x 8)@: byte 0 indexes table 1, byte 1
-- table 2, byte 2 table 3 and byte 3 table 0.
gWordRot8 :: (Array32, Array32, Array32, Array32) -> Word32 -> Word32
gWordRot8 (sa, sb, sc, sd) x =
    byteIndex sb x `xor` byteIndex sc (shiftR x 8) `xor` byteIndex sd (shiftR x 16) `xor` byteIndex sa (shiftR x 24)

-- Read the entry selected by the low byte of @ind@. The mask keeps the index
-- in 0..255; arrayRead32 itself is presumably unchecked, so the table must
-- have 256 entries (genSboxes builds them that way).
byteIndex :: Array32 -> Word32 -> Word32
byteIndex xs ind = arrayRead32 xs (fromIntegral (ind .&. 0xff))

-- | Decrypts the given ByteString using the given Key
decrypt :: ByteArray ba
        => Twofish     -- ^ The key to use
        -> ba          -- ^ The data to decrypt
        -> ba
decrypt cipher = mapBlocks (decryptBlock cipher)

-- Decrypt one 16-byte block: the exact inverse of encryptBlock. It undoes the
-- output swap/whitening, runs the double rounds backwards (iterations 8 down
-- to 1), then removes the input whitening.
decryptBlock :: ByteArray ba => Twofish -> ba -> ba
decryptBlock Twofish { s = boxes, k = ks } message =
    store32ls (a' `xor` subkey 0, b' `xor` subkey 1, c' `xor` subkey 2, d' `xor` subkey 3)
  where
    subkey = arrayRead32 ks
    (a, b, c, d) = load32ls message
    (!a', !b', !c', !d') =
        foldl' undoDoubleRound (c `xor` subkey 6, d `xor` subkey 7, a `xor` subkey 4, b `xor` subkey 5) [8, 7 .. 1]

    undoDoubleRound :: (Word32, Word32, Word32, Word32) -> Int -> (Word32, Word32, Word32, Word32)
    undoDoubleRound (!ra, !rb, !rc, !rd) i = (ra', rb', rc', rd')
      where
        -- iteration i here uses the subkeys of encryptBlock's iteration i-1
        rk j = subkey (4 + 4 * i + j)
        -- undo the second encryption round first
        u2 = gWordRot8 boxes rd
        u1 = gWord boxes rc + u2
        ra' = rotateL ra 1 `xor` (u1 + rk 2)
        rb' = rotateR (rb `xor` (u2 + u1 + rk 3)) 1
        -- then the first
        v2 = gWordRot8 boxes rb'
        v1 = gWord boxes ra' + v2
        rc' = rotateL rc 1 `xor` (v1 + rk 0)
        rd' = rotateR (rd `xor` (v2 + v1 + rk 1)) 1

-- Fixed 8-bit permutation q0 (indexed 0..255).
sbox0 :: Int -> Word8
sbox0 = arrayRead8 t
  where
    t = array8
        "\xa9\x67\xb3\xe8\x04\xfd\xa3\x76\x9a\x92\x80\x78\xe4\xdd\xd1\x38\
        \\x0d\xc6\x35\x98\x18\xf7\xec\x6c\x43\x75\x37\x26\xfa\x13\x94\x48\
        \\xf2\xd0\x8b\x30\x84\x54\xdf\x23\x19\x5b\x3d\x59\xf3\xae\xa2\x82\
        \\x63\x01\x83\x2e\xd9\x51\x9b\x7c\xa6\xeb\xa5\xbe\x16\x0c\xe3\x61\
        \\xc0\x8c\x3a\xf5\x73\x2c\x25\x0b\xbb\x4e\x89\x6b\x53\x6a\xb4\xf1\
        \\xe1\xe6\xbd\x45\xe2\xf4\xb6\x66\xcc\x95\x03\x56\xd4\x1c\x1e\xd7\
        \\xfb\xc3\x8e\xb5\xe9\xcf\xbf\xba\xea\x77\x39\xaf\x33\xc9\x62\x71\
        \\x81\x79\x09\xad\x24\xcd\xf9\xd8\xe5\xc5\xb9\x4d\x44\x08\x86\xe7\
        \\xa1\x1d\xaa\xed\x06\x70\xb2\xd2\x41\x7b\xa0\x11\x31\xc2\x27\x90\
        \\x20\xf6\x60\xff\x96\x5c\xb1\xab\x9e\x9c\x52\x1b\x5f\x93\x0a\xef\
        \\x91\x85\x49\xee\x2d\x4f\x8f\x3b\x47\x87\x6d\x46\xd6\x3e\x69\x64\
        \\x2a\xce\xcb\x2f\xfc\x97\x05\x7a\xac\x7f\xd5\x1a\x4b\x0e\xa7\x5a\
        \\x28\x14\x3f\x29\x88\x3c\x4c\x02\xb8\xda\xb0\x17\x55\x1f\x8a\x7d\
        \\x57\xc7\x8d\x74\xb7\xc4\x9f\x72\x7e\x15\x22\x12\x58\x07\x99\x34\
        \\x6e\x50\xde\x68\x65\xbc\xdb\xf8\xc8\xa8\x2b\x40\xdc\xfe\x32\xa4\
        \\xca\x10\x21\xf0\xd3\x5d\x0f\x00\x6f\x9d\x36\x42\x4a\x5e\xc1\xe0"#

-- Fixed 8-bit permutation q1 (indexed 0..255).
sbox1 :: Int -> Word8
sbox1 = arrayRead8 t
  where
    t = array8
        "\x75\xf3\xc6\xf4\xdb\x7b\xfb\xc8\x4a\xd3\xe6\x6b\x45\x7d\xe8\x4b\
        \\xd6\x32\xd8\xfd\x37\x71\xf1\xe1\x30\x0f\xf8\x1b\x87\xfa\x06\x3f\
        \\x5e\xba\xae\x5b\x8a\x00\xbc\x9d\x6d\xc1\xb1\x0e\x80\x5d\xd2\xd5\
        \\xa0\x84\x07\x14\xb5\x90\x2c\xa3\xb2\x73\x4c\x54\x92\x74\x36\x51\
        \\x38\xb0\xbd\x5a\xfc\x60\x62\x96\x6c\x42\xf7\x10\x7c\x28\x27\x8c\
        \\x13\x95\x9c\xc7\x24\x46\x3b\x70\xca\xe3\x85\xcb\x11\xd0\x93\xb8\
        \\xa6\x83\x20\xff\x9f\x77\xc3\xcc\x03\x6f\x08\xbf\x40\xe7\x2b\xe2\
        \\x79\x0c\xaa\x82\x41\x3a\xea\xb9\xe4\x9a\xa4\x97\x7e\xda\x7a\x17\
        \\x66\x94\xa1\x1d\x3d\xf0\xde\xb3\x0b\x72\xa7\x1c\xef\xd1\x53\x3e\
        \\x8f\x33\x26\x5f\xec\x76\x2a\x49\x81\x88\xee\x21\xc4\x1a\xeb\xd9\
        \\xc5\x39\x99\xcd\xad\x31\x8b\x01\x18\x23\xdd\x1f\x4e\x2d\xf9\x48\
        \\x4f\xf2\x65\x8e\x78\x5c\x58\x19\x8d\xe5\x98\x57\x67\x7f\x05\x64\
        \\xaf\x63\xb6\xfe\xf5\xb7\x3c\xa5\xce\xe9\x68\x44\xe0\x4d\x43\x69\
        \\x29\x2e\xac\x15\x59\xa8\x0a\x9e\x6e\x47\xdf\x34\x35\x6a\xcf\xdc\
        \\x22\xc9\xc0\x9b\x89\xd4\xed\xab\x12\xa2\x0d\x52\xbb\x02\x2f\xa9\
        \\xd7\x61\x1e\xb4\x50\x04\xf6\xc2\x16\x25\x86\x56\x55\x09\xbe\x91"#

-- The 4x8 RS matrix multiplied (over GF(2^8) mod rsPolynomial) with each
-- 8-byte key word in sWords.
rs :: [[Word8]]
rs = [ [0x01, 0xA4, 0x55, 0x87, 0x5A, 0x58, 0xDB, 0x9E]
     , [0xA4, 0x56, 0x82, 0xF3, 0x1E, 0xC6, 0x68, 0xE5]
     , [0x02, 0xA1, 0xFC, 0xC1, 0x47, 0xAE, 0x3D, 0x19]
     , [0xA4, 0x55, 0x87, 0x5A, 0x58, 0xDB, 0x9E, 0x03] ]

-- Read four little-endian 32-bit words from bytes 0-3, 4-7, 8-11 and 12-15.
-- Bytes missing from a short input are treated as zero.
load32ls :: ByteArray ba => ba -> (Word32, Word32, Word32, Word32)
load32ls message = (wordAt 0, wordAt 4, wordAt 8, wordAt 12)
  where
    bytes = B.unpack message
    wordAt :: Int -> Word32
    wordAt off = foldr (\byte acc -> shiftL acc 8 .|. fromIntegral byte) 0 (take 4 (drop off bytes))

-- Write four 32-bit words as 16 bytes, each word little-endian.
store32ls :: ByteArray ba => (Word32, Word32, Word32, Word32) -> ba
store32ls (a, b, c, d) = B.pack (concatMap littleEndianBytes [a, b, c, d])
  where
    littleEndianBytes :: Word32 -> [Word8]
    littleEndianBytes w = [ fromIntegral (shiftR w n) | n <- [0, 8, 16, 24] ]

-- Derive the S vector from the key. For each 8-byte key word i, the four
-- bytes RS . key[8i .. 8i+7] over GF(2^8) modulo 'rsPolynomial'. The result
-- has 4 * (key length / 8) bytes: 8, 12 or 16 for the accepted key sizes.
sWords :: ByteArray ba => ba -> [Word8]
sWords key = concatMap rsTimesWord [0 .. B.length key `div` 8 - 1]
  where
    rsTimesWord wordIndex = map (rsRowTimes wordIndex) rs
    rsRowTimes wordIndex rsRow =
        foldl' xor 0 [ gfMult rsPolynomial (B.index key (8 * wordIndex + col)) rsVal
                     | (rsVal, col) <- zip rsRow [0 ..] ]

-- Column index (0-3) of the MDS matrix.
data Column = Zero | One | Two | Three deriving (Show, Eq, Enum, Bounded)

-- Build the four key-dependent 256-entry tables. Entry x of the table for
-- column c is 'mdsColumnMult' of x after a chain of S-box lookups. Each inner
-- lookup is XORed with the next S byte of that column (S[c], S[c+4], ...).
-- The chain depends on the key size and ends with a fixed per-column S-box.
genSboxes :: KeyPackage ba -> [Word8] -> (Array32, Array32, Array32, Array32)
genSboxes keyPackage ws = (table Zero, table One, table Two, table Three)
  where
    table :: Column -> Array32
    table col = array32 256 [ mdsColumnMult (substitute x) col | x <- [0 .. 255 :: Int] ]
      where
        stages = zip (innerBoxes (byteSize keyPackage) col) (keyBytes col)
        substitute x = finalBox col (fromIntegral (foldl' stage (fromIntegral x) stages))
        stage acc (box, w) = box (fromIntegral acc) `xor` w

    -- S bytes feeding column c: every fourth byte starting at index c.
    keyBytes :: Column -> [Word8]
    keyBytes col = [ w | (i, w) <- zip [0 ..] ws, i `mod` 4 == fromEnum col ]

    -- Inner S-boxes, innermost (applied first) on the left.
    innerBoxes :: ByteSize -> Column -> [Int -> Word8]
    innerBoxes Bytes16 Zero  = [sbox0, sbox0]
    innerBoxes Bytes16 One   = [sbox1, sbox0]
    innerBoxes Bytes16 Two   = [sbox0, sbox1]
    innerBoxes Bytes16 Three = [sbox1, sbox1]
    innerBoxes Bytes24 Zero  = [sbox1, sbox0, sbox0]
    innerBoxes Bytes24 One   = [sbox1, sbox1, sbox0]
    innerBoxes Bytes24 Two   = [sbox0, sbox0, sbox1]
    innerBoxes Bytes24 Three = [sbox0, sbox1, sbox1]
    innerBoxes Bytes32 Zero  = [sbox1, sbox1, sbox0, sbox0]
    innerBoxes Bytes32 One   = [sbox0, sbox1, sbox1, sbox0]
    innerBoxes Bytes32 Two   = [sbox0, sbox0, sbox0, sbox1]
    innerBoxes Bytes32 Three = [sbox1, sbox0, sbox1, sbox1]

    -- Last lookup, applied without a key XOR.
    finalBox :: Column -> Int -> Word8
    finalBox Zero  = sbox1
    finalBox One   = sbox0
    finalBox Two   = sbox1
    finalBox Three = sbox0

-- Subkey words 0..39, produced in pairs: for i in 0..19,
-- A = h(2i, even key words), B = h(2i+1, odd key words) <<< 8, and the pair is
-- (A + B, (A + 2B) <<< 9).
genK :: (ByteArray ba) => KeyPackage ba -> [Word32]
genK keyPackage = concatMap subkeyPair [0 .. 19]
  where
    subkeyPair :: Word8 -> [Word32]
    subkeyPair i = [a + b, rotateL (2 * b + a) 9]
      where
        a = h (replicate 4 (2 * i)) keyPackage 0
        b = rotateL (h (replicate 4 (2 * i + 1)) keyPackage 1) 8

-- The key-schedule h function. The four input bytes pass through a
-- key-size-dependent chain of S-box lookups XORed with key bytes, and are then
-- combined through the MDS matrix. @offset@ (0 or 1 as used by genK) shifts
-- which 4-byte key words are mixed in.
h :: (ByteArray ba) => [Word8] -> KeyPackage ba -> Int -> Word32
h input keyPackage offset =
    foldl' xor 0 (zipWith mdsColumnMult [z0, z1, z2, z3] [Zero ..])
  where
    key = rawKeyBytes keyPackage
    -- genK always passes exactly four bytes
    [y0, y1, y2, y3] = take 4 input
    (!z0, !z1, !z2, !z3) = run (y0, y1, y2, y3) (byteSize keyPackage)

    -- Byte i of the 4-byte key word number (wordNo + offset).
    keyByte :: Int -> Int -> Word8
    keyByte wordNo i = B.index key (4 * (wordNo + offset) + i)

    -- S-box lookup by byte value.
    via :: (Int -> Word8) -> Word8 -> Word8
    via box x = box (fromIntegral x)

    -- Larger keys add an outer layer, then fall through to the smaller case.
    run :: (Word8, Word8, Word8, Word8) -> ByteSize -> (Word8, Word8, Word8, Word8)
    run (!a, !b, !c, !d) Bytes32 =
        run ( via sbox1 a `xor` keyByte 6 0
            , via sbox0 b `xor` keyByte 6 1
            , via sbox0 c `xor` keyByte 6 2
            , via sbox1 d `xor` keyByte 6 3
            ) Bytes24
    run (!a, !b, !c, !d) Bytes24 =
        run ( via sbox1 a `xor` keyByte 4 0
            , via sbox1 b `xor` keyByte 4 1
            , via sbox0 c `xor` keyByte 4 2
            , via sbox0 d `xor` keyByte 4 3
            ) Bytes16
    run (!a, !b, !c, !d) Bytes16 =
        ( via sbox1 (via sbox0 (via sbox0 a `xor` keyByte 2 0) `xor` keyByte 0 0)
        , via sbox0 (via sbox0 (via sbox1 b `xor` keyByte 2 1) `xor` keyByte 0 1)
        , via sbox1 (via sbox1 (via sbox0 c `xor` keyByte 2 2) `xor` keyByte 0 2)
        , via sbox0 (via sbox1 (via sbox1 d `xor` keyByte 2 3) `xor` keyByte 0 3)
        )

-- Column @col@ of the MDS matrix multiplied by @byte@, packed little-endian.
-- Each output byte is one of byte*0x01, byte*0x5B or byte*0xEF over GF(2^8)
-- mod 'mdsPolynomial'.
mdsColumnMult :: Word8 -> Column -> Word32
mdsColumnMult !byte !col =
    case col of
        Zero  -> packLE m01 m5B mEF mEF
        One   -> packLE mEF mEF m5B m01
        Two   -> packLE m5B mEF m01 mEF
        Three -> packLE m5B m01 mEF m5B
  where
    m01 = byte
    m5B = gfMult mdsPolynomial byte 0x5B
    mEF = gfMult mdsPolynomial byte 0xEF
    packLE :: Word8 -> Word8 -> Word8 -> Word8 -> Word32
    packLE b0 b1 b2 b3 =
        fromIntegral b0 .|. shiftL (fromIntegral b1) 8 .|. shiftL (fromIntegral b2) 16 .|. shiftL (fromIntegral b3) 24

-- Select the second component when bit 0 of @b@ is set, otherwise the first.
tupInd :: (Bits b) => b -> (a, a) -> a
tupInd b
    | testBit b 0 = snd
    | otherwise   = fst

-- Multiply @a@ by @b@ in GF(2^8) reduced by polynomial @p@. This is a
-- shift-and-add loop over the 8 bits of @a@, selecting terms via 'tupInd'
-- (a port of the Go "branchless" multiplier).
gfMult :: Word32 -> Word8 -> Word8 -> Word8
gfMult p a b = fromIntegral (go a (0, fromIntegral b) 0 0)
  where
    reducer = (0, p)

    go :: Word8 -> (Word32, Word32) -> Word32 -> Int -> Word32
    go x multiple acc i
        | i == 7    = acc'
        | otherwise = go (shiftR x 1) multiple' acc' (i + 1)
      where
        -- add the current multiple of b when the low bit of x is set
        acc' = acc `xor` tupInd (x .&. 1) multiple
        -- double the multiple, reducing by p when bit 7 overflows
        multiple' = (fst multiple, tupInd (shiftR (snd multiple) 7) reducer `xor` shiftL (snd multiple) 1)