-- Disable this warning so we can still test deprecated functionality.
{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
module Connection
    ( newPairContext
    , arbitraryCiphers
    , arbitraryVersions
    , arbitraryHashSignatures
    , arbitraryGroups
    , arbitraryKeyUsage
    , arbitraryPairParams
    , arbitraryPairParams13
    , arbitraryPairParamsWithVersionsAndCiphers
    , arbitraryClientCredential
    , arbitraryCredentialsOfEachCurve
    , arbitraryRSACredentialWithUsage
    , dhParamsGroup
    , getConnectVersion
    , isVersionEnabled
    , isCustomDHParams
    , isLeafRSA
    , isCredentialDSA
    , arbitraryEMSMode
    , setEMSMode
    , readClientSessionRef
    , twoSessionRefs
    , twoSessionManagers
    , setPairParamsSessionManagers
    , setPairParamsSessionResuming
    , withDataPipe
    , initiateDataPipe
    , byeBye
    ) where

import Test.Tasty.QuickCheck
import Certificate
import PubKey
import PipeChan
import Network.TLS as TLS
import Network.TLS.Extra
import Data.X509
import Data.Default.Class
import Data.IORef
import Control.Applicative
import Control.Concurrent.Async
import Control.Concurrent.Chan
import Control.Concurrent
import qualified Control.Exception as E
import Control.Monad (unless, when)
import Data.List (intersect, isInfixOf)

import qualified Data.ByteString as B

-- | When 'True', the test contexts log every packet they send or receive to
-- stdout (used by 'newPairContext').
debug :: Bool
debug = False

-- | Every cipher the tests may pick: 'ciphersuite_all' plus the weak ciphers
-- listed below.
knownCiphers :: [Cipher]
knownCiphers = ciphersuite_all ++ ciphersuite_weak
  where
    ciphersuite_weak =
        [ cipher_DHE_DSS_RC4_SHA1
        , cipher_RC4_128_MD5
        , cipher_null_MD5
        , cipher_null_SHA1
        ]

-- | Non-empty list of ciphers drawn (with possible repetition) from
-- 'knownCiphers'.
arbitraryCiphers :: Gen [Cipher]
arbitraryCiphers = listOf1 $ elements knownCiphers

-- | Protocol versions exercised by the tests, newest first.
knownVersions :: [Version]
knownVersions = [TLS13,TLS12,TLS11,TLS10,SSL3]

-- | Possibly-empty sublist of 'knownVersions' (order preserved).
arbitraryVersions :: Gen [Version]
arbitraryVersions = sublistOf knownVersions

-- | Hash/signature pairs exercised by the tests.
--
-- For performance reasons, ecdsa_secp521r1_sha512 is not tested.
knownHashSignatures :: [HashAndSignatureAlgorithm]
knownHashSignatures =
    [(TLS.HashIntrinsic, SignatureRSApssRSAeSHA512)
    ,(TLS.HashIntrinsic, SignatureRSApssRSAeSHA384)
    ,(TLS.HashIntrinsic, SignatureRSApssRSAeSHA256)
    ,(TLS.HashIntrinsic, SignatureEd25519)
    ,(TLS.HashIntrinsic, SignatureEd448)
    ,(TLS.HashSHA512,    SignatureRSA)
    ,(TLS.HashSHA384,    SignatureRSA)
    ,(TLS.HashSHA384,    SignatureECDSA)
    ,(TLS.HashSHA256,    SignatureRSA)
    ,(TLS.HashSHA256,    SignatureECDSA)
    ,(TLS.HashSHA1,      SignatureRSA)
    ,(TLS.HashSHA1,      SignatureDSS)
    ]

-- | The subset of 'knownHashSignatures' offered for TLS 1.3: every pair using
-- SHA1, DSS or 'SignatureRSA' is dropped.
knownHashSignatures13 :: [HashAndSignatureAlgorithm]
knownHashSignatures13 = filter compat knownHashSignatures
  where
    compat (h,s) = h /= TLS.HashSHA1 && s /= SignatureDSS && s /= SignatureRSA

-- | Possibly-empty sublist of the hash/signature pairs usable at the given
-- version ('knownHashSignatures13' from TLS 1.3 on, 'knownHashSignatures'
-- below it).
arbitraryHashSignatures :: Version -> Gen [HashAndSignatureAlgorithm]
arbitraryHashSignatures v = sublistOf l
  where l = if v < TLS13 then knownHashSignatures else knownHashSignatures13

-- For performance reasons, P521, FFDHE6144 and FFDHE8192 are not tested.
knownGroups, knownECGroups, knownFFGroups :: [Group]
knownECGroups = [P256,P384,X25519,X448]
knownFFGroups = [FFDHE2048,FFDHE3072,FFDHE4096]
knownGroups   = knownECGroups ++ knownFFGroups

-- | EC group that every generated group pair has in common.
defaultECGroup :: Group
defaultECGroup = P256 -- same as defaultECCurve

-- | 'knownECGroups' without 'defaultECGroup'.
otherKnownECGroups :: [Group]
otherKnownECGroups = filter (/= defaultECGroup) knownECGroups

-- | List of 1 to 5 groups (repetition allowed) drawn from 'knownGroups'.
arbitraryGroups :: Gen [Group]
arbitraryGroups = scale (min 5) $ listOf1 $ elements knownGroups

-- | Whether the credential's private key is a DSA key.
isCredentialDSA :: (CertificateChain, PrivKey) -> Bool
isCredentialDSA (_, PrivKeyDSA _) = True
isCredentialDSA _ = False

-- | One credential per key type (see 'arbitraryCredentialsOfEachType''), in
-- random order.
arbitraryCredentialsOfEachType :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachType = arbitraryCredentialsOfEachType' >>= shuffle

-- | One credential per key type, in this fixed order: RSA (the global RSA
-- pair), DSA, ECDSA on 'defaultECCurve', Ed25519, Ed448.
arbitraryCredentialsOfEachType' :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachType' = do
    let (pubKey, privKey) = getGlobalRSAPair
        curveName = defaultECCurve
    (dsaPub, dsaPriv) <- arbitraryDSAPair
    (ecdsaPub, ecdsaPriv) <- arbitraryECDSAPair curveName
    (ed25519Pub, ed25519Priv) <- arbitraryEd25519Pair
    (ed448Pub, ed448Priv) <- arbitraryEd448Pair
    mapM keyPairToCredential
        [ (PubKeyRSA pubKey, PrivKeyRSA privKey)
        , (PubKeyDSA dsaPub, PrivKeyDSA dsaPriv)
        , (toPubKeyEC curveName ecdsaPub, toPrivKeyEC curveName ecdsaPriv)
        , (PubKeyEd25519 ed25519Pub, PrivKeyEd25519 ed25519Priv)
        , (PubKeyEd448 ed448Pub, PrivKeyEd448 ed448Priv)
        ]

-- | One credential per supported curve (Ed25519, Ed448 and ECDSA on each of
-- 'knownECCurves'), in random order.
arbitraryCredentialsOfEachCurve :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachCurve = arbitraryCredentialsOfEachCurve' >>= shuffle

-- | Ed25519 and Ed448 credentials followed by one ECDSA credential per entry
-- of 'knownECCurves'.
arbitraryCredentialsOfEachCurve' :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachCurve' = do
    ecdsaPairs <- mapM ecdsaPairOn knownECCurves
    (ed25519Pub, ed25519Priv) <- arbitraryEd25519Pair
    (ed448Pub, ed448Priv) <- arbitraryEd448Pair
    mapM keyPairToCredential $
        [ (PubKeyEd25519 ed25519Pub, PrivKeyEd25519 ed25519Priv)
        , (PubKeyEd448 ed448Pub, PrivKeyEd448 ed448Priv)
        ] ++ ecdsaPairs
  where
    ecdsaPairOn curveName = do
        (ecdsaPub, ecdsaPriv) <- arbitraryECDSAPair curveName
        return (toPubKeyEC curveName ecdsaPub, toPrivKeyEC curveName ecdsaPriv)

-- | Wrap a key pair into a credential: a single-certificate chain generated by
-- 'arbitraryX509WithKey' for that pair, plus the private key.
keyPairToCredential :: (PubKey, PrivKey) -> Gen (CertificateChain, PrivKey)
keyPairToCredential (pub, priv) = do
    cert <- arbitraryX509WithKey (pub, priv)
    return (CertificateChain [cert], priv)

-- | The named FFDHE group whose parameters are exactly the given ones, or
-- 'Nothing' for any other (custom) parameter set such as 'dhParams512'.
dhParamsGroup :: DHParams -> Maybe Group
dhParamsGroup params = lookup params namedFFGroups
  where
    -- every named parameter set from Network.TLS.Extra, not only the ones
    -- currently generated, so the mapping stays correct if more are used
    namedFFGroups =
        [ (ffdhe2048, FFDHE2048)
        , (ffdhe3072, FFDHE3072)
        , (ffdhe4096, FFDHE4096)
        , (ffdhe6144, FFDHE6144)
        , (ffdhe8192, FFDHE8192)
        ]

-- | Whether the parameters are the non-standard 'dhParams512' set.
isCustomDHParams :: DHParams -> Bool
isCustomDHParams params = params == dhParams512

-- | Public key of the first (leaf) certificate of a chain, if any.
leafPublicKey :: CertificateChain -> Maybe PubKey
leafPublicKey (CertificateChain []) = Nothing
leafPublicKey (CertificateChain (leaf:_)) = Just (certPubKey $ getCertificate leaf)

-- | Whether a chain is present and its leaf certificate carries an RSA key.
isLeafRSA :: Maybe CertificateChain -> Bool
isLeafRSA chain = case chain >>= leafPublicKey of
    Just (PubKeyRSA _) -> True
    _                  -> False

-- | Generate (client, server) cipher lists such that the server supports at
-- least one cipher allowed at @connectVersion@, and the client offers at least
-- one cipher that the server supports and that is allowed at that version.
arbitraryCipherPair :: Version -> Gen ([Cipher], [Cipher])
arbitraryCipherPair connectVersion = do
    serverCiphers <- arbitraryCiphers `suchThat`
        (\cs -> or [cipherAllowedForVersion connectVersion x | x <- cs])
    clientCiphers <- arbitraryCiphers `suchThat`
        (\cs -> or [x `elem` serverCiphers &&
                    cipherAllowedForVersion connectVersion x | x <- cs])
    return (clientCiphers, serverCiphers)

-- | Client/server parameters for a connection version picked from
-- 'knownVersions'.
arbitraryPairParams :: Gen (ClientParams, ServerParams)
arbitraryPairParams = elements knownVersions >>= arbitraryPairParamsAt

-- Pair of groups so that at least the default EC group P256 and one FF group
-- are in common. This makes DHE and ECDHE ciphers always compatible with
-- extension "Supported Elliptic Curves" / "Supported Groups".
arbitraryGroupPair :: Gen ([Group], [Group])
arbitraryGroupPair = do
    (serverECGroups, clientECGroups) <- arbitraryGroupPairWith defaultECGroup otherKnownECGroups
    (serverFFGroups, clientFFGroups) <- arbitraryGroupPairFrom knownFFGroups
    serverGroups <- shuffle (serverECGroups ++ serverFFGroups)
    clientGroups <- shuffle (clientECGroups ++ clientFFGroups)
    return (clientGroups, serverGroups)
  where
    -- pick one common element from the list, then extend each side with an
    -- independent sublist of the remaining elements
    arbitraryGroupPairFrom list = elements list >>= \e ->
        arbitraryGroupPairWith e (filter (/= e) list)
    arbitraryGroupPairWith e es = do
        s <- sublistOf es
        c <- sublistOf es
        return (e : s, e : c)

-- | Client/server parameters that connect with TLS 1.3.
arbitraryPairParams13 :: Gen (ClientParams, ServerParams)
arbitraryPairParams13 = arbitraryPairParamsAt TLS13

-- | Client/server parameters whose highest common version is @connectVersion@.
arbitraryPairParamsAt :: Version -> Gen (ClientParams, ServerParams)
arbitraryPairParamsAt connectVersion = do
    (clientCiphers, serverCiphers) <- arbitraryCipherPair connectVersion
    -- Select version lists containing connectVersion, as well as some other
    -- versions for which we have compatible ciphers. Criteria about cipher
    -- ensure we can test version downgrade.
    let allowedVersions = [ v | v <- knownVersions,
                                or [ x `elem` serverCiphers &&
                                     cipherAllowedForVersion v x | x <- clientCiphers ]]
        allowedVersionsFiltered = filter (<= connectVersion) allowedVersions
    -- Server or client is allowed to have versions > connectVersion, but not
    -- both simultaneously.
    filterSrv <- arbitrary
    let (clientAllowedVersions, serverAllowedVersions)
            | filterSrv = (allowedVersions, allowedVersionsFiltered)
            | otherwise = (allowedVersionsFiltered, allowedVersions)
    -- Generate version lists of at most 127 elements (connectVersion plus up
    -- to 126 others), otherwise the "supported_versions" extension cannot be
    -- correctly serialized.
    clientVersions <- listWithOthers connectVersion 126 clientAllowedVersions
    serverVersions <- listWithOthers connectVersion 126 serverAllowedVersions
    arbitraryPairParamsWithVersionsAndCiphers (clientVersions, serverVersions) (clientCiphers, serverCiphers)
  where
    -- @fixedElement@ placed at a random position among up to
    -- @min size maxOthers@ elements drawn from @others@.
    listWithOthers :: a -> Int -> [a] -> Gen [a]
    listWithOthers fixedElement maxOthers others
        | maxOthers < 1 = return [fixedElement]
        | otherwise = sized $ \n -> do
            num <- choose (0, min n maxOthers)
            pos <- choose (0, num)
            prefix <- vectorOf pos $ elements others
            suffix <- vectorOf (num - pos) $ elements others
            return $ prefix ++ (fixedElement : suffix)

-- | Highest version enabled on both sides. Fails with a descriptive error when
-- client and server have no version in common.
getConnectVersion :: (ClientParams, ServerParams) -> Version
getConnectVersion (cparams, sparams) =
    case cver `intersect` sver of
        []     -> error "getConnectVersion: client and server share no protocol version"
        common -> maximum common
  where
    sver = supportedVersions (serverSupported sparams)
    cver = supportedVersions (clientSupported cparams)

-- | Whether @ver@ is enabled on both the client and the server.
isVersionEnabled :: Version -> (ClientParams, ServerParams) -> Bool
isVersionEnabled ver (cparams, sparams) =
    (ver `elem` supportedVersions (serverSupported sparams)) &&
    (ver `elem` supportedVersions (clientSupported cparams))

-- | Both sides support every pair of 'knownHashSignatures', each side in an
-- independently shuffled order.
arbitraryHashSignaturePair :: Gen ([HashAndSignatureAlgorithm], [HashAndSignatureAlgorithm])
arbitraryHashSignaturePair = do
    serverHashSignatures <- shuffle knownHashSignatures
    clientHashSignatures <- shuffle knownHashSignatures
    return (clientHashSignatures, serverHashSignatures)

-- | Client/server parameters with the given version and cipher lists; the
-- remaining settings (secure renegotiation, DHE parameters, credentials,
-- groups, hash/signature pairs) are generated. The client accepts any
-- server certificate through a pass-through validation cache.
arbitraryPairParamsWithVersionsAndCiphers :: ([Version], [Version])
                                          -> ([Cipher], [Cipher])
                                          -> Gen (ClientParams, ServerParams)
arbitraryPairParamsWithVersionsAndCiphers (clientVersions, serverVersions) (clientCiphers, serverCiphers) = do
    secNeg <- arbitrary
    dhparams <- elements [dhParams512,ffdhe2048,ffdhe3072]

    creds <- arbitraryCredentialsOfEachType
    (clientGroups, serverGroups) <- arbitraryGroupPair
    (clientHashSignatures, serverHashSignatures) <- arbitraryHashSignaturePair
    let serverState = def
            { serverSupported = def { supportedCiphers = serverCiphers
                                    , supportedVersions = serverVersions
                                    , supportedSecureRenegotiation = secNeg
                                    , supportedGroups = serverGroups
                                    , supportedHashSignatures = serverHashSignatures
                                    }
            , serverDHEParams = Just dhparams
            , serverShared = def { sharedCredentials = Credentials creds }
            }
    let clientState = (defaultParamsClient "" B.empty)
            { clientSupported = def { supportedCiphers = clientCiphers
                                    , supportedVersions = clientVersions
                                    , supportedSecureRenegotiation = secNeg
                                    , supportedGroups = clientGroups
                                    , supportedHashSignatures = clientHashSignatures
                                    }
            , clientShared = def { sharedValidationCache = ValidationCache
                                        { cacheAdd = \_ _ _ -> return ()
                                        , cacheQuery = \_ _ _ -> return ValidationCachePass
                                        }
                                 }
            }
    return (clientState, serverState)

-- | A client credential whose key type is usable at the given version.
-- Credentials are selected by key type rather than by list position, so this
-- does not depend on the order produced by 'arbitraryCredentialsOfEachType''.
arbitraryClientCredential :: Version -> Gen Credential
arbitraryClientCredential SSL3 =
    -- for SSL3 there is no EC but only RSA/DSA
    arbitraryCredentialsOfEachType' >>= elements . filter isCredentialRSAOrDSA
arbitraryClientCredential v | v < TLS12 =
    -- for TLS10 and TLS11 there is no EdDSA but only RSA/DSA/ECDSA
    arbitraryCredentialsOfEachType' >>= elements . filter (not . isCredentialEdDSA)
arbitraryClientCredential _ = arbitraryCredentialsOfEachType' >>= elements

-- | Whether the credential's private key is an RSA or DSA key.
isCredentialRSAOrDSA :: Credential -> Bool
isCredentialRSAOrDSA (_, PrivKeyRSA _) = True
isCredentialRSAOrDSA (_, PrivKeyDSA _) = True
isCredentialRSAOrDSA _ = False

-- | Whether the credential's private key is an EdDSA (Ed25519/Ed448) key.
isCredentialEdDSA :: Credential -> Bool
isCredentialEdDSA (_, PrivKeyEd25519 _) = True
isCredentialEdDSA (_, PrivKeyEd448 _) = True
isCredentialEdDSA _ = False

-- | RSA credential (the global RSA pair) whose certificate carries the given
-- extended key usage flags.
arbitraryRSACredentialWithUsage :: [ExtKeyUsageFlag] -> Gen (CertificateChain, PrivKey)
arbitraryRSACredentialWithUsage usageFlags = do
    let (pubKey, privKey) = getGlobalRSAPair
    cert <- arbitraryX509WithKeyAndUsage usageFlags (PubKeyRSA pubKey, ())
    return (CertificateChain [cert], PrivKeyRSA privKey)

-- | Independent (client, server) extended-master-secret modes.
arbitraryEMSMode :: Gen (EMSMode, EMSMode)
arbitraryEMSMode = (,) <$> gen <*> gen
  where gen = elements [ NoEMS, AllowEMS, RequireEMS ]

-- | Set the client and server extended-master-secret modes.
setEMSMode :: (EMSMode, EMSMode) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setEMSMode (cems, sems) (clientParam, serverParam) = (clientParam', serverParam')
  where
    clientParam' = clientParam { clientSupported = (clientSupported clientParam)
                                   { supportedExtendedMasterSec = cems }
                               }
    serverParam' = serverParam { serverSupported = (serverSupported serverParam)
                                   { supportedExtendedMasterSec = sems }
                               }

