1import unittest
2from util import *
3
# Known-answer vectors: pairs of (plaintext, expected base64 encoding).
ROUND_TRIP_CASES = [
    # RFC 4648 test vectors
    ('f',      'Zg=='),
    ('fo',     'Zm8='),
    ('foo',    'Zm9v'),
    ('foob',   'Zm9vYg=='),
    ('fooba',  'Zm9vYmE='),
    ('foobar', 'Zm9vYmFy'),

    # Vectors taken from the Apache Commons Codec Base64 tests:
    # https://commons.apache.org/proper/commons-codec/xref-test/org/apache/commons/codec/binary/Base64Test.html
    ('Hello World', 'SGVsbG8gV29ybGQ='),
    ('A',           'QQ=='),
    ('AA',          'QUE='),
    ('AAA',         'QUFB'),
    (
        'The quick brown fox jumped over the lazy dogs.',
        'VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2dzLg==',
    ),
    (
        'It was the best of times, it was the worst of times.',
        'SXQgd2FzIHRoZSBiZXN0IG9mIHRpbWVzLCBpdCB3YXMgdGhlIHdvcnN0IG9mIHRpbWVzLg==',
    ),
    # The triple 'm' in 'commmons' matches its base64 text (which decodes
    # to 'commmons'), so it is intentional -- do not "fix" the spelling.
    (
        'http://jakarta.apache.org/commmons',
        'aHR0cDovL2pha2FydGEuYXBhY2hlLm9yZy9jb21tbW9ucw==',
    ),
    (
        'AaBbCcDdEeFfGgHhIiJjKkLlMmNnOoPpQqRrSsTtUuVvWwXxYyZz',
        'QWFCYkNjRGRFZUZmR2dIaElpSmpLa0xsTW1Obk9vUHBRcVJyU3NUdFV1VnZXd1h4WXlaeg==',
    ),
    ('xyzzy!', 'eHl6enkh'),
]
26
class Base64Tests(unittest.TestCase):
    """Tests for the wally_base64_* encoding/decoding bindings."""

    def test_vectors(self):
        """Tests for encoding and decoding a base 64 string"""

        # Scratch output buffer for decoding; 1024 bytes comfortably exceeds
        # the decoded size of the vectors in ROUND_TRIP_CASES.
        buf, _ = make_cbuffer('00' * 1024)
        for str_in, b64_in in ROUND_TRIP_CASES:
            with self.subTest(str_in=str_in):
                # Work with the encoded bytes so lengths are byte counts,
                # not character counts (they differ for non-ASCII input).
                str_bytes = utf8(str_in)

                # The reported maximum decoded length must hold the plaintext.
                ret, max_len = wally_base64_get_maximum_length(b64_in, 0)
                self.assertEqual(ret, WALLY_OK)
                self.assertGreaterEqual(max_len, len(str_bytes))

                ret, b64_out = wally_base64_from_bytes(str_bytes, len(str_bytes), 0)
                self.assertEqual((ret, b64_out), (WALLY_OK, b64_in))

                # Decode into exactly max_len bytes: the length the library
                # itself reported as sufficient.
                ret, written = wally_base64_to_bytes(utf8(b64_in), 0, buf, max_len)
                self.assertEqual((ret, buf[:written]), (WALLY_OK, str_bytes))

    def test_get_maximum_length(self):
        """Invalid arguments yield WALLY_EINVAL and a zero length."""
        valid_b64 = utf8(ROUND_TRIP_CASES[0][1])

        invalid_args = [
            (None,      0),  # Null base64 string
            (bytes(),   0),  # Zero-length base64 string
            (valid_b64, 1),  # Invalid flags
        ]
        for args in invalid_args:
            with self.subTest(args=args):
                ret, max_len = wally_base64_get_maximum_length(*args)
                self.assertEqual((ret, max_len), (WALLY_EINVAL, 0))

    def test_base64_from_bytes(self):
        """Invalid arguments yield WALLY_EINVAL and no output string."""
        valid_str = utf8(ROUND_TRIP_CASES[0][0])
        valid_str_len = len(valid_str)

        invalid_args = [
            (None,      valid_str_len, 0),  # Null input bytes
            (valid_str, 0,             0),  # Zero-length input bytes
            (valid_str, valid_str_len, 1),  # Invalid flags
        ]
        for args in invalid_args:
            with self.subTest(args=args):
                ret, b64_out = wally_base64_from_bytes(*args)
                self.assertEqual((ret, b64_out), (WALLY_EINVAL, None))

    def test_base64_to_bytes(self):
        """Invalid arguments yield WALLY_EINVAL and zero bytes written; a
        too-short output buffer succeeds but reports the length needed."""
        buf, _ = make_cbuffer('00' * 1024)
        valid_b64 = utf8(ROUND_TRIP_CASES[0][1])
        # Precondition: the checks below are only meaningful with a valid
        # max_len (max_len - 1 is used as a buffer length).
        ret, max_len = wally_base64_get_maximum_length(valid_b64, 0)
        self.assertEqual(ret, WALLY_OK)

        invalid_args = [
            (None,      0, buf,  max_len),  # Null base64 string
            (valid_b64, 1, buf,  max_len),  # Invalid flags
            (valid_b64, 0, None, max_len),  # Null output buffer
            (valid_b64, 0, buf,  0),        # Zero output length
        ]
        for args in invalid_args:
            with self.subTest(args=args):
                ret, written = wally_base64_to_bytes(*args)
                self.assertEqual((ret, written), (WALLY_EINVAL, 0))

        # Too short output length returns the number of bytes needed
        ret, written = wally_base64_to_bytes(valid_b64, 0, buf, max_len - 1)
        self.assertEqual((ret, written), (WALLY_OK, max_len))
86
if __name__ == "__main__":
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main()
89