1# Copyright 2005 Divmod, Inc.  See LICENSE file for details
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6Tests for L{twisted.internet._sslverify}.
7"""
8
9from __future__ import division, absolute_import
10
11import itertools
12
13from zope.interface import implementer
14
15try:
16    from OpenSSL import SSL
17    from OpenSSL.crypto import PKey, X509
18    from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
19    skipSSL = False
20    if hasattr(SSL.Context, "set_tlsext_servername_callback"):
21        skipSNI = False
22    else:
23        skipSNI = "PyOpenSSL 0.13 or greater required for SNI support."
24except ImportError:
25    skipSSL = "OpenSSL is required for SSL tests."
26    skipSNI = skipSSL
27
28from twisted.test.iosim import connectedServerAndClient
29
30from twisted.internet.error import ConnectionClosed
31from twisted.python.compat import nativeString
32from twisted.python.constants import NamedConstant, Names
33from twisted.python.filepath import FilePath
34
35from twisted.trial import unittest
36from twisted.internet import protocol, defer, reactor
37
38from twisted.internet.error import CertificateError, ConnectionLost
39from twisted.internet import interfaces
40
41if not skipSSL:
42    from twisted.internet.ssl import platformTrust, VerificationError
43    from twisted.internet import _sslverify as sslverify
44    from twisted.protocols.tls import TLSMemoryBIOFactory
45
# A couple of static PEM-format certificates to be used by various tests.
#
# A_HOST_CERTIFICATE_PEM is loaded as-is (indented base64 lines included) by
# test_inspectCertificate, which asserts the exact subject, issuer, serial
# number, digest and public key hash reported by Certificate.inspect(); any
# edit to the text below will break that test.
A_HOST_CERTIFICATE_PEM = """
-----BEGIN CERTIFICATE-----
        MIIC2jCCAkMCAjA5MA0GCSqGSIb3DQEBBAUAMIG0MQswCQYDVQQGEwJVUzEiMCAG
        A1UEAxMZZXhhbXBsZS50d2lzdGVkbWF0cml4LmNvbTEPMA0GA1UEBxMGQm9zdG9u
        MRwwGgYDVQQKExNUd2lzdGVkIE1hdHJpeCBMYWJzMRYwFAYDVQQIEw1NYXNzYWNo
        dXNldHRzMScwJQYJKoZIhvcNAQkBFhhub2JvZHlAdHdpc3RlZG1hdHJpeC5jb20x
        ETAPBgNVBAsTCFNlY3VyaXR5MB4XDTA2MDgxNjAxMDEwOFoXDTA3MDgxNjAxMDEw
        OFowgbQxCzAJBgNVBAYTAlVTMSIwIAYDVQQDExlleGFtcGxlLnR3aXN0ZWRtYXRy
        aXguY29tMQ8wDQYDVQQHEwZCb3N0b24xHDAaBgNVBAoTE1R3aXN0ZWQgTWF0cml4
        IExhYnMxFjAUBgNVBAgTDU1hc3NhY2h1c2V0dHMxJzAlBgkqhkiG9w0BCQEWGG5v
        Ym9keUB0d2lzdGVkbWF0cml4LmNvbTERMA8GA1UECxMIU2VjdXJpdHkwgZ8wDQYJ
        KoZIhvcNAQEBBQADgY0AMIGJAoGBAMzH8CDF/U91y/bdbdbJKnLgnyvQ9Ig9ZNZp
        8hpsu4huil60zF03+Lexg2l1FIfURScjBuaJMR6HiMYTMjhzLuByRZ17KW4wYkGi
        KXstz03VIKy4Tjc+v4aXFI4XdRw10gGMGQlGGscXF/RSoN84VoDKBfOMWdXeConJ
        VyC4w3iJAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAviMT4lBoxOgQy32LIgZ4lVCj
        JNOiZYg8GMQ6y0ugp86X80UjOvkGtNf/R7YgED/giKRN/q/XJiLJDEhzknkocwmO
        S+4b2XpiaZYxRyKWwL221O7CGmtWYyZl2+92YYmmCiNzWQPfP6BOMlfax0AGLHls
        fXzCWdG0O/3Lk2SRM0I=
-----END CERTIFICATE-----
"""

# A second static certificate, available to tests that need a distinct peer
# certificate.
A_PEER_CERTIFICATE_PEM = """
-----BEGIN CERTIFICATE-----
        MIIC3jCCAkcCAjA6MA0GCSqGSIb3DQEBBAUAMIG2MQswCQYDVQQGEwJVUzEiMCAG
        A1UEAxMZZXhhbXBsZS50d2lzdGVkbWF0cml4LmNvbTEPMA0GA1UEBxMGQm9zdG9u
        MRwwGgYDVQQKExNUd2lzdGVkIE1hdHJpeCBMYWJzMRYwFAYDVQQIEw1NYXNzYWNo
        dXNldHRzMSkwJwYJKoZIhvcNAQkBFhpzb21lYm9keUB0d2lzdGVkbWF0cml4LmNv
        bTERMA8GA1UECxMIU2VjdXJpdHkwHhcNMDYwODE2MDEwMTU2WhcNMDcwODE2MDEw
        MTU2WjCBtjELMAkGA1UEBhMCVVMxIjAgBgNVBAMTGWV4YW1wbGUudHdpc3RlZG1h
        dHJpeC5jb20xDzANBgNVBAcTBkJvc3RvbjEcMBoGA1UEChMTVHdpc3RlZCBNYXRy
        aXggTGFiczEWMBQGA1UECBMNTWFzc2FjaHVzZXR0czEpMCcGCSqGSIb3DQEJARYa
        c29tZWJvZHlAdHdpc3RlZG1hdHJpeC5jb20xETAPBgNVBAsTCFNlY3VyaXR5MIGf
        MA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCnm+WBlgFNbMlHehib9ePGGDXF+Nz4
        CjGuUmVBaXCRCiVjg3kSDecwqfb0fqTksBZ+oQ1UBjMcSh7OcvFXJZnUesBikGWE
        JE4V8Bjh+RmbJ1ZAlUPZ40bAkww0OpyIRAGMvKG+4yLFTO4WDxKmfDcrOb6ID8WJ
        e1u+i3XGkIf/5QIDAQABMA0GCSqGSIb3DQEBBAUAA4GBAD4Oukm3YYkhedUepBEA
        vvXIQhVDqL7mk6OqYdXmNj6R7ZMC8WWvGZxrzDI1bZuB+4aIxxd1FXC3UOHiR/xg
        i9cDl1y8P/qRp4aEBNF6rI0D4AxTbfnHQx4ERDAOShJdYZs/2zifPJ6va6YvrEyr
        yqDtGhklsWW3ZwBzEh5VEOUp
