1{-# LANGUAGE ScopedTypeVariables #-}
2module Network.TLS.Util
3        ( sub
4        , takelast
5        , partition3
6        , partition6
7        , fromJust
8        , (&&!)
9        , bytesEq
10        , fmapEither
11        , catchException
12        , forEitherM
13        , mapChunks_
14        , getChunks
15        , Saved
16        , saveMVar
17        , restoreMVar
18        ) where
19
20import qualified Data.ByteArray as BA
21import qualified Data.ByteString as B
22import Network.TLS.Imports
23
24import Control.Exception (SomeException)
25import Control.Concurrent.Async
26import Control.Concurrent.MVar
27
-- | Extract @len@ bytes of @b@ starting at byte position @offset@.
--
-- Returns 'Nothing' when @offset@ or @len@ is negative, or when the
-- requested range extends past the end of @b@.  Otherwise returns
-- 'Just' the requested slice.
sub :: ByteString -> Int -> Int -> Maybe ByteString
sub b offset len
    | offset < 0 || len < 0     = Nothing
    -- Compare against the remaining length instead of computing
    -- @offset + len@, which could wrap around for very large arguments.
    | len > B.length b - offset = Nothing
    | otherwise                 = Just $ B.take len $ B.drop offset b
32
-- | Return the last @i@ bytes of @b@.
--
-- Returns 'Nothing' when @i@ is negative or larger than the length of
-- @b@; otherwise 'Just' the trailing @i@ bytes (which is the empty
-- string for @i == 0@).
takelast :: Int -> ByteString -> Maybe ByteString
takelast i b
    | i < 0 || i > total = Nothing
    -- Dropping everything but the final @i@ bytes leaves exactly the tail.
    | otherwise          = Just $ B.drop (total - i) b
  where
    total = B.length b
37
-- | Split @bytes@ into three consecutive pieces of the given lengths.
--
-- The split only succeeds when every length is non-negative and the
-- three lengths add up to exactly the length of @bytes@; in every other
-- case the result is 'Nothing'.
partition3 :: ByteString -> (Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString)
partition3 bytes (d1,d2,d3)
    | valid     = Just (first, second, third)
    | otherwise = Nothing
  where
    sizes  = [d1, d2, d3]
    valid  = all (>= 0) sizes && sum sizes == B.length bytes
    -- Peel the pieces off the front one after another.
    first  = B.take d1 bytes
    rest1  = B.drop d1 bytes
    second = B.take d2 rest1
    rest2  = B.drop d2 rest1
    third  = B.take d3 rest2
47
-- | Split the front of @bytes@ into six consecutive pieces of the given
-- lengths.
--
-- Returns 'Nothing' when any length is negative or when @bytes@ is
-- shorter than the sum of the lengths.  Unlike 'partition3', @bytes@ may
-- be longer than the requested total: any bytes left after the sixth
-- piece are discarded.
partition6 :: ByteString -> (Int,Int,Int,Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString, ByteString, ByteString, ByteString)
partition6 bytes (d1,d2,d3,d4,d5,d6)
    -- 'B.splitAt' treats a negative length as zero, which would silently
    -- yield misaligned pieces, so negative lengths are rejected up front.
    | any (< 0) l        = Nothing
    | B.length bytes < s = Nothing
    | otherwise          = Just (p1,p2,p3,p4,p5,p6)
  where l        = [d1,d2,d3,d4,d5,d6]
        s        = sum l
        (p1, r1) = B.splitAt d1 bytes
        (p2, r2) = B.splitAt d2 r1
        (p3, r3) = B.splitAt d3 r2
        (p4, r4) = B.splitAt d4 r3
        (p5, r5) = B.splitAt d5 r4
        (p6, _)  = B.splitAt d6 r5
57
-- | Unwrap a 'Maybe', calling 'error' on 'Nothing'.
--
-- The first argument is a label that is embedded in the error message so
-- the failing call site can be identified.  This function is partial.
fromJust :: String -> Maybe a -> a
fromJust what mval =
    case mval of
        Just x  -> x
        Nothing -> error ("fromJust " ++ what ++ ": Nothing")
61
-- | Strict conjunction: unlike '&&', both operands are always evaluated
-- before the result is produced, even when the first one is 'False'.
(&&!) :: Bool -> Bool -> Bool
lhs &&! rhs = lhs `seq` rhs `seq` (lhs && rhs)
68
-- | Check whether two bytestrings are equal.
--
-- The comparison is delegated to 'BA.constEq'.  It is a non-lazy
-- comparison that looks at every byte rather than stopping at the first
-- mismatch; arguments of different lengths bail out early.
bytesEq :: ByteString -> ByteString -> Bool
bytesEq lhs rhs = BA.constEq lhs rhs
74
-- | Apply a function to the 'Right' value of an 'Either', leaving a
-- 'Left' value untouched.
fmapEither :: (a -> b) -> Either l a -> Either l b
fmapEither _ (Left err)  = Left err
fmapEither f (Right val) = Right (f val)
77
-- | Run @action@ and pass any exception it ends with to @handler@.
--
-- The action is run through 'withAsync' and its outcome collected with
-- 'waitCatch': a successful result is returned as is, while an exception
-- is handed to @handler@, whose result is returned instead.
catchException :: IO a -> (SomeException -> IO a) -> IO a
catchException action handler = do
    outcome <- withAsync action waitCatch
    case outcome of
        Left exc  -> handler exc
        Right val -> return val
80
-- | Run a monadic action over each list element in order, collecting the
-- 'Right' results.
--
-- Processing stops at the first 'Left': that value is returned and the
-- action is not run on the remaining elements.  An empty list yields
-- @Right []@.
forEitherM :: Monad m => [a] -> (a -> m (Either l b)) -> m (Either l [b])
forEitherM items act = go items
  where
    go []       = return (Right [])
    go (y : ys) = do
        outcome <- act y
        case outcome of
            Left err  -> return (Left err)
            Right val -> fmap (val :) <$> go ys
87
-- | Cut a bytestring into chunks with 'getChunks' and run the given
-- action on each chunk in order, discarding the results.
mapChunks_ :: Monad m
           => Maybe Int -> (B.ByteString -> m a) -> B.ByteString -> m ()
mapChunks_ len f bs = mapM_ f (getChunks len bs)
91
-- | Cut a bytestring into chunks of at most the given size.
--
-- With 'Nothing' the input is returned as a single chunk.  With
-- @'Just' len@ every chunk except possibly the last has exactly @len@
-- bytes.  The result always contains at least one chunk, so an empty
-- input yields a single empty chunk.
--
-- A non-positive @len@ cannot make progress through the input, so it is
-- treated like 'Nothing' (one chunk holding the whole input) instead of
-- producing an endless list of empty chunks.
getChunks :: Maybe Int -> B.ByteString -> [B.ByteString]
getChunks Nothing = (: [])
getChunks (Just len)
    | len <= 0  = (: [])
    | otherwise = go
  where
    go bs
        | B.length bs > len =
            let (chunk, remain) = B.splitAt len bs
             in chunk : go remain
        | otherwise = [bs]
100
-- | Wrapper around a value that has been saved for later restoration.
-- The module exports the type without its constructor, so code outside
-- this module cannot inspect or alter the saved content.
newtype Saved content = Saved content
104
-- | Read the current content of an 'MVar' (using 'readMVar', which
-- leaves the content in place) and wrap it in 'Saved' so that it can be
-- restored later.
saveMVar :: MVar a -> IO (Saved a)
saveMVar ref = do
    current <- readMVar ref
    return (Saved current)
108
-- | Put a previously saved value back into an 'MVar' (using 'swapMVar')
-- and return, wrapped in 'Saved', the content that was replaced.
restoreMVar :: MVar a -> Saved a -> IO (Saved a)
restoreMVar ref (Saved val) = do
    replaced <- swapMVar ref val
    return (Saved replaced)
113