-- | Read the client-side reference of a (client, server) pair.
readClientSessionRef :: (IORef mclient, IORef mserver) -> IO mclient
readClientSessionRef refs = readIORef (fst refs)

-- | Fresh (client, server) references, both initialised to 'Nothing'.
twoSessionRefs :: IO (IORef (Maybe client), IORef (Maybe server))
twoSessionRefs = (,) <$> newIORef Nothing <*> newIORef Nothing

-- | Simple session manager that stores one session ID and its session data,
-- for use by a single thread. A real concurrent session manager would use an
-- MVar and hold multiple entries.
-- Each side of a test connection gets its own instance (see 'twoSessionManagers').
oneSessionManager :: IORef (Maybe (SessionID, SessionData)) -> SessionManager
oneSessionManager ref = SessionManager
    { sessionResume         = \myId -> readIORef ref >>= maybeResume False myId
    , sessionResumeOnlyOnce = \myId -> readIORef ref >>= maybeResume True myId
    , sessionEstablish      = \myId dat -> writeIORef ref $ Just (myId, dat)
    , sessionInvalidate     = \_ -> return ()
    }
  where
    -- return the stored data when the ID matches; in "only once" mode the
    -- stored entry is cleared as it is handed out
    maybeResume onlyOnce myId (Just (sid, sdata))
        | sid == myId = when onlyOnce (writeIORef ref Nothing) >> return (Just sdata)
    maybeResume _ _ _ = return Nothing

