1// Copyright 2009 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package qtls 6 7import ( 8 "bytes" 9 "math/rand" 10 "reflect" 11 "strings" 12 "testing" 13 "testing/quick" 14 "time" 15) 16 17var tests = []interface{}{ 18 &clientHelloMsg{}, 19 &serverHelloMsg{}, 20 &finishedMsg{}, 21 22 &certificateMsg{}, 23 &certificateRequestMsg{}, 24 &certificateVerifyMsg{ 25 hasSignatureAlgorithm: true, 26 }, 27 &certificateStatusMsg{}, 28 &clientKeyExchangeMsg{}, 29 &newSessionTicketMsg{}, 30 &sessionState{}, 31 &sessionStateTLS13{}, 32 &encryptedExtensionsMsg{}, 33 &endOfEarlyDataMsg{}, 34 &keyUpdateMsg{}, 35 &newSessionTicketMsgTLS13{}, 36 &certificateRequestMsgTLS13{}, 37 &certificateMsgTLS13{}, 38} 39 40func TestMarshalUnmarshal(t *testing.T) { 41 rand := rand.New(rand.NewSource(time.Now().UnixNano())) 42 43 for i, iface := range tests { 44 ty := reflect.ValueOf(iface).Type() 45 46 n := 100 47 if testing.Short() { 48 n = 5 49 } 50 for j := 0; j < n; j++ { 51 v, ok := quick.Value(ty, rand) 52 if !ok { 53 t.Errorf("#%d: failed to create value", i) 54 break 55 } 56 57 m1 := v.Interface().(handshakeMessage) 58 marshaled := m1.marshal() 59 m2 := iface.(handshakeMessage) 60 if !m2.unmarshal(marshaled) { 61 t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled) 62 break 63 } 64 m2.marshal() // to fill any marshal cache in the message 65 66 if !reflect.DeepEqual(m1, m2) { 67 t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled) 68 break 69 } 70 71 if i >= 3 { 72 // The first three message types (ClientHello, 73 // ServerHello and Finished) are allowed to 74 // have parsable prefixes because the extension 75 // data is optional and the length of the 76 // Finished varies across versions. 77 for j := 0; j < len(marshaled); j++ { 78 if m2.unmarshal(marshaled[0:j]) { 79 t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) 80 break 81 } 82 } 83 } 84 } 85 } 86} 87 88func TestFuzz(t *testing.T) { 89 rand := rand.New(rand.NewSource(0)) 90 for _, iface := range tests { 91 m := iface.(handshakeMessage) 92 93 for j := 0; j < 1000; j++ { 94 len := rand.Intn(100) 95 bytes := randomBytes(len, rand) 96 // This just looks for crashes due to bounds errors etc. 97 m.unmarshal(bytes) 98 } 99 } 100} 101 102func randomBytes(n int, rand *rand.Rand) []byte { 103 r := make([]byte, n) 104 if _, err := rand.Read(r); err != nil { 105 panic("rand.Read failed: " + err.Error()) 106 } 107 return r 108} 109 110func randomString(n int, rand *rand.Rand) string { 111 b := randomBytes(n, rand) 112 return string(b) 113} 114 115func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 116 m := &clientHelloMsg{} 117 m.vers = uint16(rand.Intn(65536)) 118 m.random = randomBytes(32, rand) 119 m.sessionId = randomBytes(rand.Intn(32), rand) 120 m.cipherSuites = make([]uint16, rand.Intn(63)+1) 121 for i := 0; i < len(m.cipherSuites); i++ { 122 cs := uint16(rand.Int31()) 123 if cs == scsvRenegotiation { 124 cs += 1 125 } 126 m.cipherSuites[i] = cs 127 } 128 m.compressionMethods = randomBytes(rand.Intn(63)+1, rand) 129 if rand.Intn(10) > 5 { 130 m.serverName = randomString(rand.Intn(255), rand) 131 for strings.HasSuffix(m.serverName, ".") { 132 m.serverName = m.serverName[:len(m.serverName)-1] 133 } 134 } 135 m.ocspStapling = rand.Intn(10) > 5 136 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 137 m.supportedCurves = make([]CurveID, rand.Intn(5)+1) 138 for i := range m.supportedCurves { 139 m.supportedCurves[i] = CurveID(rand.Intn(30000) + 1) 140 } 141 if rand.Intn(10) > 5 { 142 m.ticketSupported = true 143 if rand.Intn(10) > 5 { 144 m.sessionTicket = randomBytes(rand.Intn(300), rand) 145 } else { 146 m.sessionTicket = make([]byte, 0) 147 } 148 } 149 if rand.Intn(10) > 5 { 150 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 151 } 152 if rand.Intn(10) > 5 { 153 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 154 } 155 for i := 0; i < rand.Intn(5); i++ { 156 m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand)) 157 } 158 if rand.Intn(10) > 5 { 159 m.scts = true 160 } 161 if rand.Intn(10) > 5 { 162 m.secureRenegotiationSupported = true 163 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 164 } 165 for i := 0; i < rand.Intn(5); i++ { 166 m.supportedVersions = append(m.supportedVersions, uint16(rand.Intn(0xffff)+1)) 167 } 168 if rand.Intn(10) > 5 { 169 m.cookie = randomBytes(rand.Intn(500)+1, rand) 170 } 171 for i := 0; i < rand.Intn(5); i++ { 172 var ks keyShare 173 ks.group = CurveID(rand.Intn(30000) + 1) 174 ks.data = randomBytes(rand.Intn(200)+1, rand) 175 m.keyShares = append(m.keyShares, ks) 176 } 177 switch rand.Intn(3) { 178 case 1: 179 m.pskModes = []uint8{pskModeDHE} 180 case 2: 181 m.pskModes = []uint8{pskModeDHE, pskModePlain} 182 } 183 for i := 0; i < rand.Intn(5); i++ { 184 var psk pskIdentity 185 psk.obfuscatedTicketAge = uint32(rand.Intn(500000)) 186 psk.label = randomBytes(rand.Intn(500)+1, rand) 187 m.pskIdentities = append(m.pskIdentities, psk) 188 m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) 189 } 190 if rand.Intn(10) > 5 { 191 m.earlyData = true 192 } 193 if numExt := rand.Intn(10); numExt > 0 { 194 extType := 1000 + uint16(rand.Intn(5000)) 195 length := rand.Intn(50) 196 m.additionalExtensions = append(m.additionalExtensions, 197 Extension{Type: extType, Data: randomBytes(length, rand)}) 198 } 199 200 return reflect.ValueOf(m) 201} 202 203func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { 204 m := &serverHelloMsg{} 205 m.vers = uint16(rand.Intn(65536)) 206 m.random = randomBytes(32, rand) 207 m.sessionId = randomBytes(rand.Intn(32), rand) 208 m.cipherSuite = uint16(rand.Int31()) 209 m.compressionMethod = uint8(rand.Intn(256)) 210 m.supportedPoints = randomBytes(rand.Intn(5)+1, rand) 211 212 if rand.Intn(10) > 5 { 213 m.ocspStapling = true 214 } 215 if rand.Intn(10) > 5 { 216 m.ticketSupported = true 217 } 218 if rand.Intn(10) > 5 { 219 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 220 } 221 222 for i := 0; i < rand.Intn(4); i++ { 223 m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) 224 } 225 226 if rand.Intn(10) > 5 { 227 m.secureRenegotiationSupported = true 228 m.secureRenegotiation = randomBytes(rand.Intn(50)+1, rand) 229 } 230 if rand.Intn(10) > 5 { 231 m.supportedVersion = uint16(rand.Intn(0xffff) + 1) 232 } 233 if rand.Intn(10) > 5 { 234 m.cookie = randomBytes(rand.Intn(500)+1, rand) 235 } 236 if rand.Intn(10) > 5 { 237 for i := 0; i < rand.Intn(5); i++ { 238 m.serverShare.group = CurveID(rand.Intn(30000) + 1) 239 m.serverShare.data = randomBytes(rand.Intn(200)+1, rand) 240 } 241 } else if rand.Intn(10) > 5 { 242 m.selectedGroup = CurveID(rand.Intn(30000) + 1) 243 } 244 if rand.Intn(10) > 5 { 245 m.selectedIdentityPresent = true 246 m.selectedIdentity = uint16(rand.Intn(0xffff)) 247 } 248 249 return reflect.ValueOf(m) 250} 251 252func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { 253 m := &encryptedExtensionsMsg{} 254 255 if rand.Intn(10) > 5 { 256 m.alpnProtocol = randomString(rand.Intn(32)+1, rand) 257 } 258 if rand.Intn(10) > 5 { 259 m.earlyData = true 260 } 261 262 if numExt := rand.Intn(4); numExt > 0 { 263 for i := 0; i < numExt; i++ { 264 extType := 1000 + uint16(rand.Intn(5000)) 265 length := rand.Intn(50) 266 m.additionalExtensions = append(m.additionalExtensions, 267 Extension{Type: extType, Data: randomBytes(length, rand)}) 268 } 269 } 270 271 return reflect.ValueOf(m) 272} 273 274func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 275 m := &certificateMsg{} 276 numCerts := rand.Intn(20) 277 m.certificates = make([][]byte, numCerts) 278 for i := 0; i < numCerts; i++ { 279 m.certificates[i] = randomBytes(rand.Intn(10)+1, rand) 280 } 281 return reflect.ValueOf(m) 282} 283 284func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value { 285 m := &certificateRequestMsg{} 286 m.certificateTypes = randomBytes(rand.Intn(5)+1, rand) 287 for i := 0; i < rand.Intn(100); i++ { 288 m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand)) 289 } 290 return reflect.ValueOf(m) 291} 292 293func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { 294 m := &certificateVerifyMsg{} 295 m.hasSignatureAlgorithm = true 296 m.signatureAlgorithm = SignatureScheme(rand.Intn(30000)) 297 m.signature = randomBytes(rand.Intn(15)+1, rand) 298 return reflect.ValueOf(m) 299} 300 301func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { 302 m := &certificateStatusMsg{} 303 m.response = randomBytes(rand.Intn(10)+1, rand) 304 return reflect.ValueOf(m) 305} 306 307func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value { 308 m := &clientKeyExchangeMsg{} 309 m.ciphertext = randomBytes(rand.Intn(1000)+1, rand) 310 return reflect.ValueOf(m) 311} 312 313func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value { 314 m := &finishedMsg{} 315 m.verifyData = randomBytes(12, rand) 316 return reflect.ValueOf(m) 317} 318 319func (*newSessionTicketMsg) Generate(rand *rand.Rand, size int) reflect.Value { 320 m := &newSessionTicketMsg{} 321 m.ticket = randomBytes(rand.Intn(4), rand) 322 return reflect.ValueOf(m) 323} 324 325func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { 326 s := &sessionState{} 327 s.vers = uint16(rand.Intn(10000)) 328 s.cipherSuite = uint16(rand.Intn(10000)) 329 s.masterSecret = randomBytes(rand.Intn(100)+1, rand) 330 s.createdAt = uint64(rand.Int63()) 331 for i := 0; i < rand.Intn(20); i++ { 332 s.certificates = append(s.certificates, randomBytes(rand.Intn(500)+1, rand)) 333 } 334 return reflect.ValueOf(s) 335} 336 337func (*sessionStateTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 338 s := &sessionStateTLS13{} 339 s.cipherSuite = uint16(rand.Intn(10000)) 340 s.resumptionSecret = randomBytes(rand.Intn(100)+1, rand) 341 s.createdAt = uint64(rand.Int63()) 342 s.maxEarlyData = uint32(rand.Int31()) 343 s.appData = randomBytes(rand.Intn(100)+1, rand) 344 for i := 0; i < rand.Intn(2)+1; i++ { 345 s.certificate.Certificate = append( 346 s.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 347 } 348 if rand.Intn(10) > 5 { 349 s.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 350 } 351 if rand.Intn(10) > 5 { 352 for i := 0; i < rand.Intn(2)+1; i++ { 353 s.certificate.SignedCertificateTimestamps = append( 354 s.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 355 } 356 } 357 s.alpn = randomString(6, rand) 358 return reflect.ValueOf(s) 359} 360 361func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { 362 m := &endOfEarlyDataMsg{} 363 return reflect.ValueOf(m) 364} 365 366func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { 367 m := &keyUpdateMsg{} 368 m.updateRequested = rand.Intn(10) > 5 369 return reflect.ValueOf(m) 370} 371 372func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 373 m := &newSessionTicketMsgTLS13{} 374 m.lifetime = uint32(rand.Intn(500000)) 375 m.ageAdd = uint32(rand.Intn(500000)) 376 m.nonce = randomBytes(rand.Intn(100), rand) 377 m.label = randomBytes(rand.Intn(1000), rand) 378 if rand.Intn(10) > 5 { 379 m.maxEarlyData = uint32(rand.Intn(500000)) 380 } 381 return reflect.ValueOf(m) 382} 383 384func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 385 m := &certificateRequestMsgTLS13{} 386 if rand.Intn(10) > 5 { 387 m.ocspStapling = true 388 } 389 if rand.Intn(10) > 5 { 390 m.scts = true 391 } 392 if rand.Intn(10) > 5 { 393 m.supportedSignatureAlgorithms = supportedSignatureAlgorithms 394 } 395 if rand.Intn(10) > 5 { 396 m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms 397 } 398 if rand.Intn(10) > 5 { 399 m.certificateAuthorities = make([][]byte, 3) 400 for i := 0; i < 3; i++ { 401 m.certificateAuthorities[i] = randomBytes(rand.Intn(10)+1, rand) 402 } 403 } 404 return reflect.ValueOf(m) 405} 406 407func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { 408 m := &certificateMsgTLS13{} 409 for i := 0; i < rand.Intn(2)+1; i++ { 410 m.certificate.Certificate = append( 411 m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) 412 } 413 if rand.Intn(10) > 5 { 414 m.ocspStapling = true 415 m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) 416 } 417 if rand.Intn(10) > 5 { 418 m.scts = true 419 for i := 0; i < rand.Intn(2)+1; i++ { 420 m.certificate.SignedCertificateTimestamps = append( 421 m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) 422 } 423 } 424 return reflect.ValueOf(m) 425} 426 427func TestRejectEmptySCTList(t *testing.T) { 428 // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. 429 430 var random [32]byte 431 sct := []byte{0x42, 0x42, 0x42, 0x42} 432 serverHello := serverHelloMsg{ 433 vers: VersionTLS12, 434 random: random[:], 435 scts: [][]byte{sct}, 436 } 437 serverHelloBytes := serverHello.marshal() 438 439 var serverHelloCopy serverHelloMsg 440 if !serverHelloCopy.unmarshal(serverHelloBytes) { 441 t.Fatal("Failed to unmarshal initial message") 442 } 443 444 // Change serverHelloBytes so that the SCT list is empty 445 i := bytes.Index(serverHelloBytes, sct) 446 if i < 0 { 447 t.Fatal("Cannot find SCT in ServerHello") 448 } 449 450 var serverHelloEmptySCT []byte 451 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) 452 // Append the extension length and SCT list length for an empty list. 453 serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) 454 serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) 455 456 // Update the handshake message length. 457 serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) 458 serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) 459 serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) 460 461 // Update the extensions length 462 serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) 463 serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) 464 465 if serverHelloCopy.unmarshal(serverHelloEmptySCT) { 466 t.Fatal("Unmarshaled ServerHello with empty SCT list") 467 } 468} 469 470func TestRejectEmptySCT(t *testing.T) { 471 // Not only must the SCT list be non-empty, but the SCT elements must 472 // not be zero length. 473 474 var random [32]byte 475 serverHello := serverHelloMsg{ 476 vers: VersionTLS12, 477 random: random[:], 478 scts: [][]byte{nil}, 479 } 480 serverHelloBytes := serverHello.marshal() 481 482 var serverHelloCopy serverHelloMsg 483 if serverHelloCopy.unmarshal(serverHelloBytes) { 484 t.Fatal("Unmarshaled ServerHello with zero-length SCT") 485 } 486} 487