1--
2-- Licensed to the Apache Software Foundation (ASF) under one
3-- or more contributor license agreements. See the NOTICE file
4-- distributed with this work for additional information
5-- regarding copyright ownership. The ASF licenses this file
6-- to you under the Apache License, Version 2.0 (the
7-- "License"); you may not use this file except in compliance
8-- with the License. You may obtain a copy of the License at
9--
10--   http://www.apache.org/licenses/LICENSE-2.0
11--
12-- Unless required by applicable law or agreed to in writing,
13-- software distributed under the License is distributed on an
14-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15-- KIND, either express or implied. See the License for the
16-- specific language governing permissions and limitations
17-- under the License.
18--
19
20module Thrift.Transport.Header
21  ( module Thrift.Transport
22  , HeaderTransport(..)
23  , openHeaderTransport
24  , ProtocolType(..)
25  , TransformType(..)
26  , ClientType(..)
27  , tResetProtocol
28  , tSetProtocol
29  ) where
30
31import Thrift.Transport
32import Thrift.Protocol.Compact
33import Control.Applicative
34import Control.Exception ( throw )
35import Control.Monad
36import Data.Bits
37import Data.IORef
38import Data.Int
39import Data.Monoid
40import Data.Word
41
42import qualified Data.Attoparsec.ByteString as P
43import qualified Data.Binary as Binary
44import qualified Data.ByteString as BS
45import qualified Data.ByteString.Char8 as C
46import qualified Data.ByteString.Lazy as LBS
47import qualified Data.ByteString.Lazy.Builder as B
48import qualified Data.Map as Map
49
50data ProtocolType = TBinary | TCompact | TJSON deriving (Enum, Eq)
51data ClientType = HeaderClient | Framed | Unframed deriving (Enum, Eq)
52
53infoIdKeyValue = 1
54
55type Headers = Map.Map String String
56
57data TransformType = ZlibTransform deriving (Enum, Eq)
58
59fromTransportType :: TransformType -> Int16
60fromTransportType ZlibTransform = 1
61
62toTransportType :: Int16 -> TransformType
63toTransportType 1 = ZlibTransform
64toTransportType _ =  throw $ TransportExn "HeaderTransport: Unknown transform ID" TE_UNKNOWN
65
66data HeaderTransport i o = (Transport i, Transport o) => HeaderTransport
67    { readBuffer :: IORef LBS.ByteString
68    , writeBuffer :: IORef B.Builder
69    , inTrans :: i
70    , outTrans :: o
71    , clientType :: IORef ClientType
72    , protocolType :: IORef ProtocolType
73    , headers :: IORef [(String, String)]
74    , writeHeaders :: Headers
75    , transforms :: IORef [TransformType]
76    , writeTransforms :: [TransformType]
77    }
78
79openHeaderTransport :: (Transport i, Transport o) => i -> o -> IO (HeaderTransport i o)
80openHeaderTransport i o = do
81  pid <- newIORef TCompact
82  rBuf <- newIORef LBS.empty
83  wBuf <- newIORef mempty
84  cType <- newIORef HeaderClient
85  h <- newIORef []
86  trans <- newIORef []
87  return HeaderTransport
88      { readBuffer = rBuf
89      , writeBuffer = wBuf
90      , inTrans = i
91      , outTrans = o
92      , clientType = cType
93      , protocolType = pid
94      , headers = h
95      , writeHeaders = Map.empty
96      , transforms = trans
97      , writeTransforms = []
98      }
99
100isFramed t = (/= Unframed) <$> readIORef (clientType t)
101
102readFrame :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
103readFrame t = do
104  let input = inTrans t
105  let rBuf = readBuffer t
106  let cType = clientType t
107  lsz <- tRead input 4
108  let sz = LBS.toStrict lsz
109  case P.parseOnly P.endOfInput sz of
110    Right _ -> do return False
111    Left _ -> do
112      case parseBinaryMagic sz of
113        Right _ -> do
114          writeIORef rBuf $ lsz
115          writeIORef cType Unframed
116          writeIORef (protocolType t) TBinary
117          return True
118        Left _ -> do
119          case parseCompactMagic sz of
120            Right _ -> do
121              writeIORef rBuf $ lsz
122              writeIORef cType Unframed
123              writeIORef (protocolType t) TCompact
124              return True
125            Left _ -> do
126              let len = Binary.decode lsz :: Int32
127              lbuf <- tReadAll input $ fromIntegral len
128              let buf = LBS.toStrict lbuf
129              case parseBinaryMagic buf of
130                Right _ -> do
131                  writeIORef cType Framed
132                  writeIORef (protocolType t) TBinary
133                  writeIORef rBuf lbuf
134                  return True
135                Left _ -> do
136                  case parseCompactMagic buf of
137                    Right _ -> do
138                      writeIORef cType Framed
139                      writeIORef (protocolType t) TCompact
140                      writeIORef rBuf lbuf
141                      return True
142                    Left _ -> do
143                      case parseHeaderMagic buf of
144                        Right flags -> do
145                          let (flags, seqNum, header, body) = extractHeader buf
146                          writeIORef cType HeaderClient
147                          handleHeader t header
148                          payload <- untransform t body
149                          writeIORef rBuf $ LBS.fromStrict $ payload
150                          return True
151                        Left _ ->
152                          throw $ TransportExn "HeaderTransport: unkonwn client type" TE_UNKNOWN
153
154parseBinaryMagic = P.parseOnly $ P.word8 0x80 *> P.word8 0x01 *> P.word8 0x00 *> P.anyWord8
155parseCompactMagic = P.parseOnly $ P.word8 0x82 *> P.satisfy (\b -> b .&. 0x1f == 0x01)
156parseHeaderMagic = P.parseOnly $ P.word8 0x0f *> P.word8 0xff *> (P.count 2 P.anyWord8)
157
158parseI32 :: P.Parser Int32
159parseI32 = Binary.decode . LBS.fromStrict <$> P.take 4
160parseI16 :: P.Parser Int16
161parseI16 = Binary.decode . LBS.fromStrict <$> P.take 2
162
163extractHeader :: BS.ByteString -> (Int16, Int32, BS.ByteString, BS.ByteString)
164extractHeader bs =
165  case P.parse extractHeader_ bs of
166    P.Done remain (flags, seqNum, header) -> (flags, seqNum, header, remain)
167    _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
168  where
169    extractHeader_ = do
170      magic <- P.word8 0x0f *> P.word8 0xff
171      flags <- parseI16
172      seqNum <- parseI32
173      (headerSize :: Int) <- (* 4) . fromIntegral <$> parseI16
174      header <- P.take headerSize
175      return (flags, seqNum, header)
176
177handleHeader t header =
178  case P.parseOnly parseHeader header of
179    Right (pType, trans, info) -> do
180      writeIORef (protocolType t) pType
181      writeIORef (transforms t) trans
182      writeIORef (headers t) info
183    _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
184
185
186iw16 :: Int16 -> Word16
187iw16 = fromIntegral
188iw32 :: Int32 -> Word32
189iw32 = fromIntegral
190wi16 :: Word16 -> Int16
191wi16 = fromIntegral
192wi32 :: Word32 -> Int32
193wi32 = fromIntegral
194
195parseHeader :: P.Parser (ProtocolType, [TransformType], [(String, String)])
196parseHeader = do
197  protocolType <- toProtocolType <$> parseVarint wi16
198  numTrans <- fromIntegral <$> parseVarint wi16
199  trans <- replicateM numTrans parseTransform
200  info <- parseInfo
201  return (protocolType, trans, info)
202
203toProtocolType :: Int16 -> ProtocolType
204toProtocolType 0 = TBinary
205toProtocolType 1 = TJSON
206toProtocolType 2 = TCompact
207
208fromProtocolType :: ProtocolType -> Int16
209fromProtocolType TBinary = 0
210fromProtocolType TJSON = 1
211fromProtocolType TCompact = 2
212
213parseTransform :: P.Parser TransformType
214parseTransform = toTransportType <$> parseVarint wi16
215
216parseInfo :: P.Parser [(String, String)]
217parseInfo = do
218  n <- P.eitherP P.endOfInput (parseVarint wi32)
219  case n of
220    Left _ -> return []
221    Right n0 ->
222      replicateM (fromIntegral n0) $ do
223        klen <- parseVarint wi16
224        k <- P.take $ fromIntegral klen
225        vlen <- parseVarint wi16
226        v <- P.take $ fromIntegral vlen
227        return (C.unpack k, C.unpack v)
228
229parseString :: P.Parser BS.ByteString
230parseString = parseVarint wi32 >>= (P.take . fromIntegral)
231
232buildHeader :: HeaderTransport i o -> IO B.Builder
233buildHeader t = do
234  pType <- readIORef $ protocolType t
235  let pId = buildVarint $ iw16 $ fromProtocolType pType
236  let headerContent = pId <> (buildTransforms t) <> (buildInfo t)
237  let len = fromIntegral $ LBS.length $ B.toLazyByteString headerContent
238  -- TODO: length limit check
239  let padding = mconcat $ replicate (mod len 4) $ B.word8 0
240  let codedLen = B.int16BE (fromIntegral $ (quot (len - 1) 4) + 1)
241  let flags = 0
242  let seqNum = 0
243  return $ B.int16BE 0x0fff <> B.int16BE flags <> B.int32BE seqNum <> codedLen <> headerContent <> padding
244
245buildTransforms :: HeaderTransport i o -> B.Builder
246-- TODO: check length limit
247buildTransforms t =
248  let trans = writeTransforms t in
249  (buildVarint $ iw16 $ fromIntegral $ length trans) <>
250  (mconcat $ map (buildVarint . iw16 . fromTransportType) trans)
251
252buildInfo :: HeaderTransport i o -> B.Builder
253buildInfo t =
254  let h = Map.assocs $ writeHeaders t in
255  -- TODO: check length limit
256  case length h of
257    0 -> mempty
258    len -> (buildVarint $ iw16 $ fromIntegral $ len) <> (mconcat $ map buildInfoEntry h)
259  where
260    buildInfoEntry (k, v) = buildVarStr k <> buildVarStr v
261    -- TODO: check length limit
262    buildVarStr s = (buildVarint $ iw16 $ fromIntegral $ length s) <> B.string8 s
263
264tResetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
265tResetProtocol t = do
266  rBuf <- readIORef $ readBuffer t
267  writeIORef (clientType t) HeaderClient
268  readFrame t
269
270tSetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> ProtocolType -> IO ()
271tSetProtocol t = writeIORef (protocolType t)
272
273transform :: HeaderTransport i o -> LBS.ByteString -> LBS.ByteString
274transform t bs =
275  foldr applyTransform bs $ writeTransforms t
276  where
277    -- applyTransform bs ZlibTransform =
278    --   throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform  " TE_UNKNOWN
279    applyTransform bs _ =
280      throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
281
282untransform :: HeaderTransport i o -> BS.ByteString -> IO BS.ByteString
283untransform t bs = do
284  trans <- readIORef $ transforms t
285  return $ foldl unapplyTransform bs trans
286  where
287    -- unapplyTransform bs ZlibTransform =
288    --   throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform  " TE_UNKNOWN
289    unapplyTransform bs _ =
290      throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
291
292instance (Transport i, Transport o) => Transport (HeaderTransport i o) where
293  tIsOpen t = do
294    tIsOpen (inTrans t)
295    tIsOpen (outTrans t)
296
297  tClose t = do
298    tClose(outTrans t)
299    tClose(inTrans t)
300
301  tRead t len = do
302    rBuf <- readIORef $ readBuffer t
303    if not $ LBS.null rBuf
304      then do
305        let (consumed, remain) = LBS.splitAt (fromIntegral len) rBuf
306        writeIORef (readBuffer t) remain
307        return consumed
308      else do
309        framed <- isFramed t
310        if not framed
311          then tRead (inTrans t) len
312          else do
313            ok <- readFrame t
314            if ok
315              then tRead t len
316              else return LBS.empty
317
318  tPeek t = do
319    rBuf <- readIORef (readBuffer t)
320    if not $ LBS.null rBuf
321      then return $ Just $ LBS.head rBuf
322      else do
323        framed <- isFramed t
324        if not framed
325          then tPeek (inTrans t)
326          else do
327            ok <- readFrame t
328            if ok
329              then tPeek t
330              else return Nothing
331
332  tWrite t buf = do
333    let wBuf = writeBuffer t
334    framed <- isFramed t
335    if framed
336      then modifyIORef wBuf (<> B.lazyByteString buf)
337      else
338        -- TODO: what should we do when switched to unframed in the middle ?
339        tWrite(outTrans t) buf
340
341  tFlush t = do
342    cType <- readIORef $ clientType t
343    case cType of
344      Unframed -> tFlush $ outTrans t
345      Framed -> flushBuffer t id mempty
346      HeaderClient -> buildHeader t >>= flushBuffer t (transform t)
347    where
348      flushBuffer t f header = do
349        wBuf <- readIORef $ writeBuffer t
350        writeIORef (writeBuffer t) mempty
351        let payload = B.toLazyByteString (header <> wBuf)
352        tWrite (outTrans t) $ Binary.encode (fromIntegral $ LBS.length payload :: Int32)
353        tWrite (outTrans t) $ f payload
354        tFlush (outTrans t)
355