1{-# LANGUAGE CPP #-}
2{-# LANGUAGE Rank2Types #-}
3{-# LANGUAGE MultiParamTypeClasses #-}
4{-# LANGUAGE GeneralizedNewtypeDeriving #-}
5-- |
6-- Module      : Network.TLS.State
7-- License     : BSD-style
8-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
9-- Stability   : experimental
10-- Portability : unknown
11--
-- The State module contains functions for initializing and manipulating the
-- TLS state; they are used by the Receiving and Sending modules.
14--
15module Network.TLS.State
16    ( TLSState(..)
17    , TLSSt
18    , runTLSState
19    , newTLSState
20    , withTLSRNG
21    , updateVerifiedData
22    , finishHandshakeTypeMaterial
23    , finishHandshakeMaterial
24    , certVerifyHandshakeTypeMaterial
25    , certVerifyHandshakeMaterial
26    , setVersion
27    , setVersionIfUnset
28    , getVersion
29    , getVersionWithDefault
30    , setSecureRenegotiation
31    , getSecureRenegotiation
32    , setExtensionALPN
33    , getExtensionALPN
34    , setNegotiatedProtocol
35    , getNegotiatedProtocol
36    , setClientALPNSuggest
37    , getClientALPNSuggest
38    , setClientEcPointFormatSuggest
39    , getClientEcPointFormatSuggest
40    , getClientCertificateChain
41    , setClientCertificateChain
42    , setClientSNI
43    , getClientSNI
44    , getVerifiedData
45    , setSession
46    , getSession
47    , isSessionResuming
48    , isClientContext
49    , setExporterMasterSecret
50    , getExporterMasterSecret
51    , setTLS13KeyShare
52    , getTLS13KeyShare
53    , setTLS13PreSharedKey
54    , getTLS13PreSharedKey
55    , setTLS13HRR
56    , getTLS13HRR
57    , setTLS13Cookie
58    , getTLS13Cookie
59    , setClientSupportsPHA
60    , getClientSupportsPHA
61    -- * random
62    , genRandom
63    , withRNG
64    ) where
65
66import Network.TLS.Imports
67import Network.TLS.Struct
68import Network.TLS.Struct13
69import Network.TLS.RNG
70import Network.TLS.Types (Role(..), HostName)
71import Network.TLS.Wire (GetContinuation)
72import Network.TLS.Extension
73import qualified Data.ByteString as B
74import Control.Monad.State.Strict
75import Network.TLS.ErrT
76import Crypto.Random
77import Data.X509 (CertificateChain)
78
-- | Per-context TLS state, read and updated through the 'TLSSt' monad.
-- The initial value is built by 'newTLSState'.
data TLSState = TLSState
    { stSession             :: Session
      -- ^ Current session; written together with 'stSessionResuming' by 'setSession'.
    , stSessionResuming     :: Bool
      -- ^ Resumption flag passed to 'setSession'.
    , stSecureRenegotiation :: Bool
      -- ^ Secure renegotiation flag (RFC 5746).
    , stClientVerifiedData  :: ByteString
      -- ^ Client verified data (RFC 5746); written by 'updateVerifiedData'.
    , stServerVerifiedData  :: ByteString
      -- ^ Server verified data (RFC 5746); written by 'updateVerifiedData'.
    , stExtensionALPN       :: Bool
      -- ^ ALPN extension flag (RFC 7301).
    , stHandshakeRecordCont :: Maybe (GetContinuation (HandshakeType, ByteString))
      -- ^ Pending handshake-parser continuation, if any.
      -- NOTE(review): presumably for messages split across records — confirm.
    , stNegotiatedProtocol  :: Maybe B.ByteString
      -- ^ ALPN protocol, once set by 'setNegotiatedProtocol'.
    , stHandshakeRecordCont13 :: Maybe (GetContinuation (HandshakeType13, ByteString))
      -- ^ TLS 1.3 counterpart of 'stHandshakeRecordCont' (uses 'HandshakeType13').
    , stClientALPNSuggest   :: Maybe [B.ByteString]
      -- ^ ALPN protocol list suggested by the client, if recorded.
    , stClientGroupSuggest  :: Maybe [Group]
      -- ^ Groups suggested by the client (no accessor in this module).
    , stClientEcPointFormatSuggest :: Maybe [EcPointFormat]
      -- ^ EC point formats suggested by the client, if recorded.
    , stClientCertificateChain :: Maybe CertificateChain
      -- ^ Client certificate chain, if recorded.
    , stClientSNI           :: Maybe HostName
      -- ^ Host name from the client's SNI, if recorded.
    , stRandomGen           :: StateRNG
      -- ^ RNG state threaded through 'withRNG' / 'genRandom'.
    , stVersion             :: Maybe Version
      -- ^ Protocol version; 'Nothing' until set ('getVersion' errors while unset).
    , stClientContext       :: Role
      -- ^ Role of this context; returned by 'isClientContext'.
    , stTLS13KeyShare       :: Maybe KeyShare
      -- ^ TLS 1.3 key share, if set.
    , stTLS13PreSharedKey   :: Maybe PreSharedKey
      -- ^ TLS 1.3 pre-shared key, if set.
    , stTLS13HRR            :: !Bool
      -- ^ TLS 1.3 HelloRetryRequest flag (strict field).
    , stTLS13Cookie         :: Maybe Cookie
      -- ^ TLS 1.3 cookie, if set.
    , stExporterMasterSecret :: Maybe ByteString
      -- ^ Exporter master secret (TLS 1.3), if set.
    , stClientSupportsPHA   :: !Bool
      -- ^ Whether the client supports Post-Handshake Authentication (TLS 1.3).
    }
104
-- | Computations over 'TLSState' that may fail with a 'TLSError'.
-- The error layer ('ErrT') sits on top of the state layer, so state changes
-- made before a failure are still visible to 'runTLSState'.
newtype TLSSt a = TLSSt { runTLSSt :: ErrT TLSError (State TLSState) a }
    deriving (Functor, Applicative, Monad, MonadError TLSError)
107
-- | Exposes the underlying 'State' operations directly in 'TLSSt' by lifting
-- them through the 'ErrT' layer.
instance MonadState TLSState TLSSt where
    get = TLSSt (lift get)
    put = TLSSt . lift . put
#if MIN_VERSION_mtl(2,1,0)
    -- guarded: 'state' is only defined here for mtl >= 2.1
    state = TLSSt . lift . state
#endif
114
-- | Runs a 'TLSSt' computation from the given initial state. The result is
-- either the raised 'TLSError' or the computed value, paired with the final
-- state (returned in both cases).
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState = runState . runErrT . runTLSSt
117
-- | Builds the initial state for a context playing @role@, seeded with @rng@.
-- The session starts as @Session Nothing@ and both verified-data fields as
-- empty strings. Every optional field starts as 'Nothing' and every flag as
-- 'False'.
newTLSState :: StateRNG -> Role -> TLSState
newTLSState rng role = TLSState
    { stSession                    = Session Nothing
    , stSessionResuming            = False
    , stSecureRenegotiation        = False
    , stClientVerifiedData         = B.empty
    , stServerVerifiedData         = B.empty
    , stExtensionALPN              = False
    , stHandshakeRecordCont        = Nothing
    , stNegotiatedProtocol         = Nothing
    , stHandshakeRecordCont13      = Nothing
    , stClientALPNSuggest          = Nothing
    , stClientGroupSuggest         = Nothing
    , stClientEcPointFormatSuggest = Nothing
    , stClientCertificateChain     = Nothing
    , stClientSNI                  = Nothing
    , stRandomGen                  = rng
    , stVersion                    = Nothing
    , stClientContext              = role
    , stTLS13KeyShare              = Nothing
    , stTLS13PreSharedKey          = Nothing
    , stTLS13HRR                   = False
    , stTLS13Cookie                = Nothing
    , stExporterMasterSecret       = Nothing
    , stClientSupportsPHA          = False
    }
144
-- | Stores @bs@ as either the client or the server verified data (RFC 5746).
--
-- The client slot is chosen when @sending@ equals this context's own role
-- (as returned by 'isClientContext'); otherwise the server slot is chosen.
-- NOTE(review): this implies the 'Role' argument encodes message direction
-- (ClientRole = sent by us, ServerRole = received) rather than the author's
-- role. Confirm against callers.
updateVerifiedData :: Role -> ByteString -> TLSSt ()
updateVerifiedData sending bs = do
    ourRole <- isClientContext
    if ourRole == sending
        then modify (\st -> st { stClientVerifiedData = bs })
        else modify (\st -> st { stServerVerifiedData = bs })
151
-- | Whether a handshake message of this type counts as Finished material.
-- Every type except 'HandshakeType_HelloRequest' qualifies.
-- NOTE(review): judging by the name, this selects the messages that feed the
-- Finished computation (RFC 5246 likewise excludes HelloRequest). Confirm.
finishHandshakeTypeMaterial :: HandshakeType -> Bool
finishHandshakeTypeMaterial ht = case ht of
    HandshakeType_HelloRequest    -> False
    HandshakeType_ClientHello     -> True
    HandshakeType_ServerHello     -> True
    HandshakeType_Certificate     -> True
    HandshakeType_ServerKeyXchg   -> True
    HandshakeType_CertRequest     -> True
    HandshakeType_ServerHelloDone -> True
    HandshakeType_CertVerify      -> True
    HandshakeType_ClientKeyXchg   -> True
    HandshakeType_Finished        -> True
163
-- | 'finishHandshakeTypeMaterial' applied to the type of a handshake message.
finishHandshakeMaterial :: Handshake -> Bool
finishHandshakeMaterial hs = finishHandshakeTypeMaterial (typeOfHandshake hs)
166
-- | Whether a handshake message of this type counts as CertificateVerify
-- material. HelloRequest, CertVerify and Finished are excluded; every other
-- type qualifies.
-- NOTE(review): this matches RFC 5246, where CertificateVerify covers the
-- messages up to, but not including, itself. Confirm against callers.
certVerifyHandshakeTypeMaterial :: HandshakeType -> Bool
certVerifyHandshakeTypeMaterial ht = case ht of
    HandshakeType_HelloRequest    -> False
    HandshakeType_CertVerify      -> False
    HandshakeType_Finished        -> False
    HandshakeType_ClientHello     -> True
    HandshakeType_ServerHello     -> True
    HandshakeType_Certificate     -> True
    HandshakeType_ServerKeyXchg   -> True
    HandshakeType_CertRequest     -> True
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg   -> True
178
-- | 'certVerifyHandshakeTypeMaterial' applied to the type of a handshake message.
certVerifyHandshakeMaterial :: Handshake -> Bool
certVerifyHandshakeMaterial hs = certVerifyHandshakeTypeMaterial (typeOfHandshake hs)
181
-- | Records the current session together with its resumption flag.
setSession :: Session -> Bool -> TLSSt ()
setSession session resuming = modify $ \st ->
    st { stSession = session, stSessionResuming = resuming }

-- | Returns the session last stored by 'setSession'.
getSession :: TLSSt Session
getSession = stSession <$> get

-- | Returns the resumption flag last stored by 'setSession'.
isSessionResuming :: TLSSt Bool
isSessionResuming = stSessionResuming <$> get
190
-- | Sets the protocol version unconditionally.
setVersion :: Version -> TLSSt ()
setVersion ver = modify $ \st -> st { stVersion = Just ver }

-- | Sets the protocol version only if none is set yet; an existing version
-- is kept.
setVersionIfUnset :: Version -> TLSSt ()
setVersionIfUnset ver = modify $ \st ->
    maybe (st { stVersion = Just ver }) (const st) (stVersion st)

-- | Returns the protocol version.
--
-- Partial: if no version has been set, the returned value is an 'error'
-- thunk, so the failure only surfaces when the result is forced.
getVersion :: TLSSt Version
getVersion = gets (fromMaybe (error "internal error: version hasn't been set yet") . stVersion)

-- | Returns the protocol version, or @defaultVer@ when none is set.
getVersionWithDefault :: Version -> TLSSt Version
getVersionWithDefault defaultVer = gets (fromMaybe defaultVer . stVersion)
205
-- | Sets the secure renegotiation flag (RFC 5746).
setSecureRenegotiation :: Bool -> TLSSt ()
setSecureRenegotiation flag = modify $ \st -> st { stSecureRenegotiation = flag }

-- | Reads the secure renegotiation flag (RFC 5746).
getSecureRenegotiation :: TLSSt Bool
getSecureRenegotiation = stSecureRenegotiation <$> get
211
-- | Sets the ALPN extension flag (RFC 7301).
setExtensionALPN :: Bool -> TLSSt ()
setExtensionALPN flag = modify $ \st -> st { stExtensionALPN = flag }

-- | Reads the ALPN extension flag (RFC 7301).
getExtensionALPN :: TLSSt Bool
getExtensionALPN = stExtensionALPN <$> get
217
-- | Records the ALPN protocol.
setNegotiatedProtocol :: B.ByteString -> TLSSt ()
setNegotiatedProtocol proto = modify $ \st -> st { stNegotiatedProtocol = Just proto }

-- | Returns the recorded ALPN protocol, or 'Nothing' if none was set.
getNegotiatedProtocol :: TLSSt (Maybe B.ByteString)
getNegotiatedProtocol = stNegotiatedProtocol <$> get
223
-- | Records the ALPN protocol list suggested by the client.
setClientALPNSuggest :: [B.ByteString] -> TLSSt ()
setClientALPNSuggest protos = modify $ \st -> st { stClientALPNSuggest = Just protos }

-- | Returns the client's ALPN suggestions, or 'Nothing' if none were recorded.
getClientALPNSuggest :: TLSSt (Maybe [B.ByteString])
getClientALPNSuggest = stClientALPNSuggest <$> get
229
-- | Records the EC point formats suggested by the client.
setClientEcPointFormatSuggest :: [EcPointFormat] -> TLSSt ()
setClientEcPointFormatSuggest formats =
    modify $ \st -> st { stClientEcPointFormatSuggest = Just formats }

-- | Returns the client's EC point formats, or 'Nothing' if none were recorded.
getClientEcPointFormatSuggest :: TLSSt (Maybe [EcPointFormat])
getClientEcPointFormatSuggest = stClientEcPointFormatSuggest <$> get
235
-- | Records the client's certificate chain.
setClientCertificateChain :: CertificateChain -> TLSSt ()
setClientCertificateChain chain = modify $ \st -> st { stClientCertificateChain = Just chain }

-- | Returns the client's certificate chain, or 'Nothing' if none was recorded.
getClientCertificateChain :: TLSSt (Maybe CertificateChain)
getClientCertificateChain = stClientCertificateChain <$> get
241
-- | Records the host name sent by the client via SNI.
setClientSNI :: HostName -> TLSSt ()
setClientSNI host = modify $ \st -> st { stClientSNI = Just host }

-- | Returns the client's SNI host name, or 'Nothing' if none was recorded.
getClientSNI :: TLSSt (Maybe HostName)
getClientSNI = stClientSNI <$> get
247
-- | Returns the client verified data for 'ClientRole' and the server
-- verified data for any other role (RFC 5746).
getVerifiedData :: Role -> TLSSt ByteString
getVerifiedData role = gets pick
  where
    pick | role == ClientRole = stClientVerifiedData
         | otherwise          = stServerVerifiedData
250
-- | Returns the role this context was created with (see 'newTLSState').
-- Despite the name, the result is a 'Role', not a 'Bool'.
isClientContext :: TLSSt Role
isClientContext = stClientContext <$> get
253
-- | Draws @n@ random bytes from the RNG held in the state. The updated
-- generator is stored back (see 'withRNG').
genRandom :: Int -> TLSSt ByteString
genRandom n = withRNG (getRandomBytes n)
257
-- | Runs a pseudo-random computation against the generator in
-- 'stRandomGen', stores the generator returned by 'withTLSRNG' back into
-- the state, and returns the computation's result.
withRNG :: MonadPseudoRandom StateRNG a -> TLSSt a
withRNG f = do
    gen <- gets stRandomGen
    let (result, gen') = withTLSRNG gen f
    modify $ \st -> st { stRandomGen = gen' }
    return result
264
-- | Records the TLS 1.3 exporter master secret.
setExporterMasterSecret :: ByteString -> TLSSt ()
setExporterMasterSecret secret = modify $ \st -> st { stExporterMasterSecret = Just secret }

-- | Returns the exporter master secret, or 'Nothing' if none was recorded.
getExporterMasterSecret :: TLSSt (Maybe ByteString)
getExporterMasterSecret = stExporterMasterSecret <$> get
270
-- | Sets (or clears, with 'Nothing') the TLS 1.3 key share.
setTLS13KeyShare :: Maybe KeyShare -> TLSSt ()
setTLS13KeyShare keyShare = modify $ \st -> st { stTLS13KeyShare = keyShare }

-- | Returns the TLS 1.3 key share, if any.
getTLS13KeyShare :: TLSSt (Maybe KeyShare)
getTLS13KeyShare = stTLS13KeyShare <$> get
276
-- | Sets (or clears, with 'Nothing') the TLS 1.3 pre-shared key.
setTLS13PreSharedKey :: Maybe PreSharedKey -> TLSSt ()
setTLS13PreSharedKey psk = modify $ \st -> st { stTLS13PreSharedKey = psk }

-- | Returns the TLS 1.3 pre-shared key, if any.
getTLS13PreSharedKey :: TLSSt (Maybe PreSharedKey)
getTLS13PreSharedKey = stTLS13PreSharedKey <$> get
282
-- | Sets the TLS 1.3 HelloRetryRequest flag.
setTLS13HRR :: Bool -> TLSSt ()
setTLS13HRR flag = modify $ \st -> st { stTLS13HRR = flag }

-- | Reads the TLS 1.3 HelloRetryRequest flag.
getTLS13HRR :: TLSSt Bool
getTLS13HRR = stTLS13HRR <$> get
288
-- | Sets (or clears, with 'Nothing') the TLS 1.3 cookie.
setTLS13Cookie :: Maybe Cookie -> TLSSt ()
setTLS13Cookie cookie = modify $ \st -> st { stTLS13Cookie = cookie }

-- | Returns the TLS 1.3 cookie, if any.
getTLS13Cookie :: TLSSt (Maybe Cookie)
getTLS13Cookie = stTLS13Cookie <$> get
294
-- | Records whether the client supports Post-Handshake Authentication (TLS 1.3).
setClientSupportsPHA :: Bool -> TLSSt ()
setClientSupportsPHA flag = modify $ \st -> st { stClientSupportsPHA = flag }

-- | Reads the client's Post-Handshake Authentication support flag (TLS 1.3).
getClientSupportsPHA :: TLSSt Bool
getClientSupportsPHA = stClientSupportsPHA <$> get
300