1-- Copyright (C) 2010-2012 John Millikin <john@john-millikin.com>
2--
3-- Licensed under the Apache License, Version 2.0 (the "License");
4-- you may not use this file except in compliance with the License.
5-- You may obtain a copy of the License at
6--
7--     http://www.apache.org/licenses/LICENSE-2.0
8--
9-- Unless required by applicable law or agreed to in writing, software
10-- distributed under the License is distributed on an "AS IS" BASIS,
11-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-- See the License for the specific language governing permissions and
13-- limitations under the License.
14
15module DBusTests.Util
16    ( assertVariant
17    , assertValue
18    , assertAtom
19    , assertException
20    , assertThrows
21
22    , getTempPath
23    , listenRandomUnixPath
24    , listenRandomUnixAbstract
25    , listenRandomIPv4
26    , listenRandomIPv6
27    , noIPv6
28    , forkVar
29
30    , withEnv
31    , countFileDescriptors
32
33    , dropWhileEnd
34
35    , halfSized
36    , clampedSize
37    , smallListOf
38    , smallListOf1
39
40    , DBusTests.Util.requireLeft
41    , DBusTests.Util.requireRight
42    ) where
43
44import Control.Concurrent
45import Control.Exception (Exception, IOException, try, bracket, bracket_)
46import Control.Monad.IO.Class (MonadIO, liftIO)
47import Control.Monad.Trans.Resource
48import Data.Bits ((.&.))
49import Data.Char (chr)
50import System.Directory (getTemporaryDirectory, removeFile)
51import System.FilePath ((</>))
52import Test.QuickCheck hiding ((.&.))
53import Test.Tasty.HUnit
54import qualified Data.ByteString
55import qualified Data.ByteString.Lazy
56import qualified Data.Map as Map
57import qualified Data.Text as T
58import qualified Network.Socket as NS
59import qualified System.Posix as Posix
60
61import DBus
62import DBus.Internal.Types
63
-- | Check that a value converts to a 'Variant' of the expected type and
-- that 'fromVariant' recovers the original value from it.
assertVariant :: (Eq a, Show a, IsVariant a) => Type -> a -> Test.Tasty.HUnit.Assertion
assertVariant t a = do
    t @=? variantType boxed
    Just a @=? fromVariant boxed
    -- Comparing the variant with itself exercises its equality instance.
    boxed @=? boxed
  where
    boxed = toVariant a
69
-- | Check that a value reports the expected type through both @typeOf@
-- functions and through its converted form, that 'fromValue' recovers
-- it, and that it also passes 'assertVariant'.
assertValue :: (Eq a, Show a, IsValue a) => Type -> a -> Test.Tasty.HUnit.Assertion
assertValue t a = do
    t @=? DBus.typeOf a
    t @=? DBus.Internal.Types.typeOf a
    t @=? valueType wrapped
    fromValue wrapped @?= Just a
    -- Comparing the converted value with itself exercises its equality instance.
    wrapped @=? wrapped
    assertVariant t a
  where
    wrapped = toValue a
78
-- | Check that a value converts to an atom of the expected type, that
-- 'fromAtom' recovers it, and that it also passes 'assertValue'.
assertAtom :: (Eq a, Show a, IsAtom a) => Type -> a -> Test.Tasty.HUnit.Assertion
assertAtom t a = do
    t @=? atomType atom
    fromAtom atom @?= Just a
    -- Comparing the atom with itself exercises its equality instance.
    atom @=? atom
    assertValue t a
  where
    atom = toAtom a
85
-- | Build a path inside the system temporary directory whose final
-- component is a freshly generated, formatted UUID.
--
-- Only the path is computed; nothing is created on disk.
getTempPath :: IO String
getTempPath = do
    dir <- getTemporaryDirectory
    fmap (\uuid -> dir </> formatUUID uuid) randomUUID
