-- |
-- Module      : Data.ASN1.Get
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- A simple Get module with minimal accessors for parsing ASN.1.
--
-- The original code is taken from the Get module of cereal,
-- which is covered by:
-- Copyright   : Lennart Kolmodin, Galois Inc. 2009
-- License     : BSD3-style (see LICENSE)
--
-- The original code has been tailored and reduced to cover only the cases
-- useful for ASN.1, and augmented with a position.
--
18{-# LANGUAGE Rank2Types #-}
19{-# LANGUAGE CPP #-}
20module Data.ASN1.Get
21    ( Result(..)
22    , Input
23    , Get
24    , runGetPos
25    , runGet
26    , getBytes
27    , getBytesCopy
28    , getWord8
29    ) where
30
31import Control.Applicative (Applicative(..),Alternative(..))
32import Control.Monad (ap,MonadPlus(..))
33import Data.Maybe (fromMaybe)
34import Foreign
35
36import qualified Data.ByteString          as B
37
-- | Outcome of running a 'Get' parser.
data Result r
    = Fail String
      -- ^ Parsing failed; the 'String' carries the error description.
    | Partial (B.ByteString -> Result r)
      -- ^ The parser needs more input. Feed the next chunk to the
      --   continuation to resume; feed 'B.empty' to signal that no
      --   further input is available.
    | Done r Position B.ByteString
      -- ^ Parsing succeeded, yielding the parsed value, the position
      --   reached, and whatever input was left unconsumed.
50
-- | Textual rendering of a 'Result'; the continuation held by a
--   'Partial' cannot be shown and is rendered as an underscore.
instance Show r => Show (Result r) where
    show res = case res of
        Fail msg      -> "Fail " ++ show msg
        Partial _     -> "Partial _"
        Done r pos bs -> unwords ["Done", show r, show pos, show bs]
55
-- | Map over the value of a successful parse. A 'Partial' is mapped by
--   transforming whatever result its continuation eventually produces;
--   a 'Fail' is passed through unchanged.
instance Functor Result where
    fmap f res = case res of
        Fail msg        -> Fail msg
        Partial resume  -> Partial (\chunk -> fmap f (resume chunk))
        Done val pos bs -> Done (f val) pos bs
60
-- | The not-yet-consumed input of a parser.
type Input = B.ByteString

-- | Input chunks accumulated for backtracking. 'runGetPos' starts with
--   'Nothing'; an alternative ('mplus') starts a fresh @Just B.empty@
--   and 'demandInput' appends each newly supplied chunk to it.
type Buffer = Maybe B.ByteString

-- | Failure continuation: receives the parser state and an error message.
type Failure r = Input -> Buffer -> More -> Position -> String -> Result r

-- | Success continuation: receives the parser state and the parsed value.
type Success a r = Input -> Buffer -> More -> Position -> a -> Result r

-- | Running count that the accessors advance by the number of bytes
--   they consume, starting from the value given to 'runGetPos'.
type Position = Word64
67
-- | Whether further input may still be requested.
data More
    = Complete
      -- ^ No more input can be requested; 'demandInput' fails at once.
    | Incomplete (Maybe Int)
      -- ^ More input may be requested. The optional count, when present,
      --   is decreased by the length of every chunk that is supplied.
    deriving (Eq)
72
-- | The Get monad: a parser in continuation-passing style. It threads
--   the remaining input, the backtracking buffer, the 'More' flag and
--   the current position, and finishes by calling either the failure
--   or the success continuation.
newtype Get a = Get
    { unGet
        :: forall r
         . Input
        -> Buffer
        -> More
        -> Position
        -> Failure r
        -> Success a r
        -> Result r
    }
76
-- | Concatenate two buffers. The result is 'Nothing' as soon as either
--   side is 'Nothing'; the left side is inspected first.
append :: Buffer -> Buffer -> Buffer
append (Just l) (Just r) = Just (l `B.append` r)
append _        _        = Nothing
{-# INLINE append #-}
80
-- | The bytes held in a buffer, or the empty string for 'Nothing'.
bufferBytes :: Buffer -> B.ByteString
bufferBytes buf = case buf of
    Nothing -> B.empty
    Just bs -> bs
{-# INLINE bufferBytes #-}
84
-- | Apply a function to the value a parser produces, by wrapping the
--   success continuation; the parser state is passed through untouched.
instance Functor Get where
    fmap f parser = Get $ \s0 b0 m0 p0 kf ks ->
        unGet parser s0 b0 m0 p0 kf (\s1 b1 m1 p1 a -> ks s1 b1 m1 p1 (f a))
90
-- | 'pure' succeeds immediately without touching the parser state;
--   '<*>' runs the function parser first, then the argument parser.
instance Applicative Get where
    pure a = Get $ \s0 b0 m0 p0 _ ks -> ks s0 b0 m0 p0 a

    mf <*> mx =
        mf >>= \f ->
        mx >>= \x ->
        return (f x)
94
-- | 'empty' is a parser that always fails with the message @"empty"@;
--   choice is delegated to 'mplus'.
instance Alternative Get where
    empty = failDesc "empty"

    first <|> second = first `mplus` second
98
-- Definition directly from Control.Monad.State.Strict
--
-- The CPP conditional below places the definition of fail in a separate
-- MonadFail instance when base is at least 4.13, and leaves it as a
-- method of this Monad instance otherwise.
instance Monad Get where
    return a = Get (\s b m p _ ks -> ks s b m p a)

    m >>= g = Get $ \s0 b0 m0 p0 kf ks ->
        unGet m s0 b0 m0 p0 kf $ \s1 b1 m1 p1 a ->
            unGet (g a) s1 b1 m1 p1 kf ks

#if MIN_VERSION_base(4,13,0)
instance MonadFail Get where
#endif
    fail = failDesc
111
-- | Backtracking choice. @mplus a b@ runs @a@; only if @a@ fails is @b@
--   run, on the input as it was when @a@ started, extended with every
--   chunk that was supplied through 'Partial' while @a@ was running.
--   'mzero' always fails with the message @"mzero"@.
instance MonadPlus Get where
    mzero     = failDesc "mzero"
    mplus a b =
      Get $ \s0 b0 m0 p0 kf ks ->
        let -- On success, put the chunks gathered while running @a@ back
            -- behind the enclosing buffer, so that an outer alternative
            -- can still replay everything it has seen.
            ks' s1 b1 = ks s1 (b0 `append` b1)
            -- On failure, rewind: the input is restored to @s0@ plus the
            -- buffered chunks, so the position must be restored to @p0@
            -- as well (not the position at which @a@ failed).
            kf' _ b1 m1 _ _ = unGet b (s0 `B.append` bufferBytes b1)
                                      (b0 `append` b1) m1 p0 kf ks
         in unGet a s0 (Just B.empty) m0 p0 kf' ks'
119
120------------------------------------------------------------------------
121
-- | Replace the remaining input with @s@ and advance the position by
--   @pos@. The buffer and the 'More' flag are left unchanged.
put :: Position -> B.ByteString -> Get ()
put pos s = Get $ \_ b0 m p0 _ k ->
    let p1 = p0 + pos
        -- Force the new position before continuing, so that a long run
        -- of accessors does not pile up a chain of unevaluated additions.
     in p1 `seq` k s b0 m p1 ()
{-# INLINE put #-}
125
-- | Top-level success continuation: wrap the parsed value, the final
--   position and the leftover input into 'Done'. The buffer and the
--   'More' flag are ignored.
finalK :: B.ByteString -> t -> t1 -> Position -> r -> Result r
finalK leftover _buffer _more pos val = Done val pos leftover
128
-- | Top-level failure continuation: produce a 'Fail' whose message is
--   the position, a colon, and the error description.
failK :: Failure a
failK _input _buffer _more pos msg = Fail (concat [show pos, ":", msg])
131
-- | Run a 'Get' parser on the given input, counting positions from
--   @pos@. The run starts without a backtracking buffer and with more
--   input allowed, so the result may be a 'Partial'.
runGetPos :: Position -> Get a -> B.ByteString -> Result a
runGetPos pos parser input =
    unGet parser input noBuffer moreAllowed pos failK finalK
  where
    noBuffer    = Nothing
    moreAllowed = Incomplete Nothing
{-# INLINE runGetPos #-}
136
-- | Run a 'Get' parser on the given input, counting positions from 0.
runGet :: Get a -> B.ByteString -> Result a
runGet = runGetPos origin
  where
    origin = 0
{-# INLINE runGet #-}
140
-- | Return the current input once it holds at least @n@ bytes, without
--   consuming anything. When the input is shorter, more is requested
--   through 'demandInput' and the check is repeated by 'ensureRec';
--   the parser fails if no more input can be obtained.
ensure :: Int -> Get B.ByteString
ensure n = n `seq` Get $ \input buf more pos onFail onOk ->
    let refill = demandInput >> ensureRec n
     in if B.length input < n
            then unGet refill input buf more pos onFail onOk
            else onOk input buf more pos input
{-# INLINE ensure #-}
149
-- | Recursive worker behind 'ensure': return the current input once it
--   holds at least @n@ bytes, otherwise request another chunk through
--   'demandInput' and check again.
ensureRec :: Int -> Get B.ByteString
ensureRec n = Get $ \input buf more pos onFail onOk ->
    case B.length input >= n of
        True  -> onOk input buf more pos input
        False -> unGet (demandInput >> ensureRec n) input buf more pos onFail onOk
157
-- | Ask the caller for more input by returning a 'Partial'.
--
--   Fails with @"too few bytes"@ when the input is already 'Complete',
--   or when the caller answers with an empty chunk. Otherwise the chunk
--   is appended to both the remaining input and the buffer, and any
--   count carried by 'Incomplete' is reduced by the chunk length.
demandInput :: Get ()
demandInput = Get $ \s0 b0 m0 p0 kf ks ->
    let starved = kf s0 b0 m0 p0 "too few bytes"
        resume remaining chunk
            | B.null chunk = starved
            | otherwise    =
                ks (s0 `B.append` chunk)
                   (b0 `append` Just chunk)
                   (Incomplete (fmap (subtract (B.length chunk)) remaining))
                   p0
                   ()
     in case m0 of
            Complete             -> starved
            Incomplete remaining -> Partial (resume remaining)
171
-- | Fail with the given description, prefixed by @"Failed reading: "@.
--   The current parser state is handed to the failure continuation
--   unchanged.
failDesc :: String -> Get a
failDesc err = Get $ \input buf more pos onFail _ ->
    onFail input buf more pos ("Failed reading: " ++ err)
174
175------------------------------------------------------------------------
176-- Utility with ByteStrings
177
-- | Like 'getBytes', but the returned ByteString is a fresh copy of the
--   bytes (made with 'B.copy') instead of a slice of the input. Fails if
--   fewer than @n@ bytes can be obtained.
getBytesCopy :: Int -> Get B.ByteString
getBytesCopy n = getBytes n >>= \slice -> return $! B.copy slice
185
186------------------------------------------------------------------------
187-- Helpers
188
-- | Consume @n@ bytes from the input and return them as a strict
--   ByteString, advancing the position by @n@. For @n <= 0@ nothing is
--   consumed and the empty string is returned.
getBytes :: Int -> Get B.ByteString
getBytes n
  | n > 0 =
      ensure n >>= \available ->
        let (wanted, rest) = B.splitAt n available
         in put (fromIntegral n) rest >> return wanted
  | otherwise = return B.empty
198
-- | Consume a single byte from the input and return it, advancing the
--   position by one.
getWord8 :: Get Word8
getWord8 =
    ensure 1 >>= \available ->
        maybe
            -- 'ensure' has just guaranteed at least one byte.
            (error "getWord8: ensure internal error")
            (\(byte, rest) -> put 1 rest >> return byte)
            (B.uncons available)
205