1{-# LANGUAGE ExistentialQuantification #-}
2{-# LANGUAGE OverloadedStrings #-}
3{-# LANGUAGE RecordWildCards #-}
4-- |
5-- Module      : Network.TLS.Context.Internal
6-- License     : BSD-style
7-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
8-- Stability   : experimental
9-- Portability : unknown
10--
11module Network.TLS.Context.Internal
12    (
13    -- * Context configuration
14      ClientParams(..)
15    , ServerParams(..)
16    , defaultParamsClient
17    , SessionID
18    , SessionData(..)
19    , MaxFragmentEnum(..)
20    , Measurement(..)
21
22    -- * Context object and accessor
23    , Context(..)
24    , Hooks(..)
25    , Established(..)
26    , PendingAction(..)
27    , ctxEOF
28    , ctxHasSSLv2ClientHello
29    , ctxDisableSSLv2ClientHello
30    , ctxEstablished
31    , withLog
32    , ctxWithHooks
33    , contextModifyHooks
34    , setEOF
35    , setEstablished
36    , contextFlush
37    , contextClose
38    , contextSend
39    , contextRecv
40    , updateRecordLayer
41    , updateMeasure
42    , withMeasure
43    , withReadLock
44    , withWriteLock
45    , withStateLock
46    , withRWLock
47
48    -- * information
49    , Information(..)
50    , contextGetInformation
51
52    -- * Using context states
53    , throwCore
54    , failOnEitherError
55    , usingState
56    , usingState_
57    , runTxState
58    , runRxState
59    , usingHState
60    , getHState
61    , saveHState
62    , restoreHState
63    , getStateRNG
64    , tls13orLater
65    , addCertRequest13
66    , getCertRequest13
67    , decideRecordVersion
68
69    -- * Misc
70    , HandshakeSync(..)
71    ) where
72
73import Network.TLS.Backend
74import Network.TLS.Cipher
75import Network.TLS.Compression (Compression)
76import Network.TLS.Extension
77import Network.TLS.Handshake.Control
78import Network.TLS.Handshake.State
79import Network.TLS.Hooks
80import Network.TLS.Imports
81import Network.TLS.Measurement
82import Network.TLS.Parameters
83import Network.TLS.Record.Layer
84import Network.TLS.Record.State
85import Network.TLS.State
86import Network.TLS.Struct
87import Network.TLS.Struct13
88import Network.TLS.Types
89import Network.TLS.Util
90
91import Control.Concurrent.MVar
92import Control.Exception (throwIO)
93import Control.Monad.State.Strict
94import qualified Data.ByteString as B
95import Data.IORef
96import Data.Tuple
97
-- | Information related to a running context, e.g. the current cipher.
--
-- Values of this type are produced by 'contextGetInformation'.
data Information = Information
    { infoVersion             :: Version               -- ^ protocol version recorded in the TLS state
    , infoCipher              :: Cipher                -- ^ cipher of the receiving record state
    , infoCompression         :: Compression           -- ^ compression of the receiving record state
    , infoMasterSecret        :: Maybe ByteString      -- ^ master secret, if present in the handshake state
    , infoExtendedMasterSec   :: Bool                  -- ^ extended-master-secret flag from the handshake state
    , infoClientRandom        :: Maybe ClientRandom    -- ^ client random, if a handshake state exists
    , infoServerRandom        :: Maybe ServerRandom    -- ^ server random, if present in the handshake state
    , infoNegotiatedGroup     :: Maybe Group           -- ^ negotiated group, if any
    , infoTLS13HandshakeMode  :: Maybe HandshakeMode13 -- ^ TLS 1.3 handshake mode; 'Nothing' for other versions
    , infoIsEarlyDataAccepted :: Bool                  -- ^ whether 0-RTT early data was accepted
    } deriving (Show, Eq)
111
-- | A TLS 'Context' keeps TLS-specific state, parameters and backend information.
--
-- The byte representation used by the record layer (@bytes@) is
-- existentially quantified and only constrained to be a 'Monoid'. Because
-- of that, 'ctxRecordLayer' cannot be changed with ordinary record-update
-- syntax; 'updateRecordLayer' rebuilds the whole record instead.
data Context = forall bytes . Monoid bytes => Context
    { ctxConnection       :: Backend   -- ^ return the backend object associated with this context
    , ctxSupported        :: Supported -- ^ supported parameters (e.g. 'supportedVersions')
    , ctxShared           :: Shared    -- ^ shared configuration parameters
    , ctxState            :: MVar TLSState -- ^ TLS state, accessed through 'usingState'
    , ctxMeasurement      :: IORef Measurement -- ^ traffic counters, updated by 'contextSend' / 'contextRecv'
    , ctxEOF_             :: IORef Bool    -- ^ has the handle EOFed or not.
    , ctxEstablished_     :: IORef Established -- ^ has the handshake been done and been successful.
    , ctxNeedEmptyPacket  :: IORef Bool    -- ^ empty packet workaround for CBC guessability.
    , ctxSSLv2ClientHello :: IORef Bool    -- ^ enable the reception of compatibility SSLv2 client hello.
                                           -- the flag will be set to false regardless of its initial value
                                           -- after the first packet received.
    , ctxFragmentSize     :: Maybe Int        -- ^ maximum size of plaintext fragments
    , ctxTxState          :: MVar RecordState -- ^ current tx state
    , ctxRxState          :: MVar RecordState -- ^ current rx state
    , ctxHandshake        :: MVar (Maybe HandshakeState) -- ^ optional handshake state; 'Nothing' makes 'usingHState' throw
    , ctxDoHandshake      :: Context -> IO () -- ^ handshake driver (NOTE(review): presumably the initial handshake — confirm)
    , ctxDoHandshakeWith  :: Context -> Handshake -> IO () -- ^ handshake driver fed a received 'Handshake' message (presumably; verify)
    , ctxDoRequestCertificate :: Context -> IO Bool -- ^ certificate-request hook (NOTE(review): confirm semantics of the result)
    , ctxDoPostHandshakeAuthWith :: Context -> Handshake13 -> IO () -- ^ post-handshake-auth handler for a 'Handshake13' message
    , ctxHooks            :: IORef Hooks   -- ^ hooks for this context
    , ctxLockWrite        :: MVar ()       -- ^ lock to use for writing data (including updating the state)
    , ctxLockRead         :: MVar ()       -- ^ lock to use for reading data (including updating the state)
    , ctxLockState        :: MVar ()       -- ^ lock used during read/write when receiving and sending packet.
                                           -- it is usually nested in a write or read lock.
    , ctxPendingActions   :: IORef [PendingAction] -- ^ queued 'PendingAction's (presumably run on later TLS 1.3 messages)
    , ctxCertRequests     :: IORef [Handshake13]  -- ^ pending PHA requests, see 'addCertRequest13' / 'getCertRequest13'
    , ctxKeyLogger        :: String -> IO () -- ^ sink for key-log lines (NOTE(review): format not visible here)
    , ctxRecordLayer      :: RecordLayer bytes -- ^ record layer over the existential @bytes@ type
    , ctxHandshakeSync    :: HandshakeSync -- ^ callbacks receiving client/server handshake state
    , ctxQUICMode         :: Bool -- ^ presumably 'True' when the context is driven by QUIC — verify
    }
145
-- | A pair of callbacks that are handed the context together with the
-- client-side or the server-side handshake state, respectively.
data HandshakeSync
    = HandshakeSync
        (Context -> ClientState -> IO ())  -- ^ callback taking the client state
        (Context -> ServerState -> IO ())  -- ^ callback taking the server state
148
-- | Replace the record layer of a context, possibly changing the underlying
-- @bytes@ type. The record is rebuilt field by field because a field whose
-- type mentions an existential variable cannot be set via record update.
updateRecordLayer :: Monoid bytes => RecordLayer bytes -> Context -> Context
updateRecordLayer newLayer Context{..} =
    Context{ ctxRecordLayer = newLayer, .. }
152
-- | How far the connection has progressed towards being established.
data Established
    -- | The handshake has not completed.
    = NotEstablished
    -- | Remaining 0-RTT bytes allowed.
    | EarlyDataAllowed Int
    -- | Remaining 0-RTT packets allowed to skip.
    | EarlyDataNotAllowed Int
    -- | The handshake has completed.
    | Established
    deriving (Eq, Show)
158
-- | An action queued on the context for a TLS 1.3 handshake message.
-- NOTE(review): the meaning of the 'Bool' field is not visible here — confirm.
data PendingAction
    -- | Simple pending action.
    = PendingAction Bool (Handshake13 -> IO ())
    -- | Pending action taking the transcript hash up to the preceding message.
    | PendingActionHash Bool (ByteString -> Handshake13 -> IO ())
164
-- | Strictly apply a transformation to the context's traffic measurement.
updateMeasure :: Context -> (Measurement -> Measurement) -> IO ()
updateMeasure ctx change = modifyIORef' (ctxMeasurement ctx) change
167
-- | Run an action on a snapshot of the context's current measurement.
withMeasure :: Context -> (Measurement -> IO a) -> IO a
withMeasure ctx act = act =<< readIORef (ctxMeasurement ctx)
170
-- | A shortcut for 'backendFlush . ctxConnection'.
contextFlush :: Context -> IO ()
contextFlush ctx = backendFlush (ctxConnection ctx)
174
-- | A shortcut for 'backendClose . ctxConnection'.
contextClose :: Context -> IO ()
contextClose ctx = backendClose (ctxConnection ctx)
178
-- | Information about the current context.
--
-- Returns 'Nothing' until both a protocol version and a receiving cipher
-- are known. Handshake-derived fields fall back to 'Nothing' / 'False'
-- when no handshake state is present.
contextGetInformation :: Context -> IO (Maybe Information)
contextGetInformation ctx = do
    mver <- usingState_ ctx (gets stVersion)
    mhst <- getHState ctx
    rxSt <- readMVar (ctxRxState ctx)
    let -- project a field out of the handshake state, or use a default
        fromH :: a -> (HandshakeState -> a) -> a
        fromH dflt field = maybe dflt field mhst
        isTLS13 = mver == Just TLS13
    return $ case (mver, stCipher rxSt) of
        (Just v, Just c) -> Just Information
            { infoVersion             = v
            , infoCipher              = c
            , infoCompression         = stCompression rxSt
            , infoMasterSecret        = fromH Nothing hstMasterSecret
            , infoExtendedMasterSec   = fromH False hstExtendedMasterSec
            , infoClientRandom        = fromH Nothing (Just . hstClientRandom)
            , infoServerRandom        = fromH Nothing hstServerRandom
            , infoNegotiatedGroup     = fromH Nothing hstNegotiatedGroup
            , infoTLS13HandshakeMode  =
                if isTLS13
                    then fromH Nothing (Just . hstTLS13HandshakeMode)
                    else Nothing
            , infoIsEarlyDataAccepted =
                fromH False ((== RTT0Accepted) . hstTLS13RTT0Status)
            }
        _ -> Nothing
200
-- | Send bytes through the backend, first crediting their length to the
-- sent-bytes measurement.
contextSend :: Context -> ByteString -> IO ()
contextSend c b = do
    updateMeasure c (addBytesSent (B.length b))
    backendSend (ctxConnection c) b
203
-- | Receive up to @sz@ bytes from the backend and credit the number of bytes
-- actually returned to the received-bytes measurement.
--
-- Counting the returned length (rather than the requested size) keeps the
-- counter accurate on short reads such as EOF, and mirrors 'contextSend',
-- which counts the real payload length.
contextRecv :: Context -> Int -> IO ByteString
contextRecv c sz = do
    bs <- backendRecv (ctxConnection c) sz
    updateMeasure c (addBytesReceived (B.length bs))
    return bs
206
-- | Whether the context has been marked as having reached EOF.
ctxEOF :: Context -> IO Bool
ctxEOF = readIORef . ctxEOF_
209
-- | Read the flag that enables reception of an SSLv2-compatible client hello.
ctxHasSSLv2ClientHello :: Context -> IO Bool
ctxHasSSLv2ClientHello = readIORef . ctxSSLv2ClientHello
212
-- | Clear the SSLv2-compatible client hello flag.
ctxDisableSSLv2ClientHello :: Context -> IO ()
ctxDisableSSLv2ClientHello = flip writeIORef False . ctxSSLv2ClientHello
215
-- | Mark the context as having reached EOF.
setEOF :: Context -> IO ()
setEOF = flip writeIORef True . ctxEOF_
218
-- | Read the current establishment status of the context.
ctxEstablished :: Context -> IO Established
ctxEstablished = readIORef . ctxEstablished_
221
-- | Run an action on the context's current hooks.
ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a
ctxWithHooks ctx act = act =<< readIORef (ctxHooks ctx)
224
-- | Modify the hooks installed on a context.
--
-- The update is applied strictly ('modifyIORef'') so that repeated
-- modifications do not accumulate a chain of thunks inside the 'IORef';
-- this matches 'updateMeasure'.
contextModifyHooks :: Context -> (Hooks -> Hooks) -> IO ()
contextModifyHooks ctx = modifyIORef' (ctxHooks ctx)
227
-- | Overwrite the establishment status of the context.
setEstablished :: Context -> Established -> IO ()
setEstablished ctx status = writeIORef (ctxEstablished_ ctx) status
230
-- | Run an action with the logging hooks of the context.
withLog :: Context -> (Logging -> IO ()) -> IO ()
withLog ctx act = ctxWithHooks ctx $ \hooks -> act (hookLogging hooks)
233
-- | Throw a 'TLSError' as an IO exception from any 'MonadIO'.
throwCore :: MonadIO m => TLSError -> m a
throwCore err = liftIO (throwIO err)
236
-- | Run an action and rethrow a 'Left' result via 'throwCore'; a 'Right'
-- result is returned unchanged.
failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a
failOnEitherError action = action >>= either throwCore return
243
-- | Run a 'TLSSt' computation against the context's TLS state under its
-- 'MVar', storing the new state (forced to WHNF) and returning the result.
usingState :: Context -> TLSSt a -> IO (Either TLSError a)
usingState ctx f = modifyMVar (ctxState ctx) step
  where
    step st = case runTLSState f st of
        (result, st') -> st' `seq` return (st', result)
249
-- | Like 'usingState', but throws the 'TLSError' instead of returning it.
usingState_ :: Context -> TLSSt a -> IO a
usingState_ ctx = failOnEitherError . usingState ctx
252
-- | Run a 'HandshakeM' computation against the context's handshake state.
--
-- Throws @Error_Misc "missing handshake"@ when no handshake state is set;
-- in that case the stored value is left untouched.
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) step
  where
    step Nothing   = throwCore $ Error_Misc "missing handshake"
    step (Just st) = return . swap . fmap Just $ runHandshake st f
258
-- | Read the optional handshake state of the context.
getHState :: MonadIO m => Context -> m (Maybe HandshakeState)
getHState = liftIO . readMVar . ctxHandshake
261
-- | Snapshot the handshake-state 'MVar' via 'saveMVar'.
saveHState :: Context -> IO (Saved (Maybe HandshakeState))
saveHState = saveMVar . ctxHandshake
264
-- | Restore a snapshot taken with 'saveHState' via 'restoreMVar'.
restoreHState :: Context
              -> Saved (Maybe HandshakeState)
              -> IO (Saved (Maybe HandshakeState))
restoreHState = restoreMVar . ctxHandshake
269
-- | Decide the version to put in outgoing record headers, and whether the
-- connection uses TLS 1.3 or later.
--
-- For TLS 1.3 the header version only matters for ClientHello: the first
-- ClientHello SHOULD use TLS 1.0, and the second one (after a
-- HelloRetryRequest) MUST use TLS 1.2.
decideRecordVersion :: Context -> IO (Version, Bool)
decideRecordVersion ctx = usingState_ ctx $ do
    -- Default passed to getVersionWithDefault: highest supported version.
    -- NOTE(review): 'maximum' is partial; assumes supportedVersions is non-empty.
    negotiated <- getVersionWithDefault (maximum (supportedVersions (ctxSupported ctx)))
    isHRR <- getTLS13HRR
    let tls13 = negotiated >= TLS13
        recordVer
          | not tls13 = negotiated
          | isHRR     = TLS12
          | otherwise = TLS10
    return (recordVer, tls13)
281
-- | Run a 'RecordM' computation against the transmit record state.
-- The state is only replaced when the computation succeeds.
runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx f = do
    (recordVer, isTLS13) <- decideRecordVersion ctx
    let opts = RecordOptions { recordVersion = recordVer, recordTLS13 = isTLS13 }
    modifyMVar (ctxTxState ctx) $ \st -> case runRecordM f opts st of
        Right (a, st') -> return (st', Right a)
        Left err       -> return (st, Left err)
292
-- | Run a 'RecordM' computation against the receive record state.
-- The state is only replaced when the computation succeeds.
runRxState :: Context -> RecordM a -> IO (Either TLSError a)
runRxState ctx f = do
    ver <- usingState_ ctx getVersion
    -- With TLS 1.3 the record version is ignored, so no conversion is needed.
    let opts = RecordOptions { recordVersion = ver, recordTLS13 = ver >= TLS13 }
    modifyMVar (ctxRxState ctx) $ \st -> case runRecordM f opts st of
        Right (a, st') -> return (st', Right a)
        Left err       -> return (st, Left err)
304
-- | Draw @n@ random bytes from the RNG held in the TLS state.
getStateRNG :: Context -> Int -> IO ByteString
getStateRNG ctx = usingState_ ctx . genRandom
307
-- | Run an action while holding the context's read lock.
withReadLock :: Context -> IO a -> IO a
withReadLock ctx act = withMVar (ctxLockRead ctx) $ \_ -> act
310
-- | Run an action while holding the context's write lock.
withWriteLock :: Context -> IO a -> IO a
withWriteLock ctx act = withMVar (ctxLockWrite ctx) $ \_ -> act
313
-- | Run an action holding the read lock and, nested inside it, the write lock.
withRWLock :: Context -> IO a -> IO a
withRWLock ctx = withReadLock ctx . withWriteLock ctx
316
-- | Run an action while holding the context's state lock.
withStateLock :: Context -> IO a -> IO a
withStateLock ctx act = withMVar (ctxLockState ctx) $ \_ -> act
319
-- | Whether the context's version is TLS 1.3 or later. Defaults to TLS 1.0
-- when no version is set, and answers 'False' if reading the state fails.
tls13orLater :: MonadIO m => Context -> m Bool
tls13orLater ctx =
    liftIO $ isTLS13 <$> usingState ctx (getVersionWithDefault TLS10) -- fixme
  where
    isTLS13 = either (const False) (>= TLS13)
326
-- | Record a pending post-handshake certificate request.
--
-- The newest request is put at the head of the list. The update is atomic
-- and strict so that it cannot be lost when it races with the
-- read-modify-write done by 'getCertRequest13'.
addCertRequest13 :: Context -> Handshake13 -> IO ()
addCertRequest13 ctx certReq =
    atomicModifyIORef' (ctxCertRequests ctx) $ \reqs -> (certReq : reqs, ())
329
-- | Take the pending certificate request whose context matches @context@.
--
-- Returns the first matching request (the most recently added one, since
-- 'addCertRequest13' prepends). All requests with that context are then
-- removed from the pending list. Returns 'Nothing' and leaves the list
-- unchanged when nothing matches.
--
-- The lookup and removal happen in a single 'atomicModifyIORef'', so a
-- request added concurrently is not lost. The match predicate is total:
-- entries that are not 'CertRequest13' never match, instead of crashing
-- with a pattern-match failure.
getCertRequest13 :: Context -> CertReqContext -> IO (Maybe Handshake13)
getCertRequest13 ctx context =
    atomicModifyIORef' (ctxCertRequests ctx) $ \reqs ->
        let (matched, others) = partition matches reqs
         in case matched of
                []          -> (reqs, Nothing)
                (certReq:_) -> (others, Just certReq)
  where
    matches (CertRequest13 c _) = context == c
    matches _                   = False
338