1import unittest
2from util import *
3
4
class ECDHTests(unittest.TestCase):
    """Tests for the wally_ecdh shared-secret binding."""

    def priv_to_pub(self, priv):
        """Return the 33-byte public key derived from private key ``priv``.

        Fails the calling test if wally_ec_public_key_from_private_key
        does not return WALLY_OK.
        """
        pub, _ = make_cbuffer('00'*33)
        ret = wally_ec_public_key_from_private_key(priv, len(priv), pub, len(pub))
        self.assertEqual(ret, WALLY_OK)
        return pub

    def test_ecdh(self):
        """Tests for ECDH.

        Checks that both parties derive the same (non-trivial) secret, and
        that each invalid argument on its own makes wally_ecdh return
        WALLY_EINVAL without writing to the output buffer.
        """
        zeros = utf8('00'*32)  # hex of an untouched 32-byte output buffer

        priv1, _ = make_cbuffer('aa'*32)
        priv2, _ = make_cbuffer('bb'*32)
        pub1 = self.priv_to_pub(priv1)
        pub2 = self.priv_to_pub(priv2)

        # Argument order is (pub_key, pub_key_len, priv_key, priv_key_len,
        # bytes_out, len).
        out12, _ = make_cbuffer('00'*32)
        out21, _ = make_cbuffer('00'*32)
        ret = wally_ecdh(pub1, len(pub1), priv2, len(priv2), out12, len(out12))
        self.assertEqual(ret, WALLY_OK)
        ret = wally_ecdh(pub2, len(pub2), priv1, len(priv1), out21, len(out21))
        self.assertEqual(ret, WALLY_OK)

        self.assertEqual(out12, out21)
        # Two untouched zero buffers would also compare equal above, so make
        # sure a secret was actually written.
        self.assertNotEqual(h(out12), zeros)

        out, _ = make_cbuffer('00'*32)
        priv_bad, _ = make_cbuffer('00'*32)
        pub_bad, _ = make_cbuffer('02' + '00'*32)

        # Every case is valid apart from the single argument it describes, so
        # the EINVAL result must be caused by that argument alone.
        cases = [
            ('Missing public key',         (None,    33, priv1,    32, out,  32)),
            ('Invalid public key',         (pub_bad, 33, priv1,    32, out,  32)),
            ('Invalid public key length',  (pub1,    32, priv1,    32, out,  32)),
            ('Missing private key',        (pub1,    33, None,     32, out,  32)),
            ('Invalid private key',        (pub1,    33, priv_bad, 32, out,  32)),
            ('Invalid private key length', (pub1,    33, priv1,    31, out,  32)),
            ('Missing output',             (pub1,    33, priv1,    32, None, 32)),
            ('Invalid output length',      (pub1,    33, priv1,    32, out,  31)),
        ]
        for desc, args in cases:
            self.assertEqual(WALLY_EINVAL, wally_ecdh(*args), desc)
            # A failed call must leave the output buffer zeroed.
            self.assertEqual(h(out), zeros, desc)
47
if __name__ == '__main__':
    # Only when executed directly (not imported): run this module's tests
    # through the standard unittest runner.
    unittest.main()
50