1{-# LANGUAGE OverloadedStrings #-}
2
3-- |
4-- Module      : Network.TLS.Handshake.State13
5-- License     : BSD-style
6-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
7-- Stability   : experimental
8-- Portability : unknown
9--
10module Network.TLS.Handshake.State13
11       ( CryptLevel ( CryptEarlySecret
12                    , CryptHandshakeSecret
13                    , CryptApplicationSecret
14                    )
15       , TrafficSecret
16       , getTxState
17       , getRxState
18       , setTxState
19       , setRxState
20       , clearTxState
21       , clearRxState
22       , setHelloParameters13
23       , transcriptHash
24       , wrapAsMessageHash13
25       , PendingAction(..)
26       , setPendingActions
27       , popPendingAction
28       ) where
29
import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Data.List (foldl')
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Util
45
-- | Hash, cipher, crypt level and traffic secret currently installed for the
-- sending ('ctxTxState') direction.
getTxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxState = flip getXState ctxTxState
48
-- | Hash, cipher, crypt level and traffic secret currently installed for the
-- receiving ('ctxRxState') direction.
getRxState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxState = flip getXState ctxRxState
51
-- | Read the record state held in one direction's 'MVar' (chosen by @func@,
-- e.g. 'ctxTxState' or 'ctxRxState') and return its hash, cipher, crypt level
-- and traffic secret.
--
-- The secret is taken from 'cstMacSecret', which 'setXState'' fills with the
-- traffic secret.  The cipher (and the hash derived from it) stays lazily
-- bound, so callers that only look at the level or secret still work when no
-- cipher is installed (e.g. after 'clearXState').  Forcing the cipher in that
-- situation now raises a descriptive error instead of an opaque
-- irrefutable-pattern failure.
getXState :: Context
          -> (Context -> MVar RecordState)
          -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    let usedCipher = case stCipher tx of
            Just c  -> c
            Nothing -> error "getXState: no cipher installed in record state"
        usedHash = cipherHash usedCipher
        level    = stCryptLevel tx
        secret   = cstMacSecret $ stCryptState tx
    return (usedHash, usedCipher, level, secret)
62
-- | A traffic secret whose type tells which 'CryptLevel' it belongs to.
-- 'fromTrafficSecret' returns that level together with the raw secret bytes.
class TrafficSecret ty where
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)

-- The level is read from the type parameter @a@ through 'getCryptLevel';
-- the bytes are unwrapped from the constructor.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret secret = case secret of
        AnyTrafficSecret bytes -> (getCryptLevel secret, bytes)

instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret secret = case secret of
        ClientTrafficSecret bytes -> (getCryptLevel secret, bytes)

instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret secret = case secret of
        ServerTrafficSecret bytes -> (getCryptLevel secret, bytes)