-----END CERTIFICATE-----
"""
88
89
90
def counter(counter=itertools.count()):
    """
    Draw the next integer from a sequence shared across calls.

    The default argument is evaluated only once, when the function is defined.
    Every call that leaves C{counter} unspecified therefore advances the same
    L{itertools.count} instance and sees 0, 1, 2, ... in turn.

    @param counter: the iterator to advance; callers normally rely on the
        shared default.

    @return: the next value produced by C{counter}.
    """
    nextValue = next(counter)
    return nextValue
96
97
98
def makeCertificate(**kw):
    """
    Create a fresh RSA key pair and a self-signed certificate for it.

    Every keyword argument is applied to both the issuer and the subject name,
    and the certificate is signed with its own key, so it is self-signed.  It
    is valid from now for one year, and its serial number comes from
    L{counter}.

    @param kw: distinguished-name attributes (for example C{O} or C{CN}) to
        set on the certificate's issuer and subject.  Values are converted
        with L{nativeString}.

    @return: a 2-tuple of C{(keypair, certificate)}.
    @rtype: L{tuple} of (L{OpenSSL.crypto.PKey}, L{OpenSSL.crypto.X509})
    """
    keypair = PKey()
    # 2048 bits: modern OpenSSL builds refuse RSA keys as small as 768 bits
    # at their default security level.
    keypair.generate_key(TYPE_RSA, 2048)

    certificate = X509()
    certificate.gmtime_adj_notBefore(0)
    certificate.gmtime_adj_notAfter(60 * 60 * 24 * 365) # One year
    for xname in certificate.get_issuer(), certificate.get_subject():
        for (k, v) in kw.items():
            setattr(xname, k, nativeString(v))

    certificate.set_serial_number(counter())
    certificate.set_pubkey(keypair)
    # MD5-signed certificates are rejected by modern OpenSSL; use SHA-256.
    certificate.sign(keypair, "sha256")

    return keypair, certificate
115
116
117
def certificatesForAuthorityAndServer(commonName=b'example.com'):
    """
    Create a self-signed CA certificate and server certificate signed by the
    CA.

    @param commonName: The C{commonName} to embed in the certificate.
    @type commonName: L{bytes}

    @return: a 2-tuple of C{(certificate_authority_certificate,
        server_certificate)}
    @rtype: L{tuple} of (L{sslverify.Certificate},
        L{sslverify.PrivateCertificate})
    """
    serverDN = sslverify.DistinguishedName(commonName=commonName)
    serverKey = sslverify.KeyPair.generate()
    serverCertReq = serverKey.certificateRequest(serverDN)

    caDN = sslverify.DistinguishedName(commonName=b'CA')
    caKey = sslverify.KeyPair.generate()
    caCertReq = caKey.certificateRequest(caDN)

    # Both certificates below are issued under caDN, so they must carry
    # distinct serial numbers: RFC 5280 requires the (issuer, serial) pair to
    # identify exactly one certificate.
    caSelfCertData = caKey.signCertificateRequest(
            caDN, caCertReq, lambda dn: True, 516)
    caSelfCert = caKey.newCertificate(caSelfCertData)

    serverCertData = caKey.signCertificateRequest(
            caDN, serverCertReq, lambda dn: True, 517)
    serverCert = serverKey.newCertificate(serverCertData)
    return caSelfCert, serverCert
146
147
148
def loopbackTLSConnection(trustRoot, privateKeyFile, chainedCertFile=None):
    """
    Create a loopback TLS connection with the given trust and keys.

    The connecting (C{isClient=True}) side is configured with
    L{sslverify.OpenSSLCertificateOptions} using C{trustRoot}.  The accepting
    side uses a plain C{OpenSSL.SSL.Context} created with C{TLSv1_METHOD},
    loaded from C{privateKeyFile} and, optionally, C{chainedCertFile}.  Both
    sides are L{TLSMemoryBIOFactory} protocols wired together in memory by
    L{connectedServerAndClient}; nothing touches the network.

    @param trustRoot: the C{trustRoot} argument for the client connection's
        context.
    @type trustRoot: L{sslverify.IOpenSSLTrustRoot}

    @param privateKeyFile: The name of the file containing the private key.
    @type privateKeyFile: L{str} (native string; file name)

    @param chainedCertFile: The name of the chained certificate file.
    @type chainedCertFile: L{str} (native string; file name)

    @return: 3-tuple of server-protocol, client-protocol, and L{IOPump}
    @rtype: L{tuple}
    """
    class ContextFactory(object):
        # Minimal context factory for the accepting side: it only needs a
        # getContext() method.
        def getContext(self):
            """
            Create a context for the server side of the connection.

            @return: an SSL context using a certificate and key.
            @rtype: C{OpenSSL.SSL.Context}
            """
            ctx = SSL.Context(SSL.TLSv1_METHOD)
            if chainedCertFile is not None:
                ctx.use_certificate_chain_file(chainedCertFile)
            ctx.use_privatekey_file(privateKeyFile)
            # Let the test author know if they screwed something up.
            ctx.check_privatekey()
            return ctx

    class GreetingServer(protocol.Protocol):
        # Writes a fixed greeting as soon as the connection is established.
        greeting = b"greetings!"
        def connectionMade(self):
            self.transport.write(self.greeting)

    class ListeningClient(protocol.Protocol):
        # Accumulates everything received in C{data} and records the reason
        # the connection was lost in C{lostReason}.
        data = b''
        lostReason = None
        def dataReceived(self, data):
            self.data += data
        def connectionLost(self, reason):
            self.lostReason = reason

    serverOpts = ContextFactory()
    clientOpts = sslverify.OpenSSLCertificateOptions(trustRoot=trustRoot)

    # NOTE(review): the names are crossed on purpose or by accident.  The
    # TLS *client* (isClient=True, verifying with trustRoot) wraps
    # GreetingServer, and the TLS *server* wraps ListeningClient, so the side
    # that verifies the peer's certificate is the side that sends the
    # greeting.
    clientFactory = TLSMemoryBIOFactory(
        clientOpts, isClient=True,
        wrappedFactory=protocol.Factory.forProtocol(GreetingServer)
    )
    serverFactory = TLSMemoryBIOFactory(
        serverOpts, isClient=False,
        wrappedFactory=protocol.Factory.forProtocol(ListeningClient)
    )

    # NOTE(review): confirm the element order returned by
    # connectedServerAndClient (twisted.test.iosim) before relying on the
    # sProto/cProto names; callers depend on this exact tuple order.
    sProto, cProto, pump = connectedServerAndClient(
        lambda: serverFactory.buildProtocol(None),
        lambda: clientFactory.buildProtocol(None)
    )
    return sProto, cProto, pump
212
213
214
def pathContainingDumpOf(testCase, *dumpables):
    """
    Write the PEM serialization of each of C{dumpables} into one fresh
    temporary file, in argument order, and return the file's name.

    @param testCase: supplies the temporary file name through its C{mktemp}
        method.
    @type testCase: L{twisted.trial.unittest.TestCase}

    @param dumpables: pyOpenSSL objects whose C{dump} method accepts a
        pyOpenSSL file-type constant, such as L{OpenSSL.crypto.FILETYPE_PEM}
        or L{OpenSSL.crypto.FILETYPE_ASN1}, and returns the encoded bytes.
    @type dumpables: L{tuple} of L{object} with C{dump} method taking L{int}
        returning L{bytes}

    @return: the path of the file holding the concatenated PEM dumps.
    @rtype: L{str}
    """
    path = testCase.mktemp()
    with open(path, "wb") as pemFile:
        pemFile.writelines(item.dump(FILETYPE_PEM) for item in dumpables)
    return path
238
239
240
class DataCallbackProtocol(protocol.Protocol):
    """
    Protocol that reports received data and connection loss through the
    C{onData} and C{onLost} L{Deferred}s stored on its factory.

    Each Deferred is detached from the factory (reset to L{None}) before it is
    fired, so it fires at most once; a Deferred that is already L{None} is
    skipped.
    """
    def dataReceived(self, data):
        pending = self.factory.onData
        self.factory.onData = None
        if pending is None:
            return
        pending.callback(data)

    def connectionLost(self, reason):
        pending = self.factory.onLost
        self.factory.onLost = None
        if pending is None:
            return
        pending.errback(reason)
251
class WritingProtocol(protocol.Protocol):
    """
    Protocol that writes a single byte as soon as it is connected, and reports
    connection loss by errbacking its factory's C{onLost} L{Deferred}.

    @ivar byte: the data written on connection.
    """
    byte = b'x'

    def connectionMade(self):
        self.transport.write(self.byte)

    def connectionLost(self, reason):
        # Same fire-at-most-once pattern as DataCallbackProtocol: detach the
        # Deferred from the factory before firing it, and skip it when it is
        # None rather than failing with AttributeError.
        onLost, self.factory.onLost = self.factory.onLost, None
        if onLost is not None:
            onLost.errback(reason)
259
260
261
class FakeContext(object):
    """
    Introspectable fake of an C{OpenSSL.SSL.Context}.

    Saves call arguments for later introspection.

    Necessary because C{Context} offers poor introspection.  cf. this
    U{pyOpenSSL bug<https://bugs.launchpad.net/pyopenssl/+bug/1173899>}.

    @ivar _method: See C{method} parameter of L{__init__}.

    @ivar _options: C{int} of C{OR}ed values from calls of L{set_options}.

    @ivar _certificate: Set by L{use_certificate}.

    @ivar _privateKey: Set by L{use_privatekey}.

    @ivar _verify: Set by L{set_verify}.

    @ivar _verifyDepth: Set by L{set_verify_depth}.

    @ivar _sessionID: Set by L{set_session_id}.

    @ivar _extraCertChain: Accumulated C{list} of all extra certificates added
        by L{add_extra_chain_cert}.

    @ivar _cipherList: Set by L{set_cipher_list}.

    @ivar _dhFilename: Set by L{load_tmp_dh}.

    @ivar _defaultVerifyPathsSet: Set by L{set_default_verify_paths}
    """
    # Class-level default; the first set_options call creates the instance
    # attribute, so instances never share accumulated options.
    _options = 0

    def __init__(self, method):
        """
        @param method: the SSL method the context is created with; stored
            unchanged as C{_method}.
        """
        self._method = method
        self._extraCertChain = []
        self._defaultVerifyPathsSet = False


    def set_options(self, options):
        """
        OR C{options} into the accumulated option bitmask.

        @param options: bitmask of option flags to add.
        @type options: L{int}

        @return: the updated bitmask, mirroring the real
            C{OpenSSL.SSL.Context.set_options}.
        @rtype: L{int}
        """
        self._options |= options
        return self._options


    def use_certificate(self, certificate):
        """
        Record C{certificate} as C{_certificate}.
        """
        self._certificate = certificate


    def use_privatekey(self, privateKey):
        """
        Record C{privateKey} as C{_privateKey}.
        """
        self._privateKey = privateKey


    def check_privatekey(self):
        """
        Accept any key/certificate pairing; the fake performs no check.
        """
        return None


    def set_verify(self, flags, callback):
        """
        Record the C{(flags, callback)} pair as C{_verify}.
        """
        self._verify = flags, callback


    def set_verify_depth(self, depth):
        """
        Record C{depth} as C{_verifyDepth}.
        """
        self._verifyDepth = depth


    def set_session_id(self, sessionID):
        """
        Record C{sessionID} as C{_sessionID}.
        """
        self._sessionID = sessionID


    def add_extra_chain_cert(self, cert):
        """
        Append C{cert} to C{_extraCertChain}.
        """
        self._extraCertChain.append(cert)


    def set_cipher_list(self, cipherList):
        """
        Record C{cipherList} as C{_cipherList}.
        """
        self._cipherList = cipherList


    def load_tmp_dh(self, dhfilename):
        """
        Record C{dhfilename} as C{_dhFilename}; no file is read.
        """
        self._dhFilename = dhfilename


    def set_default_verify_paths(self):
        """
        Set the default paths for the platform.
        """
        self._defaultVerifyPathsSet = True
347
348
349
class ClientOptions(unittest.SynchronousTestCase):
    """
    Tests for L{sslverify.optionsForClientTLS}.
    """
    if skipSSL:
        skip = skipSSL

    def test_extraKeywords(self):
        """
        An unknown keyword argument (anything besides
        C{extraCertificateOptions}) makes L{sslverify.optionsForClientTLS}
        raise the same L{TypeError} an ordinary Python function would.
        """
        expectedMessage = (
            "optionsForClientTLS() got an unexpected keyword argument "
            "'someRandomThing'"
        )
        error = self.assertRaises(
            TypeError,
            sslverify.optionsForClientTLS,
            hostname=u'alpha',
            someRandomThing=u'beta',
        )
        self.assertEqual(str(error), expectedMessage)


    def test_bytesFailFast(self):
        """
        Passing L{bytes} rather than text as the host name makes
        L{sslverify.optionsForClientTLS} raise L{TypeError} immediately.
        """
        expectedText = (
            "optionsForClientTLS requires text for host names, not " +
            bytes.__name__
        )
        bytesHostname = b'not-actually-a-hostname.com'
        error = self.assertRaises(
            TypeError, sslverify.optionsForClientTLS, bytesHostname)
        self.assertEqual(str(error), expectedText)
389
390
391
392class OpenSSLOptions(unittest.TestCase):
393    if skipSSL:
394        skip = skipSSL
395
396    serverPort = clientConn = None
397    onServerLost = onClientLost = None
398
399    sKey = None
400    sCert = None
401    cKey = None
402    cCert = None
403
404    def setUp(self):
405        """
406        Create class variables of client and server certificates.
407        """
408        self.sKey, self.sCert = makeCertificate(
409            O=b"Server Test Certificate",
410            CN=b"server")
411        self.cKey, self.cCert = makeCertificate(
412            O=b"Client Test Certificate",
413            CN=b"client")
414        self.caCert1 = makeCertificate(
415            O=b"CA Test Certificate 1",
416            CN=b"ca1")[1]
417        self.caCert2 = makeCertificate(
418            O=b"CA Test Certificate",
419            CN=b"ca2")[1]
420        self.caCerts = [self.caCert1, self.caCert2]
421        self.extraCertChain = self.caCerts
422
423
424    def tearDown(self):
425        if self.serverPort is not None:
426            self.serverPort.stopListening()
427        if self.clientConn is not None:
428            self.clientConn.disconnect()
429
430        L = []
431        if self.onServerLost is not None:
432            L.append(self.onServerLost)
433        if self.onClientLost is not None:
434            L.append(self.onClientLost)
435
436        return defer.DeferredList(L, consumeErrors=True)
437
438    def loopback(self, serverCertOpts, clientCertOpts,
439                 onServerLost=None, onClientLost=None, onData=None):
440        if onServerLost is None:
441            self.onServerLost = onServerLost = defer.Deferred()
442        if onClientLost is None:
443            self.onClientLost = onClientLost = defer.Deferred()
444        if onData is None:
445            onData = defer.Deferred()
446
447        serverFactory = protocol.ServerFactory()
448        serverFactory.protocol = DataCallbackProtocol
449        serverFactory.onLost = onServerLost
450        serverFactory.onData = onData
451
452        clientFactory = protocol.ClientFactory()
453        clientFactory.protocol = WritingProtocol
454        clientFactory.onLost = onClientLost
455
456        self.serverPort = reactor.listenSSL(0, serverFactory, serverCertOpts)
457        self.clientConn = reactor.connectSSL('127.0.0.1',
458                self.serverPort.getHost().port, clientFactory, clientCertOpts)
459
460
461    def test_constructorWithOnlyPrivateKey(self):
462        """
463        C{privateKey} and C{certificate} make only sense if both are set.
464        """
465        self.assertRaises(
466            ValueError,
467            sslverify.OpenSSLCertificateOptions, privateKey=self.sKey
468        )
469
470
471    def test_constructorWithOnlyCertificate(self):
472        """
473        C{privateKey} and C{certificate} make only sense if both are set.
474        """
475        self.assertRaises(
476            ValueError,
477            sslverify.OpenSSLCertificateOptions, certificate=self.sCert
478        )
479
480
481    def test_constructorWithCertificateAndPrivateKey(self):
482        """
483        Specifying C{privateKey} and C{certificate} initializes correctly.
484        """
485        opts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
486                                                   certificate=self.sCert)
487        self.assertEqual(opts.privateKey, self.sKey)
488        self.assertEqual(opts.certificate, self.sCert)
489        self.assertEqual(opts.extraCertChain, [])
490
491
492    def test_constructorDoesNotAllowVerifyWithoutCACerts(self):
493        """
494        C{verify} must not be C{True} without specifying C{caCerts}.
495        """
496        self.assertRaises(
497            ValueError,
498            sslverify.OpenSSLCertificateOptions,
499            privateKey=self.sKey, certificate=self.sCert, verify=True
500        )
501
502
503    def test_constructorDoesNotAllowLegacyWithTrustRoot(self):
504        """
505        C{verify}, C{requireCertificate}, and C{caCerts} must not be specified
506        by the caller (to be I{any} value, even the default!) when specifying
507        C{trustRoot}.
508        """
509        self.assertRaises(
510            TypeError,
511            sslverify.OpenSSLCertificateOptions,
512            privateKey=self.sKey, certificate=self.sCert,
513            verify=True, trustRoot=None, caCerts=self.caCerts,
514        )
515        self.assertRaises(
516            TypeError,
517            sslverify.OpenSSLCertificateOptions,
518            privateKey=self.sKey, certificate=self.sCert,
519            trustRoot=None, requireCertificate=True,
520        )
521
522
523    def test_constructorAllowsCACertsWithoutVerify(self):
524        """
525        It's currently a NOP, but valid.
526        """
527        opts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
528                                                   certificate=self.sCert,
529                                                   caCerts=self.caCerts)
530        self.assertFalse(opts.verify)
531        self.assertEqual(self.caCerts, opts.caCerts)
532
533
534    def test_constructorWithVerifyAndCACerts(self):
535        """
536        Specifying C{verify} and C{caCerts} initializes correctly.
537        """
538        opts = sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
539                                                   certificate=self.sCert,
540                                                   verify=True,
541                                                   caCerts=self.caCerts)
542        self.assertTrue(opts.verify)
543        self.assertEqual(self.caCerts, opts.caCerts)
544
545
546    def test_constructorSetsExtraChain(self):
547        """
548        Setting C{extraCertChain} works if C{certificate} and C{privateKey} are
549        set along with it.
550        """
551        opts = sslverify.OpenSSLCertificateOptions(
552            privateKey=self.sKey,
553            certificate=self.sCert,
554            extraCertChain=self.extraCertChain,
555        )
556        self.assertEqual(self.extraCertChain, opts.extraCertChain)
557
558
559    def test_constructorDoesNotAllowExtraChainWithoutPrivateKey(self):
560        """
561        A C{extraCertChain} without C{privateKey} doesn't make sense and is
562        thus rejected.
563        """
564        self.assertRaises(
565            ValueError,
566            sslverify.OpenSSLCertificateOptions,
567            certificate=self.sCert,
568            extraCertChain=self.extraCertChain,
569        )
570
571
572    def test_constructorDoesNotAllowExtraChainWithOutPrivateKey(self):
573        """
574        A C{extraCertChain} without C{certificate} doesn't make sense and is
575        thus rejected.
576        """
577        self.assertRaises(
578            ValueError,
579            sslverify.OpenSSLCertificateOptions,
580            privateKey=self.sKey,
581            extraCertChain=self.extraCertChain,
582        )
583
584
585    def test_extraChainFilesAreAddedIfSupplied(self):
586        """
587        If C{extraCertChain} is set and all prerequisites are met, the
588        specified chain certificates are added to C{Context}s that get
589        created.
590        """
591        opts = sslverify.OpenSSLCertificateOptions(
592            privateKey=self.sKey,
593            certificate=self.sCert,
594            extraCertChain=self.extraCertChain,
595        )
596        opts._contextFactory = FakeContext
597        ctx = opts.getContext()
598        self.assertEqual(self.sKey, ctx._privateKey)
599        self.assertEqual(self.sCert, ctx._certificate)
600        self.assertEqual(self.extraCertChain, ctx._extraCertChain)
601
602
603    def test_extraChainDoesNotBreakPyOpenSSL(self):
604        """
605        C{extraCertChain} doesn't break C{OpenSSL.SSL.Context} creation.
606        """
607        opts = sslverify.OpenSSLCertificateOptions(
608            privateKey=self.sKey,
609            certificate=self.sCert,
610            extraCertChain=self.extraCertChain,
611        )
612        ctx = opts.getContext()
613        self.assertIsInstance(ctx, SSL.Context)
614
615
616    def test_acceptableCiphersAreAlwaysSet(self):
617        """
618        If the user doesn't supply custom acceptable ciphers, a shipped secure
619        default is used.  We can't check directly for it because the effective
620        cipher string we set varies with platforms.
621        """
622        opts = sslverify.OpenSSLCertificateOptions(
623            privateKey=self.sKey,
624            certificate=self.sCert,
625        )
626        opts._contextFactory = FakeContext
627        ctx = opts.getContext()
628        self.assertEqual(opts._cipherString, ctx._cipherList)
629
630
631    def test_givesMeaningfulErrorMessageIfNoCipherMatches(self):
632        """
633        If there is no valid cipher that matches the user's wishes,
634        a L{ValueError} is raised.
635        """
636        self.assertRaises(
637            ValueError,
638            sslverify.OpenSSLCertificateOptions,
639            privateKey=self.sKey,
640            certificate=self.sCert,
641            acceptableCiphers=
642            sslverify.OpenSSLAcceptableCiphers.fromOpenSSLCipherString('')
643        )
644
645
646    def test_honorsAcceptableCiphersArgument(self):
647        """
648        If acceptable ciphers are passed, they are used.
649        """
650        @implementer(interfaces.IAcceptableCiphers)
651        class FakeAcceptableCiphers(object):
652            def selectCiphers(self, _):
653                return [sslverify.OpenSSLCipher(u'sentinel')]
654
655        opts = sslverify.OpenSSLCertificateOptions(
656            privateKey=self.sKey,
657            certificate=self.sCert,
658            acceptableCiphers=FakeAcceptableCiphers(),
659        )
660        opts._contextFactory = FakeContext
661        ctx = opts.getContext()
662        self.assertEqual(u'sentinel', ctx._cipherList)
663
664
665    def test_basicSecurityOptionsAreSet(self):
666        """
667        Every context must have C{OP_NO_SSLv2}, C{OP_NO_COMPRESSION}, and
668        C{OP_CIPHER_SERVER_PREFERENCE} set.
669        """
670        opts = sslverify.OpenSSLCertificateOptions(
671            privateKey=self.sKey,
672            certificate=self.sCert,
673        )
674        opts._contextFactory = FakeContext
675        ctx = opts.getContext()
676        options = (SSL.OP_NO_SSLv2 | opts._OP_NO_COMPRESSION |
677                   opts._OP_CIPHER_SERVER_PREFERENCE)
678        self.assertEqual(options, ctx._options & options)
679
680
681    def test_singleUseKeys(self):
682        """
683        If C{singleUseKeys} is set, every context must have
684        C{OP_SINGLE_DH_USE} and C{OP_SINGLE_ECDH_USE} set.
685        """
686        opts = sslverify.OpenSSLCertificateOptions(
687            privateKey=self.sKey,
688            certificate=self.sCert,
689            enableSingleUseKeys=True,
690        )
691        opts._contextFactory = FakeContext
692        ctx = opts.getContext()
693        options = SSL.OP_SINGLE_DH_USE | opts._OP_SINGLE_ECDH_USE
694        self.assertEqual(options, ctx._options & options)
695
696
697    def test_dhParams(self):
698        """
699        If C{dhParams} is set, they are loaded into each new context.
700        """
701        class FakeDiffieHellmanParameters(object):
702            _dhFile = FilePath(b'dh.params')
703
704        dhParams = FakeDiffieHellmanParameters()
705        opts = sslverify.OpenSSLCertificateOptions(
706            privateKey=self.sKey,
707            certificate=self.sCert,
708            dhParameters=dhParams,
709        )
710        opts._contextFactory = FakeContext
711        ctx = opts.getContext()
712        self.assertEqual(
713            FakeDiffieHellmanParameters._dhFile.path,
714            ctx._dhFilename
715        )
716
717
718    def test_ecDoesNotBreakConstructor(self):
719        """
720        Missing ECC does not break the constructor and sets C{_ecCurve} to
721        L{None}.
722        """
723        def raiser(self):
724            raise NotImplementedError
725        self.patch(sslverify._OpenSSLECCurve, "_getBinding", raiser)
726
727        opts = sslverify.OpenSSLCertificateOptions(
728            privateKey=self.sKey,
729            certificate=self.sCert,
730        )
731        self.assertIs(None, opts._ecCurve)
732
733
734    def test_ecNeverBreaksGetContext(self):
735        """
736        ECDHE support is best effort only and errors are ignored.
737        """
738        opts = sslverify.OpenSSLCertificateOptions(
739            privateKey=self.sKey,
740            certificate=self.sCert,
741        )
742        opts._ecCurve = object()
743        ctx = opts.getContext()
744        self.assertIsInstance(ctx, SSL.Context)
745
746
747    def test_ecSuccessWithRealBindings(self):
748        """
749        Integration test that checks the positive code path to ensure that we
750        use the API properly.
751        """
752        try:
753            defaultCurve = sslverify._OpenSSLECCurve(
754                sslverify._defaultCurveName
755            )
756        except NotImplementedError:
757            raise unittest.SkipTest(
758                "Underlying pyOpenSSL is not based on cryptography."
759            )
760        opts = sslverify.OpenSSLCertificateOptions(
761            privateKey=self.sKey,
762            certificate=self.sCert,
763        )
764        self.assertEqual(defaultCurve, opts._ecCurve)
765        # Exercise positive code path.  getContext swallows errors so we do it
766        # explicitly by hand.
767        opts._ecCurve.addECKeyToContext(opts.getContext())
768
769
770    def test_abbreviatingDistinguishedNames(self):
771        """
772        Check that abbreviations used in certificates correctly map to
773        complete names.
774        """
775        self.assertEqual(
776                sslverify.DN(CN=b'a', OU=b'hello'),
777                sslverify.DistinguishedName(commonName=b'a',
778                                            organizationalUnitName=b'hello'))
779        self.assertNotEquals(
780                sslverify.DN(CN=b'a', OU=b'hello'),
781                sslverify.DN(CN=b'a', OU=b'hello', emailAddress=b'xxx'))
782        dn = sslverify.DN(CN=b'abcdefg')
783        self.assertRaises(AttributeError, setattr, dn, 'Cn', b'x')
784        self.assertEqual(dn.CN, dn.commonName)
785        dn.CN = b'bcdefga'
786        self.assertEqual(dn.CN, dn.commonName)
787
788
789    def testInspectDistinguishedName(self):
790        n = sslverify.DN(commonName=b'common name',
791                         organizationName=b'organization name',
792                         organizationalUnitName=b'organizational unit name',
793                         localityName=b'locality name',
794                         stateOrProvinceName=b'state or province name',
795                         countryName=b'country name',
796                         emailAddress=b'email address')
797        s = n.inspect()
798        for k in [
799            'common name',
800            'organization name',
801            'organizational unit name',
802            'locality name',
803            'state or province name',
804            'country name',
805            'email address']:
806            self.assertIn(k, s, "%r was not in inspect output." % (k,))
807            self.assertIn(k.title(), s, "%r was not in inspect output." % (k,))
808
809
810    def testInspectDistinguishedNameWithoutAllFields(self):
811        n = sslverify.DN(localityName=b'locality name')
812        s = n.inspect()
813        for k in [
814            'common name',
815            'organization name',
816            'organizational unit name',
817            'state or province name',
818            'country name',
819            'email address']:
820            self.assertNotIn(k, s, "%r was in inspect output." % (k,))
821            self.assertNotIn(k.title(), s, "%r was in inspect output." % (k,))
822        self.assertIn('locality name', s)
823        self.assertIn('Locality Name', s)
824
825
826    def test_inspectCertificate(self):
827        """
828        Test that the C{inspect} method of L{sslverify.Certificate} returns
829        a human-readable string containing some basic information about the
830        certificate.
831        """
832        c = sslverify.Certificate.loadPEM(A_HOST_CERTIFICATE_PEM)
833        self.assertEqual(
834            c.inspect().split('\n'),
835            ["Certificate For Subject:",
836             "               Common Name: example.twistedmatrix.com",
837             "              Country Name: US",
838             "             Email Address: nobody@twistedmatrix.com",
839             "             Locality Name: Boston",
840             "         Organization Name: Twisted Matrix Labs",
841             "  Organizational Unit Name: Security",
842             "    State Or Province Name: Massachusetts",
843             "",
844             "Issuer:",
845             "               Common Name: example.twistedmatrix.com",
846             "              Country Name: US",
847             "             Email Address: nobody@twistedmatrix.com",
848             "             Locality Name: Boston",
849             "         Organization Name: Twisted Matrix Labs",
850             "  Organizational Unit Name: Security",
851             "    State Or Province Name: Massachusetts",
852             "",
853             "Serial Number: 12345",
854             "Digest: C4:96:11:00:30:C3:EC:EE:A3:55:AA:ED:8C:84:85:18",
855             "Public Key with Hash: ff33994c80812aa95a79cdb85362d054"])
856
857
858    def test_certificateOptionsSerialization(self):
859        """
860        Test that __setstate__(__getstate__()) round-trips properly.
861        """
862        firstOpts = sslverify.OpenSSLCertificateOptions(
863            privateKey=self.sKey,
864            certificate=self.sCert,
865            method=SSL.SSLv3_METHOD,
866            verify=True,
867            caCerts=[self.sCert],
868            verifyDepth=2,
869            requireCertificate=False,
870            verifyOnce=False,
871            enableSingleUseKeys=False,
872            enableSessions=False,
873            fixBrokenPeers=True,
874            enableSessionTickets=True)
875        context = firstOpts.getContext()
876        self.assertIdentical(context, firstOpts._context)
877        self.assertNotIdentical(context, None)
878        state = firstOpts.__getstate__()
879        self.assertNotIn("_context", state)
880
881        opts = sslverify.OpenSSLCertificateOptions()
882        opts.__setstate__(state)
883        self.assertEqual(opts.privateKey, self.sKey)
884        self.assertEqual(opts.certificate, self.sCert)
885        self.assertEqual(opts.method, SSL.SSLv3_METHOD)
886        self.assertEqual(opts.verify, True)
887        self.assertEqual(opts.caCerts, [self.sCert])
888        self.assertEqual(opts.verifyDepth, 2)
889        self.assertEqual(opts.requireCertificate, False)
890        self.assertEqual(opts.verifyOnce, False)
891        self.assertEqual(opts.enableSingleUseKeys, False)
892        self.assertEqual(opts.enableSessions, False)
893        self.assertEqual(opts.fixBrokenPeers, True)
894        self.assertEqual(opts.enableSessionTickets, True)
895
896
897    def test_certificateOptionsSessionTickets(self):
898        """
899        Enabling session tickets should not set the OP_NO_TICKET option.
900        """
901        opts = sslverify.OpenSSLCertificateOptions(enableSessionTickets=True)
902        ctx = opts.getContext()
903        self.assertEqual(0, ctx.set_options(0) & 0x00004000)
904
905
906    def test_certificateOptionsSessionTicketsDisabled(self):
907        """
908        Enabling session tickets should set the OP_NO_TICKET option.
909        """
910        opts = sslverify.OpenSSLCertificateOptions(enableSessionTickets=False)
911        ctx = opts.getContext()
912        self.assertEqual(0x00004000, ctx.set_options(0) & 0x00004000)
913
914
915    def test_allowedAnonymousClientConnection(self):
916        """
917        Check that anonymous connections are allowed when certificates aren't
918        required on the server.
919        """
920        onData = defer.Deferred()
921        self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
922                            certificate=self.sCert, requireCertificate=False),
923                      sslverify.OpenSSLCertificateOptions(
924                          requireCertificate=False),
925                      onData=onData)
926
927        return onData.addCallback(
928            lambda result: self.assertEqual(result, WritingProtocol.byte))
929
930
931    def test_refusedAnonymousClientConnection(self):
932        """
933        Check that anonymous connections are refused when certificates are
934        required on the server.
935        """
936        onServerLost = defer.Deferred()
937        onClientLost = defer.Deferred()
938        self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
939                            certificate=self.sCert, verify=True,
940                            caCerts=[self.sCert], requireCertificate=True),
941                      sslverify.OpenSSLCertificateOptions(
942                          requireCertificate=False),
943                      onServerLost=onServerLost,
944                      onClientLost=onClientLost)
945
946        d = defer.DeferredList([onClientLost, onServerLost],
947                               consumeErrors=True)
948
949
950        def afterLost(result):
951            ((cSuccess, cResult), (sSuccess, sResult)) = result
952            self.failIf(cSuccess)
953            self.failIf(sSuccess)
954            # Win32 fails to report the SSL Error, and report a connection lost
955            # instead: there is a race condition so that's not totally
956            # surprising (see ticket #2877 in the tracker)
957            self.assertIsInstance(cResult.value, (SSL.Error, ConnectionLost))
958            self.assertIsInstance(sResult.value, SSL.Error)
959
960        return d.addCallback(afterLost)
961
962    def test_failedCertificateVerification(self):
963        """
964        Check that connecting with a certificate not accepted by the server CA
965        fails.
966        """
967        onServerLost = defer.Deferred()
968        onClientLost = defer.Deferred()
969        self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
970                            certificate=self.sCert, verify=False,
971                            requireCertificate=False),
972                      sslverify.OpenSSLCertificateOptions(verify=True,
973                            requireCertificate=False, caCerts=[self.cCert]),
974                      onServerLost=onServerLost,
975                      onClientLost=onClientLost)
976
977        d = defer.DeferredList([onClientLost, onServerLost],
978                               consumeErrors=True)
979        def afterLost(result):
980            ((cSuccess, cResult), (sSuccess, sResult)) = result
981            self.failIf(cSuccess)
982            self.failIf(sSuccess)
983
984        return d.addCallback(afterLost)
985
986    def test_successfulCertificateVerification(self):
987        """
988        Test a successful connection with client certificate validation on
989        server side.
990        """
991        onData = defer.Deferred()
992        self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
993                            certificate=self.sCert, verify=False,
994                            requireCertificate=False),
995                      sslverify.OpenSSLCertificateOptions(verify=True,
996                            requireCertificate=True, caCerts=[self.sCert]),
997                      onData=onData)
998
999        return onData.addCallback(
1000                lambda result: self.assertEqual(result, WritingProtocol.byte))
1001
1002    def test_successfulSymmetricSelfSignedCertificateVerification(self):
1003        """
1004        Test a successful connection with validation on both server and client
1005        sides.
1006        """
1007        onData = defer.Deferred()
1008        self.loopback(sslverify.OpenSSLCertificateOptions(privateKey=self.sKey,
1009                            certificate=self.sCert, verify=True,
1010                            requireCertificate=True, caCerts=[self.cCert]),
1011                      sslverify.OpenSSLCertificateOptions(privateKey=self.cKey,
1012                            certificate=self.cCert, verify=True,
1013                            requireCertificate=True, caCerts=[self.sCert]),
1014                      onData=onData)
1015
1016        return onData.addCallback(
1017                lambda result: self.assertEqual(result, WritingProtocol.byte))
1018
1019    def test_verification(self):
1020        """
1021        Check certificates verification building custom certificates data.
1022        """
1023        clientDN = sslverify.DistinguishedName(commonName='client')
1024        clientKey = sslverify.KeyPair.generate()
1025        clientCertReq = clientKey.certificateRequest(clientDN)
1026
1027        serverDN = sslverify.DistinguishedName(commonName='server')
1028        serverKey = sslverify.KeyPair.generate()
1029        serverCertReq = serverKey.certificateRequest(serverDN)
1030
1031        clientSelfCertReq = clientKey.certificateRequest(clientDN)
1032        clientSelfCertData = clientKey.signCertificateRequest(
1033                clientDN, clientSelfCertReq, lambda dn: True, 132)
1034        clientSelfCert = clientKey.newCertificate(clientSelfCertData)
1035
1036        serverSelfCertReq = serverKey.certificateRequest(serverDN)
1037        serverSelfCertData = serverKey.signCertificateRequest(
1038                serverDN, serverSelfCertReq, lambda dn: True, 516)
1039        serverSelfCert = serverKey.newCertificate(serverSelfCertData)
1040
1041        clientCertData = serverKey.signCertificateRequest(
1042                serverDN, clientCertReq, lambda dn: True, 7)
1043        clientCert = clientKey.newCertificate(clientCertData)
1044
1045        serverCertData = clientKey.signCertificateRequest(
1046                clientDN, serverCertReq, lambda dn: True, 42)
1047        serverCert = serverKey.newCertificate(serverCertData)
1048
1049        onData = defer.Deferred()
1050
1051        serverOpts = serverCert.options(serverSelfCert)
1052        clientOpts = clientCert.options(clientSelfCert)
1053
1054        self.loopback(serverOpts,
1055                      clientOpts,
1056                      onData=onData)
1057
1058        return onData.addCallback(
1059                lambda result: self.assertEqual(result, WritingProtocol.byte))
1060
1061
1062
class ProtocolVersion(Names):
    """
    L{ProtocolVersion} provides constants representing each version of the
    SSL/TLS protocol.

    The constants are plain labels (they carry no OpenSSL value of their
    own).  They are used to describe which protocol versions a set of
    options permits.
    """
    # Listed from oldest to newest protocol version.
    SSLv2 = NamedConstant()
    SSLv3 = NamedConstant()
    TLSv1_0 = NamedConstant()
    TLSv1_1 = NamedConstant()
    TLSv1_2 = NamedConstant()
1073
1074
1075
class ProtocolVersionTests(unittest.TestCase):
    """
    Tests for L{sslverify.OpenSSLCertificateOptions}'s SSL/TLS version
    selection features.
    """
    if skipSSL:
        skip = skipSSL
    else:
        # The protocol versions each OpenSSL method can negotiate.  Methods
        # the installed pyOpenSSL lacks are keyed by a fresh object(), which
        # no real method value will ever equal.
        _METHOD_TO_PROTOCOL = {
            SSL.SSLv2_METHOD: {ProtocolVersion.SSLv2},
            SSL.SSLv3_METHOD: {ProtocolVersion.SSLv3},
            SSL.TLSv1_METHOD: {ProtocolVersion.TLSv1_0},
            getattr(SSL, "TLSv1_1_METHOD", object()):
                {ProtocolVersion.TLSv1_1},
            getattr(SSL, "TLSv1_2_METHOD", object()):
                {ProtocolVersion.TLSv1_2},

            # Presently, SSLv23_METHOD means (SSLv2, SSLv3, TLSv1.0, TLSv1.1,
            # TLSv1.2) (excluding any protocol versions not implemented by the
            # underlying version of OpenSSL).
            SSL.SSLv23_METHOD: set(ProtocolVersion.iterconstants()),
        }

        # Context options that each rule out one protocol version.  An option
        # missing from pyOpenSSL is keyed by 0, and ``mask & 0`` is always 0,
        # so such an entry never excludes anything.
        _EXCLUSION_OPS = {
            SSL.OP_NO_SSLv2: ProtocolVersion.SSLv2,
            SSL.OP_NO_SSLv3: ProtocolVersion.SSLv3,
            SSL.OP_NO_TLSv1: ProtocolVersion.TLSv1_0,
            getattr(SSL, "OP_NO_TLSv1_1", 0): ProtocolVersion.TLSv1_1,
            getattr(SSL, "OP_NO_TLSv1_2", 0): ProtocolVersion.TLSv1_2,
        }


    def _protocols(self, opts):
        """
        Work out which SSL/TLS protocol versions C{opts} permits.

        @param opts: The L{sslverify.OpenSSLCertificateOptions} to examine.

        @return: A L{set} of L{ProtocolVersion} constants naming every
            protocol version that connections using C{opts} may negotiate.
        """
        permitted = set(self._METHOD_TO_PROTOCOL[opts.method])
        # set_options(0) adds nothing and returns the current option mask.
        optionMask = opts.getContext().set_options(0)
        if opts.method != SSL.SSLv23_METHOD:
            # Exclusion options only narrow SSLv23_METHOD.
            return permitted
        excluded = set(
            version for flag, version in self._EXCLUSION_OPS.items()
            if optionMask & flag)
        return permitted - excluded


    def test_default(self):
        """
        Without an explicit protocol choice,
        L{sslverify.OpenSSLCertificateOptions} permits every version of TLS
        and no version of SSL.
        """
        expected = {ProtocolVersion.TLSv1_0,
                    ProtocolVersion.TLSv1_1,
                    ProtocolVersion.TLSv1_2}
        actual = self._protocols(sslverify.OpenSSLCertificateOptions())
        self.assertEqual(expected, actual)


    def test_SSLv23(self):
        """
        With C{SSLv23_METHOD}, L{sslverify.OpenSSLCertificateOptions} permits
        SSLv3 and every version of TLS.
        """
        expected = {ProtocolVersion.SSLv3,
                    ProtocolVersion.TLSv1_0,
                    ProtocolVersion.TLSv1_1,
                    ProtocolVersion.TLSv1_2}
        sslv23Options = sslverify.OpenSSLCertificateOptions(
            method=SSL.SSLv23_METHOD)
        self.assertEqual(expected, self._protocols(sslv23Options))
1155
1156
1157
class TrustRootTests(unittest.TestCase):
    """
    Tests for L{sslverify.OpenSSLCertificateOptions}' C{trustRoot} argument,
    L{sslverify.platformTrust}, and their interactions.
    """
    if skipSSL:
        skip = skipSSL

    def test_caCertsPlatformDefaults(self):
        """
        Specifying a C{trustRoot} of L{sslverify.OpenSSLDefaultPaths} when
        initializing L{sslverify.OpenSSLCertificateOptions} loads the
        platform-provided trusted certificates via C{set_default_verify_paths}.
        """
        opts = sslverify.OpenSSLCertificateOptions(
            trustRoot=sslverify.OpenSSLDefaultPaths(),
        )
        fc = FakeContext(SSL.TLSv1_METHOD)
        # Substitute a fake context so the call can be observed.
        opts._contextFactory = lambda method: fc
        opts.getContext()
        self.assertTrue(fc._defaultVerifyPathsSet)


    def test_trustRootPlatformRejectsUntrustedCA(self):
        """
        Specifying a C{trustRoot} of L{platformTrust} when initializing
        L{sslverify.OpenSSLCertificateOptions} causes certificates issued by a
        newly created CA to be rejected by an SSL connection using these
        options.

        Note that this test should I{always} pass, even on platforms where the
        CA certificates are not installed, as long as L{platformTrust} rejects
        completely invalid / unknown root CA certificates.  This is simply a
        smoke test to make sure that verification is happening at all.
        """
        caSelfCert, serverCert = certificatesForAuthorityAndServer()
        chainedCert = pathContainingDumpOf(self, serverCert, caSelfCert)
        privateKey = pathContainingDumpOf(self, serverCert.privateKey)

        sProto, cProto, pump = loopbackTLSConnection(
            trustRoot=platformTrust(),
            privateKeyFile=privateKey,
            chainedCertFile=chainedCert,
        )
        # No data was received.
        self.assertEqual(cProto.wrappedProtocol.data, b'')

        # It was an L{SSL.Error}.
        self.assertEqual(cProto.wrappedProtocol.lostReason.type, SSL.Error)

        # Some combination of OpenSSL and PyOpenSSL is bad at reporting errors.
        # Checking the alert text confirms the failure was caused by the
        # unknown CA and not by some unrelated error.
        err = cProto.wrappedProtocol.lostReason.value
        self.assertEqual(err.args[0][0][2], 'tlsv1 alert unknown ca')


    def test_trustRootSpecificCertificate(self):
        """
        Specifying a L{Certificate} object for L{trustRoot} will result in that
        certificate being the only trust root for a client.  A server
        certificate issued by that authority is accepted.  A server
        certificate issued by a different authority is rejected.
        """
        caCert, serverCert = certificatesForAuthorityAndServer()
        otherServer = certificatesForAuthorityAndServer()[1]

        sProto, cProto, pump = loopbackTLSConnection(
            trustRoot=caCert,
            privateKeyFile=pathContainingDumpOf(self, serverCert.privateKey),
            chainedCertFile=pathContainingDumpOf(self, serverCert),
        )
        pump.flush()
        self.assertIs(cProto.wrappedProtocol.lostReason, None)
        self.assertEqual(cProto.wrappedProtocol.data,
                         sProto.wrappedProtocol.greeting)

        # The same trust root must not vouch for a server whose certificate
        # comes from an unrelated authority.
        sProto, cProto, pump = loopbackTLSConnection(
            trustRoot=caCert,
            privateKeyFile=pathContainingDumpOf(self, otherServer.privateKey),
            chainedCertFile=pathContainingDumpOf(self, otherServer),
        )
        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(cProto.wrappedProtocol.lostReason.type, SSL.Error)
1229
1230
1231
class ServiceIdentityTests(unittest.SynchronousTestCase):
    """
    Tests for the verification of the peer's service's identity via the
    C{hostname} argument to L{sslverify.OpenSSLCertificateOptions}.
    """

    if skipSSL:
        skip = skipSSL

    def serviceIdentitySetup(self, clientHostname, serverHostname,
                             serverContextSetup=lambda ctx: None,
                             validCertificate=True,
                             clientPresentsCertificate=False,
                             validClientCertificate=True,
                             serverVerifies=False,
                             buggyInfoCallback=False,
                             fakePlatformTrust=False,
                             useDefaultTrust=False):
        """
        Connect a server and a client.

        @param clientHostname: The I{client's idea} of the server's hostname;
            passed as the C{hostname} to the
            L{sslverify.OpenSSLCertificateOptions} instance.
        @type clientHostname: L{unicode}

        @param serverHostname: The I{server's own idea} of the server's
            hostname; present in the certificate presented by the server.
        @type serverHostname: L{unicode}

        @param serverContextSetup: a 1-argument callable invoked with the
            L{OpenSSL.SSL.Context} after it's produced.
        @type serverContextSetup: L{callable} taking L{OpenSSL.SSL.Context}
            returning L{NoneType}.

        @param validCertificate: Is the server's certificate valid?  L{True} if
            so, L{False} otherwise.
        @type validCertificate: L{bool}

        @param clientPresentsCertificate: Should the client present a
            certificate to the server?  Defaults to 'no'.
        @type clientPresentsCertificate: L{bool}

        @param validClientCertificate: If the client presents a certificate,
            should it actually be a valid one, i.e. signed by the same CA that
            the server is checking?  Defaults to 'yes'.
        @type validClientCertificate: L{bool}

        @param serverVerifies: Should the server verify the client's
            certificate?  Defaults to 'no'.
        @type serverVerifies: L{bool}

        @param buggyInfoCallback: Should we patch the implementation so that
            the C{info_callback} passed to OpenSSL has a bug and raises an
            exception (L{ZeroDivisionError})?  Defaults to 'no'.
        @type buggyInfoCallback: L{bool}

        @param fakePlatformTrust: Should we fake the platformTrust to be the
            same as our fake server certificate authority, so that we can test
            it's being used?  Defaults to 'no' and we just pass platform trust.
        @type fakePlatformTrust: L{bool}

        @param useDefaultTrust: Should we avoid passing the C{trustRoot} to
            L{ssl.optionsForClientTLS}?  Defaults to 'no'.
        @type useDefaultTrust: L{bool}

        @return: see L{connectedServerAndClient}.
        @rtype: see L{connectedServerAndClient}.
        """
        serverIDNA = sslverify._idnaBytes(serverHostname)
        serverCA, serverCert = certificatesForAuthorityAndServer(serverIDNA)
        other = {}
        passClientCert = None
        clientCA, clientCert = certificatesForAuthorityAndServer(u'client')
        if serverVerifies:
            other.update(trustRoot=clientCA)

        if clientPresentsCertificate:
            if validClientCertificate:
                passClientCert = clientCert
            else:
                # Same name, but issued by a different, freshly generated
                # authority than the clientCA the server may trust.
                passClientCert = certificatesForAuthorityAndServer(
                    u'client')[1]

        serverOpts = sslverify.OpenSSLCertificateOptions(
            privateKey=serverCert.privateKey.original,
            certificate=serverCert.original,
            **other
        )
        serverContextSetup(serverOpts.getContext())
        if not validCertificate:
            # Replace the authority the client will trust with an unrelated
            # one, so the server's certificate no longer chains to it.
            serverCA = certificatesForAuthorityAndServer(serverIDNA)[0]
        if buggyInfoCallback:
            def broken(*a, **k):
                """
                Raise an exception.

                @param a: Arguments for an C{info_callback}

                @param k: Keyword arguments for an C{info_callback}
                """
                1 / 0
            self.patch(
                sslverify.ClientTLSOptions, "_identityVerifyingInfoCallback",
                broken,
            )

        signature = {'hostname': clientHostname}
        # Compare against None explicitly rather than relying on the
        # truthiness of a certificate object.
        if passClientCert is not None:
            signature.update(clientCertificate=passClientCert)
        if not useDefaultTrust:
            signature.update(trustRoot=serverCA)
        if fakePlatformTrust:
            self.patch(sslverify, "platformTrust", lambda: serverCA)

        clientOpts = sslverify.optionsForClientTLS(**signature)

        class GreetingServer(protocol.Protocol):
            """
            Send a greeting on connect; record received data and loss reason.
            """
            greeting = b"greetings!"
            lostReason = None
            data = b''
            def connectionMade(self):
                self.transport.write(self.greeting)
            def dataReceived(self, data):
                self.data += data
            def connectionLost(self, reason):
                self.lostReason = reason

        class GreetingClient(protocol.Protocol):
            """
            Send a greeting on connect; record received data and loss reason.
            """
            greeting = b'cheerio!'
            data = b''
            lostReason = None
            def connectionMade(self):
                self.transport.write(self.greeting)
            def dataReceived(self, data):
                self.data += data
            def connectionLost(self, reason):
                self.lostReason = reason

        self.serverOpts = serverOpts
        self.clientOpts = clientOpts

        clientFactory = TLSMemoryBIOFactory(
            clientOpts, isClient=True,
            wrappedFactory=protocol.Factory.forProtocol(GreetingClient)
        )
        serverFactory = TLSMemoryBIOFactory(
            serverOpts, isClient=False,
            wrappedFactory=protocol.Factory.forProtocol(GreetingServer)
        )
        return connectedServerAndClient(
            lambda: serverFactory.buildProtocol(None),
            lambda: clientFactory.buildProtocol(None),
        )


    def test_invalidHostname(self):
        """
        When a certificate containing an invalid hostname is received from the
        server, the connection is immediately dropped.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"wrong-host.example.com",
            u"correct-host.example.com",
        )
        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, VerificationError)
        self.assertIsInstance(sErr, ConnectionClosed)


    def test_validHostname(self):
        """
        Whenever a valid certificate containing a valid hostname is received,
        connection proceeds normally.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
        )
        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)


    def test_validHostnameInvalidCertificate(self):
        """
        When an invalid certificate containing a perfectly valid hostname is
        received, the connection is aborted with an OpenSSL error.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=False,
        )

        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, SSL.Error)
        self.assertIsInstance(sErr, SSL.Error)


    def test_realCAsBetterNotSignOurBogusTestCerts(self):
        """
        If we use the default trust from the platform, our dinky certificate
        should I{really} fail.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=False,
            useDefaultTrust=True,
        )

        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, SSL.Error)
        self.assertIsInstance(sErr, SSL.Error)


    def test_butIfTheyDidItWouldWork(self):
        """
        L{ssl.optionsForClientTLS} should be using L{ssl.platformTrust} by
        default, so if we fake that out then it should trust ourselves again.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            useDefaultTrust=True,
            fakePlatformTrust=True,
        )
        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)


    def test_clientPresentsCertificate(self):
        """
        When the server verifies and the client presents a valid certificate
        for that verification by passing it to
        L{sslverify.optionsForClientTLS}, communication proceeds.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=True,
            serverVerifies=True,
            clientPresentsCertificate=True,
        )

        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)


    def test_clientPresentsBadCertificate(self):
        """
        When the server verifies and the client presents an invalid certificate
        for that verification by passing it to
        L{sslverify.optionsForClientTLS}, the connection cannot be established
        with an SSL error.
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            validCertificate=True,
            serverVerifies=True,
            validClientCertificate=False,
            clientPresentsCertificate=True,
        )

        self.assertEqual(cProto.wrappedProtocol.data,
                         b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, SSL.Error)
        self.assertIsInstance(sErr, SSL.Error)


    def test_hostnameIsIndicated(self):
        """
        Specifying the C{hostname} argument to L{CertificateOptions} also sets
        the U{Server Name Extension
        <https://en.wikipedia.org/wiki/Server_Name_Indication>} TLS indication
        field to the correct value.
        """
        names = []
        def setupServerContext(ctx):
            def servername_received(conn):
                names.append(conn.get_servername().decode("ascii"))
            ctx.set_tlsext_servername_callback(servername_received)
        # Only the server-side callback's observations matter here.
        self.serviceIdentitySetup(
            u"valid.example.com",
            u"valid.example.com",
            setupServerContext
        )
        self.assertEqual(names, [u"valid.example.com"])

    test_hostnameIsIndicated.skip = skipSNI


    def test_hostnameEncoding(self):
        """
        Hostnames are encoded as IDNA.
        """
        names = []
        hello = u"h\N{LATIN SMALL LETTER A WITH ACUTE}llo.example.com"
        def setupServerContext(ctx):
            def servername_received(conn):
                serverIDNA = sslverify._idnaText(conn.get_servername())
                names.append(serverIDNA)
            ctx.set_tlsext_servername_callback(servername_received)
        cProto, sProto, pump = self.serviceIdentitySetup(
            hello, hello, setupServerContext
        )
        self.assertEqual(names, [hello])
        self.assertEqual(cProto.wrappedProtocol.data,
                         b'greetings!')

        cErr = cProto.wrappedProtocol.lostReason
        sErr = sProto.wrappedProtocol.lostReason
        self.assertIdentical(cErr, None)
        self.assertIdentical(sErr, None)

    test_hostnameEncoding.skip = skipSNI


    def test_fallback(self):
        """
        L{sslverify.simpleVerifyHostname} checks string equality on the
        commonName of a connection's certificate's subject, doing nothing if it
        matches and raising L{VerificationError} if it doesn't.
        """
        name = 'something.example.com'
        class Connection(object):
            def get_peer_certificate(self):
                """
                Fake of L{OpenSSL.SSL.Connection.get_peer_certificate}.

                @return: A certificate with a known common name.
                @rtype: L{OpenSSL.crypto.X509}
                """
                cert = X509()
                cert.get_subject().commonName = name
                return cert
        conn = Connection()
        self.assertIdentical(
            sslverify.simpleVerifyHostname(conn, u'something.example.com'),
            None
        )
        self.assertRaises(
            sslverify.SimpleVerificationError,
            sslverify.simpleVerifyHostname, conn, u'nonsense'
        )

    def test_surpriseFromInfoCallback(self):
        """
        pyOpenSSL isn't always so great about reporting errors.  If one occurs
        in the verification info callback, it should be logged and the
        connection should be shut down (if possible, anyway; the app_data could
        be clobbered but there's no point testing for that).
        """
        cProto, sProto, pump = self.serviceIdentitySetup(
            u"correct-host.example.com",
            u"correct-host.example.com",
            buggyInfoCallback=True,
        )
        self.assertEqual(cProto.wrappedProtocol.data, b'')
        self.assertEqual(sProto.wrappedProtocol.data, b'')

        cErr = cProto.wrappedProtocol.lostReason.value
        sErr = sProto.wrappedProtocol.lostReason.value

        self.assertIsInstance(cErr, ZeroDivisionError)
        self.assertIsInstance(sErr, ConnectionClosed)
        errors = self.flushLoggedErrors(ZeroDivisionError)
        self.assertTrue(errors)