-- | Session managers backed by the client and server references respectively.
twoSessionManagers :: (IORef (Maybe (SessionID, SessionData)), IORef (Maybe (SessionID, SessionData))) -> (SessionManager, SessionManager)
twoSessionManagers (cRef, sRef) = (oneSessionManager cRef, oneSessionManager sRef)

-- | Install the given (client, server) session managers in the shared params.
setPairParamsSessionManagers :: (SessionManager, SessionManager) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setPairParamsSessionManagers (clientManager, serverManager) (clientState, serverState) = (nc,ns)
  where nc = clientState { clientShared = updateSessionManager clientManager $ clientShared clientState }
        ns = serverState { serverShared = updateSessionManager serverManager $ serverShared serverState }
        updateSessionManager manager shared = shared { sharedSessionManager = manager }

-- | Make the client ask to resume the given session; the server is unchanged.
setPairParamsSessionResuming :: (SessionID, SessionData) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setPairParamsSessionResuming sessionStuff (clientState, serverState) =
    ( clientState { clientWantSessionResume = Just sessionStuff }
    , serverState)

-- | Create a client and a server context wired to the A and B ends of the
-- pipe; flush and close are no-ops. When 'debug' is set, each context logs
-- its packets with a "client: " / "server: " prefix.
newPairContext :: PipeChan -> (ClientParams, ServerParams) -> IO (Context, Context)
newPairContext pipe (cParams, sParams) = do
    let noFlush = return ()
    let noClose = return ()

    let cBackend = Backend noFlush noClose (writePipeA pipe) (readPipeA pipe)
    let sBackend = Backend noFlush noClose (writePipeB pipe) (readPipeB pipe)
    cCtx' <- contextNew cBackend cParams
    sCtx' <- contextNew sBackend sParams

    contextHookSetLogging cCtx' (logging "client: ")
    contextHookSetLogging sCtx' (logging "server: ")

    return (cCtx', sCtx')
  where
    logging pre =
        if debug
            then def { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++)
                     , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++) }
            else def