74
-- | Install new keys for the sending direction, derived from the given
-- traffic secret; the bulk cipher is initialised for encryption.
setTxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxState ctx = setXState ctxTxState BulkEncrypt ctx
77
-- | Install new keys for the receiving direction, derived from the given
-- traffic secret; the bulk cipher is initialised for decryption.
setRxState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxState ctx = setXState ctxRxState BulkDecrypt ctx
80
-- | Unpack the crypt level and raw secret from a 'TrafficSecret' and hand
-- them to 'setXState'' for the direction selected by @func@.
setXState :: TrafficSecret ty
          => (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ty
          -> IO ()
setXState func encOrDec ctx h cipher =
    uncurry (setXState' func encOrDec ctx h cipher) . fromTrafficSecret
88
-- | Replace the record state in @func ctx@ with a fresh one built from a
-- traffic secret.
--
-- The write key and IV are expanded from @secret@ with 'hkdfExpandLabel'
-- (labels @"key"@ and @"iv"@, empty context).  The IV is at least 8 bytes
-- long.  The sequence number is reset to 0, the cipher is installed, and
-- compression is set to 'nullCompression'.  The previous state is discarded,
-- and @secret@ itself is kept in 'cstMacSecret'.
setXState' :: (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> CryptLevel -> ByteString
          -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) $ const (return newState)
  where
    newState = RecordState
        { stCryptState  = CryptState
            { cstKey       = bulkInit bulk encOrDec (derive "key" keyLen)
            , cstIV        = derive "iv" ivLen
            , cstMacSecret = secret
            }
        , stMacState    = MacState { msSequence = 0 }
        , stCryptLevel  = lvl
        , stCipher      = Just cipher
        , stCompression = nullCompression
        }
    bulk   = cipherBulk cipher
    keyLen = bulkKeySize bulk
    ivLen  = max 8 (bulkIVSize bulk + bulkExplicitIV bulk)
    -- HKDF-Expand-Label with an empty context, for the given label and length.
    derive label = hkdfExpandLabel h secret label ""
112
-- | Drop the cipher from the sending direction's record state.
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx
115
-- | Drop the cipher from the receiving direction's record state.
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx
118
-- | Set 'stCipher' to 'Nothing' in the record state selected by @func@;
-- every other field is left untouched.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (return . dropCipher)
  where
    dropCipher st = st { stCipher = Nothing }
122
-- | Record the TLS 1.3 cipher chosen in the hello exchange.
--
-- On the first call (no pending cipher yet), the cipher becomes the pending
-- cipher and compression is set to 'nullCompression'.  The buffered handshake
-- messages are then hashed with the cipher's hash into a running digest
-- context.
--
-- A later call (e.g. after a HelloRetryRequest) succeeds only if it names the
-- same cipher.  Otherwise it returns an 'IllegalParameter' protocol error.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Nothing -> do
            put hst {
                  hstPendingCipher      = Just cipher
                , hstPendingCompression = nullCompression
                , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
                }
            return $ Right ()
        Just oldcipher
            | cipher == oldcipher -> return $ Right ()
            | otherwise -> return $ Left $ Error_Protocol ("TLS 1.3 cipher changed after hello retry", True, IllegalParameter)
  where
    hashAlg = cipherHash cipher
    -- The buffered messages are reversed before hashing (presumably they are
    -- accumulated newest-first; confirm against the producer).  A strict
    -- left fold avoids building a chain of hashUpdate thunks.
    updateDigest (HandshakeMessages bytes)  = HandshakeDigestContext $ foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
    updateDigest (HandshakeDigestContext _) = error "cannot initialize digest with another digest"
141
-- | Wrap the existing transcript in a "message_hash" construct.
--
-- When a HelloRetryRequest is sent or received, the existing transcript must
-- be wrapped this way (RFC 8446, section 4.4.1).  This applies to key-schedule
-- computations as well as to the ones for PSK binders.
--
-- The pending cipher's hash is used, and the digest is turned into a synthetic
-- handshake message.  The message is type byte 254, then a 3-byte length
-- (two zero bytes followed by the digest length as a single byte, so digests
-- are assumed shorter than 256 bytes), then the digest itself.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) asMessageHash
  where
    asMessageHash digest =
        "\254\0\0" `B.append` B.cons (fromIntegral (B.length digest)) digest
154
-- | Finalise the current handshake digest and return the transcript hash.
--
-- Fails with an error if there is no handshake state.  It also fails if the
-- transcript is still a list of raw messages, i.e. no digest context has been
-- set up yet.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    digest <- hstHandshakeDigest . fromJust "HState" <$> getHState ctx
    case digest of
        HandshakeMessages _         -> error "un-initialized handshake digest"
        HandshakeDigestContext hctx -> return (hashFinal hctx)
161
-- | Replace the context's queue of pending actions; any actions not yet
-- popped are discarded.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions ctx actions = writeIORef (ctxPendingActions ctx) actions
164
-- | Remove and return the first queued 'PendingAction', or 'Nothing' when the
-- queue is empty.
--
-- The read and the write are done as one step with 'atomicModifyIORef''.
-- This ensures that two callers sharing the same 'Context' can never pop the
-- same action.  The Context's other mutable state is also kept behind
-- synchronised primitives (MVars).
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = atomicModifyIORef' (ctxPendingActions ctx) pop
  where
    pop (action:rest) = (rest, Just action)
    pop []            = ([], Nothing)
172