1{-# LANGUAGE OverloadedStrings #-}
2{-# LANGUAGE CPP #-}
3{-# LANGUAGE DeriveDataTypeable #-}
4module Network.Wai.Test
5    ( -- * Session
6      Session
7    , runSession
8      -- * Client Cookies
9    , ClientCookies
10    , getClientCookies
11    , modifyClientCookies
12    , setClientCookie
13    , deleteClientCookie
14      -- * Requests
15    , request
16    , srequest
17    , SRequest (..)
18    , SResponse (..)
19    , defaultRequest
20    , setPath
21    , setRawPathInfo
22      -- * Assertions
23    , assertStatus
24    , assertContentType
25    , assertBody
26    , assertBodyContains
27    , assertHeader
28    , assertNoHeader
29    , assertClientCookieExists
30    , assertNoClientCookieExists
31    , assertClientCookieValue
32    , WaiTestFailure (..)
33    ) where
34
35#if __GLASGOW_HASKELL__ < 710
36import Control.Applicative ((<$>))
37import Data.Monoid (mempty, mappend)
38#endif
39
import Network.Wai
import Network.Wai.Internal (ResponseReceived (ResponseReceived))
import Network.Wai.Test.Internal
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Class (lift)
import qualified Control.Monad.Trans.State as ST
import Control.Monad.Trans.Reader (runReaderT, ask)
import Control.Monad (unless)
import Control.DeepSeq (deepseq)
import Control.Exception (throwIO, Exception)
import Data.Typeable (Typeable)
import qualified Data.Map as Map
import qualified Web.Cookie as Cookie
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as S8
import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as L8
import qualified Network.HTTP.Types as H
import Data.CaseInsensitive (CI)
import qualified Data.ByteString as S
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Text.Encoding.Error as TEE
import Data.IORef
import Data.Time.Clock (getCurrentTime)
65
-- | Fetch the cookies the simulated client is currently holding.
--
-- Since 3.0.6
getClientCookies :: Session ClientCookies
getClientCookies = lift (ST.gets clientCookies)
71
-- | Replace the client's cookie jar with the result of applying the given
-- function to it.
--
-- Since 3.0.6
modifyClientCookies :: (ClientCookies -> ClientCookies) -> Session ()
modifyClientCookies f = lift . ST.modify $ \st ->
  st { clientCookies = f (clientCookies st) }
78
-- | Store a cookie in the client's jar, keyed by its name. Any cookie
-- already stored under the same name is replaced.
--
-- Since 3.0.6
setClientCookie :: Cookie.SetCookie -> Session ()
setClientCookie cookie =
  modifyClientCookies $ Map.insert (Cookie.setCookieName cookie) cookie
86
-- | Remove the cookie with the given name from the client's jar. Nothing
-- happens if no such cookie is stored.
--
-- Since 3.0.6
deleteClientCookie :: ByteString -> Session ()
deleteClientCookie = modifyClientCookies . Map.delete
94
-- | Run a test 'Session' against the given 'Application'. The session
-- starts from 'initState' and returns the session's final result.
runSession :: Session a -> Application -> IO a
runSession session = flip ST.evalStateT initState . runReaderT session
98
-- | A request paired with a plain lazy body, for use with 'srequest'.
data SRequest = SRequest
    { simpleRequest :: Request
      -- ^ The request to send. Its body is replaced by 'simpleRequestBody'.
    , simpleRequestBody :: L.ByteString
      -- ^ Request body that will override the one set in 'simpleRequest'.
      --
      -- This is usually simpler than setting the body as a stateful
      -- IO-action in 'simpleRequest'.
    }
-- | The collected result of running a request through the application.
data SResponse = SResponse
    { simpleStatus :: H.Status
      -- ^ The response status.
    , simpleHeaders :: H.ResponseHeaders
      -- ^ The response headers.
    , simpleBody :: L.ByteString
      -- ^ The whole response body.
    }
    deriving (Eq, Show)
113
-- | Send a 'Request' to the application under test and collect its response.
--
-- Stored client cookies that apply to the request are attached before it is
-- sent. Any @Set-Cookie@ headers in the response are recorded in the
-- session afterwards.
request :: Request -> Session SResponse
request req = do
    application <- ask
    preparedReq <- addCookiesToRequest req
    liftIO (collect application preparedReq) >>= extractSetCookieFromSResponse
  where
    -- Run the application and read back what 'runResponse' stored.
    collect application preparedReq = do
        resultRef <- newIORef $ error "runResponse gave no result"
        ResponseReceived <- application preparedReq (runResponse resultRef)
        readIORef resultRef
123
-- | Set the whole path of a request, i.e. the request path plus the query
-- string.
--
-- The input is decoded with 'H.decodePath'. 'pathInfo' and 'queryString'
-- are set from the decoded parts. 'rawPathInfo' and 'rawQueryString' are
-- re-rendered from those parts, so they need not be byte-identical to the
-- input.
setPath :: Request -> S8.ByteString -> Request
setPath req path =
    req { pathInfo = segments
        , rawPathInfo = L8.toStrict (toLazyByteString (H.encodePathSegments segments))
        , queryString = query
        , rawQueryString = H.renderQuery True query
        }
  where
    (segments, query) = H.decodePath path
134
-- | Set 'rawPathInfo' to the given bytes and derive 'pathInfo' from them.
--
-- The bytes are decoded as UTF-8 and split on @/@. No percent-decoding is
-- done. A single leading slash is dropped, and a path of exactly @"/"@
-- yields no segments.
--
-- Decoding is lenient: invalid UTF-8 bytes are replaced by U+FFFD instead of
-- throwing. This lets a test send a malformed path to the application
-- without this helper itself crashing.
setRawPathInfo :: Request -> S8.ByteString -> Request
setRawPathInfo r rawPinfo =
    r { rawPathInfo = rawPinfo, pathInfo = segments }
  where
    segments =
        dropFrontSlash
            . T.split (== '/')
            $ TE.decodeUtf8With TEE.lenientDecode rawPinfo
    dropFrontSlash ["", ""] = [] -- homepage, a single slash
    dropFrontSlash ("" : path) = path
    dropFrontSlash path = path
143
-- | Attach a @Cookie@ header carrying every stored client cookie that
-- applies to the given request.
--
-- A stored cookie is sent when both of these hold:
--
-- * its @Expires@ time, if any, is strictly after the current time;
-- * its @Path@ attribute, if any, path-matches the request path as defined
--   in RFC 6265 section 5.1.4.
--
-- If no cookie qualifies, the request headers are left untouched.
addCookiesToRequest :: Request -> Session Request
addCookiesToRequest req = do
  oldClientCookies <- getClientCookies
  currentUTCTime <- liftIO getCurrentTime
  let requestPath =
        TE.encodeUtf8 ("/" `T.append` T.intercalate "/" (pathInfo req))
      cookiesForRequest =
        Map.filter
          (\c -> checkCookieTime currentUTCTime c
              && checkCookiePath requestPath c)
          oldClientCookies
      cookiePairs = [ (Cookie.setCookieName c, Cookie.setCookieValue c)
                    | c <- Map.elems cookiesForRequest
                    ]
      cookieValue = L8.toStrict . toLazyByteString $ Cookie.renderCookies cookiePairs
      addCookieHeader rest
        | null cookiePairs = rest
        | otherwise = ("Cookie", cookieValue) : rest
  return $ req { requestHeaders = addCookieHeader $ requestHeaders req }
    where
      checkCookieTime t c =
        case Cookie.setCookieExpires c of
          Nothing -> True
          Just t' -> t < t'
      checkCookiePath p c =
        case Cookie.setCookiePath c of
          Nothing -> True
          Just p' -> pathMatches p' p
      -- RFC 6265 5.1.4 path-match. The cookie path must equal the request
      -- path, or be a prefix of it that ends on a segment boundary. So a
      -- cookie scoped to "/foo" reaches "/foo" and "/foo/bar", but not
      -- "/foobar".
      pathMatches cookiePath reqPath
        | cookiePath == reqPath = True
        | not (cookiePath `S8.isPrefixOf` reqPath) = False
        | "/" `S8.isSuffixOf` cookiePath = True
        | otherwise = "/" `S8.isPrefixOf` S8.drop (S8.length cookiePath) reqPath
170
-- | Record every @Set-Cookie@ header of a response in the client's cookie
-- jar, then return the response unchanged.
--
-- A newly received cookie replaces a stored one with the same name. If the
-- response sets the same name more than once, the last header wins.
extractSetCookieFromSResponse :: SResponse -> Session SResponse
extractSetCookieFromSResponse response = do
  modifyClientCookies (Map.union received)
  return response
  where
    received = Map.fromList
      [ (Cookie.setCookieName cookie, cookie)
      | (name, value) <- simpleHeaders response
      , name == "Set-Cookie"
      , let cookie = Cookie.parseSetCookie value
      ]
180
-- | Like 'request', but the request body is given as a plain
-- 'L.ByteString'.
--
-- The body is handed to the application one chunk of the lazy ByteString
-- at a time. Once every chunk has been read, each further read returns an
-- empty ByteString.
srequest :: SRequest -> Session SResponse
srequest (SRequest req bod) = do
    pending <- liftIO . newIORef $ L.toChunks bod
    let nextChunk = atomicModifyIORef pending popChunk
    request req { requestBody = nextChunk }
  where
    popChunk [] = ([], S.empty)
    popChunk (chunk : rest) = (rest, chunk)
193
-- | Drive a WAI 'Response' to completion, storing its status, headers and
-- body in the given 'IORef'.
--
-- The streamed body is accumulated into a builder and fully evaluated
-- before it is stored. Any exception raised while producing the body
-- therefore surfaces here, while the application is still running, not
-- later when a test inspects 'simpleBody'.
runResponse :: IORef SResponse -> Response -> IO ResponseReceived
runResponse ref res = do
    refBuilder <- newIORef mempty
    -- Strict, atomic update so the IORef does not build a chain of thunks.
    let add y = atomicModifyIORef' refBuilder $ \x -> (x `mappend` y, ())
    withBody $ \body -> body add (return ())
    builder <- readIORef refBuilder
    let lbs = toLazyByteString builder
    -- Force the whole body now (see above), and store the very value that
    -- was forced instead of rendering the builder a second time.
    L.length lbs `seq` writeIORef ref (SResponse s h lbs)
    return ResponseReceived
  where
    (s, h, withBody) = responseToStream res
208
-- | Fail with the given message unless the condition holds.
assertBool :: String -> Bool -> Session ()
assertBool msg ok
    | ok = return ()
    | otherwise = assertFailure msg
211
-- | Fail with the given message if it is non-empty. An empty message means
-- success.
assertString :: String -> Session ()
assertString msg
    | null msg = return ()
    | otherwise = assertFailure msg
214
-- | Abort the session by throwing a 'WaiTestFailure' with the given
-- message.
--
-- The message is fully evaluated before the exception is thrown.
assertFailure :: String -> Session ()
assertFailure msg = deepseq msg . liftIO . throwIO $ WaiTestFailure msg
217
-- | The exception thrown by this module's assertion helpers when an
-- assertion does not hold. It carries the failure message.
data WaiTestFailure = WaiTestFailure String
    deriving (Eq, Show, Typeable)

instance Exception WaiTestFailure where
221
-- | Assert that the response's @Content-Type@ header matches the expected
-- one.
--
-- Only the part before the first @;@ is compared, so parameters such as
-- @charset@ are ignored. It fails if the header is missing.
assertContentType :: ByteString -> SResponse -> Session ()
assertContentType ct SResponse{simpleHeaders = h} =
    maybe missing check (lookup "content-type" h)
  where
    mediaType = S8.takeWhile (/= ';')
    missing = assertString $ concat
        [ "Expected content type "
        , show ct
        , ", but no content type provided"
        ]
    check ct' = assertBool (concat
        [ "Expected content type "
        , show ct
        , ", but received "
        , show ct'
        ]) (mediaType ct == mediaType ct')
238
-- | Assert that the response has the given numeric status code.
assertStatus :: Int -> SResponse -> Session ()
assertStatus expected SResponse{simpleStatus = status} =
    assertBool message (expected == actual)
  where
    actual = H.statusCode status
    message =
        "Expected status code " ++ show expected ++ ", but received " ++ show actual
248
-- | Assert that the response body is exactly the given bytes.
assertBody :: L.ByteString -> SResponse -> Session ()
assertBody expected SResponse{simpleBody = actual} =
    assertBool message (expected == actual)
  where
    render = show . L8.unpack
    message = concat
        [ "Expected response body "
        , render expected
        , ", but received "
        , render actual
        ]
256
-- | Assert that the given bytes occur somewhere in the response body.
assertBodyContains :: L.ByteString -> SResponse -> Session ()
assertBodyContains needle SResponse{simpleBody = haystack} =
    assertBool message (L.toStrict needle `S.isInfixOf` L.toStrict haystack)
  where
    render = show . L8.unpack
    message = concat
        [ "Expected response body to contain "
        , render needle
        , ", but received "
        , render haystack
        ]
266
-- | Assert that the response carries the given header with exactly the
-- given value.
--
-- Only the first header with that name is checked. It fails if the header
-- is absent.
assertHeader :: CI ByteString -> ByteString -> SResponse -> Session ()
assertHeader header value SResponse{simpleHeaders = h} =
    maybe absent present (lookup header h)
  where
    expectation =
        [ "Expected header "
        , show header
        , " to be "
        , show value
        ]
    absent = assertString . concat $ expectation ++ [", but it was not present"]
    present value' =
        assertBool
            (concat (expectation ++ [", but received ", show value']))
            (value == value')
285
-- | Assert that the response does not carry the given header.
assertNoHeader :: CI ByteString -> SResponse -> Session ()
assertNoHeader header SResponse{simpleHeaders = h} =
    maybe (return ()) unexpected (lookup header h)
  where
    unexpected s = assertString $ concat
        [ "Unexpected header "
        , show header
        , " containing "
        , show s
        ]
296
-- | Assert that the client currently holds a cookie with the given name.
-- The first argument is the failure message.
--
-- Since 3.0.6
assertClientCookieExists :: String -> ByteString -> Session ()
assertClientCookieExists s cookieName =
  getClientCookies >>= assertBool s . Map.member cookieName
304
-- | Assert that the client holds no cookie with the given name. The first
-- argument is the failure message.
--
-- Since 3.0.6
assertNoClientCookieExists :: String -> ByteString -> Session ()
assertNoClientCookieExists s cookieName =
  getClientCookies >>= assertBool s . Map.notMember cookieName
312
-- | Assert that the client holds a cookie with the given name and value.
-- The first argument is a failure message. It is extended with details of
-- the mismatch, or with a note that the cookie does not exist.
--
-- Since 3.0.6
assertClientCookieValue :: String -> ByteString -> ByteString -> Session ()
assertClientCookieValue s cookieName cookieValue =
  getClientCookies >>= maybe missing check . Map.lookup cookieName
  where
    missing = assertFailure (s ++ " (cookie does not exist)")
    check c = assertBool message (actual == cookieValue)
      where
        actual = Cookie.setCookieValue c
        message = concat
          [ s
          , " (actual value "
          , show actual
          , " expected value "
          , show cookieValue
          , ")"
          ]
334