1{-# LANGUAGE BangPatterns #-}
2{-# LANGUAGE OverloadedStrings #-}
3{-# LANGUAGE FlexibleContexts #-}
4
5-- |
6-- Module      : Network.TLS.Packet13
7-- License     : BSD-style
8-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
9-- Stability   : experimental
10-- Portability : unknown
11--
12module Network.TLS.Packet13
13       ( encodeHandshake13
14       , getHandshakeType13
15       , decodeHandshakeRecord13
16       , decodeHandshake13
17       , decodeHandshakes13
18       ) where
19
20import qualified Data.ByteString as B
21import Network.TLS.Struct
22import Network.TLS.Struct13
23import Network.TLS.Packet
24import Network.TLS.Wire
25import Network.TLS.Imports
26import Data.X509 (CertificateChainRaw(..), encodeCertificateChain, decodeCertificateChain)
27import Network.TLS.ErrT
28
-- | Serialise a complete TLS 1.3 handshake message: the header built by
-- 'encodeHandshakeHeader13' (message type plus body length) immediately
-- followed by the body produced by 'encodeHandshake13''.
encodeHandshake13 :: Handshake13 -> ByteString
encodeHandshake13 hdsk = B.append hdr body
  where
    -- Body first, because the header needs its length.
    !body = encodeHandshake13' hdsk
    !hdr  = encodeHandshakeHeader13 (typeOfHandshake13 hdsk) (B.length body)
37
-- | Write a list of extensions as one opaque blob, via 'putOpaque16'.
--
-- The length prefix is written even when the list is empty: TLS 1.3 does
-- not use the "select (extensions_present)" form in which it may be omitted.
putExtensions :: [ExtensionRaw] -> Put
putExtensions = putOpaque16 . runPut . mapM_ putExtension
41
-- | Serialise the body of a TLS 1.3 handshake message. The type/length
-- header is not included; 'encodeHandshake13' prepends it.
encodeHandshake13' :: Handshake13 -> ByteString
encodeHandshake13' (ClientHello13 version random session cipherIDs exts) = runPut $ do
    putBinaryVersion version
    putClientRandom32 random
    putSession session
    putWords16 cipherIDs
    -- Compression-method list containing only 0 (null compression).
    putWords8 [0]
    putExtensions exts
encodeHandshake13' (ServerHello13 random session cipherId exts) = runPut $ do
    -- The version field is always written as TLS12 for ServerHello13
    -- (presumably RFC 8446 legacy_version; the real version is expected
    -- in an extension -- confirm against the extension-building code).
    putBinaryVersion TLS12
    putServerRandom32 random
    putSession session
    putWord16 cipherId
    putWord8 0 -- compressionID nullCompression
    putExtensions exts
encodeHandshake13' (EncryptedExtensions13 exts) = runPut $ putExtensions exts
encodeHandshake13' (CertRequest13 reqctx exts) = runPut $ do
    putOpaque8 reqctx
    putExtensions exts
encodeHandshake13' (Certificate13 reqctx cc ess) = runPut $ do
    putOpaque8 reqctx
    putOpaque24 (runPut $ mapM_ putCert $ zip certs (ess ++ repeat []))
  where
    CertificateChainRaw certs = encodeCertificateChain cc
    -- Every certificate is paired with its extension list. Missing lists
    -- are treated as empty: a plain 'zip' would silently drop the trailing
    -- certificates of the chain whenever 'ess' is shorter than 'certs'.
    -- Surplus entries in 'ess' are ignored, because 'zip' stops at 'certs'.
    putCert (certRaw,exts) = do
        putOpaque24 certRaw
        putExtensions exts
encodeHandshake13' (CertVerify13 hs signature) = runPut $ do
    putSignatureHashAlgorithm hs
    putOpaque16 signature
encodeHandshake13' (Finished13 dat) = runPut $ putBytes dat
encodeHandshake13' (NewSessionTicket13 life ageadd nonce label exts) = runPut $ do
    putWord32 life
    putWord32 ageadd
    putOpaque8 nonce
    putOpaque16 label
    putExtensions exts
-- EndOfEarlyData has an empty body.
encodeHandshake13' EndOfEarlyData13 = ""
-- KeyUpdate carries a single request_update byte: 0 = not requested, 1 = requested.
encodeHandshake13' (KeyUpdate13 UpdateNotRequested) = runPut $ putWord8 0
encodeHandshake13' (KeyUpdate13 UpdateRequested)    = runPut $ putWord8 1
82
-- | Build a handshake header: one byte for the message type, then the body
-- length written with 'putWord24'.
encodeHandshakeHeader13 :: HandshakeType13 -> Int -> ByteString
encodeHandshakeHeader13 ty len =
    runPut (putWord8 (valOfType ty) >> putWord24 len)
87
-- | Decode every handshake message packed back-to-back in a byte string.
--
-- Messages are decoded left to right, recursing on whatever follows each
-- one, until no input remains. The first framing or body parse failure is
-- raised through 'MonadError'. That includes a message whose header
-- announces more bytes than are present, which the parser reports as
-- 'GotPartial'.
decodeHandshakes13 :: MonadError TLSError m => ByteString -> m [Handshake13]
decodeHandshakes13 bs = case decodeHandshakeRecord13 bs of
  GotError err                -> throwError err
  -- A truncated message depends on peer input, so it must not crash the
  -- process with 'error'. Report it as a recoverable parsing error, using
  -- the same constructor the parser uses for its own failures.
  GotPartial _cont            ->
    throwError $ Error_Packet_Parsing "decodeHandshakes13: truncated handshake message"
  GotSuccess (ty,content)     -> case decodeHandshake13 ty content of
    Left  e -> throwError e
    Right h -> return [h]
  GotSuccessRemaining (ty,content) left -> case decodeHandshake13 ty content of
    Left  e -> throwError e
    Right h -> (h:) <$> decodeHandshakes13 left
98
{- Decoding of TLS 1.3 handshake messages -}

-- | Read one byte and map it to a 'HandshakeType13'. Unknown values fail
-- the parser with a message that includes the offending byte.
getHandshakeType13 :: Get HandshakeType13
getHandshakeType13 = getWord8 >>= \byte ->
    maybe (fail ("invalid handshake type: " ++ show byte)) return (valToType byte)
106
-- | Frame a single handshake message: its type, then its body as read by
-- 'getOpaque24'. Parse failures carry the label "handshake-record".
decodeHandshakeRecord13 :: ByteString -> GetResult (HandshakeType13, ByteString)
decodeHandshakeRecord13 =
    runGet "handshake-record" ((,) <$> getHandshakeType13 <*> getOpaque24)
112
-- | Parse a handshake body whose message type is already known, selecting
-- the matching body parser. Failures are labelled "handshake[<type>]".
decodeHandshake13 :: HandshakeType13 -> ByteString -> Either TLSError Handshake13
decodeHandshake13 ty = runGetErr label parser
  where
    label = "handshake[" ++ show ty ++ "]"
    parser = case ty of
        HandshakeType_ClientHello13         -> decodeClientHello13
        HandshakeType_ServerHello13         -> decodeServerHello13
        HandshakeType_Finished13            -> decodeFinished13
        HandshakeType_EncryptedExtensions13 -> decodeEncryptedExtensions13
        HandshakeType_CertRequest13         -> decodeCertRequest13
        HandshakeType_Certificate13         -> decodeCertificate13
        HandshakeType_CertVerify13          -> decodeCertVerify13
        HandshakeType_NewSessionTicket13    -> decodeNewSessionTicket13
        HandshakeType_EndOfEarlyData13      -> return EndOfEarlyData13
        HandshakeType_KeyUpdate13           -> decodeKeyUpdate13
125
-- | Parse a ClientHello body. The fields are read in wire order: version,
-- random, session, cipher list, compression list (read and discarded),
-- and the 16-bit-length-prefixed extensions. An unrecognised version makes
-- the 'Just' pattern fail, which fails the parser.
decodeClientHello13 :: Get Handshake13
decodeClientHello13 = do
    Just ver <- getBinaryVersion
    ClientHello13 ver
        <$> getClientRandom32
        <*> getSession
        <*> getWords16
        <*  getWords8
        <*> (getWord16 >>= getExtensions . fromIntegral)
135
-- | Parse a ServerHello body. The version must be recognised but is not
-- kept. The compression byte is read and discarded, and the extensions
-- follow a 16-bit length.
decodeServerHello13 :: Get Handshake13
decodeServerHello13 = do
    Just _ <- getBinaryVersion
    ServerHello13
        <$> getServerRandom32
        <*> getSession
        <*> getWord16
        <*  getWord8
        <*> (getWord16 >>= getExtensions . fromIntegral)
145
-- | A Finished body is all of the remaining input, taken verbatim.
decodeFinished13 :: Get Handshake13
decodeFinished13 = remaining >>= fmap Finished13 . getBytes
148
-- | Parse an EncryptedExtensions body: a 16-bit length, then that many
-- bytes of extensions.
decodeEncryptedExtensions13 :: Get Handshake13
decodeEncryptedExtensions13 = do
    extLen <- fromIntegral <$> getWord16
    EncryptedExtensions13 <$> getExtensions extLen
153
-- | Parse a CertificateRequest body: the request context (read with
-- 'getOpaque8'), then the 16-bit-length-prefixed extensions.
decodeCertRequest13 :: Get Handshake13
decodeCertRequest13 =
    CertRequest13 <$> getOpaque8 <*> (getWord16 >>= getExtensions . fromIntegral)
160
-- | Parse a Certificate body: the request context, then a 24-bit-length
-- list of entries. Each entry is a 24-bit-length certificate followed by
-- its own 16-bit-length extension list. The raw certificates are decoded
-- into a chain, and a chain decoding error fails the parser.
decodeCertificate13 :: Get Handshake13
decodeCertificate13 = do
    ctx   <- getOpaque8
    total <- fromIntegral <$> getWord24
    (raws, extss) <- unzip <$> getList total certEntry
    either failChain (\chain -> return (Certificate13 ctx chain extss))
           (decodeCertificateChain (CertificateChainRaw raws))
  where
    failChain (idx, msg) = fail ("error certificate parsing " ++ show idx ++ ":" ++ msg)
    -- One list entry. The first component is the number of bytes consumed:
    -- 3 (cert length) + cert + 2 (extensions length) + extensions.
    certEntry = do
        certLen <- fromIntegral <$> getWord24
        der     <- getBytes certLen
        extLen  <- fromIntegral <$> getWord16
        extList <- getExtensions extLen
        return (certLen + extLen + 5, (der, extList))
176
-- | Parse a CertificateVerify body: the signature/hash algorithm, then the
-- signature read with 'getOpaque16'.
decodeCertVerify13 :: Get Handshake13
decodeCertVerify13 = do
    sigAlg <- getSignatureHashAlgorithm
    sig    <- getOpaque16
    return (CertVerify13 sigAlg sig)
179
-- | Parse a NewSessionTicket body. The fields are read in wire order:
-- lifetime (32-bit), age add (32-bit), nonce ('getOpaque8'), ticket label
-- ('getOpaque16'), then the 16-bit-length-prefixed extensions.
decodeNewSessionTicket13 :: Get Handshake13
decodeNewSessionTicket13 =
    NewSessionTicket13
        <$> getWord32
        <*> getWord32
        <*> getOpaque8
        <*> getOpaque16
        <*> (getWord16 >>= getExtensions . fromIntegral)
189
-- | Parse a KeyUpdate body: a single byte, where 0 means update not
-- requested and 1 means update requested. Any other value fails the parser.
decodeKeyUpdate13 :: Get Handshake13
decodeKeyUpdate13 = getWord8 >>= toRequest
  where
    toRequest 0 = return (KeyUpdate13 UpdateNotRequested)
    toRequest 1 = return (KeyUpdate13 UpdateRequested)
    toRequest x = fail ("Unknown request_update: " ++ show x)
197