1from openid.test import datadriven
2
3import unittest
4
5from openid.message import Message, BARE_NS, OPENID_NS, OPENID2_NS
6from openid import association
7import time
8from openid import cryptutil
9import warnings
10
11
class AssociationSerializationTest(unittest.TestCase):
    """Check that an Association survives a serialize/deserialize cycle."""

    def test_roundTrip(self):
        """Serialize an association, parse it back, and compare each field."""
        # The handle is packed with quotes, backslashes and other punctuation.
        handle = 'a-QoU6tM*#!*R\'q\\w<W>X`90>tj7d{[t~Wv@(j(V9(jcx:ZeGYbT0;N]"C}bxQ$aDjf{)"z6@+W<Wb$Vm`k9j0/tZ=\\J[0Qmp35ex[H9g<nUC9UGj4.Hlq7"Q]`w:w6Q'
        issued = int(time.time())
        lifetime = 600
        original = association.Association(
            handle, 'secret', issued, lifetime, 'HMAC-SHA1')

        restored = association.Association.deserialize(original.serialize())

        # Compared in a fixed order so the first mismatch reported is stable.
        compared_fields = (
            'handle',
            'issued',
            'secret',
            'lifetime',
            'assoc_type',
        )
        for field in compared_fields:
            self.assertEqual(
                getattr(original, field), getattr(restored, field))
26
27from openid.server.server import \
28     DiffieHellmanSHA1ServerSession, \
29     DiffieHellmanSHA256ServerSession, \
30     PlainTextServerSession
31
32from openid.consumer.consumer import \
33     DiffieHellmanSHA1ConsumerSession, \
34     DiffieHellmanSHA256ConsumerSession, \
35     PlainTextConsumerSession
36
37from openid.dh import DiffieHellman
38
39
def createNonstandardConsumerDH():
    """Return a SHA1 consumer session built on custom DH parameters.

    The session wraps ``DiffieHellman(1315291, 2)`` instead of whatever
    parameters the session class would pick when called with no arguments.
    """
    # NOTE(review): the two arguments are presumably (modulus, generator);
    # confirm against openid.dh.DiffieHellman.
    return DiffieHellmanSHA1ConsumerSession(DiffieHellman(1315291, 2))
43
44
class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase):
    """Data-driven check that a secret passes intact between paired sessions.

    One case is generated for every (consumer factory, server session class)
    pair in ``session_factories`` combined with every entry in ``secrets``.
    """

    # Secrets sent from the server session to the consumer session.
    secrets = [
        '\x00' * 20,
        '\xff' * 20,
        ' ' * 20,
        'This is a secret....',
    ]

    # (callable producing a consumer session, matching server session class)
    session_factories = [
        (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA1ServerSession),
        (createNonstandardConsumerDH, DiffieHellmanSHA1ServerSession),
        (PlainTextConsumerSession, PlainTextServerSession),
    ]

    @classmethod
    def generateCases(cls):
        """Return a list of (consumer factory, server class, secret) tuples.

        Factory pairs vary slowest and secrets vary fastest.
        """
        cases = []
        for consumer_factory, server_class in cls.session_factories:
            for secret in cls.secrets:
                cases.append((consumer_factory, server_class, secret))
        return cases

    def __init__(self, csess_fact, ssess_fact, secret):
        """Store one case; the consumer factory's name is the description."""
        datadriven.DataDrivenTestCase.__init__(self, csess_fact.__name__)
        self.csess_fact = csess_fact
        self.ssess_fact = ssess_fact
        self.secret = secret

    def runOneTest(self):
        """Send the secret from server to consumer and check it is unchanged."""
        consumer_session = self.csess_fact()

        # The consumer's request arguments are parsed into a Message, from
        # which the server-side session is constructed.
        request = Message.fromOpenIDArgs(consumer_session.getRequest())
        server_session = self.ssess_fact.fromMessage(request)

        # The server's answer is parsed the same way and handed back to the
        # consumer session to recover the secret.
        response = Message.fromOpenIDArgs(server_session.answer(self.secret))
        recovered = consumer_session.extractSecret(response)

        self.assertEqual(self.secret, recovered)
78
79
class TestMakePairs(unittest.TestCase):
    """Exercise the key-value pair formatting done by an association."""

    def setUp(self):
        """Build a signed-looking id_res message and a SHA1 association."""
        msg = Message(OPENID2_NS)
        msg.updateArgs(OPENID2_NS, {
            'mode': 'id_res',
            'identifier': '=example',
            'signed': 'identifier,mode',
            'sig': 'cephalopod',
        })
        # An argument outside the OpenID namespace; it is not in 'signed'.
        msg.updateArgs(BARE_NS, {'xey': 'value'})
        self.message = msg

        self.assoc = association.Association.fromExpiresIn(
            3600, '{sha1}', 'very_secret', "HMAC-SHA1")

    def testMakePairs(self):
        """Pairs come back for the fields named in 'signed', in that order."""
        pairs = self.assoc._makePairs(self.message)
        self.assertEqual(pairs, [
            ('identifier', '=example'),
            ('mode', 'id_res'),
        ])
