1{-# LANGUAGE BangPatterns #-}
2
3module Network.Wai.Handler.Warp.RequestHeader (
4      parseHeaderLines
5    ) where
6
7import Control.Exception (throwIO)
8import qualified Data.ByteString as S
9import qualified Data.ByteString.Char8 as C8 (unpack)
10import Data.ByteString.Internal (memchr)
11import qualified Data.CaseInsensitive as CI
12import Foreign.ForeignPtr (withForeignPtr)
13import Foreign.Ptr (Ptr, plusPtr, minusPtr, nullPtr)
14import Foreign.Storable (peek)
15import qualified Network.HTTP.Types as H
16
17import Network.Wai.Handler.Warp.Imports
18import Network.Wai.Handler.Warp.Types
19
20-- $setup
21-- >>> :set -XOverloadedStrings
22
23----------------------------------------------------------------
24
-- | Parse the lines of an HTTP request head.
--
-- The first line is handed to 'parseRequestLine' and every remaining line
-- is handed to 'parseHeader'. Throws 'NotEnoughLines' when the list is
-- empty.
--
-- The result holds, in order: the method, the raw path taken from the
-- request line, that same path run through 'H.extractPath', the query, the
-- HTTP version, and the parsed headers.
parseHeaderLines :: [ByteString]
                 -> IO (H.Method
                       ,ByteString  --  Path
                       ,ByteString  --  Path, parsed
                       ,ByteString  --  Query
                       ,H.HttpVersion
                       ,H.RequestHeaders
                       )
parseHeaderLines rawLines = case rawLines of
    [] -> throwIO $ NotEnoughLines []
    requestLine : headerLines -> do
        (method, rawPath, query, version) <- parseRequestLine requestLine
        pure ( method
             , rawPath
             , H.extractPath rawPath
             , query
             , version
             , parseHeader <$> headerLines
             )
39
40----------------------------------------------------------------
41
-- | Parse an HTTP request line of the form @METHOD SP PATH[?QUERY] SP HTTP/x.y@.
--
-- Returns the method, the path (everything up to the first @?@), the query
-- (starting at that @?@, or empty when there is none) and the HTTP version.
-- The returned 'ByteString's are slices of the input buffer; no bytes are
-- copied.
--
-- Throws 'BadFirstLine' when the line is shorter than 14 bytes, or when
-- either separating space is missing or leaves too few bytes for the rest of
-- the line. Throws 'NonHttp' when the protocol token does not start with
-- @HTTP/@ or has no @.@ between the version digits. Version @1.1@ maps to
-- HTTP/1.1 and @2.0@ to HTTP/2.0; any other digits are reported as HTTP/1.0.
--
-- >>> parseRequestLine "GET / HTTP/1.1"
-- ("GET","/","",HTTP/1.1)
-- >>> parseRequestLine "POST /cgi/search.cgi?key=foo HTTP/1.0"
-- ("POST","/cgi/search.cgi","?key=foo",HTTP/1.0)
-- >>> parseRequestLine "GET "
-- *** Exception: Warp: Invalid first line of request: "GET "
-- >>> parseRequestLine "GET /NotHTTP UNKNOWN/1.1"
-- *** Exception: Warp: Request line specified a non-HTTP request
-- >>> parseRequestLine "PRI * HTTP/2.0"
-- ("PRI","*","",HTTP/2.0)
parseRequestLine :: ByteString
                 -> IO (H.Method
                       ,ByteString -- Path
                       ,ByteString -- Query
                       ,H.HttpVersion)
parseRequestLine requestLine@(PS fptr off len) = withForeignPtr fptr $ \ptr -> do
    -- "GET / HTTP/1.1" (14 bytes) is the shortest line accepted here.
    when (len < 14) $ throwIO baderr
    let methodptr = ptr `plusPtr` off
        limptr = methodptr `plusPtr` len
        lim0 = fromIntegral len

    pathptr0 <- memchr methodptr 32 lim0 -- ' '
    -- At least " / HTTP/1.1" (11 bytes) must follow the method.
    when (pathptr0 == nullPtr || (limptr `minusPtr` pathptr0) < 11) $
        throwIO baderr
    let pathptr = pathptr0 `plusPtr` 1
        -- Search exactly the bytes in [pathptr, limptr). Measuring from
        -- pathptr0 instead would let memchr read one byte past the line.
        lim1 = fromIntegral (limptr `minusPtr` pathptr)

    httpptr0 <- memchr pathptr 32 lim1 -- ' '
    -- At least " HTTP/1.1" (9 bytes) must follow the path, so the fixed
    -- offsets read by checkHTTP and httpVersion stay inside the line.
    when (httpptr0 == nullPtr || (limptr `minusPtr` httpptr0) < 9) $
        throwIO baderr
    let httpptr = httpptr0 `plusPtr` 1
        lim2 = fromIntegral (httpptr0 `minusPtr` pathptr)

    checkHTTP httpptr
    !hv <- httpVersion httpptr
    -- The query, if any, starts at the first '?' inside the request target.
    queryptr <- memchr pathptr 63 lim2 -- '?'

    let !method = bs ptr methodptr pathptr0
        !path
          | queryptr == nullPtr = bs ptr pathptr httpptr0
          | otherwise           = bs ptr pathptr queryptr
        !query
          | queryptr == nullPtr = S.empty
          | otherwise           = bs ptr queryptr httpptr0

    return (method,path,query,hv)
  where
    baderr = BadFirstLine $ C8.unpack requestLine
    -- Throw 'NonHttp' unless the byte at offset n from p equals w.
    check :: Ptr Word8 -> Int -> Word8 -> IO ()
    check p n w = do
        w0 <- peek $ p `plusPtr` n
        when (w0 /= w) $ throwIO NonHttp
    -- Validate the fixed characters of "HTTP/x.y"; the digits are not checked.
    checkHTTP httpptr = do
        check httpptr 0 72 -- 'H'
        check httpptr 1 84 -- 'T'
        check httpptr 2 84 -- 'T'
        check httpptr 3 80 -- 'P'
        check httpptr 4 47 -- '/'
        check httpptr 6 46 -- '.'
    -- Read the major digit (offset 5) and minor digit (offset 7).
    httpVersion httpptr = do
        major <- peek (httpptr `plusPtr` 5) :: IO Word8
        minor <- peek (httpptr `plusPtr` 7) :: IO Word8
        let version
              | major == 49 = if minor == 49 then H.http11 else H.http10
              | major == 50 && minor == 48 = H.HttpVersion 2 0
              | otherwise   = H.http10
        return version
    -- Slice [p0, p1) of the original buffer, without copying.
    bs ptr p0 p1 = PS fptr o l
      where
        o = p0 `minusPtr` ptr
        l = p1 `minusPtr` p0
115
116----------------------------------------------------------------
117
-- | Split a single header line into its field name and field value.
--
-- The name is everything before the first @:@ and is wrapped with 'CI.mk'
-- so it compares case-insensitively. The value is everything after that
-- @:@, with leading and trailing optional whitespace (SP and HTAB) removed,
-- as RFC 7230 section 3.2.4 asks of parsers. A line without any @:@ yields
-- the whole line as the name and an empty value.
--
-- >>> parseHeader "Content-Length:47"
-- ("Content-Length","47")
-- >>> parseHeader "Accept-Ranges: bytes"
-- ("Accept-Ranges","bytes")
-- >>> parseHeader "Host:  example.com:8080"
-- ("Host","example.com:8080")
-- >>> parseHeader "Accept-Ranges: bytes \t"
-- ("Accept-Ranges","bytes")
-- >>> parseHeader "NoColon"
-- ("NoColon","")
parseHeader :: ByteString -> H.Header
parseHeader s =
    let (k, rest) = S.break (== 58) s -- ':'
        -- Drop the ':' itself, then leading OWS.
        value = S.dropWhile isOWS $ S.drop 1 rest
     in (CI.mk k, fst $ S.spanEnd isOWS value) -- and trailing OWS
  where
    -- Optional whitespace: SP (32) or HTAB (9).
    isOWS :: Word8 -> Bool
    isOWS c = c == 32 || c == 9
134