1{-# OPTIONS_GHC -fno-warn-orphans #-}
2module Marshalling
3    ( someWords8
4    , prop_header_marshalling_id
5    , prop_handshake_marshalling_id
6    , prop_handshake13_marshalling_id
7    ) where
8
9import Control.Monad
10import Control.Applicative
11import Test.Tasty.QuickCheck
12import Network.TLS.Internal
13import Network.TLS
14
15import qualified Data.ByteString as B
16import Data.Word
17import Data.X509 (CertificateChain(..))
18import Certificate
19
-- | Generate a 'B.ByteString' of exactly the requested length whose bytes
-- are arbitrary.
genByteString :: Int -> Gen B.ByteString
genByteString = fmap B.pack . vector
22
-- | Pick uniformly among every protocol version exercised by these tests,
-- from SSL2 up to TLS13.
instance Arbitrary Version where
    arbitrary = elements allVersions
      where
        allVersions = [SSL2, SSL3, TLS10, TLS11, TLS12, TLS13]
25
-- | Pick uniformly among the four record protocol types.
instance Arbitrary ProtocolType where
    arbitrary = elements recordTypes
      where
        recordTypes =
            [ ProtocolType_ChangeCipherSpec
            , ProtocolType_Alert
            , ProtocolType_Handshake
            , ProtocolType_AppData
            ]
32
-- | Build a record header with every one of its three fields drawn from
-- that field type's own 'Arbitrary' instance.
instance Arbitrary Header where
    arbitrary =
        Header
            <$> arbitrary -- NOTE(review): presumably the ProtocolType; confirm
            <*> arbitrary -- NOTE(review): presumably the Version; confirm
            <*> arbitrary -- NOTE(review): presumably the record length; confirm
35
-- | A generated client random is 32 arbitrary bytes.
instance Arbitrary ClientRandom where
    arbitrary = fmap ClientRandom (genByteString 32)
38
-- | A generated server random is 32 arbitrary bytes.
instance Arbitrary ServerRandom where
    arbitrary = fmap ServerRandom (genByteString 32)
41
-- | Generate a session that, with equal probability, is either empty or
-- carries a 32-byte identifier.
instance Arbitrary Session where
    arbitrary = do
        pick <- choose (1, 2) :: Gen Int
        if pick == 2
            then Session . Just <$> genByteString 32
            else return (Session Nothing)
48
-- | Pick uniformly among the hash algorithms listed below.
instance Arbitrary HashAlgorithm where
    arbitrary = elements hashAlgorithms
      where
        -- NOTE(review): the constructors are module-qualified, presumably to
        -- disambiguate them from same-named exports of another import;
        -- confirm before dropping the qualification.
        hashAlgorithms =
            [ Network.TLS.HashNone
            , Network.TLS.HashMD5
            , Network.TLS.HashSHA1
            , Network.TLS.HashSHA224
            , Network.TLS.HashSHA256
            , Network.TLS.HashSHA384
            , Network.TLS.HashSHA512
            , Network.TLS.HashIntrinsic
            ]
60
-- | Pick uniformly among the signature algorithms listed below.
instance Arbitrary SignatureAlgorithm where
    -- The groups exist only for readability; concatenating them keeps the
    -- full list in a fixed order for 'elements' to index into.
    arbitrary = elements (classic ++ rsaePss ++ edwards ++ pssPss)
      where
        classic = [SignatureAnonymous, SignatureRSA, SignatureDSS, SignatureECDSA]
        rsaePss =
            [ SignatureRSApssRSAeSHA256
            , SignatureRSApssRSAeSHA384
            , SignatureRSApssRSAeSHA512
            ]
        edwards = [SignatureEd25519, SignatureEd448]
        pssPss =
            [ SignatureRSApsspssSHA256
            , SignatureRSApsspssSHA384
            , SignatureRSApsspssSHA512
            ]
76
-- | Generate a signature whose first field is always 'Nothing'
-- (NOTE(review): presumably "no explicit hash/signature algorithm"; confirm)
-- and whose payload is 32 arbitrary bytes.
instance Arbitrary DigitallySigned where
    arbitrary = fmap (DigitallySigned Nothing) (genByteString 32)
79
-- | Generate a list of 0 to 200 arbitrary cipher identifiers.
arbitraryCiphersIDs :: Gen [Word16]
arbitraryCiphersIDs = do
    count <- choose (0, 200)
    vector count
82
-- | Generate a list of 0 to 200 arbitrary compression method identifiers.
arbitraryCompressionIDs :: Gen [Word8]
arbitraryCompressionIDs = do
    count <- choose (0, 200)
    vector count
85
-- | Generate exactly @len@ arbitrary bytes as a list.
someWords8 :: Int -> Gen [Word8]
someWords8 len = vector len
88
-- | Generate a raw extension from an arbitrary identifier and a payload of
-- 0 to 40 arbitrary bytes.
instance Arbitrary ExtensionRaw where
    arbitrary = ExtensionRaw <$> arbitrary <*> payload
      where
        payload = choose (0, 40) >>= genByteString
93
-- | Generate hello extensions suitable for the given protocol version.
--
-- SSL3 and later get an arbitrary extension list. Anything below SSL3
-- (only SSL2 among the versions generated in this module) gets an empty
-- list, because SSLv2 hellos carry no extensions.
arbitraryHelloExtensions :: Version -> Gen [ExtensionRaw]
arbitraryHelloExtensions ver =
    if ver >= SSL3
        then arbitrary
        else return [] -- no hello extension with SSLv2
98
-- | Pick uniformly among the certificate types listed below.
instance Arbitrary CertificateType where
    arbitrary = elements certTypes
      where
        certTypes =
            [ CertificateType_RSA_Sign
            , CertificateType_DSS_Sign
            , CertificateType_RSA_Fixed_DH
            , CertificateType_DSS_Fixed_DH
            , CertificateType_RSA_Ephemeral_DH
            , CertificateType_DSS_Ephemeral_DH
            , CertificateType_fortezza_dms
            ]
105
-- | Generate 'Handshake' messages for 'prop_handshake_marshalling_id'.
--
-- That property decodes with fixed parameters (TLS10, RSA key exchange).
-- This is presumably why the client key exchange is only generated as
-- 'CKX_RSA' and 'CertRequest' never carries a hash/signature list.
-- NOTE(review): confirm against the decoder before widening these generators.
--
-- For the hello messages the version is drawn first, because it decides
-- whether the extension list may be non-empty (see 'arbitraryHelloExtensions').
instance Arbitrary Handshake where
    arbitrary = oneof
            [ arbitrary >>= \ver -> ClientHello ver
                <$> arbitrary
                <*> arbitrary
                <*> arbitraryCiphersIDs
                <*> arbitraryCompressionIDs
                <*> arbitraryHelloExtensions ver
                -- NOTE(review): presumably the optional deprecated record; confirm
                <*> pure Nothing
            , arbitrary >>= \ver -> ServerHello ver
                <$> arbitrary
                <*> arbitrary
                <*> arbitrary
                <*> arbitrary
                <*> arbitraryHelloExtensions ver
            , Certificates . CertificateChain <$> resize 2 (listOf arbitraryX509)
            , pure HelloRequest
            , pure ServerHelloDone
            , ClientKeyXchg . CKX_RSA <$> genByteString 48
            -- NOTE(review): ServerKeyXchg is intentionally not generated; an
            -- earlier generator for it was left commented out. TODO confirm why.
            , CertRequest <$> arbitrary <*> pure Nothing <*> listOf arbitraryDN
            , CertVerify <$> arbitrary
            , Finished <$> genByteString 12
            ]
130
-- | Generate a certificate request context that, with equal probability,
-- is either empty or 32 arbitrary bytes.
arbitraryCertReqContext :: Gen B.ByteString
arbitraryCertReqContext = oneof [emptyContext, genByteString 32]
  where
    emptyContext = pure B.empty
133
-- | Generate 'Handshake13' messages for 'prop_handshake13_marshalling_id'.
--
-- Each alternative of the list below is equally likely. The larger
-- generators are named in the @where@ clause.
instance Arbitrary Handshake13 where
    arbitrary =
        oneof
            [ genClientHello
            , genServerHello
            , genNewSessionTicket
            , pure EndOfEarlyData13
            , EncryptedExtensions13 <$> arbitrary
            , genCertRequest
            , genCertificate
            , CertVerify13 <$> arbitrary <*> genByteString 32
            , Finished13 <$> genByteString 12
            , KeyUpdate13 <$> elements [UpdateNotRequested, UpdateRequested]
            ]
      where
        -- The version is drawn first; it is both stored in the message and
        -- used to decide whether extensions may be generated.
        genClientHello =
            arbitrary >>= \ver ->
                ClientHello13 ver
                    <$> arbitrary
                    <*> arbitrary
                    <*> arbitraryCiphersIDs
                    <*> arbitraryHelloExtensions ver

        -- The drawn version is not stored in the message; it only chooses
        -- between empty and arbitrary extensions.
        genServerHello =
            arbitrary >>= \ver ->
                ServerHello13
                    <$> arbitrary
                    <*> arbitrary
                    <*> arbitrary
                    <*> arbitraryHelloExtensions ver

        genNewSessionTicket =
            NewSessionTicket13
                <$> arbitrary
                <*> arbitrary
                <*> genByteString 32 -- nonce
                <*> genByteString 32 -- session ID
                <*> arbitrary

        genCertRequest =
            CertRequest13
                <$> arbitraryCertReqContext
                <*> arbitrary

        -- At most two certificates (via 'resize 2'). One extra arbitrary
        -- value is generated per certificate.
        -- NOTE(review): presumably the per-certificate extensions; confirm.
        genCertificate =
            resize 2 (listOf arbitraryX509) >>= \certs ->
                Certificate13
                    <$> arbitraryCertReqContext
                    <*> pure (CertificateChain certs)
                    <*> replicateM (length certs) arbitrary
165
{- QuickCheck properties: encode/decode round trips -}
167
-- | Decoding an encoded record header yields the original header.
prop_header_marshalling_id :: Header -> Bool
prop_header_marshalling_id hdr = roundTrip == Right hdr
  where
    roundTrip = decodeHeader (encodeHeader hdr)
170
-- | Encoding a 'Handshake' message, splitting the result back into a
-- handshake record, and decoding it with fixed parameters (TLS10, RSA key
-- exchange) yields the original message.
prop_handshake_marshalling_id :: Handshake -> Bool
prop_handshake_marshalling_id hs = decodeBack (encodeHandshake hs) == Right hs
  where
    decodeBack bytes = verifyResult (decodeHandshake params) (decodeHandshakeRecord bytes)
    params = CurrentParams
        { cParamsVersion     = TLS10
        , cParamsKeyXchgType = Just CipherKeyExchange_RSA
        }
175
-- | Encoding a 'Handshake13' message, splitting the result back into a
-- handshake record, and decoding it yields the original message.
prop_handshake13_marshalling_id :: Handshake13 -> Bool
prop_handshake13_marshalling_id hs = decodeBack (encodeHandshake13 hs) == Right hs
  where
    decodeBack bytes = verifyResult decodeHandshake13 (decodeHandshakeRecord13 bytes)
179
-- | Apply @fn@ to the (type, content) pair of a fully successful parse.
--
-- Any other outcome calls 'error' with a short description:
--
-- * a partial parse,
-- * a parse error (the error is included via 'show'),
-- * a success that left unconsumed input.
verifyResult :: (t -> b -> r) -> GetResult (t, b) -> r
verifyResult fn (GotSuccess (ty, content)) = fn ty content
verifyResult _  (GotPartial _)             = error "got partial"
verifyResult _  (GotError e)               = error ("got error: " ++ show e)
verifyResult _  (GotSuccessRemaining _ _)  = error "got remaining byte left"
187