1-- Disable this warning so we can still test deprecated functionality.
2{-# OPTIONS_GHC -fno-warn-warnings-deprecations #-}
3module Connection
4    ( newPairContext
5    , arbitraryCiphers
6    , arbitraryVersions
7    , arbitraryHashSignatures
8    , arbitraryGroups
9    , arbitraryKeyUsage
10    , arbitraryPairParams
11    , arbitraryPairParams13
12    , arbitraryPairParamsWithVersionsAndCiphers
13    , arbitraryClientCredential
14    , arbitraryCredentialsOfEachCurve
15    , arbitraryRSACredentialWithUsage
16    , dhParamsGroup
17    , getConnectVersion
18    , isVersionEnabled
19    , isCustomDHParams
20    , isLeafRSA
21    , isCredentialDSA
22    , arbitraryEMSMode
23    , setEMSMode
24    , readClientSessionRef
25    , twoSessionRefs
26    , twoSessionManagers
27    , setPairParamsSessionManagers
28    , setPairParamsSessionResuming
29    , withDataPipe
30    , initiateDataPipe
31    , byeBye
32    ) where
33
34import Test.Tasty.QuickCheck
35import Certificate
36import PubKey
37import PipeChan
38import Network.TLS as TLS
39import Network.TLS.Extra
40import Data.X509
41import Data.Default.Class
42import Data.IORef
43import Control.Applicative
44import Control.Concurrent.Async
45import Control.Concurrent.Chan
46import Control.Concurrent
47import qualified Control.Exception as E
48import Control.Monad (unless, when)
49import Data.List (intersect, isInfixOf)
50
51import qualified Data.ByteString as B
52
53debug :: Bool
54debug = False
55
56knownCiphers :: [Cipher]
57knownCiphers = ciphersuite_all ++ ciphersuite_weak
58  where
59    ciphersuite_weak = [
60        cipher_DHE_DSS_RC4_SHA1
61      , cipher_RC4_128_MD5
62      , cipher_null_MD5
63      , cipher_null_SHA1
64      ]
65
66-- local restriction: EdDSA credentials are not usable before TLS12 so without
67-- ECDSA support it is not possible for ECDHE_ECDSA to be successful with TLS10
68-- and TLS11
69cipherAllowedForVersion' :: Version -> Cipher -> Bool
70cipherAllowedForVersion' connectVersion x =
71    cipherAllowedForVersion connectVersion x &&
72    (connectVersion >= TLS12 || cipherKeyExchange x /= CipherKeyExchange_ECDHE_ECDSA)
73
74arbitraryCiphers :: Gen [Cipher]
75arbitraryCiphers = listOf1 $ elements knownCiphers
76
77knownVersions :: [Version]
78knownVersions = [TLS13,TLS12,TLS11,TLS10,SSL3]
79
80arbitraryVersions :: Gen [Version]
81arbitraryVersions = sublistOf knownVersions
82
83knownHashSignatures :: [HashAndSignatureAlgorithm]
84knownHashSignatures = filter nonECDSA availableHashSignatures
85  where
86    availableHashSignatures = [(TLS.HashIntrinsic, SignatureRSApssRSAeSHA512)
87                              ,(TLS.HashIntrinsic, SignatureRSApssRSAeSHA384)
88                              ,(TLS.HashIntrinsic, SignatureRSApssRSAeSHA256)
89                              ,(TLS.HashIntrinsic, SignatureEd25519)
90                              ,(TLS.HashIntrinsic, SignatureEd448)
91                              ,(TLS.HashSHA512, SignatureRSA)
92                              ,(TLS.HashSHA512, SignatureECDSA)
93                              ,(TLS.HashSHA384, SignatureRSA)
94                              ,(TLS.HashSHA384, SignatureECDSA)
95                              ,(TLS.HashSHA256, SignatureRSA)
96                              ,(TLS.HashSHA256, SignatureECDSA)
97                              ,(TLS.HashSHA1,   SignatureRSA)
98                              ,(TLS.HashSHA1,   SignatureDSS)
99                              ]
100    -- arbitraryCredentialsOfEachType cannot generate ECDSA
101    nonECDSA (_,s) = s /= SignatureECDSA
102
103knownHashSignatures13 :: [HashAndSignatureAlgorithm]
104knownHashSignatures13 = filter compat knownHashSignatures
105  where
106    compat (h,s) = h /= TLS.HashSHA1 && s /= SignatureDSS && s /= SignatureRSA
107
108arbitraryHashSignatures :: Version -> Gen [HashAndSignatureAlgorithm]
109arbitraryHashSignatures v = sublistOf l
110    where l = if v < TLS13 then knownHashSignatures else knownHashSignatures13
111
112-- for performance reason P521, FFDHE6144, FFDHE8192 are not tested
113knownGroups, knownECGroups, knownFFGroups :: [Group]
114knownECGroups = [P256,P384,X25519,X448]
115knownFFGroups = [FFDHE2048,FFDHE3072,FFDHE4096]
116knownGroups   = knownECGroups ++ knownFFGroups
117
118arbitraryGroups :: Gen [Group]
119arbitraryGroups = scale (min 5) $ listOf1 $ elements knownGroups
120
121isCredentialDSA :: (CertificateChain, PrivKey) -> Bool
122isCredentialDSA (_, PrivKeyDSA _) = True
123isCredentialDSA _                 = False
124
125arbitraryCredentialsOfEachType :: Gen [(CertificateChain, PrivKey)]
126arbitraryCredentialsOfEachType = do
127    let (pubKey, privKey) = getGlobalRSAPair
128    (dsaPub, dsaPriv) <- arbitraryDSAPair
129    (ed25519Pub, ed25519Priv) <- arbitraryEd25519Pair
130    (ed448Pub, ed448Priv) <- arbitraryEd448Pair
131    mapM (\(pub, priv) -> do
132              cert <- arbitraryX509WithKey (pub, priv)
133              return (CertificateChain [cert], priv)
134         ) [ (PubKeyRSA pubKey, PrivKeyRSA privKey)
135           , (PubKeyDSA dsaPub, PrivKeyDSA dsaPriv)
136           , (PubKeyEd25519 ed25519Pub, PrivKeyEd25519 ed25519Priv)
137           , (PubKeyEd448 ed448Pub, PrivKeyEd448 ed448Priv)
138           ]
139
140arbitraryCredentialsOfEachCurve :: Gen [(CertificateChain, PrivKey)]
141arbitraryCredentialsOfEachCurve = do
142    (ed25519Pub, ed25519Priv) <- arbitraryEd25519Pair
143    (ed448Pub, ed448Priv) <- arbitraryEd448Pair
144    mapM (\(pub, priv) -> do
145              cert <- arbitraryX509WithKey (pub, priv)
146              return (CertificateChain [cert], priv)
147         ) [ (PubKeyEd25519 ed25519Pub, PrivKeyEd25519 ed25519Priv)
148           , (PubKeyEd448 ed448Pub, PrivKeyEd448 ed448Priv)
149           ]
150
151dhParamsGroup :: DHParams -> Maybe Group
152dhParamsGroup params
153    | params == ffdhe2048 = Just FFDHE2048
154    | params == ffdhe3072 = Just FFDHE3072
155    | otherwise           = Nothing
156
157isCustomDHParams :: DHParams -> Bool
158isCustomDHParams params = params == dhParams512
159
160leafPublicKey :: CertificateChain -> Maybe PubKey
161leafPublicKey (CertificateChain [])       = Nothing
162leafPublicKey (CertificateChain (leaf:_)) = Just (certPubKey $ getCertificate leaf)
163
164isLeafRSA :: Maybe CertificateChain -> Bool
165isLeafRSA chain = case chain >>= leafPublicKey of
166                        Just (PubKeyRSA _) -> True
167                        _                  -> False
168
169arbitraryCipherPair :: Version -> Gen ([Cipher], [Cipher])
170arbitraryCipherPair connectVersion = do
171    serverCiphers      <- arbitraryCiphers `suchThat`
172                                (\cs -> or [cipherAllowedForVersion' connectVersion x | x <- cs])
173    clientCiphers      <- arbitraryCiphers `suchThat`
174                                (\cs -> or [x `elem` serverCiphers &&
175                                            cipherAllowedForVersion' connectVersion x | x <- cs])
176    return (clientCiphers, serverCiphers)
177
178arbitraryPairParams :: Gen (ClientParams, ServerParams)
179arbitraryPairParams = elements knownVersions >>= arbitraryPairParamsAt
180
181-- pair of groups so that at least one EC and one FF group are in common
182arbitraryGroupPair :: Gen ([Group], [Group])
183arbitraryGroupPair = do
184    (serverECGroups, clientECGroups) <- arbitraryGroupPairFrom knownECGroups
185    (serverFFGroups, clientFFGroups) <- arbitraryGroupPairFrom knownFFGroups
186    serverGroups <- shuffle (serverECGroups ++ serverFFGroups)
187    clientGroups <- shuffle (clientECGroups ++ clientFFGroups)
188    return (clientGroups, serverGroups)
189  where
190    arbitraryGroupPairFrom list = do
191        s <- arbitraryGroupsFrom list
192        c <- arbitraryGroupsFrom list `suchThat` any (`elem` s)
193        return (c, s)
194    arbitraryGroupsFrom list = listOf1 $ elements list
195
196arbitraryPairParams13 :: Gen (ClientParams, ServerParams)
197arbitraryPairParams13 = arbitraryPairParamsAt TLS13
198
199arbitraryPairParamsAt :: Version -> Gen (ClientParams, ServerParams)
200arbitraryPairParamsAt connectVersion = do
201    (clientCiphers, serverCiphers) <- arbitraryCipherPair connectVersion
202    -- Select version lists containing connectVersion, as well as some other
203    -- versions for which we have compatible ciphers.  Criteria about cipher
204    -- ensure we can test version downgrade.
205    let allowedVersions = [ v | v <- knownVersions,
206                                or [ x `elem` serverCiphers &&
207                                     cipherAllowedForVersion' v x | x <- clientCiphers ]]
208        allowedVersionsFiltered = filter (<= connectVersion) allowedVersions
209    -- Server or client is allowed to have versions > connectVersion, but not
210    -- both simultaneously.
211    filterSrv <- arbitrary
212    let (clientAllowedVersions, serverAllowedVersions)
213            | filterSrv = (allowedVersions, allowedVersionsFiltered)
214otherwise = (allowedVersionsFiltered, allowedVersions)
215    -- Generate version lists containing less than 127 elements, otherwise the
216    -- "supported_versions" extension cannot be correctly serialized
217    clientVersions <- listWithOthers connectVersion 126 clientAllowedVersions
218    serverVersions <- listWithOthers connectVersion 126 serverAllowedVersions
219    arbitraryPairParamsWithVersionsAndCiphers (clientVersions, serverVersions) (clientCiphers, serverCiphers)
220  where
221    listWithOthers :: a -> Int -> [a] -> Gen [a]
222    listWithOthers fixedElement maxOthers others
223        | maxOthers < 1 = return [fixedElement]
224        | otherwise     = sized $ \n -> do
225            num <- choose (0, min n maxOthers)
226            pos <- choose (0, num)
227            prefix <- vectorOf pos $ elements others
228            suffix <- vectorOf (num - pos) $ elements others
229            return $ prefix ++ (fixedElement : suffix)
230
231getConnectVersion :: (ClientParams, ServerParams) -> Version
232getConnectVersion (cparams, sparams) = maximum (cver `intersect` sver)
233  where
234    sver = supportedVersions (serverSupported sparams)
235    cver = supportedVersions (clientSupported cparams)
236
237isVersionEnabled :: Version -> (ClientParams, ServerParams) -> Bool
238isVersionEnabled ver (cparams, sparams) =
239    (ver `elem` supportedVersions (serverSupported sparams)) &&
240    (ver `elem` supportedVersions (clientSupported cparams))
241
242arbitraryHashSignaturePair :: Gen ([HashAndSignatureAlgorithm], [HashAndSignatureAlgorithm])
243arbitraryHashSignaturePair = do
244    serverHashSignatures <- shuffle knownHashSignatures
245    clientHashSignatures <- shuffle knownHashSignatures
246    return (clientHashSignatures, serverHashSignatures)
247
248arbitraryPairParamsWithVersionsAndCiphers :: ([Version], [Version])
249                                          -> ([Cipher], [Cipher])
250                                          -> Gen (ClientParams, ServerParams)
251arbitraryPairParamsWithVersionsAndCiphers (clientVersions, serverVersions) (clientCiphers, serverCiphers) = do
252    secNeg             <- arbitrary
253    dhparams           <- elements [dhParams512,ffdhe2048,ffdhe3072]
254
255    creds              <- arbitraryCredentialsOfEachType
256    (clientGroups, serverGroups) <- arbitraryGroupPair
257    (clientHashSignatures, serverHashSignatures) <- arbitraryHashSignaturePair
258    let serverState = def
259            { serverSupported = def { supportedCiphers  = serverCiphers
260                                    , supportedVersions = serverVersions
261                                    , supportedSecureRenegotiation = secNeg
262                                    , supportedGroups   = serverGroups
263                                    , supportedHashSignatures = serverHashSignatures
264                                    }
265            , serverDHEParams = Just dhparams
266            , serverShared = def { sharedCredentials = Credentials creds }
267            }
268    let clientState = (defaultParamsClient "" B.empty)
269            { clientSupported = def { supportedCiphers  = clientCiphers
270                                    , supportedVersions = clientVersions
271                                    , supportedSecureRenegotiation = secNeg
272                                    , supportedGroups   = clientGroups
273                                    , supportedHashSignatures = clientHashSignatures
274                                    }
275            , clientShared = def { sharedValidationCache = ValidationCache
276                                        { cacheAdd = \_ _ _ -> return ()
277                                        , cacheQuery = \_ _ _ -> return ValidationCachePass
278                                        }
279                                }
280            }
281    return (clientState, serverState)
282
283arbitraryClientCredential :: Version -> Gen Credential
284arbitraryClientCredential SSL3 = do
285    -- for SSL3 there is no EC but only RSA/DSA
286    creds <- arbitraryCredentialsOfEachType
287    elements (take 2 creds) -- RSA and DSA, but not Ed25519 and Ed448
288arbitraryClientCredential v | v < TLS12 = do
289    -- for TLS10 and TLS11 there is no EdDSA but only RSA/DSA/ECDSA
290    creds <- arbitraryCredentialsOfEachType
291    elements (take 2 creds) -- RSA and DSA (ECDSA later), but not EdDSA
292arbitraryClientCredential _    = arbitraryCredentialsOfEachType >>= elements
293
294arbitraryRSACredentialWithUsage :: [ExtKeyUsageFlag] -> Gen (CertificateChain, PrivKey)
295arbitraryRSACredentialWithUsage usageFlags = do
296    let (pubKey, privKey) = getGlobalRSAPair
297    cert <- arbitraryX509WithKeyAndUsage usageFlags (PubKeyRSA pubKey, ())
298    return (CertificateChain [cert], PrivKeyRSA privKey)
299
300arbitraryEMSMode :: Gen (EMSMode, EMSMode)
301arbitraryEMSMode = (,) <$> gen <*> gen
302  where gen = elements [ NoEMS, AllowEMS, RequireEMS ]
303
304setEMSMode :: (EMSMode, EMSMode) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
305setEMSMode (cems, sems) (clientParam, serverParam) = (clientParam', serverParam')
306  where
307    clientParam' = clientParam { clientSupported = (clientSupported clientParam)
308                                   { supportedExtendedMasterSec = cems }
309                               }
310    serverParam' = serverParam { serverSupported = (serverSupported serverParam)
311                                   { supportedExtendedMasterSec = sems }
312                               }
313
314readClientSessionRef :: (IORef mclient, IORef mserver) -> IO mclient
315readClientSessionRef refs = readIORef (fst refs)
316
317twoSessionRefs :: IO (IORef (Maybe client), IORef (Maybe server))
318twoSessionRefs = (,) <$> newIORef Nothing <*> newIORef Nothing
319
320-- | simple session manager to store one session id and session data for a single thread.
321-- a Real concurrent session manager would use an MVar and have multiples items.
322oneSessionManager :: IORef (Maybe (SessionID, SessionData)) -> SessionManager
323oneSessionManager ref = SessionManager
324    { sessionResume         = \myId     -> readIORef ref >>= maybeResume False myId
325    , sessionResumeOnlyOnce = \myId     -> readIORef ref >>= maybeResume True myId
326    , sessionEstablish      = \myId dat -> writeIORef ref $ Just (myId, dat)
327    , sessionInvalidate     = \_        -> return ()
328    }
329  where
330    maybeResume onlyOnce myId (Just (sid, sdata))
331        | sid == myId = when onlyOnce (writeIORef ref Nothing) >> return (Just sdata)
332    maybeResume _ _ _ = return Nothing
333
334twoSessionManagers :: (IORef (Maybe (SessionID, SessionData)), IORef (Maybe (SessionID, SessionData))) -> (SessionManager, SessionManager)
335twoSessionManagers (cRef, sRef) = (oneSessionManager cRef, oneSessionManager sRef)
336
337setPairParamsSessionManagers :: (SessionManager, SessionManager) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
338setPairParamsSessionManagers (clientManager, serverManager) (clientState, serverState) = (nc,ns)
339  where nc = clientState { clientShared = updateSessionManager clientManager $ clientShared clientState }
340        ns = serverState { serverShared = updateSessionManager serverManager $ serverShared serverState }
341        updateSessionManager manager shared = shared { sharedSessionManager = manager }
342
343setPairParamsSessionResuming :: (SessionID, SessionData) -> (ClientParams, ServerParams) -> (ClientParams, ServerParams)
344setPairParamsSessionResuming sessionStuff (clientState, serverState) =
345    ( clientState { clientWantSessionResume = Just sessionStuff }
346    , serverState)
347
348newPairContext :: PipeChan -> (ClientParams, ServerParams) -> IO (Context, Context)
349newPairContext pipe (cParams, sParams) = do
350    let noFlush = return ()
351    let noClose = return ()
352
353    let cBackend = Backend noFlush noClose (writePipeA pipe) (readPipeA pipe)
354    let sBackend = Backend noFlush noClose (writePipeB pipe) (readPipeB pipe)
355    cCtx' <- contextNew cBackend cParams
356    sCtx' <- contextNew sBackend sParams
357
358    contextHookSetLogging cCtx' (logging "client: ")
359    contextHookSetLogging sCtx' (logging "server: ")
360
361    return (cCtx', sCtx')
362  where
363        logging pre =
364            if debug
365                then def { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++)
366                                    , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++) }
367                else def
368
369withDataPipe :: (ClientParams, ServerParams) -> (Context -> Chan result -> IO ()) -> (Chan start -> Context -> IO ()) -> ((start -> IO (), IO result) -> IO a) -> IO a
370withDataPipe params tlsServer tlsClient cont = do
371    -- initial setup
372    pipe        <- newPipe
373    _           <- runPipe pipe
374    startQueue  <- newChan
375    resultQueue <- newChan
376
377    (cCtx, sCtx) <- newPairContext pipe params
378
379    withAsync (E.catch (tlsServer sCtx resultQueue)
380                       (printAndRaise "server" (serverSupported $ snd params))) $ \sAsync -> do
381    withAsync (E.catch (tlsClient startQueue cCtx)
382                       (printAndRaise "client" (clientSupported $ fst params))) $ \cAsync -> do
383
384      let readResult = waitBoth cAsync sAsync >> readChan resultQueue
385      cont (writeChan startQueue, readResult)
386
387  where
388        printAndRaise :: String -> Supported -> E.SomeException -> IO ()
389        printAndRaise s supported e = do
390            putStrLn $ s ++ " exception: " ++ show e ++
391                            ", supported: " ++ show supported
392            E.throwIO e
393
394initiateDataPipe :: (ClientParams, ServerParams) -> (Context -> IO a1) -> (Context -> IO a) -> IO (Either E.SomeException a, Either E.SomeException a1)
395initiateDataPipe params tlsServer tlsClient = do
396    -- initial setup
397    pipe        <- newPipe
398    _           <- runPipe pipe
399
400    (cCtx, sCtx) <- newPairContext pipe params
401
402    async (tlsServer sCtx) >>= \sAsync ->
403        async (tlsClient cCtx) >>= \cAsync -> do
404            sRes <- waitCatch sAsync
405            cRes <- waitCatch cAsync
406            return (cRes, sRes)
407
408-- Terminate the write direction and wait to receive the peer EOF.  This is
409-- necessary in situations where we want to confirm the peer status, or to make
410-- sure to receive late messages like session tickets.  In the test suite this
411-- is used each time application code ends the connection without prior call to
412-- 'recvData'.
413byeBye :: Context -> IO ()
414byeBye ctx = do
415    bye ctx
416    bs <- recvData ctx
417    unless (B.null bs) $ fail "byeBye: unexpected application data"
418