1-- Disable this warning so we can still test deprecated functionality.
2{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
3module Connection
4    ( newPairContext
5    , arbitraryCiphers
6    , arbitraryVersions
7    , arbitraryHashSignatures
8    , arbitraryGroups
9    , arbitraryKeyUsage
10    , arbitraryPairParams
11    , arbitraryPairParams13
12    , arbitraryPairParamsWithVersionsAndCiphers
13    , arbitraryClientCredential
14    , arbitraryCredentialsOfEachCurve
15    , arbitraryRSACredentialWithUsage
16    , dhParamsGroup
17    , getConnectVersion
18    , isVersionEnabled
19    , isCustomDHParams
20    , isLeafRSA
21    , isCredentialDSA
22    , arbitraryEMSMode
23    , setEMSMode
24    , readClientSessionRef
25    , twoSessionRefs
26    , twoSessionManagers
27    , setPairParamsSessionManagers
28    , setPairParamsSessionResuming
29    , withDataPipe
30    , initiateDataPipe
31    , byeBye
32    ) where
33
34import Test.Tasty.QuickCheck
35import Certificate
36import PubKey
37import PipeChan
38import Network.TLS as TLS
39import Network.TLS.Extra
40import Data.X509
41import Data.Default.Class
42import Data.IORef
43import Control.Applicative
44import Control.Concurrent.Async
45import Control.Concurrent.Chan
46import Control.Concurrent
47import qualified Control.Exception as E
48import Control.Monad (unless, when)
49import Data.List (intersect, isInfixOf)
50
51import qualified Data.ByteString as B
52
-- | Switch for packet tracing.  When 'True', 'newPairContext' installs logging
-- hooks that print every packet sent and received by the client and server
-- contexts; when 'False' the default (silent) logging is used.
debug :: Bool
debug = False
55
-- | Every cipher the generators may pick from: 'ciphersuite_all' followed by
-- a few weak ciphers that are appended here explicitly.
knownCiphers :: [Cipher]
knownCiphers = ciphersuite_all ++ weakCiphers
  where
    weakCiphers =
        [ cipher_DHE_DSS_RC4_SHA1
        , cipher_RC4_128_MD5
        , cipher_null_MD5
        , cipher_null_SHA1
        ]
65
-- | Generate a non-empty list of ciphers, each drawn independently from
-- 'knownCiphers' (so duplicates are possible).
arbitraryCiphers :: Gen [Cipher]
arbitraryCiphers = listOf1 (elements knownCiphers)
68
-- | Protocol versions the generators may pick from, listed newest first.
knownVersions :: [Version]
knownVersions =
    [ TLS13
    , TLS12
    , TLS11
    , TLS10
    , SSL3
    ]
71
-- | Generate a random sub-list of 'knownVersions'.  The result keeps the
-- relative order of 'knownVersions' and may be empty.
arbitraryVersions :: Gen [Version]
arbitraryVersions = do
    versions <- sublistOf knownVersions
    return versions
74
-- | Hash and signature algorithm pairs the generators may pick from.
--
-- For performance reasons ecdsa_secp521r1_sha512 is not tested.
knownHashSignatures :: [HashAndSignatureAlgorithm]
knownHashSignatures =
    [ (TLS.HashIntrinsic, SignatureRSApssRSAeSHA512)
    , (TLS.HashIntrinsic, SignatureRSApssRSAeSHA384)
    , (TLS.HashIntrinsic, SignatureRSApssRSAeSHA256)
    , (TLS.HashIntrinsic, SignatureEd25519)
    , (TLS.HashIntrinsic, SignatureEd448)
    , (TLS.HashSHA512, SignatureRSA)
    , (TLS.HashSHA384, SignatureRSA)
    , (TLS.HashSHA384, SignatureECDSA)
    , (TLS.HashSHA256, SignatureRSA)
    , (TLS.HashSHA256, SignatureECDSA)
    , (TLS.HashSHA1, SignatureRSA)
    , (TLS.HashSHA1, SignatureDSS)
    ]
90
-- | The entries of 'knownHashSignatures' that are offered when the version is
-- TLS13 or later: anything using the SHA1 hash, or the DSS or plain RSA
-- signature algorithms, is dropped.
knownHashSignatures13 :: [HashAndSignatureAlgorithm]
knownHashSignatures13 = [ hs | hs@(h, s) <- knownHashSignatures, usable h s ]
  where
    usable h s = h /= TLS.HashSHA1 && s `notElem` [SignatureDSS, SignatureRSA]
95
-- | Generate a random sub-list of the hash/signature pairs suitable for the
-- given version: the full 'knownHashSignatures' below TLS13, and
-- 'knownHashSignatures13' otherwise.
arbitraryHashSignatures :: Version -> Gen [HashAndSignatureAlgorithm]
arbitraryHashSignatures v
    | v < TLS13 = sublistOf knownHashSignatures
    | otherwise = sublistOf knownHashSignatures13
99
-- | Key-exchange groups the generators may pick from: all of them, only the
-- elliptic-curve ones, and only the finite-field ones.
--
-- For performance reasons P521, FFDHE6144 and FFDHE8192 are not tested.
knownGroups, knownECGroups, knownFFGroups :: [Group]
knownGroups   = knownECGroups ++ knownFFGroups
knownECGroups = [P256, P384, X25519, X448]
knownFFGroups = [FFDHE2048, FFDHE3072, FFDHE4096]
105
-- | The EC group that 'arbitraryGroupPair' always puts on both sides.
-- Same as defaultECCurve.
defaultECGroup :: Group
defaultECGroup = P256
108
-- | All of 'knownECGroups' except 'defaultECGroup'.
otherKnownECGroups :: [Group]
otherKnownECGroups = [ g | g <- knownECGroups, g /= defaultECGroup ]
111
-- | Generate a non-empty list of groups drawn from 'knownGroups'.  The
-- QuickCheck size parameter is capped at 5 to keep the list short.
arbitraryGroups :: Gen [Group]
arbitraryGroups = scale (min 5) (listOf1 (elements knownGroups))
114
-- | Tell whether the private key of a credential is a DSA key.
isCredentialDSA :: (CertificateChain, PrivKey) -> Bool
isCredentialDSA (_, key) = case key of
    PrivKeyDSA _ -> True
    _            -> False
118
-- | Same credentials as 'arbitraryCredentialsOfEachType'', in random order.
arbitraryCredentialsOfEachType :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachType = do
    creds <- arbitraryCredentialsOfEachType'
    shuffle creds
121
-- | Generate one credential per key type, always in the order RSA, DSA,
-- ECDSA (on 'defaultECCurve'), Ed25519, Ed448.  The RSA key is the global
-- pair; the others are freshly generated.  Each credential is a
-- single-certificate chain built for its key pair.
--
-- Callers such as 'arbitraryClientCredential' rely on this ordering.
arbitraryCredentialsOfEachType' :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachType' = do
    (dsaPub, dsaPriv)         <- arbitraryDSAPair
    (ecdsaPub, ecdsaPriv)     <- arbitraryECDSAPair curve
    (ed25519Pub, ed25519Priv) <- arbitraryEd25519Pair
    (ed448Pub, ed448Priv)     <- arbitraryEd448Pair
    mapM withCertificate
        [ (PubKeyRSA rsaPub, PrivKeyRSA rsaPriv)
        , (PubKeyDSA dsaPub, PrivKeyDSA dsaPriv)
        , (toPubKeyEC curve ecdsaPub, toPrivKeyEC curve ecdsaPriv)
        , (PubKeyEd25519 ed25519Pub, PrivKeyEd25519 ed25519Priv)
        , (PubKeyEd448 ed448Pub, PrivKeyEd448 ed448Priv)
        ]
  where
    (rsaPub, rsaPriv) = getGlobalRSAPair
    curve             = defaultECCurve

    -- Wrap a key pair into a one-certificate chain plus its private key.
    withCertificate keys@(_, priv) = do
        cert <- arbitraryX509WithKey keys
        return (CertificateChain [cert], priv)
139
-- | Same credentials as 'arbitraryCredentialsOfEachCurve'', in random order.
arbitraryCredentialsOfEachCurve :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachCurve = do
    creds <- arbitraryCredentialsOfEachCurve'
    shuffle creds
142
-- | Generate one credential for Ed25519, one for Ed448, and then one ECDSA
-- credential for every curve in 'knownECCurves', in that order.  Each
-- credential is a single-certificate chain built for its key pair.
arbitraryCredentialsOfEachCurve' :: Gen [(CertificateChain, PrivKey)]
arbitraryCredentialsOfEachCurve' = do
    ecdsaKeys                 <- mapM ecdsaKeyPair knownECCurves
    (ed25519Pub, ed25519Priv) <- arbitraryEd25519Pair
    (ed448Pub, ed448Priv)     <- arbitraryEd448Pair
    let edKeys =
            [ (PubKeyEd25519 ed25519Pub, PrivKeyEd25519 ed25519Priv)
            , (PubKeyEd448 ed448Pub, PrivKeyEd448 ed448Priv)
            ]
    mapM withCertificate (edKeys ++ ecdsaKeys)
  where
    -- Fresh ECDSA key pair on the given curve, wrapped as PubKey / PrivKey.
    ecdsaKeyPair curve = do
        (pub, priv) <- arbitraryECDSAPair curve
        return (toPubKeyEC curve pub, toPrivKeyEC curve priv)

    -- Wrap a key pair into a one-certificate chain plus its private key.
    withCertificate keys@(_, priv) = do
        cert <- arbitraryX509WithKey keys
        return (CertificateChain [cert], priv)
158
-- | Map DH parameters to the named group they correspond to.  Only
-- 'ffdhe2048' and 'ffdhe3072' are recognised; anything else gives 'Nothing'.
dhParamsGroup :: DHParams -> Maybe Group
dhParamsGroup params = lookup params namedParams
  where
    namedParams =
        [ (ffdhe2048, FFDHE2048)
        , (ffdhe3072, FFDHE3072)
        ]
164
-- | Tell whether the DH parameters are the custom 'dhParams512' ones.
isCustomDHParams :: DHParams -> Bool
isCustomDHParams = (== dhParams512)
167
-- | Public key of the first certificate of a chain, or 'Nothing' when the
-- chain is empty.
leafPublicKey :: CertificateChain -> Maybe PubKey
leafPublicKey (CertificateChain certs) = case certs of
    leaf : _ -> Just (certPubKey (getCertificate leaf))
    []       -> Nothing
171
-- | Tell whether a chain is present, non-empty, and has an RSA public key in
-- its first certificate.
isLeafRSA :: Maybe CertificateChain -> Bool
isLeafRSA chain
    | Just (PubKeyRSA _) <- chain >>= leafPublicKey = True
    | otherwise                                     = False
176
-- | Generate @(clientCiphers, serverCiphers)@ such that at least one cipher
-- allowed for @connectVersion@ is present on both sides.
arbitraryCipherPair :: Version -> Gen ([Cipher], [Cipher])
arbitraryCipherPair connectVersion = do
    -- The server list needs at least one cipher usable at connectVersion ...
    serverCiphers <- arbitraryCiphers `suchThat` any allowed
    -- ... and the client list must share at least one such cipher with it.
    clientCiphers <- arbitraryCiphers `suchThat`
                         any (\c -> c `elem` serverCiphers && allowed c)
    return (clientCiphers, serverCiphers)
  where
    allowed = cipherAllowedForVersion connectVersion
185
-- | Generate compatible client and server parameters for a version picked
-- at random from 'knownVersions'.
arbitraryPairParams :: Gen (ClientParams, ServerParams)
arbitraryPairParams = do
    connectVersion <- elements knownVersions
    arbitraryPairParamsAt connectVersion
188
-- | Generate @(clientGroups, serverGroups)@ so that at least the default EC
-- group P256 and one FF group are in common.  This makes DHE and ECDHE
-- ciphers always compatible with extension "Supported Elliptic Curves" /
-- "Supported Groups".  Both lists are shuffled independently.
arbitraryGroupPair :: Gen ([Group], [Group])
arbitraryGroupPair = do
    (serverEC, clientEC) <- pairSharing defaultECGroup otherKnownECGroups
    sharedFF             <- elements knownFFGroups
    (serverFF, clientFF) <- pairSharing sharedFF (filter (/= sharedFF) knownFFGroups)
    serverGroups         <- shuffle (serverEC ++ serverFF)
    clientGroups         <- shuffle (clientEC ++ clientFF)
    return (clientGroups, serverGroups)
  where
    -- Two lists that both start with @common@, each followed by its own
    -- random sub-list of @rest@.
    pairSharing common rest = do
        forServer <- sublistOf rest
        forClient <- sublistOf rest
        return (common : forServer, common : forClient)
206
-- | Generate compatible client and server parameters for version TLS13.
arbitraryPairParams13 :: Gen (ClientParams, ServerParams)
arbitraryPairParams13 =
    arbitraryPairParamsAt TLS13
209
-- | Generate compatible client and server parameters whose highest common
-- version is @connectVersion@.
--
-- Both version lists always contain @connectVersion@.  Only one of the two
-- sides (chosen at random) may additionally list versions above it.
arbitraryPairParamsAt :: Version -> Gen (ClientParams, ServerParams)
arbitraryPairParamsAt connectVersion = do
    (clientCiphers, serverCiphers) <- arbitraryCipherPair connectVersion
    -- Select version lists containing connectVersion, as well as some other
    -- versions for which we have compatible ciphers.  Criteria about cipher
    -- ensure we can test version downgrade.
    let allowedVersions = [ v | v <- knownVersions,
                                or [ x `elem` serverCiphers &&
                                     cipherAllowedForVersion v x | x <- clientCiphers ]]
        allowedVersionsFiltered = filter (<= connectVersion) allowedVersions
    -- Server or client is allowed to have versions > connectVersion, but not
    -- both simultaneously.
    filterSrv <- arbitrary
    let (clientAllowedVersions, serverAllowedVersions)
            | filterSrv = (allowedVersions, allowedVersionsFiltered)
            | otherwise = (allowedVersionsFiltered, allowedVersions)
    -- Generate version lists containing less than 127 elements, otherwise the
    -- "supported_versions" extension cannot be correctly serialized
    clientVersions <- listWithOthers connectVersion 126 clientAllowedVersions
    serverVersions <- listWithOthers connectVersion 126 serverAllowedVersions
    arbitraryPairParamsWithVersionsAndCiphers (clientVersions, serverVersions) (clientCiphers, serverCiphers)
  where
    -- Build a list that contains @fixedElement@ exactly once at a random
    -- position, surrounded by at most @maxOthers@ elements drawn (with
    -- repetition) from @others@.  The number of extra elements is also
    -- bounded by the QuickCheck size parameter.
    listWithOthers :: a -> Int -> [a] -> Gen [a]
    listWithOthers fixedElement maxOthers others
        -- 'elements' is partial on an empty list, so with nothing to draw
        -- from the only possible result is the fixed element alone.
        | maxOthers < 1 || null others = return [fixedElement]
        | otherwise     = sized $ \n -> do
            num <- choose (0, min n maxOthers)
            pos <- choose (0, num)
            prefix <- vectorOf pos $ elements others
            suffix <- vectorOf (num - pos) $ elements others
            return $ prefix ++ (fixedElement : suffix)
241
-- | Highest protocol version supported by both the client and the server
-- parameters.
--
-- Calls 'error' with a descriptive message when the two sides have no
-- version in common (previously this surfaced as an anonymous
-- "empty list" failure from 'maximum').
getConnectVersion :: (ClientParams, ServerParams) -> Version
getConnectVersion (cparams, sparams) =
    case cver `intersect` sver of
        []     -> error "getConnectVersion: client and server have no version in common"
        common -> maximum common
  where
    sver = supportedVersions (serverSupported sparams)
    cver = supportedVersions (clientSupported cparams)
247
-- | Tell whether a version is listed in the supported versions of both the
-- server and the client parameters.
isVersionEnabled :: Version -> (ClientParams, ServerParams) -> Bool
isVersionEnabled ver (cparams, sparams) =
    all (ver `elem`) [serverVersions, clientVersions]
  where
    serverVersions = supportedVersions (serverSupported sparams)
    clientVersions = supportedVersions (clientSupported cparams)
252
-- | Generate @(clientHashSignatures, serverHashSignatures)@.  Each side gets
-- all of 'knownHashSignatures', independently shuffled (the server list is
-- drawn first).
arbitraryHashSignaturePair :: Gen ([HashAndSignatureAlgorithm], [HashAndSignatureAlgorithm])
arbitraryHashSignaturePair =
    swapped <$> shuffle knownHashSignatures <*> shuffle knownHashSignatures
  where
    swapped forServer forClient = (forClient, forServer)
258
-- | Build client and server parameters from the given version lists and
-- cipher lists (client component first in both tuples).
--
-- The remaining settings are generated: a secure-renegotiation flag shared
-- by both sides, the server DHE parameters, the server credentials (one per
-- key type), a compatible pair of group lists and a pair of hash/signature
-- lists.  The client uses a validation cache that accepts every certificate.
arbitraryPairParamsWithVersionsAndCiphers :: ([Version], [Version])
                                          -> ([Cipher], [Cipher])
                                          -> Gen (ClientParams, ServerParams)
arbitraryPairParamsWithVersionsAndCiphers (clientVersions, serverVersions) (clientCiphers, serverCiphers) = do
    secNeg   <- arbitrary
    dhparams <- elements [dhParams512, ffdhe2048, ffdhe3072]
    creds    <- arbitraryCredentialsOfEachType
    (clientGroups, serverGroups)                 <- arbitraryGroupPair
    (clientHashSignatures, serverHashSignatures) <- arbitraryHashSignaturePair
    let serverState = def
            { serverSupported = mkSupported serverCiphers serverVersions secNeg
                                            serverGroups serverHashSignatures
            , serverDHEParams = Just dhparams
            , serverShared    = def { sharedCredentials = Credentials creds }
            }
        clientState = (defaultParamsClient "" B.empty)
            { clientSupported = mkSupported clientCiphers clientVersions secNeg
                                            clientGroups clientHashSignatures
            , clientShared    = def { sharedValidationCache = acceptAll }
            }
    return (clientState, serverState)
  where
    -- Default 'Supported' with the generated fields filled in.
    mkSupported :: [Cipher] -> [Version] -> Bool -> [Group]
                -> [HashAndSignatureAlgorithm] -> Supported
    mkSupported ciphers versions secNeg groups hashSigs = def
        { supportedCiphers             = ciphers
        , supportedVersions            = versions
        , supportedSecureRenegotiation = secNeg
        , supportedGroups              = groups
        , supportedHashSignatures      = hashSigs
        }

    -- Validation cache that stores nothing and answers "pass" to every query.
    acceptAll = ValidationCache
        { cacheAdd   = \_ _ _ -> return ()
        , cacheQuery = \_ _ _ -> return ValidationCachePass
        }
293
-- | Generate a client credential whose key type can be used with the given
-- version.  Relies on the fixed ordering of
-- 'arbitraryCredentialsOfEachType'' (RSA, DSA, ECDSA, Ed25519, Ed448).
arbitraryClientCredential :: Version -> Gen Credential
arbitraryClientCredential ver = do
    creds <- arbitraryCredentialsOfEachType'
    elements (usable creds)
  where
    usable = case ver of
        -- for SSL3 there is no EC but only RSA/DSA
        SSL3 -> take 2
        -- for TLS10 and TLS11 there is no EdDSA but only RSA/DSA/ECDSA
        _ | ver < TLS12 -> take 3
          -- every key type is acceptable otherwise
          | otherwise   -> id
304
-- | Generate a credential for the global RSA key pair whose certificate is
-- built with the given extended key usage flags.
arbitraryRSACredentialWithUsage :: [ExtKeyUsageFlag] -> Gen (CertificateChain, PrivKey)
arbitraryRSACredentialWithUsage usageFlags =
    toCredential <$> arbitraryX509WithKeyAndUsage usageFlags (PubKeyRSA rsaPub, ())
  where
    (rsaPub, rsaPriv) = getGlobalRSAPair
    toCredential cert = (CertificateChain [cert], PrivKeyRSA rsaPriv)
310
-- | Generate two independent Extended Master Secret modes, in the order
-- expected by 'setEMSMode' (client first, then server).
arbitraryEMSMode :: Gen (EMSMode, EMSMode)
arbitraryEMSMode = do
    clientMode <- anyMode
    serverMode <- anyMode
    return (clientMode, serverMode)
  where
    anyMode = elements [NoEMS, AllowEMS, RequireEMS]
314
-- | Set the Extended Master Secret mode of the client (first component) and
-- of the server (second component), leaving every other setting untouched.
setEMSMode :: (EMSMode, EMSMode) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setEMSMode (cems, sems) (clientParam, serverParam) =
    ( clientParam { clientSupported = withEMS cems (clientSupported clientParam) }
    , serverParam { serverSupported = withEMS sems (serverSupported serverParam) }
    )
  where
    withEMS mode supported = supported { supportedExtendedMasterSec = mode }
324
-- | Read the current content of the client reference, i.e. the first
-- component of the pair.
readClientSessionRef :: (IORef mclient, IORef mserver) -> IO mclient
readClientSessionRef = readIORef . fst
327
-- | Allocate a pair of references (client, server), both initially holding
-- 'Nothing'.
twoSessionRefs :: IO (IORef (Maybe client), IORef (Maybe server))
twoSessionRefs = do
    clientRef <- newIORef Nothing
    serverRef <- newIORef Nothing
    return (clientRef, serverRef)
330
-- | Simple session manager storing a single session ID with its session
-- data, meant to be used from a single thread.  A real concurrent session
-- manager would use an MVar and hold multiple items.
--
-- Resuming succeeds only when the requested ID equals the stored one; the
-- "only once" variant also clears the stored session.  Establishing a
-- session overwrites the stored one, and invalidation does nothing.
oneSessionManager :: IORef (Maybe (SessionID, SessionData)) -> SessionManager
oneSessionManager ref = SessionManager
    { sessionResume         = resumeWith False
    , sessionResumeOnlyOnce = resumeWith True
    , sessionEstablish      = \sid sdata -> writeIORef ref (Just (sid, sdata))
    , sessionInvalidate     = const (return ())
    }
  where
    resumeWith consume wanted = do
        stored <- readIORef ref
        case stored of
            Just (sid, sdata) | sid == wanted -> do
                when consume (writeIORef ref Nothing)
                return (Just sdata)
            _ -> return Nothing
344
-- | Build one single-session manager per reference: the first for the
-- client, the second for the server.
twoSessionManagers :: (IORef (Maybe (SessionID, SessionData)), IORef (Maybe (SessionID, SessionData))) -> (SessionManager, SessionManager)
twoSessionManagers (cRef, sRef) = (clientManager, serverManager)
  where
    clientManager = oneSessionManager cRef
    serverManager = oneSessionManager sRef
347
-- | Install the given session managers (client first, then server) in the
-- shared settings of the corresponding parameters.
setPairParamsSessionManagers :: (SessionManager, SessionManager) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setPairParamsSessionManagers (clientManager, serverManager) (clientState, serverState) =
    ( clientState { clientShared = install clientManager (clientShared clientState) }
    , serverState { serverShared = install serverManager (serverShared serverState) }
    )
  where
    install manager shared = shared { sharedSessionManager = manager }
353
-- | Make the client ask to resume the given session.  The server parameters
-- are returned unchanged.
setPairParamsSessionResuming :: (SessionID, SessionData) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
setPairParamsSessionResuming sessionStuff (clientState, serverState) =
    (resumingClient, serverState)
  where
    resumingClient = clientState { clientWantSessionResume = Just sessionStuff }
358
-- | Create a client context and a server context on top of a pipe: the
-- client uses the A side of the pipe and the server the B side.  Flushing
-- and closing the backends are no-ops.
--
-- When 'debug' is enabled, every packet sent or received is printed with a
-- @"client: "@ or @"server: "@ prefix.
newPairContext :: PipeChan -> (ClientParams, ServerParams) -> IO (Context, Context)
newPairContext pipe (cParams, sParams) = do
    cCtx <- contextNew (backendOn writePipeA readPipeA) cParams
    sCtx <- contextNew (backendOn writePipeB readPipeB) sParams

    contextHookSetLogging cCtx (logging "client: ")
    contextHookSetLogging sCtx (logging "server: ")

    return (cCtx, sCtx)
  where
    nothingToDo :: IO ()
    nothingToDo = return ()

    -- Backend reading from and writing to one side of the pipe.
    backendOn writer reader =
        Backend nothingToDo nothingToDo (writer pipe) (reader pipe)

    logging pre
        | debug     = def { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++)
                          , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++)
                          }
        | otherwise = def
