1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for I{AMP} commands related to identity.
6"""
7
8from twisted.trial import unittest
9
10from twisted.protocols import amp
11from twisted.internet.defer import succeed
12from twisted.internet.ssl import DN, KeyPair, CertificateRequest
13
14from vertex.ivertex import IQ2QUser
15from vertex.q2q import Q2Q, Q2QAddress, Identify, Sign
16
17from vertex.test.amphelpers import callResponder
18
19
def makeCert(cn):
    """
    Create a self-signed certificate with the given common name.

    A fresh key pair is generated. That key both creates a certificate
    request for C{DN(CN=cn)} and signs the request with the same DN as
    issuer, so the resulting certificate's subject and issuer coincide.

    @param cn: Common Name to use in certificate.
    @type cn: L{bytes}

    @return: Self-signed certificate, bundled with the private key that
        signed it (so it can itself be used to sign other requests).
    @rtype: L{PrivateCertificate<twisted.internet.ssl.PrivateCertificate>}
    """
    sharedDN = DN(CN=cn)
    key = KeyPair.generate()
    cr = key.certificateRequest(sharedDN)
    # The DN-verification callback accepts any DN: the request is the one we
    # just built ourselves. A fixed serial number of 1 is used.
    sscrd = key.signCertificateRequest(sharedDN, cr, lambda dn: True, 1)
    # newCertificate pairs the signed certificate data with this key pair.
    return key.newCertificate(sscrd)
35
36
def makeCertRequest(cn):
    """
    Create a serialized certificate request with the given common name.

    The request is signed by a freshly generated key pair. The key pair is
    not returned; callers only receive the encoded request.

    @param cn: Common Name to use in certificate request.
    @type cn: L{bytes}

    @return: The certificate request, serialized by
        L{KeyPair.certificateRequest} in its default format (ASN.1). Pass it
        to L{CertificateRequest.load} to obtain a L{CertificateRequest}.
    @rtype: L{bytes}
    """
    key = KeyPair.generate()
    return key.certificateRequest(DN(CN=cn))
49
50
51
class IdentityTests(unittest.TestCase):
    """
    Tests for the L{Identify} command responder.
    """

    def test_identify(self):
        """
        Given an Identify command, a presence server looks up the certificate
        stored for the requested domain and sends it back without its
        private key.
        """
        domain = "example.com"
        storedCert = makeCert("fake certificate")
        # Bound to a separate name so the fakes below can use a conventional
        # "self" while still making assertions through this TestCase.
        testCase = self

        class FakeStorage(object):
            def getPrivateCertificate(self, subject):
                testCase.assertEqual(subject, domain)
                return storedCert

        class FakeService(object):
            certificateStorage = FakeStorage()

        protocol = Q2Q()
        protocol.service = FakeService()

        result = self.successResultOf(
            callResponder(protocol, Identify, subject=Q2QAddress(domain)))
        self.assertEqual(result, {'certificate': storedCert})
        # The certificate on the wire must not carry key material.
        self.assertFalse(hasattr(result['certificate'], 'privateKey'))
79
80
81
class SignTests(unittest.TestCase):
    """
    Tests for L{Sign}.
    """

    def test_cannotSign(self):
        """
        Vertex nodes with no portal will not sign cert requests.
        """
        cr = CertificateRequest.load(makeCertRequest("example.com"))

        class FakeService(object):
            portal = None

        q = Q2Q()
        q.service = FakeService()

        d = callResponder(q, Sign,
                          certificate_request=cr,
                          password='hunter2')
        # The refusal surfaces to the caller as an AMP remote error.
        self.failureResultOf(d, amp.RemoteAmpError)


    def test_sign(self):
        """
        'Sign' messages with a cert request result in a cred login with
        the given password. The avatar returned is then asked to sign
        the cert request with the presence server's certificate. The
        resulting certificate is returned as a response: it is issued by
        the presence server's certificate and its subject is the
        requesting user.
        """
        user = 'jethro@example.com'
        passwd = 'hunter2'

        issuerName = "fake certificate"
        domainCert = makeCert(issuerName)

        # The fakes below name their instance parameter something other than
        # "self" so that "self" in their bodies still refers to this TestCase
        # and can be used for assertions.
        class FakeAvatar(object):
            def signCertificateRequest(fa, certificateRequest, hostcert,
                                       suggestedSerial):
                self.assertEqual(hostcert, domainCert)
                return hostcert.signRequestObject(certificateRequest,
                                                  suggestedSerial)

        class FakeStorage(object):
            def getPrivateCertificate(cs, subject):
                return domainCert

            def genSerial(cs, domain):
                return 1

        cr = CertificateRequest.load(makeCertRequest(user))

        class FakePortal(object):
            def login(fp, creds, proto, iface):
                self.assertEqual(iface, IQ2QUser)
                self.assertEqual(creds.username, user)
                self.assertEqual(creds.password, passwd)
                # Mirrors the (interface, avatar, logout) result shape of a
                # cred Portal.login Deferred.
                return succeed([None, FakeAvatar(), None])

        class FakeService(object):
            portal = FakePortal()
            certificateStorage = FakeStorage()

        q = Q2Q()
        q.service = FakeService()

        d = callResponder(q, Sign,
                          certificate_request=cr,
                          password=passwd)
        response = self.successResultOf(d)
        certificate = response['certificate']
        self.assertEqual(certificate.getIssuer().commonName, issuerName)
        # signRequestObject copies the subject from the request, so the
        # issued certificate must name the requesting user.
        self.assertEqual(certificate.getSubject().commonName, user)
152