1637
1638
1639
1640class _NotSSLTransport:
1641    def getHandle(self):
1642        return self
1643
1644
1645
1646class _MaybeSSLTransport:
1647    def getHandle(self):
1648        return self
1649
1650    def get_peer_certificate(self):
1651        return None
1652
1653    def get_host_certificate(self):
1654        return None
1655
1656
1657
1658class _ActualSSLTransport:
1659    def getHandle(self):
1660        return self
1661
1662    def get_host_certificate(self):
1663        return sslverify.Certificate.loadPEM(A_HOST_CERTIFICATE_PEM).original
1664
1665    def get_peer_certificate(self):
1666        return sslverify.Certificate.loadPEM(A_PEER_CERTIFICATE_PEM).original
1667
1668
1669
class Constructors(unittest.TestCase):
    """
    Tests for the L{sslverify.Certificate.peerFromTransport} and
    L{sslverify.Certificate.hostFromTransport} alternate constructors.
    """
    if skipSSL:
        skip = skipSSL

    def test_peerFromNonSSLTransport(self):
        """
        Verify that peerFromTransport raises an exception if the transport
        passed is not actually an SSL transport.
        """
        x = self.assertRaises(CertificateError,
                              sslverify.Certificate.peerFromTransport,
                              _NotSSLTransport())
        # Pass the message along so a failure shows what was actually raised.
        self.assertTrue(str(x).startswith("non-TLS"), str(x))


    def test_peerFromBlankSSLTransport(self):
        """
        Verify that peerFromTransport raises an exception if the transport
        passed is an SSL transport, but doesn't have a peer certificate.
        """
        x = self.assertRaises(CertificateError,
                              sslverify.Certificate.peerFromTransport,
                              _MaybeSSLTransport())
        self.assertTrue(str(x).startswith("TLS"), str(x))


    def test_hostFromNonSSLTransport(self):
        """
        Verify that hostFromTransport raises an exception if the transport
        passed is not actually an SSL transport.
        """
        x = self.assertRaises(CertificateError,
                              sslverify.Certificate.hostFromTransport,
                              _NotSSLTransport())
        self.assertTrue(str(x).startswith("non-TLS"), str(x))


    def test_hostFromBlankSSLTransport(self):
        """
        Verify that hostFromTransport raises an exception if the transport
        passed is an SSL transport, but doesn't have a host certificate.
        """
        x = self.assertRaises(CertificateError,
                              sslverify.Certificate.hostFromTransport,
                              _MaybeSSLTransport())
        self.assertTrue(str(x).startswith("TLS"), str(x))


    def test_hostFromSSLTransport(self):
        """
        Verify that hostFromTransport successfully creates the correct
        certificate if passed a valid SSL transport.
        """
        self.assertEqual(
            sslverify.Certificate.hostFromTransport(
                _ActualSSLTransport()).serialNumber(),
            12345)


    def test_peerFromSSLTransport(self):
        """
        Verify that peerFromTransport successfully creates the correct
        certificate if passed a valid SSL transport.
        """
        self.assertEqual(
            sslverify.Certificate.peerFromTransport(
                _ActualSSLTransport()).serialNumber(),
            12346)
