{-# LANGUAGE ScopedTypeVariables #-}
-- | Miscellaneous utilities: bounds-checked 'ByteString' slicing,
-- constant-time comparison, exception and 'Either' plumbing, chunking, and
-- 'MVar' snapshots.
module Network.TLS.Util
    ( sub
    , takelast
    , partition3
    , partition6
    , fromJust
    , (&&!)
    , bytesEq
    , fmapEither
    , catchException
    , forEitherM
    , mapChunks_
    , getChunks
    , Saved
    , saveMVar
    , restoreMVar
    ) where

import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import Network.TLS.Imports

import Control.Exception (SomeException)
import Control.Concurrent.Async
import Control.Concurrent.MVar

-- | @sub b offset len@ returns the @len@ bytes of @b@ starting at byte
-- @offset@, or 'Nothing' when that range does not lie entirely inside @b@.
-- Negative offsets and negative lengths are rejected with 'Nothing'.
sub :: ByteString -> Int -> Int -> Maybe ByteString
sub b offset len
    | offset < 0 || len < 0 = Nothing
    -- Compare against @length - offset@ rather than computing
    -- @offset + len@, which could overflow 'Int' and pass the check.
    | offset > B.length b || len > B.length b - offset = Nothing
    | otherwise = Just $ B.take len $ B.drop offset b

-- | @takelast i b@ returns the last @i@ bytes of @b@, or 'Nothing' when @b@
-- is shorter than @i@ bytes or @i@ is negative.
takelast :: Int -> ByteString -> Maybe ByteString
takelast i b
    | B.length b >= i = sub b (B.length b - i) i
    | otherwise = Nothing

-- | Split a 'ByteString' into three consecutive parts of the given lengths.
--
-- Returns 'Nothing' if any length is negative or if the lengths do not add up
-- to exactly the length of the input.
partition3 :: ByteString -> (Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString)
partition3 bytes (d1,d2,d3)
    | any (< 0) l = Nothing
    | sum l /= B.length bytes = Nothing
    | otherwise = Just (p1,p2,p3)
  where
    l = [d1,d2,d3]
    (p1, r1) = B.splitAt d1 bytes
    (p2, r2) = B.splitAt d2 r1
    (p3, _) = B.splitAt d3 r2

-- | Split a 'ByteString' into six consecutive parts of the given lengths.
--
-- Returns 'Nothing' if any length is negative or if the input is shorter than
-- the sum of the lengths. Unlike 'partition3', bytes beyond that sum are
-- allowed and ignored.
partition6 :: ByteString -> (Int,Int,Int,Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString, ByteString, ByteString, ByteString)
partition6 bytes (d1,d2,d3,d4,d5,d6)
    -- Same guard as 'partition3': a negative size would make 'B.splitAt'
    -- return an empty part and leave the remainder unconsumed.
    | any (< 0) l = Nothing
    | B.length bytes < sum l = Nothing
    | otherwise = Just (p1,p2,p3,p4,p5,p6)
  where
    l = [d1,d2,d3,d4,d5,d6]
    (p1, r1) = B.splitAt d1 bytes
    (p2, r2) = B.splitAt d2 r1
    (p3, r3) = B.splitAt d3 r2
    (p4, r4) = B.splitAt d4 r3
    (p5, r5) = B.splitAt d5 r4
    (p6, _) = B.splitAt d6 r5

-- | Like @Data.Maybe.fromJust@, but the error raised on 'Nothing' includes the
-- caller-supplied description @what@, to identify which lookup failed.
-- This function is partial; prefer pattern matching where possible.
fromJust :: String -> Maybe a -> a
fromJust what Nothing = error ("fromJust " ++ what ++ ": Nothing")
fromJust _ (Just x) = x

-- | A strict version of '&&': both arguments are always evaluated, so there
-- is no short-circuit when the first argument is 'False'.
(&&!) :: Bool -> Bool -> Bool
True  &&! True  = True
True  &&! False = False
False &&! True  = False
False &&! False = False

-- | Check whether two bytestrings are equal without stopping at the first
-- differing byte: when the lengths match, every byte is compared
-- ('BA.constEq').
-- Arguments of different lengths still return 'False' early.
bytesEq :: ByteString -> ByteString -> Bool
bytesEq = BA.constEq

-- | Map a function over the 'Right' value of an 'Either'; a specialised
-- 'fmap'.
fmapEither :: (a -> b) -> Either l a -> Either l b
fmapEither f = fmap f

-- | Run @action@; if it throws, pass the exception to @handler@ instead.
--
-- The action runs in a separate thread ('withAsync') and its outcome is
-- collected with 'waitCatch'. Exceptions raised by the action itself are
-- therefore given to @handler@.
-- NOTE(review): presumably this is done so that asynchronous exceptions
-- thrown to the /calling/ thread (e.g. by 'killThread') interrupt
-- 'waitCatch' and propagate, rather than being swallowed by @handler@.
catchException :: IO a -> (SomeException -> IO a) -> IO a
catchException action handler = withAsync action waitCatch >>= either handler return

-- | Like 'forM' for actions that return 'Either'. The 'Right' results are
-- collected in order. The traversal stops at the first 'Left', which is
-- returned without running the actions for the remaining elements.
forEitherM :: Monad m => [a] -> (a -> m (Either l b)) -> m (Either l [b])
forEitherM [] _ = return (pure [])
forEitherM (x:xs) f = f x >>= doTail
  where
    doTail (Right b) = fmap (b :) <$> forEitherM xs f
    doTail (Left e) = return (Left e)

-- | Run an action on each chunk produced by @'getChunks' len@, in order,
-- discarding the results.
mapChunks_ :: Monad m
           => Maybe Int -> (B.ByteString -> m a) -> B.ByteString -> m ()
mapChunks_ len f = mapM_ f . getChunks len

-- | Split a 'B.ByteString' into consecutive chunks of at most @len@ bytes.
--
-- * 'Nothing' returns the input whole, as a single chunk.
--
-- * @'Just' len@ with @len > 0@ returns chunks of exactly @len@ bytes, except
--   possibly the last one, which may be shorter.
--
-- * @'Just' len@ with @len <= 0@ is treated like 'Nothing'. Splitting at a
--   non-positive size would otherwise produce an infinite list of empty
--   chunks.
--
-- The result is never empty: an empty input yields a single empty chunk.
getChunks :: Maybe Int -> B.ByteString -> [B.ByteString]
getChunks Nothing = (: [])
getChunks (Just len)
    | len <= 0 = (: [])
    | otherwise = go
  where
    go bs
        | B.length bs > len =
            let (chunk, remain) = B.splitAt len bs
             in chunk : go remain
        | otherwise = [bs]

-- | An opaque newtype wrapper that stops code from looking inside content
-- that has been saved. The constructor is not exported.
newtype Saved a = Saved a

-- | Snapshot the current content of an 'MVar' so it can be restored later
-- with 'restoreMVar'. Uses 'readMVar', so it blocks while the 'MVar' is empty
-- and leaves the content in place.
saveMVar :: MVar a -> IO (Saved a)
saveMVar ref = Saved <$> readMVar ref

-- | Restore the content of an 'MVar' to a previously saved value. Returns the
-- content that was just replaced, wrapped as 'Saved'. Uses 'swapMVar', so it
-- blocks while the 'MVar' is empty.
restoreMVar :: MVar a -> Saved a -> IO (Saved a)
restoreMVar ref (Saved val) = Saved <$> swapMVar ref val