1-- |
2-- Module      : Data.ASN1.Prim
3-- License     : BSD-style
4-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
5-- Stability   : experimental
6-- Portability : unknown
7--
-- Tools to read and write ASN.1 primitives (e.g. boolean, integer)
9--
10
11{-# LANGUAGE CPP #-}
12{-# LANGUAGE ViewPatterns #-}
13module Data.ASN1.Prim
14    (
15    -- * ASN1 high level algebraic type
16      ASN1(..)
17    , ASN1ConstructionType(..)
18
19    , encodeHeader
20    , encodePrimitiveHeader
21    , encodePrimitive
22    , decodePrimitive
23    , encodeConstructed
24    , encodeList
25    , encodeOne
26    , mkSmallestLength
27
28    -- * marshall an ASN1 type from a val struct or a bytestring
29    , getBoolean
30    , getInteger
31    , getDouble
32    , getBitString
33    , getOctetString
34    , getNull
35    , getOID
36    , getTime
37
38    -- * marshall an ASN1 type to a bytestring
39    , putTime
40    , putInteger
41    , putDouble
42    , putBitString
43    , putString
44    , putOID
45    ) where
46
47import Data.ASN1.Internal
48import Data.ASN1.Stream
49import Data.ASN1.BitArray
50import Data.ASN1.Types
51import Data.ASN1.Types.Lowlevel
52import Data.ASN1.Error
53import Data.ASN1.Serialize
54import Data.Bits
55import Data.Monoid
56import Data.Word
57import Data.List (unfoldr)
58import Data.ByteString (ByteString)
59import Data.Char (ord, isDigit)
60import qualified Data.ByteString as B
61import qualified Data.ByteString.Char8 as BC
62import qualified Data.ByteString.Unsafe as B
63import Data.Hourglass
64import Control.Arrow (first)
65import Control.Applicative
66import Control.Monad
67
-- | Build the header (class, tag, constructed flag, length) for an ASN.1
-- value. The class and tag number are picked from the value's constructor;
-- 'End' markers have no header of their own and raise an error.
encodeHeader :: Bool -> ASN1Length -> ASN1 -> ASN1Header
encodeHeader pc len value =
    case classAndTag value of
        (cls, tag) -> ASN1Header cls tag pc len
  where
    classAndTag v = case v of
        Boolean _                    -> (Universal, 0x1)
        IntVal _                     -> (Universal, 0x2)
        BitString _                  -> (Universal, 0x3)
        OctetString _                -> (Universal, 0x4)
        Null                         -> (Universal, 0x5)
        OID _                        -> (Universal, 0x6)
        Real _                       -> (Universal, 0x9)
        Enumerated _                 -> (Universal, 0xa)
        ASN1String cs                -> (Universal, stringTag (characterEncoding cs))
        ASN1Time TimeUTC _ _         -> (Universal, 0x17)
        ASN1Time TimeGeneralized _ _ -> (Universal, 0x18)
        Start Sequence               -> (Universal, 0x10)
        Start Set                    -> (Universal, 0x11)
        Start (Container tc tag)     -> (tc, tag)
        Other tc tag _               -> (tc, tag)
        End _                        -> error "this should not happen"

    -- universal tag number of each character string type
    stringTag enc = case enc of
        UTF8      -> 0xc
        Numeric   -> 0x12
        Printable -> 0x13
        T61       -> 0x14
        VideoTex  -> 0x15
        IA5       -> 0x16
        Graphic   -> 0x19
        Visible   -> 0x1a
        General   -> 0x1b
        UTF32     -> 0x1c
        Character -> 0x1d
        BMP       -> 0x1e
97
-- | Header for a primitive (non-constructed) value of the given length.
encodePrimitiveHeader :: ASN1Length -> ASN1 -> ASN1Header
encodePrimitiveHeader len asn = encodeHeader False len asn
100
-- | Content octets of a primitive value. Constructed markers ('Start',
-- 'End') are not primitives and raise an error.
encodePrimitiveData :: ASN1 -> ByteString
encodePrimitiveData asn = case asn of
    Boolean True      -> B.singleton 0xff
    Boolean False     -> B.singleton 0
    IntVal n          -> putInteger n
    BitString ba      -> putBitString ba
    OctetString bs    -> putString bs
    Null              -> B.empty
    OID arcs          -> putOID arcs
    Real r            -> putDouble r
    Enumerated n      -> putInteger (fromIntegral n)
    ASN1String str    -> getCharacterStringRawData str
    ASN1Time ty ti tz -> putTime ty ti tz
    Other _ _ raw     -> raw
    _                 -> error ("not a primitive " ++ show asn)
114
-- | Encode a primitive value as a header event followed by its content.
--
-- Returns the total encoded size in bytes (header plus content) together
-- with the events. The length field uses the shortest form, via the same
-- 'mkSmallestLength' helper used for constructed values.
encodePrimitive :: ASN1 -> (Int, [ASN1Event])
encodePrimitive a =
    let b    = encodePrimitiveData a
        blen = B.length b
        hdr  = encodePrimitiveHeader (mkSmallestLength blen) a
     in (B.length (putHeader hdr) + blen, [Header hdr, Primitive b])
127
-- | Encode a single non-constructed value; a 'Start' marker is rejected.
encodeOne :: ASN1 -> (Int, [ASN1Event])
encodeOne asn = case asn of
    Start _ -> error "encode one cannot do start"
    _       -> encodePrimitive asn
131
-- | Encode a flat stream of ASN.1 values into events, returning the total
-- encoded size. Each 'Start' is grouped with its children (up to the
-- matching end) and encoded as a constructed value; stray 'End' markers
-- are skipped.
encodeList :: [ASN1] -> (Int, [ASN1Event])
encodeList []       = (0, [])
encodeList (x : xs) =
    case x of
        End _   -> encodeList xs
        Start _ -> let (children, rest) = getConstructedEnd 0 xs
                    in combine (encodeConstructed x children) (encodeList rest)
        _       -> combine (encodeOne x) (encodeList xs)
  where
    -- irrefutable patterns keep the result as lazy as a plain `let` would
    combine ~(n1, evs1) ~(n2, evs2) = (n1 + n2, evs1 ++ evs2)
145
-- | Encode a constructed value: the 'Start' marker's header, a
-- construction-begin event, the encoded children, and a construction-end
-- event. Returns the total encoded size alongside the events.
encodeConstructed :: ASN1 -> [ASN1] -> (Int, [ASN1Event])
encodeConstructed c children = case c of
    Start _ ->
        let (clen, events) = encodeList children
            hdr            = encodeHeader True (mkSmallestLength clen) c
            total          = B.length (putHeader hdr) + clen
         in (total, Header hdr : ConstructionBegin : events ++ [ConstructionEnd])
    _ -> error "not a start node"
155
-- | Choose the shortest length form: short form below 0x80, otherwise long
-- form with the minimal number of base-256 length octets.
mkSmallestLength :: Int -> ASN1Length
mkSmallestLength i
    | i >= 0x80 = LenLong (octetsNeeded i) i
    | otherwise = LenShort i
  where
    octetsNeeded n
        | n <= 255  = 1
        | otherwise = 1 + octetsNeeded (n `div` 256)
161
-- | Outcome of decoding one primitive: either an error or the decoded value.
type ASN1Ret = Either ASN1Error ASN1

-- | Decode the content octets of a primitive, dispatching on its header.
--
-- Universal tags with a known decoder are decoded. Universal types this
-- module does not implement yield 'TypeNotImplemented'. SEQUENCE and SET
-- yield 'TypePrimitiveInvalid', since they must be constructed. Every other
-- tag (unknown universal tags and all non-universal classes) is returned
-- untouched as 'Other'. BOOLEAN is decoded in non-DER (lenient) mode.
decodePrimitive :: ASN1Header -> B.ByteString -> ASN1Ret
decodePrimitive hdr p =
    case hdr of
        ASN1Header Universal tag _ _ -> universal tag
        ASN1Header tc        tag _ _ -> Right $ Other tc tag p
  where
    universal tag = case tag of
        0x1  -> getBoolean False p
        0x2  -> getInteger p
        0x3  -> getBitString p
        0x4  -> getOctetString p
        0x5  -> getNull p
        0x6  -> getOID p
        0x7  -> Left $ TypeNotImplemented "Object Descriptor"
        0x8  -> Left $ TypeNotImplemented "External"
        0x9  -> getDouble p
        0xa  -> getEnumerated p
        0xb  -> Left $ TypeNotImplemented "EMBEDDED PDV"
        0xc  -> charString UTF8
        0xd  -> Left $ TypeNotImplemented "RELATIVE-OID"
        0x10 -> Left $ TypePrimitiveInvalid "sequence"
        0x11 -> Left $ TypePrimitiveInvalid "set"
        0x12 -> charString Numeric
        0x13 -> charString Printable
        0x14 -> charString T61
        0x15 -> charString VideoTex
        0x16 -> charString IA5
        0x17 -> getTime TimeUTC p
        0x18 -> getTime TimeGeneralized p
        0x19 -> charString Graphic
        0x1a -> charString Visible
        0x1b -> charString General
        0x1c -> charString UTF32
        0x1d -> charString Character
        0x1e -> charString BMP
        _    -> Right $ Other Universal tag p

    charString enc = getCharacterString enc p
194
195
-- | Decode the content of a BOOLEAN, which must be exactly one octet.
--
-- 0x00 is False and 0xff is True. Any other octet is accepted as True,
-- unless the first argument (DER mode) is set; then it is a policy failure.
getBoolean :: Bool -> ByteString -> Either ASN1Error ASN1
getBoolean isDer s
    | B.length s /= 1 = Left $ TypeDecodingFailed "boolean: length not within bound"
    | otherwise = case B.head s of
        0x00 -> Right (Boolean False)
        0xff -> Right (Boolean True)
        _ | isDer     -> Left $ PolicyFailed "DER" "boolean value not canonical"
          | otherwise -> Right (Boolean True)
204
-- | Decode the content octets of an INTEGER (two's complement, minimal
-- encoding enforced by 'getIntegerRaw').
getInteger :: ByteString -> Either ASN1Error ASN1
{-# INLINE getInteger #-}
getInteger = fmap IntVal . getIntegerRaw "integer"
209
-- | Decode the content octets of an ENUMERATED; the encoding is identical
-- to an INTEGER's.
getEnumerated :: ByteString -> Either ASN1Error ASN1
{-# INLINE getEnumerated #-}
getEnumerated = fmap Enumerated . getIntegerRaw "enumerated"
214
-- | Shared decoder for INTEGER and ENUMERATED content (X.690 section 8.4
-- says both use the same encoding).
--
-- The first argument only labels error messages. Empty content is rejected,
-- and so is a redundant leading octet (the first nine bits all zeros or all
-- ones), since that is not the shortest encoding.
getIntegerRaw :: String -> ByteString -> Either ASN1Error Integer
getIntegerRaw typestr s
    | B.null s                  = failWith "null encoding"
    | B.length s > 1, redundant = failWith "not shortest encoding"
    | otherwise                 = Right (snd (intOfBytes s))
  where
    failWith msg = Left . TypeDecodingFailed $ typestr ++ ": " ++ msg
    b0 = B.index s 0
    b1 = B.index s 1
    redundant = (b0 == 0xff && testBit b1 7) || (b0 == 0x00 && not (testBit b1 7))
227
-- | Decode the content octets of a REAL into a 'Real' value.
getDouble :: ByteString -> Either ASN1Error ASN1
getDouble = fmap Real . getDoubleRaw
230
-- | Decode the content octets of an ASN.1 REAL (X.690 section 8.5).
--
-- Supported forms:
--
-- * empty content: plus zero
-- * special values: 0x40 (+infinity), 0x41 (-infinity), 0x42 (NaN),
--   0x43 (minus zero)
-- * binary encoding (top bit of the first octet set) with base 2, 8 or 16,
--   a binary scaling factor F, and any of the four exponent formats
--
-- Any other first octet (decimal encodings, reserved special values) is
-- rejected with 'TypeDecodingFailed' rather than being misread as binary.
getDoubleRaw :: ByteString -> Either ASN1Error Double
getDoubleRaw s
  | B.null s  = Right 0
getDoubleRaw s@(B.unsafeHead -> h)
  | h == 0x40 = Right $! (1/0)  -- Infinity
  | h == 0x41 = Right $! (-1/0) -- -Infinity
  | h == 0x42 = Right $! (0/0)  -- NaN
  | h == 0x43 = Right $! (-0)   -- minus zero
  | not (h `testBit` 7) =
      Left . TypeDecodingFailed $ "real: unsupported encoding (only binary and special values are handled)"
  | otherwise = do
      let len = B.length s
      base <- case (h `testBit` 5, h `testBit` 4) of
                -- extract bits 5,4 for the base
                (False, False) -> return 2
                (False, True)  -> return 8
                (True,  False) -> return 16
                _              -> Left . TypeDecodingFailed $ "real: invalid base detected"
      -- check bit 6 for the sign
      let mkSigned = if h `testBit` 6 then negate else id
      -- extract bits 3,2 for the binary scaling factor F
      let scaleFactor = (h .&. 0x0c) `shiftR` 2
      expLength <- getExponentLength len h s
      -- Exponent format 3 (X.690 8.5.7.4 d) stores the exponent length in
      -- the second octet, so the exponent itself starts one octet later.
      let expOffset = (if h .&. 0x03 == 0x03 then 2 else 1) :: Int
          expLen    = fromIntegral expLength :: Int
      -- header (plus length octet), the exponent, and at least 1 byte for the mantissa
      unless (len > expOffset + expLen) $
        Left . TypeDecodingFailed $ "real: not enough input for exponent and mantissa"
      let (_, exp'') = intOfBytes $ B.unsafeTake expLen $ B.unsafeDrop expOffset s
      let exp' = case base :: Int of
                   2 -> exp''
                   8 -> 3 * exp''
                   _ -> 4 * exp'' -- must be 16
          -- X.690 8.5.7.4: M = S * N * 2^F, so F adds to the binary exponent
          exponent = exp' + fromIntegral scaleFactor
          -- whatever is leftover is the mantissa, unsigned
          (_, mantissa) = uintOfBytes $ B.unsafeDrop (expOffset + expLen) s
      Right $! encodeFloat (mkSigned $ toInteger mantissa) (fromIntegral exponent)
263
-- | Number of exponent octets in a binary REAL, from the low two bits of
-- the first octet: formats 0..2 mean 1..3 octets. Format 3 means the count
-- is stored in the second octet, which must be present.
getExponentLength :: Int -> Word8 -> ByteString -> Either ASN1Error Word8
getExponentLength len h s
  | expFormat < 0x03 = Right (expFormat + 1)
  | len <= 1         = Left . TypeDecodingFailed $ "real: not enough input to decode exponent length"
  | otherwise        = Right (B.unsafeIndex s 1)
  where
    expFormat = h .&. 0x03
271
-- | Decode the content octets of a BIT STRING.
--
-- The first octet is the number of unused bits (0..7) in the last octet,
-- and the remaining octets carry the bits. For leniency, an ASCII digit
-- '0'..'7' is also accepted as the unused-bit count. Empty content is
-- reported as 'TypeDecodingFailed' instead of crashing in 'B.head'.
getBitString :: ByteString -> Either ASN1Error ASN1
getBitString s =
    case B.uncons s of
        Nothing -> Left $ TypeDecodingFailed "bitstring: null encoding"
        Just (toSkip, xs) ->
            let toSkip' = if toSkip >= 48 && toSkip <= 48 + 7
                              then toSkip - fromIntegral (ord '0')
                              else toSkip
             in if toSkip' <= 7 -- a Word8 is never negative
                    then Right $ BitString $ toBitArray xs (fromIntegral toSkip')
                    else Left $ TypeDecodingFailed ("bitstring: skip number not within bound " ++ show toSkip' ++ " " ++ show s)
280
-- | Wrap raw content octets as a character string of the given encoding.
-- The bytes are not validated against the encoding.
getCharacterString :: ASN1StringEncoding -> ByteString -> Either ASN1Error ASN1
getCharacterString encoding = Right . ASN1String . ASN1CharacterString encoding
283
-- | Wrap raw content octets as an OCTET STRING; this never fails.
getOctetString :: ByteString -> Either ASN1Error ASN1
getOctetString bs = Right (OctetString bs)
286
-- | Decode a NULL, whose content must be empty.
getNull :: ByteString -> Either ASN1Error ASN1
getNull s
    | B.null s  = Right Null
    | otherwise = Left $ TypeDecodingFailed "Null: data length not within bound"
291
{- | Decode the content octets of an OBJECT IDENTIFIER (X.690 section 8.19).

Every sub-identifier is a base-128 number whose non-final octets have the
top bit set. The first sub-identifier packs the first two arcs as
@X * 40 + Y@. Because it can itself span several octets (e.g. @2.999@), it
is decoded like any other sub-identifier before being split into two arcs.
Empty content is reported as 'TypeDecodingFailed'.
-}
getOID :: ByteString -> Either ASN1Error ASN1
getOID s =
    case groupOID (B.unpack s) of
        []              -> Left $ TypeDecodingFailed "OID: null encoding"
        (firstSub:rest) -> Right $ OID (splitFirst firstSub ++ rest)
  where
        -- arcs 0 and 1 have at most 40 children, so 80 and above belong to arc 2
        splitFirst :: Integer -> [Integer]
        splitFirst v
            | v < 40    = [0, v]
            | v < 80    = [1, v - 40]
            | otherwise = [2, v - 80]

        groupOID :: [Word8] -> [Integer]
        groupOID = map (foldl (\acc n -> (acc `shiftL` 7) + fromIntegral n) 0) . groupSubOID

        groupSubOIDHelper [] = Nothing
        groupSubOIDHelper l  = Just $ spanSubOIDbound l

        groupSubOID :: [Word8] -> [[Word8]]
        groupSubOID = unfoldr groupSubOIDHelper

        spanSubOIDbound [] = ([], [])
        spanSubOIDbound (a:as) = if testBit a 7 then (clearBit a 7 : ys, zs) else ([a], as)
            where (ys, zs) = spanSubOIDbound as
310
-- | Decode UTCTime or GeneralizedTime content octets.
--
-- The content must be ASCII. It is parsed with the hourglass format
-- "YYMMDDHMIS" (UTCTime) or "YYYYMMDDHMIS" (GeneralizedTime), falling back
-- to the same format without seconds. After that come optional fractional
-- seconds ('.' followed by digits, kept to nanosecond precision) and an
-- optional timezone: "Z", "+hhmm" or "-hhmm".
getTime :: ASN1TimeType -> ByteString -> Either ASN1Error ASN1
getTime timeType bs
    | hasNonASCII bs = decodingError "contains non ASCII characters"
    | otherwise      =
        case timeParseE format (BC.unpack bs) of -- BC.unpack is safe as we check ASCIIness first
            Left _  ->
                case timeParseE formatNoSeconds (BC.unpack bs) of
                    Left _  -> decodingError ("cannot convert string " ++ BC.unpack bs)
                    Right r -> parseRemaining r
            Right r -> parseRemaining r
  where
        parseRemaining r =
            case parseTimezone $ parseMs $ first adjustUTC r of
                Left err        -> decodingError err
                Right (dt', tz) -> Right $ ASN1Time timeType dt' tz

        -- UTCTime only has a two-digit year. Following the RFC 5280
        -- convention (YY >= 50 means 19YY), any year of 2050 or later coming
        -- out of the two-digit parse is moved back one century; 2050 itself
        -- is included.
        adjustUTC dt@(DateTime (Date y m d) tod)
            | timeType == TimeGeneralized = dt
            | y >= 2050                   = DateTime (Date (y - 100) m d) tod
            | otherwise                   = dt
        formatNoSeconds = init format
        format | timeType == TimeGeneralized = 'Y':'Y':baseFormat
               | otherwise                   = baseFormat
        baseFormat = "YYMMDDHMIS"

        -- Fractional seconds may have any number of digits. All of them are
        -- consumed, and the first 9 (nanosecond precision) are kept.
        parseMs (dt,s) =
            case s of
                '.':s' -> let (digits, r) = span isDigit s'
                           in (dt { dtTime = (dtTime dt) { todNSec = toNano (take 9 digits) } }, r)
                _      -> (dt,s)
        parseTimezone (dt,s) =
            case s of
                '+':s' -> Right (dt, parseTimezoneFormat id s')
                '-':s' -> Right (dt, parseTimezoneFormat ((-1) *) s')
                'Z':[] -> Right (dt, Just timezone_UTC)
                ""     -> Right (dt, Nothing)
                _      -> Left ("unknown timezone format: " ++ s)

        parseTimezoneFormat transform s
            | length s == 4  = Just $ toTz $ toInt $ fst $ spanToLength 4 isDigit s
            | otherwise      = Nothing
          where toTz z = let (h,m) = z `divMod` 100 in TimezoneOffset $ transform (h * 60 + m)

        -- scale at most 9 fractional digits to nanoseconds ("5" -> 500000000)
        toNano :: String -> NanoSeconds
        toNano l = fromIntegral (toInt l * 10 ^ (9 - length l))

        spanToLength :: Int -> (Char -> Bool) -> String -> (String, String)
        spanToLength len p l = loop 0 l
          where loop i z
                    | i >= len  = ([], z)
                    | otherwise = case z of
                                    []   -> ([], [])
                                    x:xs -> if p x
                                                then let (r1,r2) = loop (i+1) xs
                                                      in (x:r1, r2)
                                                else ([], z)

        toInt :: String -> Int
        toInt = foldl (\acc w -> acc * 10 + (ord w - ord '0')) 0

        decodingError reason = Left $ TypeDecodingFailed ("time format invalid for " ++ show timeType ++ " : " ++ reason)
        hasNonASCII = maybe False (const True) . B.find (\c -> c > 0x7f)
379
-- | Render a time as UTCTime ("YYMMDDHMIS") or GeneralizedTime
-- ("YYYYMMDDHMIS") content octets, followed by "Z" for UTC, the shown
-- offset for any other timezone, or nothing when no timezone is given.
--
-- FIXME need msec printed: fractional seconds are currently not emitted.
putTime :: ASN1TimeType -> DateTime -> Maybe TimezoneOffset -> ByteString
putTime ty dt mtz = BC.pack (stamp ++ zone)
  where
    stamp = case ty of
        TimeUTC -> timePrint "YYMMDDHMIS" dt
        _       -> timePrint "YYYYMMDDHMIS" dt
    zone = maybe "" renderTz mtz
    renderTz tz
        | tz == timezone_UTC = "Z"
        | otherwise          = show tz
392
-- | Content octets of an INTEGER: its two's complement bytes.
putInteger :: Integer -> ByteString
putInteger = B.pack . bytesOfInt
395
-- | Content octets of a BIT STRING: the unused-bit count, then the bits.
putBitString :: BitArray -> ByteString
putBitString (BitArray n bits) = B.cons unusedBits bits
  where
    -- padding bits in the final octet; 0 when n is a multiple of 8
    unusedBits = fromIntegral ((8 - (n `mod` 8)) .&. 0x7)
400
-- | Content octets of an octet string: the bytes themselves, unchanged.
putString :: ByteString -> ByteString
putString = id
403
{- | Encode an OBJECT IDENTIFIER as content octets (X.690 section 8.19).

The first two arcs are combined into a single sub-identifier
@oid1 * 40 + oid2@. Like every later arc, it is written in base-128, so
values of 128 or more (e.g. @2.48@ or @2.999@) are not truncated to one
octet.

No check is enforced that oid1 is in [0..2], nor that oid2 is in [0..39]
when oid1 is 0 or 1.
-}
putOID :: [Integer] -> ByteString
putOID oids = case oids of
    (oid1:oid2:suboids) ->
        B.concat $ map encode ((oid1 * 40 + oid2) : suboids)
    _                   -> error ("invalid OID format " ++ show oids)
  where
        -- putVarEncodingIntegral emits nothing for 0, so zero is special-cased
        encode x | x == 0    = B.singleton 0
                 | otherwise = putVarEncodingIntegral x
415
-- | Encode a 'Double' as the content octets of an ASN.1 REAL.
--
-- Zero becomes empty content; negative zero is included, since it compares
-- equal to 0. Infinities and NaN use the one-octet special values 0x40,
-- 0x41 and 0x42. Every other value uses the base-2 binary form: a first
-- octet holding the sign and (exponent length - 1), then the exponent, then
-- the normalized unsigned mantissa, both written with 'putInteger'.
putDouble :: Double -> ByteString
putDouble d
  | d == 0            = B.empty
  | d == (1/0)        = B.singleton 0x40
  | d == negate (1/0) = B.singleton 0x41
  | isNaN d           = B.singleton 0x42
  | otherwise         = B.cons firstOctet (expBS <> manBS)
  where
    (unsign, signId)
      | d < 0     = (negate, bINARY_NEGATIVE_NUMBER_ID)
      | otherwise = (id, bINARY_POSITIVE_NUMBER_ID)
    (rawMan, rawExp)   = decodeFloat d
    (normMan, normExp) = normalize (fromIntegral (unsign rawMan), rawExp)
    expBS      = putInteger (fromIntegral normExp)
    manBS      = putInteger (fromIntegral normMan)
    -- the low two bits of the first octet hold the exponent length minus one
    firstOctet = signId .|. (fromIntegral (B.length expBS) - 1)
433
-- | Normalize the mantissa and adjust the exponent.
--
-- DER requires the mantissa to either be 0 or odd, so we right-shift it
-- until the LSB is 1, and then add the shift amount to the exponent.
-- A zero mantissa is returned unchanged. It has no set bit to stop at, and
-- shifting it by the full word width would only distort the exponent.
--
-- TODO: handle denormal numbers
normalize :: (Word64, Int) -> (Word64, Int)
normalize (0, exponent) = (0, exponent)
normalize (mantissa, exponent) = (mantissa `shiftR` sh, exponent + sh)
  where
    sh = countTrailingZeros mantissa

#if !(MIN_VERSION_base(4,8,0))
    countTrailingZeros :: FiniteBits b => b -> Int
    countTrailingZeros x = go 0
      where
        go i | i >= w      = i
             | testBit x i = i
             | otherwise   = go (i+1)
        w = finiteBitSize x
#endif
454
-- | First-octet template for a binary-encoded, non-negative REAL
-- (top bit set, sign bit clear).
bINARY_POSITIVE_NUMBER_ID :: Word8
bINARY_POSITIVE_NUMBER_ID = 0x80

-- | First-octet template for a binary-encoded, negative REAL
-- (top bit and sign bit both set).
bINARY_NEGATIVE_NUMBER_ID :: Word8
bINARY_NEGATIVE_NUMBER_ID = 0xc0
458