-- |
-- Module      : Data.ASN1.Prim
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Tools to read and write ASN.1 primitive values (e.g. boolean, integer).
--

{-# LANGUAGE CPP #-}
{-# LANGUAGE ViewPatterns #-}
module Data.ASN1.Prim
    (
    -- * ASN1 high level algebraic type
      ASN1(..)
    , ASN1ConstructionType(..)

    , encodeHeader
    , encodePrimitiveHeader
    , encodePrimitive
    , decodePrimitive
    , encodeConstructed
    , encodeList
    , encodeOne
    , mkSmallestLength

    -- * marshall an ASN1 type from a val struct or a bytestring
    , getBoolean
    , getInteger
    , getDouble
    , getBitString
    , getOctetString
    , getNull
    , getOID
    , getTime

    -- * marshall an ASN1 type to a bytestring
    , putTime
    , putInteger
    , putDouble
    , putBitString
    , putString
    , putOID
    ) where

import Data.ASN1.Internal
import Data.ASN1.Stream
import Data.ASN1.BitArray
import Data.ASN1.Types
import Data.ASN1.Types.Lowlevel
import Data.ASN1.Error
import Data.ASN1.Serialize
import Data.Bits
import Data.Monoid
import Data.Word
import Data.List (foldl', unfoldr)
import Data.ByteString (ByteString)
import Data.Char (ord, isDigit)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Unsafe as B
import Data.Hourglass
import Control.Arrow (first)
import Control.Applicative
import Control.Monad

-- | Build the identifier and length header for an ASN.1 value.
--
-- The 'Bool' is stored as the primitive/constructed flag of the header.
-- Builtin types get their universal tag number; 'Container' and 'Other'
-- carry their own class and tag. Calling this on an 'End' marker is an
-- error, as end markers have no header of their own.
encodeHeader :: Bool -> ASN1Length -> ASN1 -> ASN1Header
encodeHeader pc len (Boolean _)                = ASN1Header Universal 0x1 pc len
encodeHeader pc len (IntVal _)                 = ASN1Header Universal 0x2 pc len
encodeHeader pc len (BitString _)              = ASN1Header Universal 0x3 pc len
encodeHeader pc len (OctetString _)            = ASN1Header Universal 0x4 pc len
encodeHeader pc len Null                       = ASN1Header Universal 0x5 pc len
encodeHeader pc len (OID _)                    = ASN1Header Universal 0x6 pc len
encodeHeader pc len (Real _)                   = ASN1Header Universal 0x9 pc len
encodeHeader pc len (Enumerated _)             = ASN1Header Universal 0xa pc len
encodeHeader pc len (ASN1String cs)            = ASN1Header Universal (characterStringType $ characterEncoding cs) pc len
  where characterStringType UTF8      = 0xc
        characterStringType Numeric   = 0x12
        characterStringType Printable = 0x13
        characterStringType T61       = 0x14
        characterStringType VideoTex  = 0x15
        characterStringType IA5       = 0x16
        characterStringType Graphic   = 0x19
        characterStringType Visible   = 0x1a
        characterStringType General   = 0x1b
        characterStringType UTF32     = 0x1c
        characterStringType Character = 0x1d
        characterStringType BMP       = 0x1e
encodeHeader pc len (ASN1Time TimeUTC _ _)         = ASN1Header Universal 0x17 pc len
encodeHeader pc len (ASN1Time TimeGeneralized _ _) = ASN1Header Universal 0x18 pc len
encodeHeader pc len (Start Sequence)               = ASN1Header Universal 0x10 pc len
encodeHeader pc len (Start Set)                    = ASN1Header Universal 0x11 pc len
encodeHeader pc len (Start (Container tc tag))     = ASN1Header tc tag pc len
encodeHeader pc len (Other tc tag _)               = ASN1Header tc tag pc len
encodeHeader _ _ (End _)                           = error "this should not happen"

-- | Header for a primitive (non-constructed) value.
encodePrimitiveHeader :: ASN1Length -> ASN1 -> ASN1Header
encodePrimitiveHeader = encodeHeader False

-- | Content octets of a primitive value. Errors on 'Start' and 'End'.
encodePrimitiveData :: ASN1 -> ByteString
encodePrimitiveData (Boolean b)         = B.singleton (if b then 0xff else 0)
encodePrimitiveData (IntVal i)          = putInteger i
encodePrimitiveData (BitString bits)    = putBitString bits
encodePrimitiveData (OctetString b)     = putString b
encodePrimitiveData Null                = B.empty
encodePrimitiveData (OID oidv)          = putOID oidv
encodePrimitiveData (Real d)            = putDouble d
encodePrimitiveData (Enumerated i)      = putInteger $ fromIntegral i
encodePrimitiveData (ASN1String cs)     = getCharacterStringRawData cs
encodePrimitiveData (ASN1Time ty ti tz) = putTime ty ti tz
encodePrimitiveData (Other _ _ b)       = b
encodePrimitiveData o                   = error ("not a primitive " ++ show o)

-- | Encode a primitive value into its header and content events.
--
-- Returns the total encoded size (header bytes plus content bytes)
-- together with the events.
encodePrimitive :: ASN1 -> (Int, [ASN1Event])
encodePrimitive a =
    let b    = encodePrimitiveData a
        blen = B.length b
        len  = mkSmallestLength blen
        hdr  = encodePrimitiveHeader len a
     in (B.length (putHeader hdr) + blen, [Header hdr, Primitive b])

-- | Encode a single non-'Start' value (see 'encodePrimitive').
encodeOne :: ASN1 -> (Int, [ASN1Event])
encodeOne (Start _) = error "encode one cannot do start"
encodeOne t         = encodePrimitive t

-- | Encode a flat stream of values, turning each 'Start' ... 'End' run
-- into a constructed encoding. Stray 'End' markers are skipped.
-- Returns the total encoded size and the events.
encodeList :: [ASN1] -> (Int, [ASN1Event])
encodeList [] = (0, [])
encodeList (End _:xs) = encodeList xs
encodeList (t@(Start _):xs) =
    let (ys, zs)    = getConstructedEnd 0 xs
        (llen, lev) = encodeList zs
        (len, ev)   = encodeConstructed t ys
     in (llen + len, ev ++ lev)
encodeList (x:xs) =
    let (llen, lev) = encodeList xs
        (len, ev)   = encodeOne x
     in (llen + len, ev ++ lev)

-- | Encode a constructed value from its 'Start' marker and its children.
-- The length in the header is the encoded size of the children.
encodeConstructed :: ASN1 -> [ASN1] -> (Int, [ASN1Event])
encodeConstructed c@(Start _) children =
    (tlen, Header h : ConstructionBegin : events ++ [ConstructionEnd])
  where (clen, events) = encodeList children
        len  = mkSmallestLength clen
        h    = encodeHeader True len c
        tlen = B.length (putHeader h) + clen
encodeConstructed _ _ = error "not a start node"

-- | Smallest definite length form for a given content size: short form
-- below 0x80, otherwise long form with the minimal number of length bytes.
mkSmallestLength :: Int -> ASN1Length
mkSmallestLength i
    | i < 0x80  = LenShort i
    | otherwise = LenLong (nbBytes i) i
  where nbBytes nb = if nb > 255 then 1 + nbBytes (nb `div` 256) else 1

type ASN1Ret = Either ASN1Error ASN1

-- | Decode the content octets of a primitive value according to its header.
-- Non-universal tags, and universal tags not listed here, are returned
-- undecoded as 'Other'.
decodePrimitive :: ASN1Header -> B.ByteString -> ASN1Ret
decodePrimitive (ASN1Header Universal 0x1 _ _) p   = getBoolean False p
decodePrimitive (ASN1Header Universal 0x2 _ _) p   = getInteger p
decodePrimitive (ASN1Header Universal 0x3 _ _) p   = getBitString p
decodePrimitive (ASN1Header Universal 0x4 _ _) p   = getOctetString p
decodePrimitive (ASN1Header Universal 0x5 _ _) p   = getNull p
decodePrimitive (ASN1Header Universal 0x6 _ _) p   = getOID p
decodePrimitive (ASN1Header Universal 0x7 _ _) _   = Left $ TypeNotImplemented "Object Descriptor"
decodePrimitive (ASN1Header Universal 0x8 _ _) _   = Left $ TypeNotImplemented "External"
decodePrimitive (ASN1Header Universal 0x9 _ _) p   = getDouble p
decodePrimitive (ASN1Header Universal 0xa _ _) p   = getEnumerated p
decodePrimitive (ASN1Header Universal 0xb _ _) _   = Left $ TypeNotImplemented "EMBEDDED PDV"
decodePrimitive (ASN1Header Universal 0xc _ _) p   = getCharacterString UTF8 p
decodePrimitive (ASN1Header Universal 0xd _ _) _   = Left $ TypeNotImplemented "RELATIVE-OID"
decodePrimitive (ASN1Header Universal 0x10 _ _) _  = Left $ TypePrimitiveInvalid "sequence"
decodePrimitive (ASN1Header Universal 0x11 _ _) _  = Left $ TypePrimitiveInvalid "set"
decodePrimitive (ASN1Header Universal 0x12 _ _) p  = getCharacterString Numeric p
decodePrimitive (ASN1Header Universal 0x13 _ _) p  = getCharacterString Printable p
decodePrimitive (ASN1Header Universal 0x14 _ _) p  = getCharacterString T61 p
decodePrimitive (ASN1Header Universal 0x15 _ _) p  = getCharacterString VideoTex p
decodePrimitive (ASN1Header Universal 0x16 _ _) p  = getCharacterString IA5 p
decodePrimitive (ASN1Header Universal 0x17 _ _) p  = getTime TimeUTC p
decodePrimitive (ASN1Header Universal 0x18 _ _) p  = getTime TimeGeneralized p
decodePrimitive (ASN1Header Universal 0x19 _ _) p  = getCharacterString Graphic p
decodePrimitive (ASN1Header Universal 0x1a _ _) p  = getCharacterString Visible p
decodePrimitive (ASN1Header Universal 0x1b _ _) p  = getCharacterString General p
decodePrimitive (ASN1Header Universal 0x1c _ _) p  = getCharacterString UTF32 p
decodePrimitive (ASN1Header Universal 0x1d _ _) p  = getCharacterString Character p
decodePrimitive (ASN1Header Universal 0x1e _ _) p  = getCharacterString BMP p
decodePrimitive (ASN1Header tc tag _ _) p          = Right $ Other tc tag p

-- | Decode a BOOLEAN: exactly one octet, 0 is False, 0xff is True.
-- Any other non-zero octet is True, or a policy failure when the first
-- argument (DER mode) is set.
getBoolean :: Bool -> ByteString -> Either ASN1Error ASN1
getBoolean isDer s =
    if B.length s == 1
        then case B.head s of
            0    -> Right (Boolean False)
            0xff -> Right (Boolean True)
            _    -> if isDer then Left $ PolicyFailed "DER" "boolean value not canonical" else Right (Boolean True)
        else Left $ TypeDecodingFailed "boolean: length not within bound"

{- | getInteger, parse a value bytestring and get the integer out of the two complement encoded bytes -}
getInteger :: ByteString -> Either ASN1Error ASN1
{-# INLINE getInteger #-}
getInteger s = IntVal <$> getIntegerRaw "integer" s

{- | getEnumerated, parse an enumerated value the same way that integer values are parsed. -}
getEnumerated :: ByteString -> Either ASN1Error ASN1
{-# INLINE getEnumerated #-}
getEnumerated s = Enumerated <$> getIntegerRaw "enumerated" s

{- | According to X.690 section 8.4 integer and enumerated values should be encoded the same way.

Rejects empty content and non-minimal encodings (a redundant leading 0x00 or
0xff octet). The first argument is only used as a prefix in error messages. -}
getIntegerRaw :: String -> ByteString -> Either ASN1Error Integer
getIntegerRaw typestr s
    | B.null s        = Left . TypeDecodingFailed $ typestr ++ ": null encoding"
    | B.length s == 1 = Right $ snd $ intOfBytes s
    | otherwise       =
        if (v1 == 0xff && testBit v2 7) || (v1 == 0x0 && not (testBit v2 7))
            then Left . TypeDecodingFailed $ typestr ++ ": not shortest encoding"
            else Right $ snd $ intOfBytes s
  where
    v1 = s `B.index` 0
    v2 = s `B.index` 1

-- | Decode a REAL value.
getDouble :: ByteString -> Either ASN1Error ASN1
getDouble s = Real <$> getDoubleRaw s

-- | Decode the content octets of a REAL (X.690 section 8.5).
--
-- Empty content is zero. Single-octet special values are PLUS-INFINITY
-- (0x40), MINUS-INFINITY (0x41), NOT-A-NUMBER (0x42) and minus zero (0x43).
-- Otherwise bit 8 of the first octet selects the binary encoding.
-- Decimal (and reserved) encodings are reported as not implemented.
--
-- Binary encoding: the value is S * N * 2^F * B^E, where:
--
-- * S is the sign (bit 7);
-- * B is the base 2, 8 or 16 (bits 6-5);
-- * F is the scaling factor 0..3 (bits 4-3);
-- * E is the two-complement exponent, whose length is given by bits 2-1;
--   in the long form (bits 2-1 = 11) octet 1 holds the exponent length and
--   the exponent starts at octet 2;
-- * N is the unsigned mantissa, made of the remaining octets.
getDoubleRaw :: ByteString -> Either ASN1Error Double
getDoubleRaw s
    | B.null s = Right 0
getDoubleRaw s@(B.unsafeHead -> h)
    | h == 0x40 = Right $! (1/0)   -- PLUS-INFINITY
    | h == 0x41 = Right $! (-1/0)  -- MINUS-INFINITY
    | h == 0x42 = Right $! (0/0)   -- NOT-A-NUMBER
    | h == 0x43 = Right $! (-0)    -- minus zero
    | not (h `testBit` 7) =
        -- bit 8 clear: decimal (00xxxxxx) or reserved special value (01xxxxxx)
        Left $ TypeNotImplemented "real: decimal or reserved encoding"
    | otherwise = do
        let len = B.length s
        base <- case (h `testBit` 5, h `testBit` 4) of
                    -- extract bits 6,5 for the base
                    (False, False) -> return 2
                    (False, True)  -> return 8
                    (True, False)  -> return 16
                    _              -> Left . TypeDecodingFailed $ "real: invalid base detected"
        -- bit 7 carries the sign
        let mkSigned = if h `testBit` 6 then negate else id
            -- bits 4,3 carry the binary scaling factor F
            scaleFactor = (h .&. 0x0c) `shiftR` 2
            -- in the long form the exponent length occupies octet 1,
            -- so the exponent itself begins one octet later
            expStart = (if h .&. 0x03 == 0x03 then 2 else 1) :: Int
        expLength <- fromIntegral <$> getExponentLength len h s
        -- header, optional exponent-length octet, exponent, and at least
        -- one octet of mantissa
        unless (len > expStart + expLength) $
            Left . TypeDecodingFailed $ "real: not enough input for exponent and mantissa"
        let (_, rawExp) = intOfBytes $ B.unsafeTake expLength $ B.unsafeDrop expStart s
            baseExp = case base :: Int of
                        2 -> rawExp
                        8 -> 3 * rawExp
                        _ -> 4 * rawExp -- must be 16
            -- M = S * N * 2^F, so F adds to the binary exponent
            binExp = baseExp + fromIntegral scaleFactor
            -- whatever is leftover is the mantissa, unsigned
            (_, mantissa) = uintOfBytes $ B.unsafeDrop (expStart + expLength) s
        Right $! encodeFloat (mkSigned $ toInteger mantissa) (fromIntegral binExp)

-- | Number of exponent octets of a binary REAL, from the low two bits of the
-- first octet. Values 0..2 mean 1..3 octets. Value 3 means the count is
-- stored in octet 1, which must therefore be present.
getExponentLength :: Int -> Word8 -> ByteString -> Either ASN1Error Word8
getExponentLength len h s =
    case h .&. 0x03 of
        l | l == 0x03 -> do
              unless (len > 1) $ Left . TypeDecodingFailed $ "real: not enough input to decode exponent length"
              return $ B.unsafeIndex s 1
          | otherwise -> return $ l + 1

-- | Decode a BIT STRING. The first octet is the number (0..7) of unused
-- bits in the last octet; the rest is the bit data.
--
-- An initial octet between ASCII @0@ and @7@ is accepted as that digit.
-- Presumably this tolerates encoders that wrote the count as a character.
getBitString :: ByteString -> Either ASN1Error ASN1
getBitString s =
    case B.uncons s of
        Nothing -> Left $ TypeDecodingFailed "bitstring: empty encoding"
        Just (toSkip, xs) ->
            let toSkip' = if toSkip >= 48 && toSkip <= 48 + 7
                              then toSkip - fromIntegral (ord '0')
                              else toSkip
             in if toSkip' <= 7
                    then Right $ BitString $ toBitArray xs (fromIntegral toSkip')
                    else Left $ TypeDecodingFailed ("bitstring: skip number not within bound " ++ show toSkip' ++ " " ++ show s)

-- | Wrap raw bytes as a character string of the given encoding (no validation).
getCharacterString :: ASN1StringEncoding -> ByteString -> Either ASN1Error ASN1
getCharacterString encoding bs = Right $ ASN1String (ASN1CharacterString encoding bs)

-- | Wrap raw bytes as an OCTET STRING.
getOctetString :: ByteString -> Either ASN1Error ASN1
getOctetString = Right . OctetString

-- | Decode a NULL, whose content must be empty.
getNull :: ByteString -> Either ASN1Error ASN1
getNull s
    | B.null s  = Right Null
    | otherwise = Left $ TypeDecodingFailed "Null: data length not within bound"

{- | return an OID

Decodes the content octets of an OBJECT IDENTIFIER (X.690 section 8.19).
Each sub-identifier is stored base 128, most significant group first, with
bit 8 set on every octet except its last one. The first sub-identifier
packs the first two arcs as @arc1 * 40 + arc2@. Because arc2 is unbounded
when arc1 is 2, that value can span several octets, so it is split only
after being fully decoded. -}
getOID :: ByteString -> Either ASN1Error ASN1
getOID s
    | B.null s             = Left $ TypeDecodingFailed "oid: empty encoding"
    | testBit (B.last s) 7 = Left $ TypeDecodingFailed "oid: truncated sub-identifier"
    | otherwise            =
        case groupOID (B.unpack s) of
            (firstSub : subs) -> Right $ OID (splitFirstArcs firstSub ++ subs)
            []                -> Left $ TypeDecodingFailed "oid: empty encoding"
  where
    splitFirstArcs :: Integer -> [Integer]
    splitFirstArcs v
        | v < 40    = [0, v]
        | v < 80    = [1, v - 40]
        | otherwise = [2, v - 80]

    groupOID :: [Word8] -> [Integer]
    groupOID = map (foldl' (\acc n -> (acc `shiftL` 7) + fromIntegral n) 0) . groupSubOID

    groupSubOIDHelper [] = Nothing
    groupSubOIDHelper l  = Just $ spanSubOIDbound l

    groupSubOID :: [Word8] -> [[Word8]]
    groupSubOID = unfoldr groupSubOIDHelper

    spanSubOIDbound [] = ([], [])
    spanSubOIDbound (a:as) = if testBit a 7 then (clearBit a 7 : ys, zs) else ([a], as)
      where (ys, zs) = spanSubOIDbound as

-- | Decode a UTCTime or GeneralizedTime.
--
-- The input must be ASCII. It is parsed first with seconds, then without.
-- Then come an optional fraction of up to 3 digits and an optional
-- timezone (@Z@, or @+hhmm@ / @-hhmm@).
-- For UTCTime the two-digit year follows RFC 5280: 50..99 is 19YY and
-- 00..49 is 20YY.
getTime :: ASN1TimeType -> ByteString -> Either ASN1Error ASN1
getTime timeType bs
    | hasNonASCII bs = decodingError "contains non ASCII characters"
    | otherwise      =
        case timeParseE format (BC.unpack bs) of -- BC.unpack is safe as we check ASCIIness first
            Left _  ->
                case timeParseE formatNoSeconds (BC.unpack bs) of
                    Left _  -> decodingError ("cannot convert string " ++ BC.unpack bs)
                    Right r -> parseRemaining r
            Right r -> parseRemaining r
  where
    parseRemaining r =
        case parseTimezone $ parseMs $ first adjustUTC r of
            Left err        -> decodingError err
            Right (dt', tz) -> Right $ ASN1Time timeType dt' tz

    -- move parsed years 2050 and later back a century (RFC 5280 4.1.2.5.1)
    adjustUTC dt@(DateTime (Date y m d) tod)
        | timeType == TimeGeneralized = dt
        | y >= 2050                   = DateTime (Date (y - 100) m d) tod
        | otherwise                   = dt
    formatNoSeconds = init format
    format | timeType == TimeGeneralized = 'Y':'Y':baseFormat
           | otherwise                   = baseFormat
    baseFormat = "YYMMDDHMIS"

    parseMs (dt,s) =
        case s of
            '.':s' -> let (ns, r) = first toNano $ spanToLength 3 isDigit s'
                       in (dt { dtTime = (dtTime dt) { todNSec = ns } }, r)
            _      -> (dt,s)
    parseTimezone (dt,s) =
        case s of
            '+':s'  -> Right (dt, parseTimezoneFormat id s')
            '-':s'  -> Right (dt, parseTimezoneFormat ((-1) *) s')
            'Z':[]  -> Right (dt, Just timezone_UTC)
            ""      -> Right (dt, Nothing)
            _       -> Left ("unknown timezone format: " ++ s)

    parseTimezoneFormat transform s
        | length s == 4 = Just $ toTz $ toInt $ fst $ spanToLength 4 isDigit s
        | otherwise     = Nothing
      where toTz z = let (h,m) = z `divMod` 100 in TimezoneOffset $ transform (h * 60 + m)

    -- fraction digits are milliseconds-scaled, then converted to nanoseconds
    toNano :: String -> NanoSeconds
    toNano l = fromIntegral (toInt l * order * 1000000)
      where len   = length l
            order = case len of
                        1 -> 100
                        2 -> 10
                        3 -> 1
                        _ -> 1

    spanToLength :: Int -> (Char -> Bool) -> String -> (String, String)
    spanToLength len p l = loop 0 l
      where loop i z
                | i >= len  = ([], z)
                | otherwise = case z of
                                  []   -> ([], [])
                                  x:xs -> if p x
                                              then let (r1,r2) = loop (i+1) xs
                                                    in (x:r1, r2)
                                              else ([], z)

    toInt :: String -> Int
    toInt = foldl' (\acc w -> acc * 10 + (ord w - ord '0')) 0

    decodingError reason = Left $ TypeDecodingFailed ("time format invalid for " ++ show timeType ++ " : " ++ reason)
    hasNonASCII = B.any (> 0x7f)

-- | Encode a UTCTime (two-digit year) or GeneralizedTime, followed by
-- @Z@ for UTC, the shown offset for other zones, or nothing.
--
-- FIXME need msec printed
putTime :: ASN1TimeType -> DateTime -> Maybe TimezoneOffset -> ByteString
putTime ty dt mtz = BC.pack etime
  where
    etime
        | ty == TimeUTC = timePrint "YYMMDDHMIS" dt ++ tzStr
        | otherwise     = timePrint "YYYYMMDDHMIS" dt ++ msecStr ++ tzStr
    msecStr = []
    tzStr = case mtz of
                Nothing -> ""
                Just tz | tz == timezone_UTC -> "Z"
                        | otherwise          -> show tz

-- | Two-complement big-endian encoding of an integer.
putInteger :: Integer -> ByteString
putInteger i = B.pack $ bytesOfInt i

-- | Encode a bit array: one octet of unused-bit count, then the bits.
putBitString :: BitArray -> ByteString
putBitString (BitArray n bits) =
    B.concat [B.singleton (fromIntegral i), bits]
  where i = (8 - (n `mod` 8)) .&. 0x7

-- | Octet strings are stored verbatim.
putString :: ByteString -> ByteString
putString l = l

{- | Encode an OBJECT IDENTIFIER (X.690 section 8.19).

The first two arcs are packed into one sub-identifier (@arc1 * 40 + arc2@).
Like every other sub-identifier, it is written in base-128, so it may take
several octets when arc1 is 2.
Fails with an error when fewer than two arcs are given or an arc is
negative. There is still no check that arc1 is in [0..2] and arc2 in [0..39]. -}
putOID :: [Integer] -> ByteString
putOID oids = case oids of
    (oid1:oid2:suboids) | all (>= 0) oids ->
        B.concat $ map encode (oid1 * 40 + oid2 : suboids)
    _ -> error ("invalid OID format " ++ show oids)
  where
    encode x | x == 0    = B.singleton 0
             | otherwise = putVarEncodingIntegral x

-- | Encode a Double as a REAL.
--
-- Zero is empty content; infinities and NaN use their special octets.
-- Other values use base 2 with a scaling factor of 0, a two-complement
-- exponent and a normalized (odd) mantissa.
putDouble :: Double -> ByteString
putDouble d
    | d == 0            = B.pack []
    | d == (1/0)        = B.pack [0x40]
    | d == negate (1/0) = B.pack [0x41]
    | isNaN d           = B.pack [0x42]
    | otherwise         = B.cons (header .|. (expLen - 1)) -- encode length of exponent
                                 (expBS <> manBS)
  where
    (mkUnsigned, header)
        | d < 0     = (negate, bINARY_NEGATIVE_NUMBER_ID)
        | otherwise = (id, bINARY_POSITIVE_NUMBER_ID)
    (rawMantissa, rawExponent) = decodeFloat d
    (normMantissa, normExponent) = normalize (fromIntegral $ mkUnsigned rawMantissa, rawExponent)
    expBS  = putInteger (fromIntegral normExponent)
    expLen = fromIntegral (B.length expBS)
    manBS  = putInteger (fromIntegral normMantissa)

-- | Normalize the mantissa and adjust the exponent.
--
-- DER requires the mantissa to either be 0 or odd, so we right-shift it
-- until the LSB is 1, and then add the shift amount to the exponent.
--
-- TODO: handle denormal numbers
normalize :: (Word64, Int) -> (Word64, Int)
normalize (m, e) = (m `shiftR` sh, e + sh)
  where
    sh = countTrailingZeros m

#if !(MIN_VERSION_base(4,8,0))
    countTrailingZeros :: FiniteBits b => b -> Int
    countTrailingZeros x = go 0
      where
        go i | i >= w      = i
             | testBit x i = i
             | otherwise   = go (i+1)
        w = finiteBitSize x
#endif

-- | First octet of a binary REAL (bit 8 set), positive and negative sign.
bINARY_POSITIVE_NUMBER_ID, bINARY_NEGATIVE_NUMBER_ID :: Word8
bINARY_POSITIVE_NUMBER_ID = 0x80
bINARY_NEGATIVE_NUMBER_ID = 0xc0