91
-- | Open a listening Unix-domain stream socket bound to a fresh path from
-- 'getTempPath', and return the @unix@ 'Address' carrying that path.
--
-- The socket is closed and the path removed when the enclosing resource
-- scope is released.
listenRandomUnixPath :: MonadResource m => m Address
listenRandomUnixPath = do
    path <- liftIO getTempPath

    (_, sock) <- allocate
        (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
        NS.close
    -- Register the file removal as part of the bind step, so that a
    -- failure in 'NS.listen' below can no longer leave the path behind.
    _ <- allocate
        (NS.bind sock (NS.SockAddrUnix path))
        (\_ -> removeFile path)
    liftIO (NS.listen sock 1)

    -- Fail with a named error instead of an irrefutable pattern match.
    case address "unix" (Map.fromList [("path", path)]) of
        Just addr -> return addr
        Nothing -> error "listenRandomUnixPath: could not build unix path address"
108
-- | Open a listening Unix-domain stream socket named after a freshly
-- generated UUID, and return the @unix@ 'Address' carrying that name
-- under the @abstract@ key, together with the socket's release key.
--
-- The socket is closed when the release key is released or the enclosing
-- resource scope ends.
listenRandomUnixAbstract :: MonadResource m => m (Address, ReleaseKey)
listenRandomUnixAbstract = do
    name <- liftIO (fmap formatUUID randomUUID)

    (key, sock) <- allocate
        (NS.socket NS.AF_UNIX NS.Stream NS.defaultProtocol)
        NS.close

    -- The socket address is the name prefixed with a NUL character; the
    -- returned address carries the bare name.
    liftIO $ NS.bind sock (NS.SockAddrUnix ('\x00' : name))
    liftIO $ NS.listen sock 1

    -- Fail with a named error instead of an irrefutable pattern match.
    case address "unix" (Map.fromList [("abstract", name)]) of
        Just addr -> return (addr, key)
        Nothing -> error "listenRandomUnixAbstract: could not build unix abstract address"
125
-- | Open a listening TCP socket bound to the address resolved for
-- @127.0.0.1@, and return the @tcp@ 'Address' for it (family @ipv4@, host
-- @localhost@, and the port reported by 'NS.socketPort'), the socket
-- itself, and its release key.
--
-- The socket is closed when the release key is released or the enclosing
-- resource scope ends.
listenRandomIPv4 :: MonadResource m => m (Address, NS.Socket, ReleaseKey)
listenRandomIPv4 = do
    let hints = NS.defaultHints
            { NS.addrFlags = [NS.AI_NUMERICHOST]
            , NS.addrFamily = NS.AF_INET
            , NS.addrSocketType = NS.Stream
            }
    found <- liftIO $ NS.getAddrInfo (Just hints) (Just "127.0.0.1") Nothing
    -- Match the list explicitly, as listenRandomIPv6 does, rather than
    -- relying on the partial 'head'.
    let sockAddr = case found of
            [] -> error "listenRandomIPv4: no address for 127.0.0.1?"
            a:_ -> NS.addrAddress a

    (key, sock) <- allocate
        (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol)
        NS.close
    liftIO $ NS.bind sock sockAddr
    liftIO $ NS.listen sock 1

    sockPort <- liftIO $ NS.socketPort sock
    let params = Map.fromList
            [ ("family", "ipv4")
            , ("host", "localhost")
            , ("port", show (toInteger sockPort))
            ]
    -- Fail with a named error instead of an irrefutable pattern match.
    case address "tcp" params of
        Just addr -> return (addr, sock, key)
        Nothing -> error "listenRandomIPv4: could not build tcp address"
149
-- | Open a listening TCP socket bound to the address resolved for @::1@,
-- and return the @tcp@ 'Address' for it (family @ipv6@, host @::1@, and
-- the port reported by 'NS.socketPort').
--
-- The socket is closed when the enclosing resource scope is released.
listenRandomIPv6 :: MonadResource m => m Address
listenRandomIPv6 = do
    addrs <- liftIO $ NS.getAddrInfo Nothing (Just "::1") Nothing
    let sockAddr = case addrs of
            [] -> error "listenRandomIPv6: no address for localhost?"
            a:_ -> NS.addrAddress a

    (_, sock) <- allocate
        (NS.socket NS.AF_INET6 NS.Stream NS.defaultProtocol)
        NS.close
    liftIO $ NS.bind sock sockAddr
    liftIO $ NS.listen sock 1

    sockPort <- liftIO $ NS.socketPort sock
    let params = Map.fromList
            [ ("family", "ipv6")
            , ("host", "::1")
            , ("port", show (toInteger sockPort))
            ]
    -- Fail with a named error instead of an irrefutable pattern match.
    case address "tcp" params of
        Just addr -> return addr
        Nothing -> error "listenRandomIPv6: could not build tcp address"
170
-- | Report whether looking up @::1@ either raises an 'IOException' or
-- yields no addresses.
noIPv6 :: IO Bool
noIPv6 = fmap unavailable (try (NS.getAddrInfo Nothing (Just "::1") Nothing))
  where
    unavailable :: Either IOException [NS.AddrInfo] -> Bool
    unavailable = either (const True) null
177
-- | Run an action on a new thread and return an 'MVar' that receives the
-- action's result once it finishes.
--
-- The 'MVar' is only filled when the action returns normally.
forkVar :: MonadIO m => IO a -> m (MVar a)
forkVar io = liftIO $ do
    result <- newEmptyMVar
    _ <- forkIO (putMVar result =<< io)
    return result
183
-- | Run an action with an environment variable temporarily set to the
-- given value ('Just') or unset ('Nothing').
--
-- The variable's previous state is read first and put back after the
-- action, whether it returns or throws.
withEnv :: MonadIO m => String -> Maybe String -> IO a -> m a
withEnv name value io = liftIO $ do
    previous <- Posix.getEnv name
    bracket_ (assign value) (assign previous) io
  where
    -- 'Nothing' unsets the variable; 'Just' overwrites it.
    assign = maybe (Posix.unsetEnv name) (\x -> Posix.setEnv name x True)
191
-- | Count the entries of this process's @\/proc\/\<pid\>\/fd@ directory.
--
-- Every name returned by the directory stream is counted, up to the
-- first empty name; the stream is closed afterwards even on failure.
countFileDescriptors :: MonadIO m => m Int
countFileDescriptors = liftIO $ do
    pid <- Posix.getProcessID
    bracket
        (Posix.openDirStream ("/proc/" ++ show pid ++ "/fd"))
        Posix.closeDirStream
        (tally 0)
  where
    -- An empty name from the stream marks the end of the listing.
    tally seen dir = do
        entry <- Posix.readDirStream dir
        if null entry
            then return seen
            else tally (seen + 1) dir
204
-- | Run a generator with half the current size parameter (rounded down).
-- A size that is not positive is left unchanged.
halfSized :: Gen a -> Gen a
halfSized gen = sized halve where
    halve n
        | n > 0 = resize (n `div` 2) gen
        | otherwise = gen
209
-- | Like 'listOf', but with the size parameter capped at 10.
smallListOf :: Gen a -> Gen [a]
smallListOf = clampedSize 10 . listOf
212
-- | Like 'listOf1', but with the size parameter capped at 10.
smallListOf1 :: Gen a -> Gen [a]
smallListOf1 = clampedSize 10 . listOf1
215
-- | Run a generator with its size parameter limited to at most @maxN@.
clampedSize :: Int -> Gen a -> Gen a
clampedSize maxN gen = sized $ \current ->
    resize (min current maxN) gen
218
-- | Arbitrary 'T.Text' values are packed from 'genUnicode' strings.
instance Arbitrary T.Text where
    arbitrary = T.pack <$> genUnicode
221
-- | Generate a string of Unicode characters whose length is chosen
-- between zero and the current size parameter.
--
-- Code points are drawn from a fixed set of ranges. Surrogates, and code
-- points whose low sixteen bits are 0xFFFE or 0xFFFF, are rejected and
-- drawn again.
genUnicode :: Gen [Char]
genUnicode = chars where
    chars = sized $ \n ->
        choose (0, n) >>= \len -> sequence (replicate len char)

    char = fmap chr (excluding reserved (oneof planes))

    -- Redraw from the generator until no predicate in the list matches.
    excluding :: [a -> Bool] -> Gen a -> Gen a
    excluding rejects gen = retry where
        retry = gen >>= \candidate ->
            if any ($ candidate) rejects
                then retry
                else return candidate

    reserved = [lowSurrogate, highSurrogate, noncharacter]
    lowSurrogate c = 0xDC00 <= c && c <= 0xDFFF
    highSurrogate c = 0xD800 <= c && c <= 0xDBFF
    noncharacter c = (c .&. 0xFFFF) `elem` [0xFFFE, 0xFFFF]

    planes = [ascii, plane0, plane1, plane2, plane14]
    ascii = choose (0x20, 0x7F)
    plane0 = choose (0xF0, 0xFFFF)
    plane1 = oneof (map block
        [ 0x10000, 0x11000, 0x12000, 0x13000
        , 0x1D000, 0x1F000
        ])
    plane2 = oneof (map block
        [ 0x20000, 0x21000, 0x22000, 0x23000
        , 0x24000, 0x25000, 0x26000, 0x27000
        , 0x28000, 0x29000, 0x2A000, 0x2B000
        , 0x2F000
        ])
    plane14 = block 0xE0000

    -- A run of 0x1000 consecutive code points starting at the given one.
    block start = choose (start, start + 0xFFF)
269
-- | Arbitrary strict byte strings are packed from an arbitrary list of
-- bytes.
instance Arbitrary Data.ByteString.ByteString where
    arbitrary = Data.ByteString.pack <$> arbitrary
272
-- | Arbitrary lazy byte strings are assembled from an arbitrary list of
-- strict chunks.
instance Arbitrary Data.ByteString.Lazy.ByteString where
    arbitrary = Data.ByteString.Lazy.fromChunks <$> arbitrary
275
-- | Drop the longest suffix of a string whose characters all satisfy the
-- predicate.
--
-- This folds over the list directly instead of converting to text and
-- back, so every remaining character is returned exactly as it was given
-- (including code points such as lone surrogates).
dropWhileEnd :: (Char -> Bool) -> String -> String
dropWhileEnd p = foldr step [] where
    -- A character is dropped only while everything after it has already
    -- been dropped.
    step c kept
        | p c && null kept = []
        | otherwise = c : kept
278
-- | Return the contents of a 'Left', or fail the test with a message
-- showing the unexpected 'Right' value.
requireLeft :: Show b => Either a b -> IO a
requireLeft outcome = case outcome of
    Left a -> return a
    Right b -> assertFailure ("Right " ++ show b ++ " is not Left") >> undefined
282
-- | Return the contents of a 'Right', or fail the test with a message
-- showing the unexpected 'Left' value.
requireRight :: Show a => Either a b -> IO b
requireRight outcome = case outcome of
    Right b -> return b
    Left a -> assertFailure ("Left " ++ show a ++ " is not Right") >> undefined
286
-- | Run an action and check that it throws an exception equal to the
-- expected one. The test fails if the action returns normally.
--
-- Only exceptions of the expected type are caught.
assertException :: (Eq e, Exception e) => e -> IO a -> Test.Tasty.HUnit.Assertion
assertException e f = try f >>= either matches missing where
    matches ex = ex @?= e
    missing _ = assertFailure "expected exception not thrown"
293
-- | Run an action and check that it throws an exception accepted by the
-- predicate. The test fails if the predicate rejects the exception or if
-- the action returns normally.
--
-- Only exceptions of the predicate's argument type are caught.
assertThrows :: Exception e => (e -> Bool) -> IO a -> Test.Tasty.HUnit.Assertion
assertThrows check f = try f >>= verdict where
    verdict (Left ex) = assertBool ("unexpected exception " ++ show ex) (check ex)
    verdict (Right _) = assertFailure "expected exception not thrown"
300