104
105
class TestMac(unittest.TestCase):
    """Compare association signatures against fixed, known byte strings."""

    def setUp(self):
        """Provide the key/value pairs that every test signs."""
        self.pairs = [('key1', 'value1'), ('key2', 'value2')]

    def test_sha1(self):
        """HMAC-SHA1 signature of the fixture pairs matches the known value."""
        sha1_assoc = association.Association.fromExpiresIn(
            3600, '{sha1}', 'very_secret', "HMAC-SHA1")
        actual = sha1_assoc.sign(self.pairs)
        known = (b'\xe0\x1bv\x04\xf1G\xc0\xbb\x7f\x9a\x8b'
                 b'\xe9\xbc\xee}\\\xe5\xbb7*')
        self.assertEqual(actual, known)

    # The SHA256 test only exists when cryptutil reports SHA256 support.
    if cryptutil.SHA256_AVAILABLE:

        def test_sha256(self):
            """HMAC-SHA256 signature of the fixture pairs matches the known value."""
            sha256_assoc = association.Association.fromExpiresIn(
                3600, '{sha256SA}', 'very_secret', "HMAC-SHA256")
            actual = sha256_assoc.sign(self.pairs)
            known = (b'\xfd\xaa\xfe;\xac\xfc*\x988\xad\x05d6-\xeaVy'
                     b'\xd5\xa5Z.<\xa9\xed\x18\x82\\$\x95x\x1c&')
            self.assertEqual(actual, known)
127
128
class TestMessageSigning(unittest.TestCase):
    """Check what Association.signMessage adds to a message."""

    def setUp(self):
        """Build an unsigned id_res message plus its flat argument form."""
        msg = Message(OPENID2_NS)
        msg.updateArgs(OPENID2_NS, {'mode': 'id_res', 'identifier': '=example'})
        msg.updateArgs(BARE_NS, {'xey': 'value'})
        self.message = msg
        self.args = {
            'openid.mode': 'id_res',
            'openid.identifier': '=example',
            'xey': 'value'
        }

    def _signAndCheck(self, assoc_type):
        """Sign the fixture message with *assoc_type* and verify the result.

        The signed message must have a truthy 'sig', the expected 'signed'
        field list, and must still carry the bare-namespace argument.
        """
        assoc = association.Association.fromExpiresIn(
            3600, '{sha1}', 'very_secret', assoc_type)
        signed = assoc.signMessage(self.message)

        self.assertTrue(signed.getArg(OPENID_NS, "sig"))

        signed_fields = signed.getArg(OPENID_NS, "signed")
        self.assertEqual(
            signed_fields, "assoc_handle,identifier,mode,ns,signed")

        bare_value = signed.getArg(BARE_NS, "xey")
        self.assertEqual(bare_value, "value", signed)

    def test_signSHA1(self):
        """Signing with an HMAC-SHA1 association yields the expected fields."""
        self._signAndCheck("HMAC-SHA1")

    # The SHA256 test only exists when cryptutil reports SHA256 support.
    if cryptutil.SHA256_AVAILABLE:

        def test_signSHA256(self):
            """Signing with an HMAC-SHA256 association yields the expected fields."""
            self._signAndCheck("HMAC-SHA256")
161
162
class TestCheckMessageSignature(unittest.TestCase):
    """Check signature verification on malformed messages."""

    def test_aintGotSignedList(self):
        """A message with 'sig' but no 'signed' field raises ValueError."""
        unsigned_list_msg = Message(OPENID2_NS)
        unsigned_list_msg.updateArgs(OPENID2_NS, {
            'mode': 'id_res',
            'identifier': '=example',
            'sig': 'coyote',
        })
        unsigned_list_msg.updateArgs(BARE_NS, {'xey': 'value'})

        assoc = association.Association.fromExpiresIn(
            3600, '{sha1}', 'very_secret', "HMAC-SHA1")

        with self.assertRaises(ValueError):
            assoc.checkMessageSignature(unsigned_list_msg)
175
176
def pyUnitTests():
    """Return the result of datadriven.loadTests() for this module."""
    this_module = __name__
    return datadriven.loadTests(this_module)
179
180
if __name__ == '__main__':
    # Script entry point: collect this module's tests, then run them with
    # the plain text runner.
    all_tests = pyUnitTests()
    unittest.TextTestRunner().run(all_tests)
185