1# Copyright 2005 Divmod, Inc. See LICENSE file for details 2# Copyright (c) Twisted Matrix Laboratories. 3# See LICENSE for details. 4 5""" 6Tests for L{twisted.internet._sslverify}. 7""" 8 9from __future__ import division, absolute_import 10 11import itertools 12 13from zope.interface import implementer 14 15try: 16 from OpenSSL import SSL 17 from OpenSSL.crypto import PKey, X509 18 from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM 19 skipSSL = False 20 if hasattr(SSL.Context, "set_tlsext_servername_callback"): 21 skipSNI = False 22 else: 23 skipSNI = "PyOpenSSL 0.13 or greater required for SNI support." 24except ImportError: 25 skipSSL = "OpenSSL is required for SSL tests." 26 skipSNI = skipSSL 27 28from twisted.test.iosim import connectedServerAndClient 29 30from twisted.internet.error import ConnectionClosed 31from twisted.python.compat import nativeString 32from twisted.python.constants import NamedConstant, Names 33from twisted.python.filepath import FilePath 34 35from twisted.trial import unittest 36from twisted.internet import protocol, defer, reactor 37 38from twisted.internet.error import CertificateError, ConnectionLost 39from twisted.internet import interfaces 40 41if not skipSSL: 42 from twisted.internet.ssl import platformTrust, VerificationError 43 from twisted.internet import _sslverify as sslverify 44 from twisted.protocols.tls import TLSMemoryBIOFactory 45 46# A couple of static PEM-format certificates to be used by various tests. 
47A_HOST_CERTIFICATE_PEM = """ 48-----BEGIN CERTIFICATE----- 49 MIIC2jCCAkMCAjA5MA0GCSqGSIb3DQEBBAUAMIG0MQswCQYDVQQGEwJVUzEiMCAG 50 A1UEAxMZZXhhbXBsZS50d2lzdGVkbWF0cml4LmNvbTEPMA0GA1UEBxMGQm9zdG9u 51 MRwwGgYDVQQKExNUd2lzdGVkIE1hdHJpeCBMYWJzMRYwFAYDVQQIEw1NYXNzYWNo 52 dXNldHRzMScwJQYJKoZIhvcNAQkBFhhub2JvZHlAdHdpc3RlZG1hdHJpeC5jb20x 53 ETAPBgNVBAsTCFNlY3VyaXR5MB4XDTA2MDgxNjAxMDEwOFoXDTA3MDgxNjAxMDEw 54 OFowgbQxCzAJBgNVBAYTAlVTMSIwIAYDVQQDExlleGFtcGxlLnR3aXN0ZWRtYXRy 55 aXguY29tMQ8wDQYDVQQHEwZCb3N0b24xHDAaBgNVBAoTE1R3aXN0ZWQgTWF0cml4 56 IExhYnMxFjAUBgNVBAgTDU1hc3NhY2h1c2V0dHMxJzAlBgkqhkiG9w0BCQEWGG5v 57 Ym9keUB0d2lzdGVkbWF0cml4LmNvbTERMA8GA1UECxMIU2VjdXJpdHkwgZ8wDQYJ 58 KoZIhvcNAQEBBQADgY0AMIGJAoGBAMzH8CDF/U91y/bdbdbJKnLgnyvQ9Ig9ZNZp 59 8hpsu4huil60zF03+Lexg2l1FIfURScjBuaJMR6HiMYTMjhzLuByRZ17KW4wYkGi 60 KXstz03VIKy4Tjc+v4aXFI4XdRw10gGMGQlGGscXF/RSoN84VoDKBfOMWdXeConJ 61 VyC4w3iJAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAviMT4lBoxOgQy32LIgZ4lVCj 62 JNOiZYg8GMQ6y0ugp86X80UjOvkGtNf/R7YgED/giKRN/q/XJiLJDEhzknkocwmO 63 S+4b2XpiaZYxRyKWwL221O7CGmtWYyZl2+92YYmmCiNzWQPfP6BOMlfax0AGLHls 64 fXzCWdG0O/3Lk2SRM0I= 65-----END CERTIFICATE----- 66""" 67 68A_PEER_CERTIFICATE_PEM = """ 69-----BEGIN CERTIFICATE----- 70 MIIC3jCCAkcCAjA6MA0GCSqGSIb3DQEBBAUAMIG2MQswCQYDVQQGEwJVUzEiMCAG 71 A1UEAxMZZXhhbXBsZS50d2lzdGVkbWF0cml4LmNvbTEPMA0GA1UEBxMGQm9zdG9u 72 MRwwGgYDVQQKExNUd2lzdGVkIE1hdHJpeCBMYWJzMRYwFAYDVQQIEw1NYXNzYWNo 73 dXNldHRzMSkwJwYJKoZIhvcNAQkBFhpzb21lYm9keUB0d2lzdGVkbWF0cml4LmNv 74 bTERMA8GA1UECxMIU2VjdXJpdHkwHhcNMDYwODE2MDEwMTU2WhcNMDcwODE2MDEw 75 MTU2WjCBtjELMAkGA1UEBhMCVVMxIjAgBgNVBAMTGWV4YW1wbGUudHdpc3RlZG1h 76 dHJpeC5jb20xDzANBgNVBAcTBkJvc3RvbjEcMBoGA1UEChMTVHdpc3RlZCBNYXRy 77 aXggTGFiczEWMBQGA1UECBMNTWFzc2FjaHVzZXR0czEpMCcGCSqGSIb3DQEJARYa 78 c29tZWJvZHlAdHdpc3RlZG1hdHJpeC5jb20xETAPBgNVBAsTCFNlY3VyaXR5MIGf 79 MA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCnm+WBlgFNbMlHehib9ePGGDXF+Nz4 80 CjGuUmVBaXCRCiVjg3kSDecwqfb0fqTksBZ+oQ1UBjMcSh7OcvFXJZnUesBikGWE 81 
JE4V8Bjh+RmbJ1ZAlUPZ40bAkww0OpyIRAGMvKG+4yLFTO4WDxKmfDcrOb6ID8WJ 82 e1u+i3XGkIf/5QIDAQABMA0GCSqGSIb3DQEBBAUAA4GBAD4Oukm3YYkhedUepBEA 83 vvXIQhVDqL7mk6OqYdXmNj6R7ZMC8WWvGZxrzDI1bZuB+4aIxxd1FXC3UOHiR/xg 84 i9cDl1y8P/qRp4aEBNF6rI0D4AxTbfnHQx4ERDAOShJdYZs/2zifPJ6va6YvrEyr 85 yqDtGhklsWW3ZwBzEh5VEOUp 86-----END CERTIFICATE----- 87""" 88 89 90 91def counter(counter=itertools.count()): 92 """ 93 Each time we're called, return the next integer in the natural numbers. 94 """ 95 return next(counter) 96 97 98 99def makeCertificate(**kw): 100 keypair = PKey() 101 keypair.generate_key(TYPE_RSA, 768) 102 103 certificate = X509() 104 certificate.gmtime_adj_notBefore(0) 105 certificate.gmtime_adj_notAfter(60 * 60 * 24 * 365) # One year 106 for xname in certificate.get_issuer(), certificate.get_subject(): 107 for (k, v) in kw.items(): 108 setattr(xname, k, nativeString(v)) 109 110 certificate.set_serial_number(counter()) 111 certificate.set_pubkey(keypair) 112 certificate.sign(keypair, "md5") 113 114 return keypair, certificate 115 116 117 118def certificatesForAuthorityAndServer(commonName=b'example.com'): 119 """ 120 Create a self-signed CA certificate and server certificate signed by the 121 CA. 122 123 @param commonName: The C{commonName} to embed in the certificate. 
124 @type commonName: L{bytes} 125 126 @return: a 2-tuple of C{(certificate_authority_certificate, 127 server_certificate)} 128 @rtype: L{tuple} of (L{sslverify.Certificate}, 129 L{sslverify.PrivateCertificate}) 130 """ 131 serverDN = sslverify.DistinguishedName(commonName=commonName) 132 serverKey = sslverify.KeyPair.generate() 133 serverCertReq = serverKey.certificateRequest(serverDN) 134 135 caDN = sslverify.DistinguishedName(commonName=b'CA') 136 caKey= sslverify.KeyPair.generate() 137 caCertReq = caKey.certificateRequest(caDN) 138 caSelfCertData = caKey.signCertificateRequest( 139 caDN, caCertReq, lambda dn: True, 516) 140 caSelfCert = caKey.newCertificate(caSelfCertData) 141 142 serverCertData = caKey.signCertificateRequest( 143 caDN, serverCertReq, lambda dn: True, 516) 144 serverCert = serverKey.newCertificate(serverCertData) 145 return caSelfCert, serverCert 146 147 148 149def loopbackTLSConnection(trustRoot, privateKeyFile, chainedCertFile=None): 150 """ 151 Create a loopback TLS connection with the given trust and keys. 152 153 @param trustRoot: the C{trustRoot} argument for the client connection's 154 context. 155 @type trustRoot: L{sslverify.IOpenSSLTrustRoot} 156 157 @param privateKeyFile: The name of the file containing the private key. 158 @type privateKeyFile: L{str} (native string; file name) 159 160 @param chainedCertFile: The name of the chained certificate file. 161 @type chainedCertFile: L{str} (native string; file name) 162 163 @return: 3-tuple of server-protocol, client-protocol, and L{IOPump} 164 @rtype: L{tuple} 165 """ 166 class ContextFactory(object): 167 def getContext(self): 168 """ 169 Create a context for the server side of the connection. 170 171 @return: an SSL context using a certificate and key. 
172 @rtype: C{OpenSSL.SSL.Context} 173 """ 174 ctx = SSL.Context(SSL.TLSv1_METHOD) 175 if chainedCertFile is not None: 176 ctx.use_certificate_chain_file(chainedCertFile) 177 ctx.use_privatekey_file(privateKeyFile) 178 # Let the test author know if they screwed something up. 179 ctx.check_privatekey() 180 return ctx 181 182 class GreetingServer(protocol.Protocol): 183 greeting = b"greetings!" 184 def connectionMade(self): 185 self.transport.write(self.greeting) 186 187 class ListeningClient(protocol.Protocol): 188 data = b'' 189 lostReason = None 190 def dataReceived(self, data): 191 self.data += data 192 def connectionLost(self, reason): 193 self.lostReason = reason 194 195 serverOpts = ContextFactory() 196 clientOpts = sslverify.OpenSSLCertificateOptions(trustRoot=trustRoot) 197 198 clientFactory = TLSMemoryBIOFactory( 199 clientOpts, isClient=True, 200 wrappedFactory=protocol.Factory.forProtocol(GreetingServer) 201 ) 202 serverFactory = TLSMemoryBIOFactory( 203 serverOpts, isClient=False, 204 wrappedFactory=protocol.Factory.forProtocol(ListeningClient) 205 ) 206 207 sProto, cProto, pump = connectedServerAndClient( 208 lambda: serverFactory.buildProtocol(None), 209 lambda: clientFactory.buildProtocol(None) 210 ) 211 return sProto, cProto, pump 212 213 214 215def pathContainingDumpOf(testCase, *dumpables): 216 """ 217 Create a temporary file to store some serializable-as-PEM objects in, and 218 return its name. 219 220 @param testCase: a test case to use for generating a temporary directory. 221 @type testCase: L{twisted.trial.unittest.TestCase} 222 223 @param dumpables: arguments are objects from pyOpenSSL with a C{dump} 224 method, taking a pyOpenSSL file-type constant, such as 225 L{OpenSSL.crypto.FILETYPE_PEM} or L{OpenSSL.crypto.FILETYPE_ASN1}. 226 @type dumpables: L{tuple} of L{object} with C{dump} method taking L{int} 227 returning L{bytes} 228 229 @return: the path to a file where all of the dumpables were dumped in PEM 230 format. 
231 @rtype: L{str} 232 """ 233 fname = testCase.mktemp() 234 with open(fname, "wb") as f: 235 for dumpable in dumpables: 236 f.write(dumpable.dump(FILETYPE_PEM)) 237 return fname 238 239 240 241class DataCallbackProtocol(protocol.Protocol): 242 def dataReceived(self, data): 243 d, self.factory.onData = self.factory.onData, None 244 if d is not None: 245 d.callback(data) 246 247 def connectionLost(self, reason): 248 d, self.factory.onLost = self.factory.onLost, None 249 if d is not None: 250 d.errback(reason) 251 252class WritingProtocol(protocol.Protocol): 253 byte = b'x' 254 def connectionMade(self): 255 self.transport.write(self.byte) 256 257 def connectionLost(self, reason): 258 self.factory.onLost.errback(reason) 259 260 261 262class FakeContext(object): 263 """ 264 Introspectable fake of an C{OpenSSL.SSL.Context}. 265 266 Saves call arguments for later introspection. 267 268 Necessary because C{Context} offers poor introspection. cf. this 269 U{pyOpenSSL bug<https://bugs.launchpad.net/pyopenssl/+bug/1173899>}. 270 271 @ivar _method: See C{method} parameter of L{__init__}. 272 273 @ivar _options: C{int} of C{OR}ed values from calls of L{set_options}. 274 275 @ivar _certificate: Set by L{use_certificate}. 276 277 @ivar _privateKey: Set by L{use_privatekey}. 278 279 @ivar _verify: Set by L{set_verify}. 280 281 @ivar _verifyDepth: Set by L{set_verify_depth}. 282 283 @ivar _sessionID: Set by L{set_session_id}. 284 285 @ivar _extraCertChain: Accumulated C{list} of all extra certificates added 286 by L{add_extra_chain_cert}. 287 288 @ivar _cipherList: Set by L{set_cipher_list}. 289 290 @ivar _dhFilename: Set by L{load_tmp_dh}. 
291 292 @ivar _defaultVerifyPathsSet: Set by L{set_default_verify_paths} 293 """ 294 _options = 0 295 296 def __init__(self, method): 297 self._method = method 298 self._extraCertChain = [] 299 self._defaultVerifyPathsSet = False 300 301 302 def set_options(self, options): 303 self._options |= options 304 305 306 def use_certificate(self, certificate): 307 self._certificate = certificate 308 309 310 def use_privatekey(self, privateKey): 311 self._privateKey = privateKey 312 313 314 def check_privatekey(self): 315 return None 316 317 318 def set_verify(self, flags, callback): 319 self._verify = flags, callback 320 321 322 def set_verify_depth(self, depth): 323 self._verifyDepth = depth 324 325 326 def set_session_id(self, sessionID): 327 self._sessionID = sessionID 328 329 330 def add_extra_chain_cert(self, cert): 331 self._extraCertChain.append(cert) 332 333 334 def set_cipher_list(self, cipherList): 335 self._cipherList = cipherList 336 337 338 def load_tmp_dh(self, dhfilename): 339 self._dhFilename = dhfilename 340 341 342 def set_default_verify_paths(self): 343 """ 344 Set the default paths for the platform. 345 """ 346 self._defaultVerifyPathsSet = True 347 348 349 350class ClientOptions(unittest.SynchronousTestCase): 351 """ 352 Tests for L{sslverify.optionsForClientTLS}. 353 """ 354 if skipSSL: 355 skip = skipSSL 356 357 def test_extraKeywords(self): 358 """ 359 When passed a keyword parameter other than C{extraCertificateOptions}, 360 L{sslverify.optionsForClientTLS} raises an exception just like a 361 normal Python function would. 
362 """ 363 error = self.assertRaises( 364 TypeError, 365 sslverify.optionsForClientTLS, 366 hostname=u'alpha', someRandomThing=u'beta', 367 ) 368 self.assertEqual( 369 str(error), 370 "optionsForClientTLS() got an unexpected keyword argument " 371 "'someRandomThing'" 372 ) 373 374 375 def test_bytesFailFast(self): 376 """ 377 If you pass L{bytes} as the hostname to 378 L{sslverify.optionsForClientTLS} it immediately raises a L{TypeError}. 379 """ 380 error = self.assertRaises( 381 TypeError, 382 sslverify.optionsForClientTLS, b'not-actually-a-hostname.com' 383 ) 384 expectedText = ( 385 "optionsForClientTLS requires text for host names, not " + 386 bytes.__name__ 387 ) 388 self.assertEqual(str(error), expectedText) 389 390 391 392class OpenSSLOptions(unittest.TestCase): 393 if skipSSL: 394 skip = skipSSL 395 396 serverPort = clientConn = None 397 onServerLost = onClientLost = None 398 399 sKey = None 400 sCert = None 401 cKey = None 402 cCert = None 403 404 def setUp(self): 405 """ 406 Create class variables of client and server certificates. 
407 """ 408 self.sKey, self.sCert = makeCertificate( 409 O=b"Server Test Certificate", 410 CN=b"server") 411 self.cKey, self.cCert = makeCertificate( 412 O=b"Client Test Certificate", 413 CN=b"client") 414 self.caCert1 = makeCertificate( 415 O=b"CA Test Certificate 1", 416 CN=b"ca1")[1] 417 self.caCert2 = makeCertificate( 418 O=b"CA Test Certificate", 419 CN=b"ca2")[1] 420 self.caCerts = [self.caCert1, self.caCert2] 421 self.extraCertChain = self.caCerts 422 423 424 def tearDown(self): 425 if self.serverPort is not None: 426 self.serverPort.stopListening() 427 if self.clientConn is not None: 428 self.clientConn.disconnect() 429 430 L = [] 431 if self.onServerLost is not None: 432 L.append(self.onServerLost) 433 if self.onClientLost is not None: 434 L.append(self.onClientLost) 435 436 return defer.DeferredList(L, consumeErrors=True) 437 438 def loopback(self, serverCertOpts, clientCertOpts, 439 onServerLost=None, onClientLost=None, onData=None): 440 if onServerLost is None: 441 self.onServerLost = onServerLost = defer.Deferred() 442 if onClientLost is None: 443 self.onClientLost = onClientLost = defer.Deferred() 444 if onData is None: 445 onData = defer.Deferred() 446 447 serverFactory = protocol.ServerFactory() 448 serverFactory.protocol = DataCallbackProtocol 449 serverFactory.onLost = onServerLost 450 serverFactory.onData = onData 451 452 clientFactory = protocol.ClientFactory() 453 clientFactory.protocol = WritingProtocol 454 clientFactory.onLost = onClientLost 455 456 self.serverPort = reactor.listenSSL(0, serverFactory, serverCertOpts) 457 self.clientConn = reactor.connectSSL('127.0.0.1', 458 self.serverPort.getHost().port, clientFactory, clientCertOpts) 459 460 461 def test_constructorWithOnlyPrivateKey(self): 462 """ 463 C{privateKey} and C{certificate} make only sense if both are set. 
464 """ 465 self.assertRaises( 466 ValueError, 467 sslverify.OpenSSLCertificateOptions, privateKey=self.sKey 468 ) 469 470 471 def test_constructorWithOnlyCertificate(self): 472 """ 473 C{privateKey} and C{certificate} make only sense if both are set. 474 """ 475 self.assertRaises( 476 ValueError, 477 sslverify.OpenSSLCertificateOptions, certificate=self.sCert 478 ) 479 480 481 def test_constructorWithCertificateAndPrivateKey(self): 482 """ 483 Specifying C{privateKey} and C{certificate} initializes correctly. 484 """ 485 opts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 486 certificate=self.sCert) 487 self.assertEqual(opts.privateKey, self.sKey) 488 self.assertEqual(opts.certificate, self.sCert) 489 self.assertEqual(opts.extraCertChain, []) 490 491 492 def test_constructorDoesNotAllowVerifyWithoutCACerts(self): 493 """ 494 C{verify} must not be C{True} without specifying C{caCerts}. 495 """ 496 self.assertRaises( 497 ValueError, 498 sslverify.OpenSSLCertificateOptions, 499 privateKey=self.sKey, certificate=self.sCert, verify=True 500 ) 501 502 503 def test_constructorDoesNotAllowLegacyWithTrustRoot(self): 504 """ 505 C{verify}, C{requireCertificate}, and C{caCerts} must not be specified 506 by the caller (to be I{any} value, even the default!) when specifying 507 C{trustRoot}. 508 """ 509 self.assertRaises( 510 TypeError, 511 sslverify.OpenSSLCertificateOptions, 512 privateKey=self.sKey, certificate=self.sCert, 513 verify=True, trustRoot=None, caCerts=self.caCerts, 514 ) 515 self.assertRaises( 516 TypeError, 517 sslverify.OpenSSLCertificateOptions, 518 privateKey=self.sKey, certificate=self.sCert, 519 trustRoot=None, requireCertificate=True, 520 ) 521 522 523 def test_constructorAllowsCACertsWithoutVerify(self): 524 """ 525 It's currently a NOP, but valid. 
526 """ 527 opts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 528 certificate=self.sCert, 529 caCerts=self.caCerts) 530 self.assertFalse(opts.verify) 531 self.assertEqual(self.caCerts, opts.caCerts) 532 533 534 def test_constructorWithVerifyAndCACerts(self): 535 """ 536 Specifying C{verify} and C{caCerts} initializes correctly. 537 """ 538 opts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 539 certificate=self.sCert, 540 verify=True, 541 caCerts=self.caCerts) 542 self.assertTrue(opts.verify) 543 self.assertEqual(self.caCerts, opts.caCerts) 544 545 546 def test_constructorSetsExtraChain(self): 547 """ 548 Setting C{extraCertChain} works if C{certificate} and C{privateKey} are 549 set along with it. 550 """ 551 opts = sslverify.OpenSSLCertificateOptions( 552 privateKey=self.sKey, 553 certificate=self.sCert, 554 extraCertChain=self.extraCertChain, 555 ) 556 self.assertEqual(self.extraCertChain, opts.extraCertChain) 557 558 559 def test_constructorDoesNotAllowExtraChainWithoutPrivateKey(self): 560 """ 561 A C{extraCertChain} without C{privateKey} doesn't make sense and is 562 thus rejected. 563 """ 564 self.assertRaises( 565 ValueError, 566 sslverify.OpenSSLCertificateOptions, 567 certificate=self.sCert, 568 extraCertChain=self.extraCertChain, 569 ) 570 571 572 def test_constructorDoesNotAllowExtraChainWithOutPrivateKey(self): 573 """ 574 A C{extraCertChain} without C{certificate} doesn't make sense and is 575 thus rejected. 576 """ 577 self.assertRaises( 578 ValueError, 579 sslverify.OpenSSLCertificateOptions, 580 privateKey=self.sKey, 581 extraCertChain=self.extraCertChain, 582 ) 583 584 585 def test_extraChainFilesAreAddedIfSupplied(self): 586 """ 587 If C{extraCertChain} is set and all prerequisites are met, the 588 specified chain certificates are added to C{Context}s that get 589 created. 
590 """ 591 opts = sslverify.OpenSSLCertificateOptions( 592 privateKey=self.sKey, 593 certificate=self.sCert, 594 extraCertChain=self.extraCertChain, 595 ) 596 opts._contextFactory = FakeContext 597 ctx = opts.getContext() 598 self.assertEqual(self.sKey, ctx._privateKey) 599 self.assertEqual(self.sCert, ctx._certificate) 600 self.assertEqual(self.extraCertChain, ctx._extraCertChain) 601 602 603 def test_extraChainDoesNotBreakPyOpenSSL(self): 604 """ 605 C{extraCertChain} doesn't break C{OpenSSL.SSL.Context} creation. 606 """ 607 opts = sslverify.OpenSSLCertificateOptions( 608 privateKey=self.sKey, 609 certificate=self.sCert, 610 extraCertChain=self.extraCertChain, 611 ) 612 ctx = opts.getContext() 613 self.assertIsInstance(ctx, SSL.Context) 614 615 616 def test_acceptableCiphersAreAlwaysSet(self): 617 """ 618 If the user doesn't supply custom acceptable ciphers, a shipped secure 619 default is used. We can't check directly for it because the effective 620 cipher string we set varies with platforms. 621 """ 622 opts = sslverify.OpenSSLCertificateOptions( 623 privateKey=self.sKey, 624 certificate=self.sCert, 625 ) 626 opts._contextFactory = FakeContext 627 ctx = opts.getContext() 628 self.assertEqual(opts._cipherString, ctx._cipherList) 629 630 631 def test_givesMeaningfulErrorMessageIfNoCipherMatches(self): 632 """ 633 If there is no valid cipher that matches the user's wishes, 634 a L{ValueError} is raised. 635 """ 636 self.assertRaises( 637 ValueError, 638 sslverify.OpenSSLCertificateOptions, 639 privateKey=self.sKey, 640 certificate=self.sCert, 641 acceptableCiphers= 642 sslverify.OpenSSLAcceptableCiphers.fromOpenSSLCipherString('') 643 ) 644 645 646 def test_honorsAcceptableCiphersArgument(self): 647 """ 648 If acceptable ciphers are passed, they are used. 
649 """ 650 @implementer(interfaces.IAcceptableCiphers) 651 class FakeAcceptableCiphers(object): 652 def selectCiphers(self, _): 653 return [sslverify.OpenSSLCipher(u'sentinel')] 654 655 opts = sslverify.OpenSSLCertificateOptions( 656 privateKey=self.sKey, 657 certificate=self.sCert, 658 acceptableCiphers=FakeAcceptableCiphers(), 659 ) 660 opts._contextFactory = FakeContext 661 ctx = opts.getContext() 662 self.assertEqual(u'sentinel', ctx._cipherList) 663 664 665 def test_basicSecurityOptionsAreSet(self): 666 """ 667 Every context must have C{OP_NO_SSLv2}, C{OP_NO_COMPRESSION}, and 668 C{OP_CIPHER_SERVER_PREFERENCE} set. 669 """ 670 opts = sslverify.OpenSSLCertificateOptions( 671 privateKey=self.sKey, 672 certificate=self.sCert, 673 ) 674 opts._contextFactory = FakeContext 675 ctx = opts.getContext() 676 options = (SSL.OP_NO_SSLv2 | opts._OP_NO_COMPRESSION | 677 opts._OP_CIPHER_SERVER_PREFERENCE) 678 self.assertEqual(options, ctx._options & options) 679 680 681 def test_singleUseKeys(self): 682 """ 683 If C{singleUseKeys} is set, every context must have 684 C{OP_SINGLE_DH_USE} and C{OP_SINGLE_ECDH_USE} set. 685 """ 686 opts = sslverify.OpenSSLCertificateOptions( 687 privateKey=self.sKey, 688 certificate=self.sCert, 689 enableSingleUseKeys=True, 690 ) 691 opts._contextFactory = FakeContext 692 ctx = opts.getContext() 693 options = SSL.OP_SINGLE_DH_USE | opts._OP_SINGLE_ECDH_USE 694 self.assertEqual(options, ctx._options & options) 695 696 697 def test_dhParams(self): 698 """ 699 If C{dhParams} is set, they are loaded into each new context. 
700 """ 701 class FakeDiffieHellmanParameters(object): 702 _dhFile = FilePath(b'dh.params') 703 704 dhParams = FakeDiffieHellmanParameters() 705 opts = sslverify.OpenSSLCertificateOptions( 706 privateKey=self.sKey, 707 certificate=self.sCert, 708 dhParameters=dhParams, 709 ) 710 opts._contextFactory = FakeContext 711 ctx = opts.getContext() 712 self.assertEqual( 713 FakeDiffieHellmanParameters._dhFile.path, 714 ctx._dhFilename 715 ) 716 717 718 def test_ecDoesNotBreakConstructor(self): 719 """ 720 Missing ECC does not break the constructor and sets C{_ecCurve} to 721 L{None}. 722 """ 723 def raiser(self): 724 raise NotImplementedError 725 self.patch(sslverify._OpenSSLECCurve, "_getBinding", raiser) 726 727 opts = sslverify.OpenSSLCertificateOptions( 728 privateKey=self.sKey, 729 certificate=self.sCert, 730 ) 731 self.assertIs(None, opts._ecCurve) 732 733 734 def test_ecNeverBreaksGetContext(self): 735 """ 736 ECDHE support is best effort only and errors are ignored. 737 """ 738 opts = sslverify.OpenSSLCertificateOptions( 739 privateKey=self.sKey, 740 certificate=self.sCert, 741 ) 742 opts._ecCurve = object() 743 ctx = opts.getContext() 744 self.assertIsInstance(ctx, SSL.Context) 745 746 747 def test_ecSuccessWithRealBindings(self): 748 """ 749 Integration test that checks the positive code path to ensure that we 750 use the API properly. 751 """ 752 try: 753 defaultCurve = sslverify._OpenSSLECCurve( 754 sslverify._defaultCurveName 755 ) 756 except NotImplementedError: 757 raise unittest.SkipTest( 758 "Underlying pyOpenSSL is not based on cryptography." 759 ) 760 opts = sslverify.OpenSSLCertificateOptions( 761 privateKey=self.sKey, 762 certificate=self.sCert, 763 ) 764 self.assertEqual(defaultCurve, opts._ecCurve) 765 # Exercise positive code path. getContext swallows errors so we do it 766 # explicitly by hand. 
767 opts._ecCurve.addECKeyToContext(opts.getContext()) 768 769 770 def test_abbreviatingDistinguishedNames(self): 771 """ 772 Check that abbreviations used in certificates correctly map to 773 complete names. 774 """ 775 self.assertEqual( 776 sslverify.DN(CN=b'a', OU=b'hello'), 777 sslverify.DistinguishedName(commonName=b'a', 778 organizationalUnitName=b'hello')) 779 self.assertNotEquals( 780 sslverify.DN(CN=b'a', OU=b'hello'), 781 sslverify.DN(CN=b'a', OU=b'hello', emailAddress=b'xxx')) 782 dn = sslverify.DN(CN=b'abcdefg') 783 self.assertRaises(AttributeError, setattr, dn, 'Cn', b'x') 784 self.assertEqual(dn.CN, dn.commonName) 785 dn.CN = b'bcdefga' 786 self.assertEqual(dn.CN, dn.commonName) 787 788 789 def testInspectDistinguishedName(self): 790 n = sslverify.DN(commonName=b'common name', 791 organizationName=b'organization name', 792 organizationalUnitName=b'organizational unit name', 793 localityName=b'locality name', 794 stateOrProvinceName=b'state or province name', 795 countryName=b'country name', 796 emailAddress=b'email address') 797 s = n.inspect() 798 for k in [ 799 'common name', 800 'organization name', 801 'organizational unit name', 802 'locality name', 803 'state or province name', 804 'country name', 805 'email address']: 806 self.assertIn(k, s, "%r was not in inspect output." % (k,)) 807 self.assertIn(k.title(), s, "%r was not in inspect output." % (k,)) 808 809 810 def testInspectDistinguishedNameWithoutAllFields(self): 811 n = sslverify.DN(localityName=b'locality name') 812 s = n.inspect() 813 for k in [ 814 'common name', 815 'organization name', 816 'organizational unit name', 817 'state or province name', 818 'country name', 819 'email address']: 820 self.assertNotIn(k, s, "%r was in inspect output." % (k,)) 821 self.assertNotIn(k.title(), s, "%r was in inspect output." 
% (k,)) 822 self.assertIn('locality name', s) 823 self.assertIn('Locality Name', s) 824 825 826 def test_inspectCertificate(self): 827 """ 828 Test that the C{inspect} method of L{sslverify.Certificate} returns 829 a human-readable string containing some basic information about the 830 certificate. 831 """ 832 c = sslverify.Certificate.loadPEM(A_HOST_CERTIFICATE_PEM) 833 self.assertEqual( 834 c.inspect().split('\n'), 835 ["Certificate For Subject:", 836 " Common Name: example.twistedmatrix.com", 837 " Country Name: US", 838 " Email Address: nobody@twistedmatrix.com", 839 " Locality Name: Boston", 840 " Organization Name: Twisted Matrix Labs", 841 " Organizational Unit Name: Security", 842 " State Or Province Name: Massachusetts", 843 "", 844 "Issuer:", 845 " Common Name: example.twistedmatrix.com", 846 " Country Name: US", 847 " Email Address: nobody@twistedmatrix.com", 848 " Locality Name: Boston", 849 " Organization Name: Twisted Matrix Labs", 850 " Organizational Unit Name: Security", 851 " State Or Province Name: Massachusetts", 852 "", 853 "Serial Number: 12345", 854 "Digest: C4:96:11:00:30:C3:EC:EE:A3:55:AA:ED:8C:84:85:18", 855 "Public Key with Hash: ff33994c80812aa95a79cdb85362d054"]) 856 857 858 def test_certificateOptionsSerialization(self): 859 """ 860 Test that __setstate__(__getstate__()) round-trips properly. 
861 """ 862 firstOpts = sslverify.OpenSSLCertificateOptions( 863 privateKey=self.sKey, 864 certificate=self.sCert, 865 method=SSL.SSLv3_METHOD, 866 verify=True, 867 caCerts=[self.sCert], 868 verifyDepth=2, 869 requireCertificate=False, 870 verifyOnce=False, 871 enableSingleUseKeys=False, 872 enableSessions=False, 873 fixBrokenPeers=True, 874 enableSessionTickets=True) 875 context = firstOpts.getContext() 876 self.assertIdentical(context, firstOpts._context) 877 self.assertNotIdentical(context, None) 878 state = firstOpts.__getstate__() 879 self.assertNotIn("_context", state) 880 881 opts = sslverify.OpenSSLCertificateOptions() 882 opts.__setstate__(state) 883 self.assertEqual(opts.privateKey, self.sKey) 884 self.assertEqual(opts.certificate, self.sCert) 885 self.assertEqual(opts.method, SSL.SSLv3_METHOD) 886 self.assertEqual(opts.verify, True) 887 self.assertEqual(opts.caCerts, [self.sCert]) 888 self.assertEqual(opts.verifyDepth, 2) 889 self.assertEqual(opts.requireCertificate, False) 890 self.assertEqual(opts.verifyOnce, False) 891 self.assertEqual(opts.enableSingleUseKeys, False) 892 self.assertEqual(opts.enableSessions, False) 893 self.assertEqual(opts.fixBrokenPeers, True) 894 self.assertEqual(opts.enableSessionTickets, True) 895 896 897 def test_certificateOptionsSessionTickets(self): 898 """ 899 Enabling session tickets should not set the OP_NO_TICKET option. 900 """ 901 opts = sslverify.OpenSSLCertificateOptions(enableSessionTickets=True) 902 ctx = opts.getContext() 903 self.assertEqual(0, ctx.set_options(0) & 0x00004000) 904 905 906 def test_certificateOptionsSessionTicketsDisabled(self): 907 """ 908 Enabling session tickets should set the OP_NO_TICKET option. 
909 """ 910 opts = sslverify.OpenSSLCertificateOptions(enableSessionTickets=False) 911 ctx = opts.getContext() 912 self.assertEqual(0x00004000, ctx.set_options(0) & 0x00004000) 913 914 915 def test_allowedAnonymousClientConnection(self): 916 """ 917 Check that anonymous connections are allowed when certificates aren't 918 required on the server. 919 """ 920 onData = defer.Deferred() 921 self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 922 certificate=self.sCert, requireCertificate=False), 923 sslverify.OpenSSLCertificateOptions( 924 requireCertificate=False), 925 onData=onData) 926 927 return onData.addCallback( 928 lambda result: self.assertEqual(result, WritingProtocol.byte)) 929 930 931 def test_refusedAnonymousClientConnection(self): 932 """ 933 Check that anonymous connections are refused when certificates are 934 required on the server. 935 """ 936 onServerLost = defer.Deferred() 937 onClientLost = defer.Deferred() 938 self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 939 certificate=self.sCert, verify=True, 940 caCerts=[self.sCert], requireCertificate=True), 941 sslverify.OpenSSLCertificateOptions( 942 requireCertificate=False), 943 onServerLost=onServerLost, 944 onClientLost=onClientLost) 945 946 d = defer.DeferredList([onClientLost, onServerLost], 947 consumeErrors=True) 948 949 950 def afterLost(result): 951 ((cSuccess, cResult), (sSuccess, sResult)) = result 952 self.failIf(cSuccess) 953 self.failIf(sSuccess) 954 # Win32 fails to report the SSL Error, and report a connection lost 955 # instead: there is a race condition so that's not totally 956 # surprising (see ticket #2877 in the tracker) 957 self.assertIsInstance(cResult.value, (SSL.Error, ConnectionLost)) 958 self.assertIsInstance(sResult.value, SSL.Error) 959 960 return d.addCallback(afterLost) 961 962 def test_failedCertificateVerification(self): 963 """ 964 Check that connecting with a certificate not accepted by the server CA 965 fails. 
966 """ 967 onServerLost = defer.Deferred() 968 onClientLost = defer.Deferred() 969 self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 970 certificate=self.sCert, verify=False, 971 requireCertificate=False), 972 sslverify.OpenSSLCertificateOptions(verify=True, 973 requireCertificate=False, caCerts=[self.cCert]), 974 onServerLost=onServerLost, 975 onClientLost=onClientLost) 976 977 d = defer.DeferredList([onClientLost, onServerLost], 978 consumeErrors=True) 979 def afterLost(result): 980 ((cSuccess, cResult), (sSuccess, sResult)) = result 981 self.failIf(cSuccess) 982 self.failIf(sSuccess) 983 984 return d.addCallback(afterLost) 985 986 def test_successfulCertificateVerification(self): 987 """ 988 Test a successful connection with client certificate validation on 989 server side. 990 """ 991 onData = defer.Deferred() 992 self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey, 993 certificate=self.sCert, verify=False, 994 requireCertificate=False), 995 sslverify.OpenSSLCertificateOptions(verify=True, 996 requireCertificate=True, caCerts=[self.sCert]), 997 onData=onData) 998 999 return onData.addCallback( 1000 lambda result: self.assertEqual(result, WritingProtocol.byte)) 1001 1002 def test_successfulSymmetricSelfSignedCertificateVerification(self): 1003 """ 1004 Test a successful connection with validation on both server and client 1005 sides. 
"""
        # Both sides require and verify a certificate, each trusting the
        # other side's certificate as its only CA.
        onData = defer.Deferred()
        self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
                          certificate=self.sCert, verify=True,
                          requireCertificate=True, caCerts=[self.cCert]),
                      sslverify.OpenSSLCertificateOptions(privateKey=self.cKey,
                          certificate=self.cCert, verify=True,
                          requireCertificate=True, caCerts=[self.sCert]),
                      onData=onData)

        return onData.addCallback(
            lambda result: self.assertEqual(result, WritingProtocol.byte))


    def test_verification(self):
        """
        Check certificate verification using certificates built from scratch:
        each side gets a self-signed certificate plus a certificate issued by
        the other side's key, and data must flow over the loopback connection.
        """
        # Two independent identities, each with its own key pair and
        # certificate request.
        clientDN = sslverify.DistinguishedName(commonName='client')
        clientKey = sslverify.KeyPair.generate()
        clientCertReq = clientKey.certificateRequest(clientDN)

        serverDN = sslverify.DistinguishedName(commonName='server')
        serverKey = sslverify.KeyPair.generate()
        serverCertReq = serverKey.certificateRequest(serverDN)

        # Self-signed certificates for each key.  The trailing integers are
        # distinct per certificate; presumably serial numbers -- TODO confirm
        # against signCertificateRequest's signature.
        clientSelfCertReq = clientKey.certificateRequest(clientDN)
        clientSelfCertData = clientKey.signCertificateRequest(
                clientDN, clientSelfCertReq, lambda dn: True, 132)
        clientSelfCert = clientKey.newCertificate(clientSelfCertData)

        serverSelfCertReq = serverKey.certificateRequest(serverDN)
        serverSelfCertData = serverKey.signCertificateRequest(
                serverDN, serverSelfCertReq, lambda dn: True, 516)
        serverSelfCert = serverKey.newCertificate(serverSelfCertData)

        # Cross-signing: the client's certificate is issued by the server's
        # key under the server's DN, and vice versa.
        clientCertData = serverKey.signCertificateRequest(
                serverDN, clientCertReq, lambda dn: True, 7)
        clientCert = clientKey.newCertificate(clientCertData)

        serverCertData = clientKey.signCertificateRequest(
                clientDN, serverCertReq, lambda dn: True, 42)
        serverCert = serverKey.newCertificate(serverCertData)

        onData = defer.Deferred()

        # NOTE(review): each side is given its own self-signed certificate,
        # whose key issued the *peer's* certificate; presumably options()
        # treats its arguments as trusted authorities -- confirm.
        serverOpts = serverCert.options(serverSelfCert)
        clientOpts = clientCert.options(clientSelfCert)

        self.loopback(serverOpts,
                      clientOpts,
                      onData=onData)

        return onData.addCallback(
            lambda result: self.assertEqual(result, WritingProtocol.byte))



class ProtocolVersion(Names):
    """
    L{ProtocolVersion} provides constants representing each version of the
    SSL/TLS protocol.
    """
    SSLv2 = NamedConstant()
    SSLv3 = NamedConstant()
    TLSv1_0 = NamedConstant()
    TLSv1_1 = NamedConstant()
    TLSv1_2 = NamedConstant()



class ProtocolVersionTests(unittest.TestCase):
    """
    Tests for L{sslverify.OpenSSLCertificateOptions}'s SSL/TLS version
    selection features.
    """
    if skipSSL:
        skip = skipSSL
    else:
        # Maps each OpenSSL method constant to the set of protocol versions it
        # can negotiate before any OP_NO_* exclusions are applied.
        _METHOD_TO_PROTOCOL = {
            SSL.SSLv2_METHOD: set([ProtocolVersion.SSLv2]),
            SSL.SSLv3_METHOD: set([ProtocolVersion.SSLv3]),
            SSL.TLSv1_METHOD: set([ProtocolVersion.TLSv1_0]),
            # Methods absent from the installed pyOpenSSL get a fresh
            # object() key, which can never equal a real method constant.
            getattr(SSL, "TLSv1_1_METHOD", object()):
                set([ProtocolVersion.TLSv1_1]),
            getattr(SSL, "TLSv1_2_METHOD", object()):
                set([ProtocolVersion.TLSv1_2]),

            # Presently, SSLv23_METHOD means (SSLv2, SSLv3, TLSv1.0, TLSv1.1,
            # TLSv1.2) (excluding any protocol versions not implemented by the
            # underlying version of OpenSSL).
            SSL.SSLv23_METHOD: set(ProtocolVersion.iterconstants()),
        }

        # Maps each OP_NO_* option bit to the protocol version it disables.
        # Options absent from the installed pyOpenSSL fall back to 0, and
        # "options & 0" is always false in _protocols, so they never
        # exclude anything.
        _EXCLUSION_OPS = {
            SSL.OP_NO_SSLv2: ProtocolVersion.SSLv2,
            SSL.OP_NO_SSLv3: ProtocolVersion.SSLv3,
            SSL.OP_NO_TLSv1: ProtocolVersion.TLSv1_0,
            getattr(SSL, "OP_NO_TLSv1_1", 0): ProtocolVersion.TLSv1_1,
            getattr(SSL, "OP_NO_TLSv1_2", 0): ProtocolVersion.TLSv1_2,
        }


    def _protocols(self, opts):
        """
        Determine which SSL/TLS protocol versions are allowed by C{opts}.

        @param opts: An L{sslverify.OpenSSLCertificateOptions} instance to
            inspect.

        @return: A L{set} of L{NamedConstant}s from L{ProtocolVersion}
            indicating which SSL/TLS protocol versions connections negotiated
            using C{opts} will allow.
        """
        # Copy so that discarding exclusions below does not mutate the
        # shared class-level table.
        protocols = self._METHOD_TO_PROTOCOL[opts.method].copy()
        context = opts.getContext()
        # set_options(0) adds no option bits; its return value is used as the
        # context's current option bitmask.
        options = context.set_options(0)
        if opts.method == SSL.SSLv23_METHOD:
            # Exclusions apply only to SSLv23_METHOD and no others.
            for opt, exclude in self._EXCLUSION_OPS.items():
                if options & opt:
                    protocols.discard(exclude)
        return protocols


    def test_default(self):
        """
        When L{sslverify.OpenSSLCertificateOptions} is initialized with no
        specific protocol versions, all versions of TLS are allowed and no
        versions of SSL are allowed.
        """
        self.assertEqual(
            set([ProtocolVersion.TLSv1_0,
                 ProtocolVersion.TLSv1_1,
                 ProtocolVersion.TLSv1_2]),
            self._protocols(sslverify.OpenSSLCertificateOptions()))


    def test_SSLv23(self):
        """
        When L{sslverify.OpenSSLCertificateOptions} is initialized with
        C{SSLv23_METHOD}, all versions of TLS and SSLv3 are allowed.
        """
        self.assertEqual(
            set([ProtocolVersion.SSLv3,
                 ProtocolVersion.TLSv1_0,
                 ProtocolVersion.TLSv1_1,
                 ProtocolVersion.TLSv1_2]),
            self._protocols(sslverify.OpenSSLCertificateOptions(
                method=SSL.SSLv23_METHOD)))



class TrustRootTests(unittest.TestCase):
    """
    Tests for L{sslverify.OpenSSLCertificateOptions}'s C{trustRoot} argument,
    L{sslverify.platformTrust}, and their interactions.
    """
    if skipSSL:
        skip = skipSSL

    def test_caCertsPlatformDefaults(self):
        """
        Specifying a C{trustRoot} of L{sslverify.OpenSSLDefaultPaths} when
        initializing L{sslverify.OpenSSLCertificateOptions} loads the
        platform-provided trusted certificates via C{set_default_verify_paths}.
        """
        opts = sslverify.OpenSSLCertificateOptions(
            trustRoot=sslverify.OpenSSLDefaultPaths(),
        )
        # Substitute a FakeContext so the test can observe whether
        # set_default_verify_paths was called while building the context.
        fc = FakeContext(SSL.TLSv1_METHOD)
        opts._contextFactory = lambda method: fc
        opts.getContext()
        self.assertTrue(fc._defaultVerifyPathsSet)


    def test_trustRootPlatformRejectsUntrustedCA(self):
        """
        Specifying a C{trustRoot} of L{platformTrust} when initializing
        L{sslverify.OpenSSLCertificateOptions} causes certificates issued by a
        newly created CA to be rejected by an SSL connection using these
        options.

        Note that this test should I{always} pass, even on platforms where the
        CA certificates are not installed, as long as L{platformTrust} rejects
        completely invalid / unknown root CA certificates.  This is simply a
        smoke test to make sure that verification is happening at all.
        """
        caSelfCert, serverCert = certificatesForAuthorityAndServer()
        chainedCert = pathContainingDumpOf(self, serverCert, caSelfCert)
        privateKey = pathContainingDumpOf(self, serverCert.privateKey)

        sProto, cProto, pump = loopbackTLSConnection(
            trustRoot=platformTrust(),
            privateKeyFile=privateKey,
            chainedCertFile=chainedCert,
        )
        # No data was received.
        self.assertEqual(cProto.wrappedProtocol.data, b'')

        # It was an L{SSL.Error}.
        self.assertEqual(cProto.wrappedProtocol.lostReason.type, SSL.Error)

        # Some combination of OpenSSL and PyOpenSSL is bad at reporting errors;
        # the human-readable reason is nested deep inside the exception args.
        err = cProto.wrappedProtocol.lostReason.value
        self.assertEqual(err.args[0][0][2], 'tlsv1 alert unknown ca')


    def test_trustRootSpecificCertificate(self):
        """
        Specifying a L{Certificate} object for L{trustRoot} will result in that
        certificate being the only trust root for a client.
"""
        caCert, serverCert = certificatesForAuthorityAndServer()
        # A second, unrelated authority and server certificate.
        otherCa, otherServer = certificatesForAuthorityAndServer()
        sProto, cProto, pump = loopbackTLSConnection(
            trustRoot=caCert,
            privateKeyFile=pathContainingDumpOf(self, serverCert.privateKey),
            chainedCertFile=pathContainingDumpOf(self, serverCert),
        )
        pump.flush()
        self.assertIs(cProto.wrappedProtocol.lostReason, None)
        self.assertEqual(cProto.wrappedProtocol.data,
                         sProto.wrappedProtocol.greeting)



class ServiceIdentityTests(unittest.SynchronousTestCase):
    """
    Tests for the verification of the peer's service's identity via the
    C{hostname} argument to L{sslverify.OpenSSLCertificateOptions}.
    """

    if skipSSL:
        skip = skipSSL

    def serviceIdentitySetup(self, clientHostname, serverHostname,
                             serverContextSetup=lambda ctx: None,
                             validCertificate=True,
                             clientPresentsCertificate=False,
                             validClientCertificate=True,
                             serverVerifies=False,
                             buggyInfoCallback=False,
                             fakePlatformTrust=False,
                             useDefaultTrust=False):
        """
        Connect a server and a client in memory and run the TLS handshake.

        @param clientHostname: The I{client's idea} of the server's hostname;
            passed as the C{hostname} to L{sslverify.optionsForClientTLS}.
        @type clientHostname: L{unicode}

        @param serverHostname: The I{server's own idea} of the server's
            hostname; present in the certificate presented by the server.
        @type serverHostname: L{unicode}

        @param serverContextSetup: a 1-argument callable invoked with the
            server's L{OpenSSL.SSL.Context} after it's produced.
        @type serverContextSetup: L{callable} taking L{OpenSSL.SSL.Context}
            returning L{NoneType}.

        @param validCertificate: Is the server's certificate valid?  L{True} if
            so, L{False} otherwise.  When L{False}, the authority the client
            is told to trust is a freshly generated one which did not sign the
            server's certificate.
        @type validCertificate: L{bool}

        @param clientPresentsCertificate: Should the client present a
            certificate to the server?  Defaults to 'no'.
        @type clientPresentsCertificate: L{bool}

        @param validClientCertificate: If the client presents a certificate,
            should it actually be a valid one, i.e. signed by the same CA that
            the server is checking?  Defaults to 'yes'.
        @type validClientCertificate: L{bool}

        @param serverVerifies: Should the server verify the client's
            certificate?  Defaults to 'no'.
        @type serverVerifies: L{bool}

        @param buggyInfoCallback: Should we patch
            L{sslverify.ClientTLSOptions} so that its
            C{_identityVerifyingInfoCallback} (the C{info_callback} passed to
            OpenSSL) has a bug and raises an exception
            (L{ZeroDivisionError})?  Defaults to 'no'.
        @type buggyInfoCallback: L{bool}

        @param fakePlatformTrust: Should we fake the platformTrust to be the
            same as our fake server certificate authority, so that we can test
            it's being used?  Defaults to 'no' and we just pass platform trust.
        @type fakePlatformTrust: L{bool}

        @param useDefaultTrust: Should we avoid passing the C{trustRoot} to
            L{ssl.optionsForClientTLS}?  Defaults to 'no'.
        @type useDefaultTrust: L{bool}

        @return: The result of L{connectedServerAndClient}; the tests in this
            class unpack it as C{(clientProtocol, serverProtocol, pump)}.
        @rtype: see L{connectedServerAndClient}.
        """
        serverIDNA = sslverify._idnaBytes(serverHostname)
        serverCA, serverCert = certificatesForAuthorityAndServer(serverIDNA)
        # Extra keyword arguments for the server's options.
        other = {}
        passClientCert = None
        clientCA, clientCert = certificatesForAuthorityAndServer(u'client')
        if serverVerifies:
            other.update(trustRoot=clientCA)

        if clientPresentsCertificate:
            if validClientCertificate:
                passClientCert = clientCert
            else:
                # Issued by an unrelated CA, not the clientCA the server
                # trusts, so the server should reject it.
                bogusCA, bogus = certificatesForAuthorityAndServer(u'client')
                passClientCert = bogus

        serverOpts = sslverify.OpenSSLCertificateOptions(
            privateKey=serverCert.privateKey.original,
            certificate=serverCert.original,
            **other
        )
        serverContextSetup(serverOpts.getContext())
        if not validCertificate:
            # Rebind serverCA to a fresh authority that did not sign
            # serverCert; the client below is told to trust this one.
            serverCA, otherServer = certificatesForAuthorityAndServer(
                serverIDNA
            )
        if buggyInfoCallback:
            def broken(*a, **k):
                """
                Raise an exception.

                @param a: Arguments for an C{info_callback}

                @param k: Keyword arguments for an C{info_callback}
                """
                1 / 0
            self.patch(
                sslverify.ClientTLSOptions, "_identityVerifyingInfoCallback",
                broken,
            )

        # Keyword arguments for optionsForClientTLS.
        signature = {'hostname': clientHostname}
        if passClientCert:
            signature.update(clientCertificate=passClientCert)
        if not useDefaultTrust:
            signature.update(trustRoot=serverCA)
        if fakePlatformTrust:
            self.patch(sslverify, "platformTrust", lambda: serverCA)

        clientOpts = sslverify.optionsForClientTLS(**signature)

        class GreetingServer(protocol.Protocol):
            """
            Writes C{greeting} on connect; records received data and the
            reason the connection was lost.
            """
            greeting = b"greetings!"
            lostReason = None
            data = b''
            def connectionMade(self):
                self.transport.write(self.greeting)
            def dataReceived(self, data):
                self.data += data
            def connectionLost(self, reason):
                self.lostReason = reason

        class GreetingClient(protocol.Protocol):
            """
            Writes C{greeting} on connect; records received data and the
            reason the connection was lost.
            """
            greeting = b'cheerio!'
            data = b''
            lostReason = None
            def connectionMade(self):
                self.transport.write(self.greeting)
            def dataReceived(self, data):
                self.data += data
            def connectionLost(self, reason):
                self.lostReason = reason

        # Exposed so individual tests can inspect the options afterwards.
        self.serverOpts = serverOpts
        self.clientOpts = clientOpts

        clientFactory = TLSMemoryBIOFactory(
            clientOpts, isClient=True,
            wrappedFactory=protocol.Factory.forProtocol(GreetingClient)
        )
        serverFactory = TLSMemoryBIOFactory(
            serverOpts, isClient=False,
            wrappedFactory=protocol.Factory.forProtocol(GreetingServer)
        )
        return connectedServerAndClient(
            lambda: serverFactory.buildProtocol(None),
            lambda: clientFactory.buildProtocol(None),
        )


    def test_invalidHostname(self):
        """
        When a certificate containing an invalid hostname is received from the
        server, the connection is immediately dropped.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"wrong-host.example.com",
            u"correct-host.example.com",
        )
        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, VerificationError)
        self.assertIsInstance(sErr, ConnectionClosed)


    def test_validHostname(self):
        """
        Whenever a valid certificate containing a valid hostname is received,
        connection proceeds normally.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
        )
        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)


    def test_validHostnameInvalidCertificate(self):
        """
        When an invalid certificate containing a perfectly valid hostname is
        received, the connection is aborted with an OpenSSL error.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=False,
        )

        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, SSL.Error)
        self.assertIsInstance(sErr, SSL.Error)


    def test_realCAsBetterNotSignOurBogusTestCerts(self):
        """
        If we use the default trust from the platform, our dinky certificate
        should I{really} fail.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=False,
            useDefaultTrust=True,
        )

        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, SSL.Error)
        self.assertIsInstance(sErr, SSL.Error)


    def test_butIfTheyDidItWouldWork(self):
        """
        L{ssl.optionsForClientTLS} should be using L{ssl.platformTrust} by
        default, so if we fake that out then it should trust ourselves again.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            useDefaultTrust=True,
            fakePlatformTrust=True,
        )
        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)


    def test_clientPresentsCertificate(self):
        """
        When the server verifies and the client presents a valid certificate
        for that verification by passing it to
        L{sslverify.optionsForClientTLS}, communication proceeds.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=True,
            serverVerifies=True,
            clientPresentsCertificate=True,
        )

        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)


    def test_clientPresentsBadCertificate(self):
        """
        When the server verifies and the client presents an invalid certificate
        for that verification by passing it to
        L{sslverify.optionsForClientTLS}, the connection cannot be established
        and fails with an SSL error.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=True,
            serverVerifies=True,
            validClientCertificate=False,
            clientPresentsCertificate=True,
        )

        self.assertEqual(cProto.wrappedProtocol.data,
                         b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, SSL.Error)
        self.assertIsInstance(sErr, SSL.Error)


    def test_hostnameIsIndicated(self):
        """
        Specifying the C{hostname} argument to
        L{sslverify.optionsForClientTLS} also sets the U{Server Name Extension
        <https://en.wikipedia.org/wiki/Server_Name_Indication>} TLS indication
        field to the correct value.
        """
        names = []
        def setupServerContext(ctx):
            # Record every server name the server receives via SNI.
            def servername_received(conn):
                names.append(conn.get_servername().decode("ascii"))
            ctx.set_tlsext_servername_callback(servername_received)
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            setupServerContext
        )
        self.assertEqual(names, [u"valid.example.com"])

    test_hostnameIsIndicated.skip = skipSNI


    def test_hostnameEncoding(self):
        """
        Hostnames are encoded as IDNA.
        """
        names = []
        hello = u"h\N{LATIN SMALL LETTER A WITH ACUTE}llo.example.com"
        def setupServerContext(ctx):
            # Decode the received SNI bytes back from IDNA before recording.
            def servername_received(conn):
                serverIDNA = sslverify._idnaText(conn.get_servername())
                names.append(serverIDNA)
            ctx.set_tlsext_servername_callback(servername_received)
        cProto, sProto, pump = self.serviceIdentitySetup(
            hello, hello, setupServerContext
        )
        self.assertEqual(names, [hello])
        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)

    test_hostnameEncoding.skip = skipSNI


    def test_fallback(self):
        """
        L{sslverify.simpleVerifyHostname} checks string equality on the
        commonName of a connection's certificate's subject, doing nothing if it
        matches and raising L{VerificationError} if it doesn't.
        """
        name = 'something.example.com'
        class Connection(object):
            def get_peer_certificate(self):
                """
                Fake of L{OpenSSL.SSL.Connection.get_peer_certificate}.

                @return: A certificate with a known common name.
                @rtype: L{OpenSSL.crypto.X509}
                """
                cert = X509()
                cert.get_subject().commonName = name
                return cert
        conn = Connection()
        self.assertIdentical(
            sslverify.simpleVerifyHostname(conn, u'something.example.com'),
            None
        )
        self.assertRaises(
            sslverify.SimpleVerificationError,
            sslverify.simpleVerifyHostname, conn, u'nonsense'
        )


    def test_surpriseFromInfoCallback(self):
        """
        pyOpenSSL isn't always so great about reporting errors.  If one occurs
        in the verification info callback, it should be logged and the
        connection should be shut down (if possible, anyway; the app_data could
        be clobbered but there's no point testing for that).
1621 """ 1622 cProto, sProto, pump = self.serviceIdentitySetup( 1623 u"correct-host.example.com", 1624 u"correct-host.example.com", 1625 buggyInfoCallback=True, 1626 ) 1627 self.assertEqual(cProto.wrappedProtocol.data, b'') 1628 self.assertEqual(sProto.wrappedProtocol.data, b'') 1629 1630 cErr = cProto.wrappedProtocol.lostReason.value 1631 sErr = sProto.wrappedProtocol.lostReason.value 1632 1633 self.assertIsInstance(cErr, ZeroDivisionError) 1634 self.assertIsInstance(sErr, ConnectionClosed) 1635 errors = self.flushLoggedErrors(ZeroDivisionError) 1636 self.assertTrue(errors) 1637 1638 1639 1640class _NotSSLTransport: 1641 def getHandle(self): 1642 return self 1643 1644 1645 1646class _MaybeSSLTransport: 1647 def getHandle(self): 1648 return self 1649 1650 def get_peer_certificate(self): 1651 return None 1652 1653 def get_host_certificate(self): 1654 return None 1655 1656 1657 1658class _ActualSSLTransport: 1659 def getHandle(self): 1660 return self 1661 1662 def get_host_certificate(self): 1663 return sslverify.Certificate.loadPEM(A_HOST_CERTIFICATE_PEM).original 1664 1665 def get_peer_certificate(self): 1666 return sslverify.Certificate.loadPEM(A_PEER_CERTIFICATE_PEM).original 1667 1668 1669 1670class Constructors(unittest.TestCase): 1671 if skipSSL: 1672 skip = skipSSL 1673 1674 def test_peerFromNonSSLTransport(self): 1675 """ 1676 Verify that peerFromTransport raises an exception if the transport 1677 passed is not actually an SSL transport. 1678 """ 1679 x = self.assertRaises(CertificateError, 1680 sslverify.Certificate.peerFromTransport, 1681 _NotSSLTransport()) 1682 self.failUnless(str(x).startswith("non-TLS")) 1683 1684 1685 def test_peerFromBlankSSLTransport(self): 1686 """ 1687 Verify that peerFromTransport raises an exception if the transport 1688 passed is an SSL transport, but doesn't have a peer certificate. 
1689 """ 1690 x = self.assertRaises(CertificateError, 1691 sslverify.Certificate.peerFromTransport, 1692 _MaybeSSLTransport()) 1693 self.failUnless(str(x).startswith("TLS")) 1694 1695 1696 def test_hostFromNonSSLTransport(self): 1697 """ 1698 Verify that hostFromTransport raises an exception if the transport 1699 passed is not actually an SSL transport. 1700 """ 1701 x = self.assertRaises(CertificateError, 1702 sslverify.Certificate.hostFromTransport, 1703 _NotSSLTransport()) 1704 self.failUnless(str(x).startswith("non-TLS")) 1705 1706 1707 def test_hostFromBlankSSLTransport(self): 1708 """ 1709 Verify that hostFromTransport raises an exception if the transport 1710 passed is an SSL transport, but doesn't have a host certificate. 1711 """ 1712 x = self.assertRaises(CertificateError, 1713 sslverify.Certificate.hostFromTransport, 1714 _MaybeSSLTransport()) 1715 self.failUnless(str(x).startswith("TLS")) 1716 1717 1718 def test_hostFromSSLTransport(self): 1719 """ 1720 Verify that hostFromTransport successfully creates the correct 1721 certificate if passed a valid SSL transport. 1722 """ 1723 self.assertEqual( 1724 sslverify.Certificate.hostFromTransport( 1725 _ActualSSLTransport()).serialNumber(), 1726 12345) 1727 1728 1729 def test_peerFromSSLTransport(self): 1730 """ 1731 Verify that peerFromTransport successfully creates the correct 1732 certificate if passed a valid SSL transport. 1733 """ 1734 self.assertEqual( 1735 sslverify.Certificate.peerFromTransport( 1736 _ActualSSLTransport()).serialNumber(), 1737 12346) 1738 1739 1740 1741class TestOpenSSLCipher(unittest.TestCase): 1742 """ 1743 Tests for twisted.internet._sslverify.OpenSSLCipher. 1744 """ 1745 if skipSSL: 1746 skip = skipSSL 1747 1748 cipherName = u'CIPHER-STRING' 1749 1750 def test_constructorSetsFullName(self): 1751 """ 1752 The first argument passed to the constructor becomes the full name. 
1753 """ 1754 self.assertEqual( 1755 self.cipherName, 1756 sslverify.OpenSSLCipher(self.cipherName).fullName 1757 ) 1758 1759 1760 def test_repr(self): 1761 """ 1762 C{repr(cipher)} returns a valid constructor call. 1763 """ 1764 cipher = sslverify.OpenSSLCipher(self.cipherName) 1765 self.assertEqual( 1766 cipher, 1767 eval(repr(cipher), {'OpenSSLCipher': sslverify.OpenSSLCipher}) 1768 ) 1769 1770 1771 def test_eqSameClass(self): 1772 """ 1773 Equal type and C{fullName} means that the objects are equal. 1774 """ 1775 cipher1 = sslverify.OpenSSLCipher(self.cipherName) 1776 cipher2 = sslverify.OpenSSLCipher(self.cipherName) 1777 self.assertEqual(cipher1, cipher2) 1778 1779 1780 def test_eqSameNameDifferentType(self): 1781 """ 1782 If ciphers have the same name but different types, they're still 1783 different. 1784 """ 1785 class DifferentCipher(object): 1786 fullName = self.cipherName 1787 1788 self.assertNotEqual( 1789 sslverify.OpenSSLCipher(self.cipherName), 1790 DifferentCipher(), 1791 ) 1792 1793 1794 1795class TestExpandCipherString(unittest.TestCase): 1796 """ 1797 Tests for twisted.internet._sslverify._expandCipherString. 1798 """ 1799 if skipSSL: 1800 skip = skipSSL 1801 1802 def test_doesNotStumbleOverEmptyList(self): 1803 """ 1804 If the expanded cipher list is empty, an empty L{list} is returned. 1805 """ 1806 self.assertEqual( 1807 [], 1808 sslverify._expandCipherString(u'', SSL.SSLv23_METHOD, 0) 1809 ) 1810 1811 1812 def test_doesNotSwallowOtherSSLErrors(self): 1813 """ 1814 Only no cipher matches get swallowed, every other SSL error gets 1815 propagated. 1816 """ 1817 def raiser(_): 1818 # Unfortunately, there seems to be no way to trigger a real SSL 1819 # error artificially. 
1820 raise SSL.Error([['', '', '']]) 1821 ctx = FakeContext(SSL.SSLv23_METHOD) 1822 ctx.set_cipher_list = raiser 1823 self.patch(sslverify.SSL, 'Context', lambda _: ctx) 1824 self.assertRaises( 1825 SSL.Error, 1826 sslverify._expandCipherString, u'ALL', SSL.SSLv23_METHOD, 0 1827 ) 1828 1829 1830 def test_returnsListOfICiphers(self): 1831 """ 1832 L{sslverify._expandCipherString} always returns a L{list} of 1833 L{interfaces.ICipher}. 1834 """ 1835 ciphers = sslverify._expandCipherString(u'ALL', SSL.SSLv23_METHOD, 0) 1836 self.assertIsInstance(ciphers, list) 1837 bogus = [] 1838 for c in ciphers: 1839 if not interfaces.ICipher.providedBy(c): 1840 bogus.append(c) 1841 1842 self.assertEqual([], bogus) 1843 1844 1845 1846class TestAcceptableCiphers(unittest.TestCase): 1847 """ 1848 Tests for twisted.internet._sslverify.OpenSSLAcceptableCiphers. 1849 """ 1850 if skipSSL: 1851 skip = skipSSL 1852 1853 def test_selectOnEmptyListReturnsEmptyList(self): 1854 """ 1855 If no ciphers are available, nothing can be selected. 1856 """ 1857 ac = sslverify.OpenSSLAcceptableCiphers([]) 1858 self.assertEqual([], ac.selectCiphers([])) 1859 1860 1861 def test_selectReturnsOnlyFromAvailable(self): 1862 """ 1863 Select only returns a cross section of what is available and what is 1864 desirable. 1865 """ 1866 ac = sslverify.OpenSSLAcceptableCiphers([ 1867 sslverify.OpenSSLCipher('A'), 1868 sslverify.OpenSSLCipher('B'), 1869 ]) 1870 self.assertEqual([sslverify.OpenSSLCipher('B')], 1871 ac.selectCiphers([sslverify.OpenSSLCipher('B'), 1872 sslverify.OpenSSLCipher('C')])) 1873 1874 1875 def test_fromOpenSSLCipherStringExpandsToListOfCiphers(self): 1876 """ 1877 If L{sslverify.OpenSSLAcceptableCiphers.fromOpenSSLCipherString} is 1878 called it expands the string to a list of ciphers. 
1879 """ 1880 ac = sslverify.OpenSSLAcceptableCiphers.fromOpenSSLCipherString('ALL') 1881 self.assertIsInstance(ac._ciphers, list) 1882 self.assertTrue(all(sslverify.ICipher.providedBy(c) 1883 for c in ac._ciphers)) 1884 1885 1886 1887class TestDiffieHellmanParameters(unittest.TestCase): 1888 """ 1889 Tests for twisted.internet._sslverify.OpenSSLDHParameters. 1890 """ 1891 if skipSSL: 1892 skip = skipSSL 1893 filePath = FilePath(b'dh.params') 1894 1895 def test_fromFile(self): 1896 """ 1897 Calling C{fromFile} with a filename returns an instance with that file 1898 name saved. 1899 """ 1900 params = sslverify.OpenSSLDiffieHellmanParameters.fromFile( 1901 self.filePath 1902 ) 1903 self.assertEqual(self.filePath, params._dhFile) 1904 1905 1906 1907class FakeECKey(object): 1908 """ 1909 An introspectable fake of a key. 1910 1911 @ivar _nid: A free form nid. 1912 """ 1913 def __init__(self, nid): 1914 self._nid = nid 1915 1916 1917 1918class FakeNID(object): 1919 """ 1920 An introspectable fake of a NID. 1921 1922 @ivar _snName: A free form sn name. 1923 """ 1924 def __init__(self, snName): 1925 self._snName = snName 1926 1927 1928 1929class FakeLib(object): 1930 """ 1931 An introspectable fake of cryptography's lib object. 1932 1933 @ivar _createdKey: A set of keys that have been created by this instance. 1934 @type _createdKey: L{set} of L{FakeKey} 1935 1936 @cvar NID_undef: A symbolic constant for undefined NIDs. 1937 @type NID_undef: L{FakeNID} 1938 """ 1939 NID_undef = FakeNID("undef") 1940 1941 def __init__(self): 1942 self._createdKeys = set() 1943 1944 1945 def OBJ_sn2nid(self, snName): 1946 """ 1947 Create a L{FakeNID} with C{snName} and return it. 1948 1949 @param snName: a free form name that gets passed to the constructor 1950 of L{FakeNID}. 1951 1952 @return: a new L{FakeNID}. 1953 @rtype: L{FakeNID}. 
1954 """ 1955 return FakeNID(snName) 1956 1957 1958 def EC_KEY_new_by_curve_name(self, nid): 1959 """ 1960 Create a L{FakeECKey}, save it to C{_createdKeys} and return it. 1961 1962 @param nid: an arbitrary object that is passed to the constructor of 1963 L{FakeECKey}. 1964 1965 @return: a new L{FakeECKey} 1966 @rtype: L{FakeECKey} 1967 """ 1968 key = FakeECKey(nid) 1969 self._createdKeys.add(key) 1970 return key 1971 1972 1973 def EC_KEY_free(self, key): 1974 """ 1975 Remove C{key} from C{_createdKey}. 1976 1977 @param key: a key object to be freed; i.e. removed from 1978 C{_createdKeys}. 1979 1980 @raises ValueError: If C{key} is not in C{_createdKeys} and thus not 1981 created by us. 1982 """ 1983 try: 1984 self._createdKeys.remove(key) 1985 except KeyError: 1986 raise ValueError("Unallocated EC key attempted to free.") 1987 1988 1989 def SSL_CTX_set_tmp_ecdh(self, ffiContext, key): 1990 """ 1991 Does not do anything. 1992 1993 @param ffiContext: ignored 1994 @param key: ignored 1995 """ 1996 1997 1998 1999class TestFakeLib(unittest.TestCase): 2000 """ 2001 Tests for FakeLib 2002 """ 2003 def test_objSn2Nid(self): 2004 """ 2005 Returns a L{FakeNID} with correct name. 2006 """ 2007 nid = FakeNID("test") 2008 self.assertEqual("test", nid._snName) 2009 2010 2011 def test_emptyKeys(self): 2012 """ 2013 A new L{FakeLib} has an empty set for created keys. 2014 """ 2015 self.assertEqual(set(), FakeLib()._createdKeys) 2016 2017 2018 def test_newKey(self): 2019 """ 2020 If a new key is created, it's added to C{_createdKeys}. 2021 """ 2022 lib = FakeLib() 2023 key = lib.EC_KEY_new_by_curve_name(FakeNID("name")) 2024 self.assertEqual(set([key]), lib._createdKeys) 2025 2026 2027 def test_freeUnknownKey(self): 2028 """ 2029 Raise L{ValueError} if an unknown key is attempted to be freed. 
2030 """ 2031 key = FakeECKey(object()) 2032 self.assertRaises( 2033 ValueError, 2034 FakeLib().EC_KEY_free, key 2035 ) 2036 2037 2038 def test_freeKnownKey(self): 2039 """ 2040 Freeing an allocated key removes it from C{_createdKeys}. 2041 """ 2042 lib = FakeLib() 2043 key = lib.EC_KEY_new_by_curve_name(FakeNID("name")) 2044 lib.EC_KEY_free(key) 2045 self.assertEqual(set(), lib._createdKeys) 2046 2047 2048 2049class FakeFFI(object): 2050 """ 2051 A fake of a cryptography's ffi object. 2052 2053 @cvar NULL: Symbolic constant for CFFI's NULL objects. 2054 """ 2055 NULL = object() 2056 2057 2058 2059class FakeBinding(object): 2060 """ 2061 A fake of cryptography's binding object. 2062 2063 @type lib: L{FakeLib} 2064 @type ffi: L{FakeFFI} 2065 """ 2066 def __init__(self, lib=None, ffi=None): 2067 self.lib = lib or FakeLib() 2068 self.ffi = ffi or FakeFFI() 2069 2070 2071 2072class TestECCurve(unittest.TestCase): 2073 """ 2074 Tests for twisted.internet._sslverify.OpenSSLECCurve. 2075 """ 2076 if skipSSL: 2077 skip = skipSSL 2078 2079 def test_missingBinding(self): 2080 """ 2081 Raise L{NotImplementedError} if pyOpenSSL is not based on cryptography. 2082 """ 2083 def raiser(self): 2084 raise NotImplementedError 2085 self.patch(sslverify._OpenSSLECCurve, "_getBinding", raiser) 2086 self.assertRaises( 2087 NotImplementedError, 2088 sslverify._OpenSSLECCurve, sslverify._defaultCurveName, 2089 ) 2090 2091 2092 def test_nonECbinding(self): 2093 """ 2094 Raise L{NotImplementedError} if pyOpenSSL is based on cryptography but 2095 cryptography lacks required EC methods. 
2096 """ 2097 def raiser(self): 2098 raise AttributeError 2099 lib = FakeLib() 2100 lib.OBJ_sn2nid = raiser 2101 self.patch(sslverify._OpenSSLECCurve, 2102 "_getBinding", 2103 lambda self: FakeBinding(lib=lib)) 2104 self.assertRaises( 2105 NotImplementedError, 2106 sslverify._OpenSSLECCurve, sslverify._defaultCurveName, 2107 ) 2108 2109 2110 def test_wrongName(self): 2111 """ 2112 Raise L{ValueError} on unknown sn names. 2113 """ 2114 lib = FakeLib() 2115 lib.OBJ_sn2nid = lambda self: FakeLib.NID_undef 2116 self.patch(sslverify._OpenSSLECCurve, 2117 "_getBinding", 2118 lambda self: FakeBinding(lib=lib)) 2119 self.assertRaises( 2120 ValueError, 2121 sslverify._OpenSSLECCurve, u"doesNotExist", 2122 ) 2123 2124 2125 def test_keyFails(self): 2126 """ 2127 Raise L{EnvironmentError} if key creation fails. 2128 """ 2129 lib = FakeLib() 2130 lib.EC_KEY_new_by_curve_name = lambda *a, **kw: FakeFFI.NULL 2131 self.patch(sslverify._OpenSSLECCurve, 2132 "_getBinding", 2133 lambda self: FakeBinding(lib=lib)) 2134 curve = sslverify._OpenSSLECCurve(sslverify._defaultCurveName) 2135 self.assertRaises( 2136 EnvironmentError, 2137 curve.addECKeyToContext, object() 2138 ) 2139 2140 2141 def test_keyGetsFreed(self): 2142 """ 2143 Don't leak a key when adding it to a context. 2144 """ 2145 lib = FakeLib() 2146 self.patch(sslverify._OpenSSLECCurve, 2147 "_getBinding", 2148 lambda self: FakeBinding(lib=lib)) 2149 curve = sslverify._OpenSSLECCurve(sslverify._defaultCurveName) 2150 ctx = FakeContext(None) 2151 ctx._context = None 2152 curve.addECKeyToContext(ctx) 2153 self.assertEqual(set(), lib._createdKeys) 2154