1738
1739
1740
class TestOpenSSLCipher(unittest.TestCase):
    """
    Tests for L{twisted.internet._sslverify.OpenSSLCipher}.
    """
    if skipSSL:
        skip = skipSSL

    cipherName = u'CIPHER-STRING'

    def test_constructorSetsFullName(self):
        """
        The constructor's first argument is stored as C{fullName}.
        """
        cipher = sslverify.OpenSSLCipher(self.cipherName)
        self.assertEqual(self.cipherName, cipher.fullName)


    def test_repr(self):
        """
        Evaluating C{repr(cipher)} rebuilds an equal cipher.
        """
        original = sslverify.OpenSSLCipher(self.cipherName)
        namespace = {'OpenSSLCipher': sslverify.OpenSSLCipher}
        rebuilt = eval(repr(original), namespace)
        self.assertEqual(original, rebuilt)


    def test_eqSameClass(self):
        """
        Two ciphers of the same type with the same C{fullName} compare equal.
        """
        first = sslverify.OpenSSLCipher(self.cipherName)
        second = sslverify.OpenSSLCipher(self.cipherName)
        self.assertEqual(first, second)


    def test_eqSameNameDifferentType(self):
        """
        A matching C{fullName} is not enough: an object of another type never
        compares equal to an L{sslverify.OpenSSLCipher}.
        """
        sharedName = self.cipherName

        class DifferentCipher(object):
            fullName = sharedName

        real = sslverify.OpenSSLCipher(self.cipherName)
        impostor = DifferentCipher()
        self.assertNotEqual(real, impostor)