-- | Run a TLS server and client concurrently over a fresh pipe and pass the
-- continuation (a) an action writing to the client's start queue and (b) an
-- action that waits for both peers ('waitBoth') and then reads one value from
-- the server's result queue. An exception on either side is printed together
-- with that side's 'Supported' settings and rethrown.
withDataPipe :: (ClientParams, ServerParams) -> (Context -> Chan result -> IO ()) -> (Chan start -> Context -> IO ()) -> ((start -> IO (), IO result) -> IO a) -> IO a
withDataPipe params tlsServer tlsClient cont = do
    -- initial setup
    pipe <- newPipe
    _ <- runPipe pipe
    startQueue <- newChan
    resultQueue <- newChan

    (cCtx, sCtx) <- newPairContext pipe params

    withAsync (E.catch (tlsServer sCtx resultQueue)
                       (printAndRaise "server" (serverSupported $ snd params))) $ \sAsync ->
        withAsync (E.catch (tlsClient startQueue cCtx)
                           (printAndRaise "client" (clientSupported $ fst params))) $ \cAsync -> do
            let readResult = waitBoth cAsync sAsync >> readChan resultQueue
            cont (writeChan startQueue, readResult)
  where
    printAndRaise :: String -> Supported -> E.SomeException -> IO ()
    printAndRaise s supported e = do
        putStrLn $ s ++ " exception: " ++ show e ++
                   ", supported: " ++ show supported
        E.throwIO e