379
-- | Run a server action and a client action concurrently over a fresh pipe,
-- and hand the continuation a pair of functions to drive them:
--
-- * the first writes a value to the start channel given to the client;
-- * the second waits for both the client and the server to finish and then
--   reads one value from the result channel given to the server.
--
-- An exception raised by either side is printed together with that side's
-- 'Supported' settings and then re-thrown.
withDataPipe :: (ClientParams, ServerParams) -> (Context -> Chan result -> IO ()) -> (Chan start -> Context -> IO ()) -> ((start -> IO (), IO result) -> IO a) -> IO a
withDataPipe params tlsServer tlsClient cont = do
    -- initial setup
    pipe        <- newPipe
    _           <- runPipe pipe
    startQueue  <- newChan
    resultQueue <- newChan

    (cCtx, sCtx) <- newPairContext pipe params

    let serverAction = tlsServer sCtx resultQueue
                           `E.catch` printAndRaise "server" (serverSupported (snd params))
        clientAction = tlsClient startQueue cCtx
                           `E.catch` printAndRaise "client" (clientSupported (fst params))

    withAsync serverAction $ \sAsync ->
        withAsync clientAction $ \cAsync ->
            cont ( writeChan startQueue
                 , waitBoth cAsync sAsync >> readChan resultQueue
                 )
  where
    -- Report the exception with some context, then propagate it.
    printAndRaise :: String -> Supported -> E.SomeException -> IO ()
    printAndRaise s supported e = do
        putStrLn (s ++ " exception: " ++ show e ++ ", supported: " ++ show supported)
        E.throwIO e