1792
1793
1794
class TestExpandCipherString(unittest.TestCase):
    """
    Tests for twisted.internet._sslverify._expandCipherString.
    """
    if skipSSL:
        skip = skipSSL

    def test_doesNotStumbleOverEmptyList(self):
        """
        If the expanded cipher list is empty, an empty L{list} is returned.
        """
        self.assertEqual(
            [],
            sslverify._expandCipherString(u'', SSL.SSLv23_METHOD, 0)
        )


    def test_doesNotSwallowOtherSSLErrors(self):
        """
        Only the "no cipher matches" error is swallowed; every other SSL error
        is propagated.
        """
        def raiser(_):
            # Unfortunately, there seems to be no way to trigger a real SSL
            # error artificially.
            raise SSL.Error([['', '', '']])
        ctx = FakeContext(SSL.SSLv23_METHOD)
        ctx.set_cipher_list = raiser
        self.patch(sslverify.SSL, 'Context', lambda _: ctx)
        self.assertRaises(
            SSL.Error,
            sslverify._expandCipherString, u'ALL', SSL.SSLv23_METHOD, 0
        )


    def test_returnsListOfICiphers(self):
        """
        L{sslverify._expandCipherString} always returns a L{list} of
        L{interfaces.ICipher}.
        """
        ciphers = sslverify._expandCipherString(u'ALL', SSL.SSLv23_METHOD, 0)
        self.assertIsInstance(ciphers, list)
        # Collect the offenders instead of asserting all(...) so that a
        # failure reports exactly which objects don't provide ICipher.
        bogus = [c for c in ciphers if not interfaces.ICipher.providedBy(c)]
        self.assertEqual([], bogus)