-- | Run the server and client actions concurrently over a fresh pipe and
-- return both outcomes as (client, server); an exception on either side is
-- captured as 'Left' instead of being rethrown.
initiateDataPipe :: (ClientParams, ServerParams) -> (Context -> IO a1) -> (Context -> IO a) -> IO (Either E.SomeException a, Either E.SomeException a1)
initiateDataPipe params tlsServer tlsClient = do
    -- initial setup
    pipe <- newPipe
    _ <- runPipe pipe

    (cCtx, sCtx) <- newPairContext pipe params

    -- withAsync (rather than a bare async) cancels both peers if this thread
    -- is interrupted while waiting, so no TLS thread outlives the call.
    withAsync (tlsServer sCtx) $ \sAsync ->
        withAsync (tlsClient cCtx) $ \cAsync -> do
            sRes <- waitCatch sAsync
            cRes <- waitCatch cAsync
            return (cRes, sRes)

-- | Terminate the write direction and wait to receive the peer EOF. This is
-- necessary in situations where we want to confirm the peer status, or to make
-- sure to receive late messages like session tickets. In the test suite this
-- is used each time application code ends the connection without a prior call
-- to 'recvData'. Fails if application data arrives instead of EOF.
byeBye :: Context -> IO ()
byeBye ctx = do
    bye ctx
    bs <- recvData ctx
    unless (B.null bs) $ fail "byeBye: unexpected application data"