1--
2-- Licensed to the Apache Software Foundation (ASF) under one
3-- or more contributor license agreements. See the NOTICE file
4-- distributed with this work for additional information
5-- regarding copyright ownership. The ASF licenses this file
6-- to you under the Apache License, Version 2.0 (the
7-- "License"); you may not use this file except in compliance
8-- with the License. You may obtain a copy of the License at
9--
10--   http://www.apache.org/licenses/LICENSE-2.0
11--
12-- Unless required by applicable law or agreed to in writing,
13-- software distributed under the License is distributed on an
14-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15-- KIND, either express or implied. See the License for the
16-- specific language governing permissions and limitations
17-- under the License.
18--
19
20{-# LANGUAGE CPP #-}
21{-# LANGUAGE ExistentialQuantification #-}
22{-# LANGUAGE OverloadedStrings #-}
23{-# LANGUAGE ScopedTypeVariables #-}
24
25module Thrift.Protocol.Compact
26    ( module Thrift.Protocol
27    , CompactProtocol(..)
28    , parseVarint
29    , buildVarint
30    ) where
31
32import Control.Applicative
33import Control.Monad
34import Data.Attoparsec.ByteString as P
35import Data.Attoparsec.ByteString.Lazy as LP
36import Data.Bits
37import Data.ByteString.Lazy.Builder as B
38import Data.Int
39import Data.List as List
40import Data.Monoid
41import Data.Word
42import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
43
44import Thrift.Protocol
45import Thrift.Transport
46import Thrift.Types
47
48import qualified Data.ByteString as BS
49import qualified Data.ByteString.Lazy as LBS
50import qualified Data.HashMap.Strict as Map
51import qualified Data.Text.Lazy as LT
52
-- | The Compact Protocol implements the standard Thrift 'TCompactProtocol',
-- which is similar to the 'TBinaryProtocol' but takes less space on the wire.
-- Integral types are encoded as varints.
data CompactProtocol a = CompactProtocol a
                         -- ^ Construct a 'CompactProtocol' with a 'Transport'
58
-- | First byte of every message header.
protocolID :: Word8
protocolID = 0x82 -- 1000 0010

-- | Protocol version carried in the low bits of the second header byte.
version :: Word8
version = 0x01

-- | Selects the version bits of the second header byte.
versionMask :: Word8
versionMask = 0x1f -- 0001 1111

-- | Selects the message-type bits of the second header byte.
typeMask :: Word8
typeMask = 0xe0 -- 1110 0000

-- | Selects the message type once it has been shifted down.
typeBits :: Word8
typeBits = 0x07 -- 0000 0111

-- | Number of bits the message type is shifted by within the header byte.
typeShiftAmount :: Int
typeShiftAmount = 5
67
-- | Unwrap the transport carried by a 'CompactProtocol'.
getTransport :: Transport t => CompactProtocol t -> t
getTransport proto = case proto of
  CompactProtocol transport -> transport
70
instance Transport t => Protocol (CompactProtocol t) where
    -- Read exactly one byte from the wrapped transport.
    readByte p = tReadAll (getTransport p) 1

    -- Write the message header, run the caller's body action, then flush.
    writeMessage p (name, msgType, seqId) body = do
      tWrite transport (toLazyByteString header)
      body
      tFlush transport
      where
        transport = getTransport p
        -- Message type moved up by 'typeShiftAmount' bits so that it sits
        -- above the version bits of the second header byte.
        shiftedType = fromIntegral (fromEnum msgType) `shiftL` typeShiftAmount
        header = mconcat
          [ B.word8 protocolID
          , B.word8 ((version .&. versionMask) .|. (shiftedType .&. typeMask))
          , buildVarint (i32ToZigZag seqId)
          , buildCompactValue (TString (encodeUtf8 name))
          ]

    -- Parse the message header and hand (name, type, sequence id) to the
    -- supplied continuation.
    readMessage p handler = runParser p header >>= handler
      where
        header = do
          pid <- P.anyWord8
          unless (pid == protocolID) $ error "Bad Protocol ID"
          versionAndType <- P.anyWord8
          unless (versionAndType .&. versionMask == version) $
            error "Bad Protocol version"
          seqId <- parseVarint zigZagToI32
          TString name <- parseCompactValue T_STRING
          let msgType = (versionAndType `shiftR` typeShiftAmount) .&. typeBits
          return (decodeUtf8 name, toEnum (fromIntegral msgType), seqId)

    -- Serialize one value and push it to the transport.
    writeVal p val =
      tWrite (getTransport p) (toLazyByteString (buildCompactValue val))

    -- Read one value of the requested type from the transport.
    readVal p ty = runParser p (parseCompactValue ty)
101
instance Transport t => StatelessProtocol (CompactProtocol t) where
    -- Encode a value to a lazy ByteString without touching the transport.
    serializeVal _ val = toLazyByteString (buildCompactValue val)

    -- Decode a value of the given type; a parse failure is raised via 'error'
    -- with the parser's message.
    deserializeVal _ ty bs =
      either error id (LP.eitherResult (LP.parse (parseCompactValue ty) bs))
108
-- Writing Functions

-- | Encode a single 'ThriftVal' in the compact wire format.
--
-- Sets are written exactly like lists and binaries exactly like strings.
buildCompactValue :: ThriftVal -> Builder
buildCompactValue value = case value of
  TStruct fields     -> buildCompactStruct fields
  TMap kt vt entries -> mapValue kt vt entries
  TList ty elems     -> listValue ty elems
  TSet ty elems      -> buildCompactValue (TList ty elems)
  TBool flag         -> B.word8 (if flag then 1 else 0)
  TByte b            -> int8 b
  TI16 i             -> buildVarint (i16ToZigZag i)
  TI32 i             -> buildVarint (i32ToZigZag i)
  TI64 i             -> buildVarint (i64ToZigZag i)
  TDouble d          -> doubleLE d
  TString s          -> stringValue s
  TBinary s          -> buildCompactValue (TString s)
  where
    -- An empty map is a single zero byte; otherwise the entry count as a
    -- varint, one byte holding key type (high nibble) and value type
    -- (low nibble), then the entries.
    mapValue kt vt entries
      | count == 0 = B.word8 0x00
      | otherwise  = mconcat
          [ buildVarint count
          , B.word8 ((fromTType kt `shiftL` 4) .|. fromTType vt)
          , buildCompactMap entries
          ]
      where
        count = fromIntegral (length entries) :: Word32

    -- Lists shorter than 15 elements pack the size into the high nibble of
    -- the type byte; longer lists set that nibble to 0xF and append the
    -- size as a varint.
    listValue ty elems = sizeAndType <> buildCompactList elems
      where
        count    = length elems
        elemType = fromTType ty
        sizeAndType
          | count < 15 = B.word8 ((fromIntegral count `shiftL` 4) .|. elemType)
          | otherwise  = B.word8 (0xF0 .|. elemType) <>
                         buildVarint (fromIntegral count :: Word32)

    -- Byte length as a varint, followed by the raw bytes.
    stringValue s =
      buildVarint (fromIntegral (LBS.length s) :: Word32) <> lazyByteString s
138
-- | Encode the fields of a struct, terminated by the T_STOP byte.
--
-- Every field starts with a header. When the field id exceeds the id of the
-- previously written field by 1 to 15, the header is one byte: that delta in
-- the high nibble and the compact type in the low nibble. Otherwise the
-- header is the type byte followed by the zig-zag varint of the field id.
-- Boolean fields carry their value in the type nibble and have no payload.
buildCompactStruct :: Map.HashMap Int16 (LT.Text, ThriftVal) -> Builder
buildCompactStruct = flip (loop 0) mempty . Map.toList
  where
    loop _ [] acc = acc <> B.word8 (fromTType T_STOP)
    loop lastId ((fid, (_, val)) : fields) acc =
      loop fid fields $ acc <> header <> payload
      where
        -- Widen before subtracting: the difference of two Int16 ids can
        -- overflow Int16 (e.g. 32767 - (-1)) and would then wrap into the
        -- short-form range, producing a header the reader cannot decode.
        delta = fromIntegral fid - fromIntegral lastId :: Int
        header
          | delta > 0 && delta <= 15 =
              B.word8 $ (fromIntegral delta `shiftL` 4) .|. typeOf val
          | otherwise =
              B.word8 (typeOf val) <> buildVarint (i16ToZigZag fid)
        payload
          | typeOf val > 0x02 = buildCompactValue val -- Not a T_BOOL
          | otherwise         = mempty -- T_BOOLs are encoded in the type
-- | Encode map entries as alternating key and value encodings, in the order
-- the entries appear in the list (the same convention as 'buildCompactList').
buildCompactMap :: [(ThriftVal, ThriftVal)] -> Builder
buildCompactMap entries =
  mconcat [ buildCompactValue key <> buildCompactValue val
          | (key, val) <- entries ]
154
-- | Encode the elements of a list back to back, in list order.
buildCompactList :: [ThriftVal] -> Builder
buildCompactList elems = mconcat (map buildCompactValue elems)
157
-- Reading Functions

-- | Parse one value of the expected type from its compact encoding.
--
-- For lists and sets the expected element type is kept in the result, while
-- the elements themselves are decoded using the type found on the wire.
-- Calls 'error' for types that have no value encoding.
parseCompactValue :: ThriftType -> Parser ThriftVal
parseCompactValue (T_STRUCT tmap) = TStruct <$> parseCompactStruct tmap
parseCompactValue (T_MAP kt' vt') = do
  n <- parseVarint id
  if n == 0
    -- An empty map has no key/value type byte; fall back to the expected types.
    then return $ TMap kt' vt' []
    else do
      w <- P.anyWord8
      let kt = typeFrom $ w `shiftR` 4
          vt = typeFrom $ w .&. 0x0F
      TMap kt vt <$> parseCompactMap kt vt n
parseCompactValue (T_LIST ty) = TList ty <$> parseCompactList
parseCompactValue (T_SET ty) = TSet ty <$> parseCompactList
-- Only 0x01 means True. This module writes 0x00 for False, but other Thrift
-- implementations (e.g. the Java library) write 0x02, which a plain
-- "non-zero" test would misread as True.
parseCompactValue T_BOOL = TBool . (== 0x01) <$> P.anyWord8
parseCompactValue T_BYTE = TByte . fromIntegral <$> P.anyWord8
parseCompactValue T_I16 = TI16 <$> parseVarint zigZagToI16
parseCompactValue T_I32 = TI32 <$> parseVarint zigZagToI32
parseCompactValue T_I64 = TI64 <$> parseVarint zigZagToI64
parseCompactValue T_DOUBLE = TDouble . bsToDoubleLE <$> P.take 8
parseCompactValue T_STRING = parseCompactString TString
parseCompactValue T_BINARY = parseCompactString TBinary
parseCompactValue ty = error $ "Cannot read value of type " ++ show ty
181
-- | Parse a length-prefixed byte string (varint length, then that many bytes)
-- and wrap the bytes with the given constructor, e.g. 'TString' or 'TBinary'.
parseCompactString :: (LBS.ByteString -> a) -> Parser a
parseCompactString wrap = do
  len :: Word32 <- parseVarint id
  wrap . LBS.fromStrict <$> P.take (fromIntegral len)
185
-- | Parse struct fields up to the terminating zero byte and collect them by
-- field id. Every parsed field is stored with an empty name.
--
-- The type map is consulted only to tell binaries from strings: both share
-- one wire type, so a string field declared as T_BINARY is read as 'TBinary'.
parseCompactStruct :: TypeMap -> Parser (Map.HashMap Int16 (LT.Text, ThriftVal))
parseCompactStruct tmap = Map.fromList <$> fieldsFrom 0
  where
    -- Read field headers until the stop byte, threading the previous id.
    fieldsFrom :: Int16 -> Parser [(Int16, (LT.Text, ThriftVal))]
    fieldsFrom prevId = do
      hdr <- P.anyWord8
      if hdr == 0x00
        then return []
        else field prevId hdr

    field :: Int16 -> Word8 -> Parser [(Int16, (LT.Text, ThriftVal))]
    field prevId hdr = do
      let typeNibble = hdr .&. 0x0F
          wireType   = typeFrom typeNibble
          delta      = (hdr .&. 0xF0) `shiftR` 4
      -- A zero delta means the full field id follows as a zig-zag varint.
      fid <- if delta == 0
             then parseVarint zigZagToI16
             else return (prevId + fromIntegral delta)
      -- Boolean fields have no payload: the type nibble is the value.
      val <- if wireType == T_BOOL
             then return (TBool (typeNibble == 0x01))
             else parseCompactValue (refine wireType (Map.lookup fid tmap))
      rest <- fieldsFrom fid
      return ((fid, (LT.empty, val)) : rest)

    -- Prefer T_BINARY when the declared type of a string field says so.
    refine T_STRING (Just (_, T_BINARY)) = T_BINARY
    refine ty _ = ty
206
-- | Parse the given number of key/value pairs, keys first; a count of zero
-- or less yields the empty list.
parseCompactMap :: ThriftType -> ThriftType -> Int32 ->
                   Parser [(ThriftVal, ThriftVal)]
parseCompactMap kt vt n = replicateM (fromIntegral n) entry
  where
    entry = liftM2 (,) (parseCompactValue kt) (parseCompactValue vt)
214
-- | Parse a list or set: a header byte with the element type in the low
-- nibble and the size in the high nibble (0xF meaning "size follows as a
-- varint"), then that many elements.
parseCompactList :: Parser [ThriftVal]
parseCompactList = do
  header <- P.anyWord8
  let elemType = typeFrom (header .&. 0x0F)
  count <- case header `shiftR` 4 of
    0xF   -> parseVarint id
    short -> return (fromIntegral short)
  elements elemType count
  where
    elements :: ThriftType -> Int32 -> Parser [ThriftVal]
    elements ty remaining
      | remaining <= 0 = return []
      | otherwise = do
          x  <- parseCompactValue ty
          xs <- elements ty (remaining - 1)
          return (x : xs)
229
-- Signed numbers are mapped to the "Zig Zag" format before being written as
-- varints.

-- | Zig-zag encode a 16-bit signed integer.
i16ToZigZag :: Int16 -> Word16
i16ToZigZag n = fromIntegral (doubled `xor` signFill)
  where
    doubled  = n `shiftL` 1
    signFill = n `shiftR` 15 -- all ones for negative n, zero otherwise
234
-- | Decode a zig-zag encoded 16-bit value.
zigZagToI16 :: Word16 -> Int16
zigZagToI16 n = fromIntegral (magnitude `xor` signFill)
  where
    magnitude = n `shiftR` 1
    signFill  = negate (n .&. 0x1) -- all ones when the low bit is set
237
-- | Zig-zag encode a 32-bit signed integer.
i32ToZigZag :: Int32 -> Word32
i32ToZigZag n = fromIntegral (doubled `xor` signFill)
  where
    doubled  = n `shiftL` 1
    signFill = n `shiftR` 31 -- all ones for negative n, zero otherwise
240
-- | Decode a zig-zag encoded 32-bit value.
zigZagToI32 :: Word32 -> Int32
zigZagToI32 n = fromIntegral (magnitude `xor` signFill)
  where
    magnitude = n `shiftR` 1
    signFill  = negate (n .&. 0x1) -- all ones when the low bit is set
243
-- | Zig-zag encode a 64-bit signed integer.
i64ToZigZag :: Int64 -> Word64
i64ToZigZag n = fromIntegral (doubled `xor` signFill)
  where
    doubled  = n `shiftL` 1
    signFill = n `shiftR` 63 -- all ones for negative n, zero otherwise
246
-- | Decode a zig-zag encoded 64-bit value.
zigZagToI64 :: Word64 -> Int64
zigZagToI64 n = fromIntegral (magnitude `xor` signFill)
  where
    magnitude = n `shiftR` 1
    signFill  = negate (n .&. 0x1) -- all ones when the low bit is set
249
-- | Encode an integer as a varint: seven payload bits per byte, least
-- significant group first, with the high bit set on every byte but the last.
--
-- The value is treated as an unsigned bit pattern. A negative value of a
-- fixed-width signed type is written as its two's-complement bits (so
-- @-1 :: Int32@ encodes like @0xFFFFFFFF :: Word32@). A negative value of an
-- unbounded type has no finite bit pattern and raises an error.
buildVarint :: (Bits a, Integral a)  => a -> Builder
buildVarint n
  | n .&. complement 0x7F == 0 = B.word8 $ fromIntegral n
  | otherwise = B.word8 (0x80 .|. (fromIntegral n .&. 0x7F)) <>
                buildVarint rest
  where
    -- 'shiftR' is arithmetic on signed types, so a negative value would keep
    -- its sign bits forever and the recursion would never end. Clearing the
    -- seven bits shifted in from the top turns it into a logical shift.
    rest
      | n >= 0 = n `shiftR` 7
      | otherwise = case bitSizeMaybe n of
          Just width ->
            (n `shiftR` 7) .&. complement (complement 0 `shiftL` (width - 7))
          Nothing ->
            error "buildVarint: negative value of unbounded width"
254
-- | Parse a varint and pass the accumulated value to the given conversion
-- (a zig-zag decoder, or 'id' for unsigned values).
--
-- All bytes with the high bit set are continuation bytes; the first byte
-- without it ends the number. Groups arrive least significant first, so the
-- final byte holds the most significant seven bits.
parseVarint :: (Bits a, Integral a, Ord a) => (a -> b) -> Parser b
parseVarint fromZigZag = do
  continuation <- P.takeWhile (`testBit` 7)
  final <- P.anyWord8
  return (fromZigZag (BS.foldr prepend (low7 final) continuation))
  where
    low7 byte = fromIntegral byte .&. 0x7f
    -- Place an earlier (less significant) group beneath the bits gathered
    -- so far.
    prepend byte acc = (acc `shiftL` 7) .|. low7 byte
262
-- | Compact type code for a 'ThriftType'. Strings and binaries share a code;
-- T_VOID has none and raises an error.
fromTType :: ThriftType -> Word8
fromTType T_STOP     = 0x00
fromTType T_BOOL     = 0x01
fromTType T_BYTE     = 0x03
fromTType T_I16      = 0x04
fromTType T_I32      = 0x05
fromTType T_I64      = 0x06
fromTType T_DOUBLE   = 0x07
fromTType T_STRING   = 0x08
fromTType T_BINARY   = 0x08
fromTType T_LIST{}   = 0x09
fromTType T_SET{}    = 0x0A
fromTType T_MAP{}    = 0x0B
fromTType T_STRUCT{} = 0x0C
fromTType T_VOID     = error "No Compact type for T_VOID"
280
-- | Compact type code for a value. Booleans get a distinct code per value
-- (0x01 for True, 0x02 for False).
typeOf :: ThriftVal -> Word8
typeOf (TBool True)  = 0x01
typeOf (TBool False) = 0x02
typeOf (TByte _)     = 0x03
typeOf (TI16 _)      = 0x04
typeOf (TI32 _)      = 0x05
typeOf (TI64 _)      = 0x06
typeOf (TDouble _)   = 0x07
typeOf (TString _)   = 0x08
typeOf (TBinary _)   = 0x08
typeOf TList{}       = 0x09
typeOf TSet{}        = 0x0A
typeOf TMap{}        = 0x0B
typeOf TStruct{}     = 0x0C
296
-- | 'ThriftType' for a compact type code. Container and struct types come
-- back with placeholder contents (T_VOID element types, an empty type map).
-- Any other code raises an error.
typeFrom :: Word8 -> ThriftType
typeFrom 0x01 = T_BOOL
typeFrom 0x02 = T_BOOL
typeFrom 0x03 = T_BYTE
typeFrom 0x04 = T_I16
typeFrom 0x05 = T_I32
typeFrom 0x06 = T_I64
typeFrom 0x07 = T_DOUBLE
typeFrom 0x08 = T_STRING
typeFrom 0x09 = T_LIST T_VOID
typeFrom 0x0A = T_SET T_VOID
typeFrom 0x0B = T_MAP T_VOID T_VOID
typeFrom 0x0C = T_STRUCT Map.empty
typeFrom n    = error $ "typeFrom: " ++ show n ++ " is not a compact type"
312