1843
1844
1845
class TestAcceptableCiphers(unittest.TestCase):
    """
    Tests for twisted.internet._sslverify.OpenSSLAcceptableCiphers.
    """
    if skipSSL:
        skip = skipSSL

    def test_selectOnEmptyListReturnsEmptyList(self):
        """
        If no ciphers are available, nothing can be selected.
        """
        ac = sslverify.OpenSSLAcceptableCiphers([])
        self.assertEqual([], ac.selectCiphers([]))


    def test_selectReturnsOnlyFromAvailable(self):
        """
        Select returns only the intersection of what is available and what
        is acceptable.
        """
        ac = sslverify.OpenSSLAcceptableCiphers([
            sslverify.OpenSSLCipher('A'),
            sslverify.OpenSSLCipher('B'),
        ])
        self.assertEqual([sslverify.OpenSSLCipher('B')],
                         ac.selectCiphers([sslverify.OpenSSLCipher('B'),
                                           sslverify.OpenSSLCipher('C')]))


    def test_fromOpenSSLCipherStringExpandsToListOfCiphers(self):
        """
        If L{sslverify.OpenSSLAcceptableCiphers.fromOpenSSLCipherString} is
        called it expands the string to a list of ciphers.
        """
        ac = sslverify.OpenSSLAcceptableCiphers.fromOpenSSLCipherString('ALL')
        self.assertIsInstance(ac._ciphers, list)
        # Same pattern as test_returnsListOfICiphers: listing the offenders
        # gives a useful failure message, unlike assertTrue(all(...)).
        bogus = [c for c in ac._ciphers
                 if not interfaces.ICipher.providedBy(c)]
        self.assertEqual([], bogus)