402
-- | Run a server action and a client action concurrently over a fresh pipe
-- and wait for both to finish.  Returns @(clientResult, serverResult)@,
-- where each component is either the exception raised by that side or the
-- value it returned.
initiateDataPipe :: (ClientParams, ServerParams) -> (Context -> IO a1) -> (Context -> IO a) -> IO (Either E.SomeException a, Either E.SomeException a1)
initiateDataPipe params tlsServer tlsClient = do
    -- initial setup
    pipe <- newPipe
    _    <- runPipe pipe

    (cCtx, sCtx) <- newPairContext pipe params

    -- start the server first, then the client
    sAsync <- async (tlsServer sCtx)
    cAsync <- async (tlsClient cCtx)

    sRes <- waitCatch sAsync
    cRes <- waitCatch cAsync
    return (cRes, sRes)
416
-- | Terminate the write direction and wait to receive the peer EOF.  This is
-- necessary in situations where we want to confirm the peer status, or to
-- make sure to receive late messages like session tickets.  In the test
-- suite this is used each time application code ends the connection without
-- prior call to 'recvData'.
--
-- Fails if the peer sends application data instead of closing.
byeBye :: Context -> IO ()
byeBye ctx = do
    bye ctx
    leftover <- recvData ctx
    if B.null leftover
        then return ()
        else fail "byeBye: unexpected application data"
427