1884
1885
1886
class TestDiffieHellmanParameters(unittest.TestCase):
    """
    Tests for twisted.internet._sslverify.OpenSSLDiffieHellmanParameters.
    """
    if skipSSL:
        skip = skipSSL
    # Only the path is recorded by fromFile in this test; the file itself is
    # never opened here, so it need not exist.
    filePath = FilePath(b'dh.params')

    def test_fromFile(self):
        """
        Calling C{fromFile} with a L{FilePath} returns an instance with that
        path saved as C{_dhFile}.
        """
        params = sslverify.OpenSSLDiffieHellmanParameters.fromFile(
            self.filePath
        )
        self.assertEqual(self.filePath, params._dhFile)
1904
1905
1906
class FakeECKey(object):
    """
    A fake elliptic-curve key that simply remembers the NID it was created
    with, so tests can inspect it.

    @ivar _nid: Whatever object was handed to the constructor.
    """
    def __init__(self, nid):
        """
        @param nid: An arbitrary object, stored unchanged as C{_nid}.
        """
        setattr(self, '_nid', nid)
1915
1916
1917
class FakeNID(object):
    """
    A fake NID that records the short name it was created from, so tests
    can inspect it.

    @ivar _snName: Whatever object was handed to the constructor.
    """
    def __init__(self, snName):
        """
        @param snName: An arbitrary short name, stored unchanged as
            C{_snName}.
        """
        setattr(self, '_snName', snName)
1926
1927
1928
class FakeLib(object):
    """
    An introspectable fake of cryptography's lib object.

    @ivar _createdKeys: The keys that have been created by this instance and
        not yet freed.
    @type _createdKeys: L{set} of L{FakeECKey}

    @cvar NID_undef: A symbolic constant for undefined NIDs.
    @type NID_undef: L{FakeNID}
    """
    NID_undef = FakeNID("undef")

    def __init__(self):
        self._createdKeys = set()


    def OBJ_sn2nid(self, snName):
        """
        Create a L{FakeNID} with C{snName} and return it.

        @param snName: a free form name that gets passed to the constructor
            of L{FakeNID}.

        @return: a new L{FakeNID}.
        @rtype: L{FakeNID}.
        """
        return FakeNID(snName)


    def EC_KEY_new_by_curve_name(self, nid):
        """
        Create a L{FakeECKey}, save it to C{_createdKeys} and return it.

        @param nid: an arbitrary object that is passed to the constructor of
            L{FakeECKey}.

        @return: a new L{FakeECKey}
        @rtype: L{FakeECKey}
        """
        key = FakeECKey(nid)
        self._createdKeys.add(key)
        return key


    def EC_KEY_free(self, key):
        """
        Remove C{key} from C{_createdKeys}.

        @param key: a key object to be freed; i.e. removed from
            C{_createdKeys}.

        @raises ValueError: If C{key} is not in C{_createdKeys} and thus was
            either not created by us or has already been freed.
        """
        try:
            self._createdKeys.remove(key)
        except KeyError:
            raise ValueError("Unallocated EC key attempted to free.")


    def SSL_CTX_set_tmp_ecdh(self, ffiContext, key):
        """
        Does not do anything.

        @param ffiContext: ignored
        @param key: ignored
        """
1996
1997
1998
class TestFakeLib(unittest.TestCase):
    """
    Tests for FakeLib
    """
    def test_objSn2Nid(self):
        """
        L{FakeLib.OBJ_sn2nid} returns a L{FakeNID} with the correct name.
        """
        # Exercise the fake's method itself; constructing FakeNID directly
        # would not test FakeLib at all.
        nid = FakeLib().OBJ_sn2nid("test")
        self.assertIsInstance(nid, FakeNID)
        self.assertEqual("test", nid._snName)


    def test_emptyKeys(self):
        """
        A new L{FakeLib} has an empty set for created keys.
        """
        self.assertEqual(set(), FakeLib()._createdKeys)


    def test_newKey(self):
        """
        If a new key is created, it's added to C{_createdKeys}.
        """
        lib = FakeLib()
        key = lib.EC_KEY_new_by_curve_name(FakeNID("name"))
        self.assertEqual(set([key]), lib._createdKeys)


    def test_freeUnknownKey(self):
        """
        Raise L{ValueError} if an unknown key is attempted to be freed.
        """
        key = FakeECKey(object())
        self.assertRaises(
            ValueError,
            FakeLib().EC_KEY_free, key
        )


    def test_freeKnownKey(self):
        """
        Freeing an allocated key removes it from C{_createdKeys}.
        """
        lib = FakeLib()
        key = lib.EC_KEY_new_by_curve_name(FakeNID("name"))
        lib.EC_KEY_free(key)
        self.assertEqual(set(), lib._createdKeys)
2046
2047
2048
class FakeFFI(object):
    """
    Stand-in for cryptography's C{ffi} object.

    @cvar NULL: A unique sentinel object playing the role of CFFI's NULL.
    """
    NULL = object()
2056
2057
2058
class FakeBinding(object):
    """
    A fake of cryptography's binding object.

    @ivar lib: The fake C{lib}; a fresh L{FakeLib} unless one was supplied.
    @type lib: L{FakeLib}

    @ivar ffi: The fake C{ffi}; a fresh L{FakeFFI} unless one was supplied.
    @type ffi: L{FakeFFI}
    """
    def __init__(self, lib=None, ffi=None):
        """
        @param lib: The C{lib} object to expose, or L{None} to create a new
            L{FakeLib}.
        @param ffi: The C{ffi} object to expose, or L{None} to create a new
            L{FakeFFI}.
        """
        # Compare against None explicitly so any caller-supplied object is
        # used as-is, even one that happens to be falsy.
        self.lib = lib if lib is not None else FakeLib()
        self.ffi = ffi if ffi is not None else FakeFFI()
2069
2070
2071
class TestECCurve(unittest.TestCase):
    """
    Tests for twisted.internet._sslverify.OpenSSLECCurve.
    """
    if skipSSL:
        skip = skipSSL

    def test_missingBinding(self):
        """
        Raise L{NotImplementedError} if pyOpenSSL is not based on cryptography.
        """
        # Patched onto the class, so this really is a method receiving self.
        def raiser(self):
            raise NotImplementedError
        self.patch(sslverify._OpenSSLECCurve, "_getBinding", raiser)
        self.assertRaises(
            NotImplementedError,
            sslverify._OpenSSLECCurve, sslverify._defaultCurveName,
        )


    def test_nonECbinding(self):
        """
        Raise L{NotImplementedError} if pyOpenSSL is based on cryptography but
        cryptography lacks required EC methods.
        """
        # Assigned onto a FakeLib *instance*, so it is not bound: its only
        # argument is the short name being looked up.
        def raiser(snName):
            raise AttributeError
        lib = FakeLib()
        lib.OBJ_sn2nid = raiser
        self.patch(sslverify._OpenSSLECCurve,
                   "_getBinding",
                   lambda self: FakeBinding(lib=lib))
        self.assertRaises(
            NotImplementedError,
            sslverify._OpenSSLECCurve, sslverify._defaultCurveName,
        )


    def test_wrongName(self):
        """
        Raise L{ValueError} on unknown sn names.
        """
        lib = FakeLib()
        # Instance attribute, not a method: the argument is the short name.
        lib.OBJ_sn2nid = lambda snName: FakeLib.NID_undef
        self.patch(sslverify._OpenSSLECCurve,
                   "_getBinding",
                   lambda self: FakeBinding(lib=lib))
        self.assertRaises(
            ValueError,
            sslverify._OpenSSLECCurve, u"doesNotExist",
        )


    def test_keyFails(self):
        """
        Raise L{EnvironmentError} if key creation fails.
        """
        lib = FakeLib()
        lib.EC_KEY_new_by_curve_name = lambda *a, **kw: FakeFFI.NULL
        self.patch(sslverify._OpenSSLECCurve,
                   "_getBinding",
                   lambda self: FakeBinding(lib=lib))
        curve = sslverify._OpenSSLECCurve(sslverify._defaultCurveName)
        self.assertRaises(
            EnvironmentError,
            curve.addECKeyToContext, object()
        )


    def test_keyGetsFreed(self):
        """
        Don't leak a key when adding it to a context.
        """
        lib = FakeLib()
        self.patch(sslverify._OpenSSLECCurve,
                   "_getBinding",
                   lambda self: FakeBinding(lib=lib))
        curve = sslverify._OpenSSLECCurve(sslverify._defaultCurveName)
        ctx = FakeContext(None)
        ctx._context = None
        curve.addECKeyToContext(ctx)
        # Every key FakeLib handed out must have been passed to EC_KEY_free.
        self.assertEqual(set(), lib._createdKeys)
2154